diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index 00823951dcc..f265a42f9d5 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -1,21 +1,6 @@
-# Document 1: Buildkite loads only this block on first parse. The next step resolves docs-only skip-ci
-# from git diff, then uploads document 2. When docs-only skip applies, image-build still runs if nightly-test
-# / main NIGHTLY so upload-nightly is not skipped together with test-ready/test-merge.
-#
-# Document 2: appended after `---`; same file, read by upload_pipeline_with_skip_ci.sh (not evaluated as a second pipeline by Buildkite).
-steps:
- - label: ":github: Resolve skip-ci & upload pipeline"
- key: upload-ci-pipeline
- commands:
- - "bash .buildkite/scripts/upload_pipeline_with_skip_ci.sh"
- agents:
- queue: "cpu_queue_premerge"
-
----
steps:
- label: ":docker: Build image"
key: image-build
- if: __IMAGE_BUILD_IF__
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "docker build --progress=plain --file docker/Dockerfile.ci -t vllm-omni-ci ."
@@ -28,7 +13,7 @@ steps:
- label: "Upload Ready Pipeline"
depends_on: image-build
key: upload-ready-pipeline
- if: __UPLOAD_READY_IF__
+ if: build.branch != "main" && build.pull_request.labels includes "ready"
commands:
- buildkite-agent pipeline upload .buildkite/test-ready.yml
agents:
@@ -38,25 +23,17 @@ steps:
- label: "Upload Merge Pipeline"
depends_on: image-build
key: upload-merge-pipeline
- if: __UPLOAD_MERGE_IF__
+ if: build.branch == "main" && build.env("NIGHTLY") != "1"
commands:
- buildkite-agent pipeline upload .buildkite/test-merge.yml
agents:
queue: "cpu_queue_premerge"
- # L4 Test — main+NIGHTLY=1 (scheduled), or PR with specific label (e.g. add label then Rebuild)
+ # L4 Test — runs on main when NIGHTLY=1 (scheduled build), or on a PR that carries the "nightly-test" label (add the label, then Rebuild).
- label: "Upload Nightly Pipeline"
depends_on: image-build
key: upload-nightly-pipeline
- if: >-
- (build.branch == "main" && build.env("NIGHTLY") == "1") ||
- (build.branch != "main" && (
- build.pull_request.labels includes "nightly-test" ||
- build.pull_request.labels includes "omni-test" ||
- build.pull_request.labels includes "tts-test" ||
- build.pull_request.labels includes "diffusion-x2iat-test" ||
- build.pull_request.labels includes "diffusion-x2v-test"
- ))
+ if: '(build.branch == "main" && build.env("NIGHTLY") == "1") || (build.branch != "main" && build.pull_request.labels includes "nightly-test")'
commands:
- buildkite-agent pipeline upload .buildkite/test-nightly.yml
agents:
diff --git a/.buildkite/scripts/generate-and-upload-nightly-index.sh b/.buildkite/scripts/generate-and-upload-nightly-index.sh
index b09c13f5cf9..6624af32303 100755
--- a/.buildkite/scripts/generate-and-upload-nightly-index.sh
+++ b/.buildkite/scripts/generate-and-upload-nightly-index.sh
@@ -19,7 +19,7 @@ has_new_python=$($PYTHON -c "print(1 if __import__('sys').version_info >= (3,12)
if [[ "$has_new_python" -eq 0 ]]; then
# use new python from docker
docker pull python:3-slim
- PYTHON="docker run --rm --user $(id -u):$(id -g) -v $(pwd):/app -w /app python:3-slim python3"
+ PYTHON="docker run --rm -v $(pwd):/app -w /app python:3-slim python3"
fi
echo "Using python interpreter: $PYTHON"
@@ -36,7 +36,7 @@ mkdir -p "$INDICES_OUTPUT_DIR"
# HACK: we do not need regex module here, but it is required by pre-commit hook
# To avoid any external dependency, we simply replace it back to the stdlib re module
-sed -i.bak 's/import regex as re/import re/g' .buildkite/scripts/generate-nightly-index.py && rm -f .buildkite/scripts/generate-nightly-index.py.bak
+sed -i 's/import regex as re/import re/g' .buildkite/scripts/generate-nightly-index.py
# Generate indices -- the version is just the commit hash (not omni/{commit})
# because relative paths are computed between the index and wheel directories,
@@ -73,16 +73,15 @@ echo "Pure version (without variant): $pure_version"
# re-generate and copy to /omni/{version}/ only if it does not have "dev" in the version
if [[ "$version" != *"dev"* ]]; then
- s3_version="v$pure_version"
- echo "Re-generating indices for /omni/$s3_version/"
+ echo "Re-generating indices for /omni/$pure_version/"
rm -rf "${INDICES_OUTPUT_DIR:?}"
mkdir -p "$INDICES_OUTPUT_DIR"
# wheel-dir is overridden to be the commit directory, so that the indices point to the correct wheel path
$PYTHON .buildkite/scripts/generate-nightly-index.py \
- --version "$s3_version" \
+ --version "$pure_version" \
--wheel-dir "$BUILDKITE_COMMIT" \
--current-objects "$obj_json" \
--output-dir "$INDICES_OUTPUT_DIR" \
--comment "version $pure_version"
- aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/omni/$s3_version/"
+ aws s3 cp --recursive "$INDICES_OUTPUT_DIR/" "s3://$BUCKET/omni/$pure_version/"
fi
diff --git a/.buildkite/scripts/generate-nightly-index.py b/.buildkite/scripts/generate-nightly-index.py
index bb4a74a7044..c616c446b09 100755
--- a/.buildkite/scripts/generate-nightly-index.py
+++ b/.buildkite/scripts/generate-nightly-index.py
@@ -4,7 +4,6 @@
import argparse
import json
-import re
import sys
from dataclasses import asdict, dataclass
from datetime import datetime
@@ -12,6 +11,8 @@
from typing import Any
from urllib.parse import quote
+import regex as re
+
def normalize_package_name(name: str) -> str:
"""Normalize package name per PEP 503."""
diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh
index 96c139c8f7b..f56f23b5deb 100755
--- a/.buildkite/scripts/hardware_ci/run-amd-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh
@@ -11,6 +11,15 @@ set -o pipefail
export PYTHONPATH=".."
# Print ROCm version
+echo "--- Confirming Clean Initial State"
+while true; do
+ sleep 3
+ if grep -q clean /opt/amdgpu/etc/gpu_state; then
+ echo "GPUs state is \"clean\""
+ break
+ fi
+done
+
echo "--- ROCm info"
rocminfo
@@ -42,14 +51,25 @@ cleanup_docker() {
# Call the cleanup docker function
cleanup_docker
+echo "--- Resetting GPUs"
+
+echo "reset" > /opt/amdgpu/etc/gpu_state
+
+while true; do
+ sleep 3
+ if grep -q clean /opt/amdgpu/etc/gpu_state; then
+ echo "GPUs state is \"clean\""
+ break
+ fi
+done
+
echo "--- Pulling container"
-## Temporary change to use AMD Docker Hub to store the vllm-omni image
+## Temporary change to use AMD Docker Hub to store the vllm-ci image
# to bypass the rate limit issue with ECR Public Gallery.
-# Images are now stored in a separate repository for vllm-omni, instead of vllm-ci.
# TODO: @tjtanaa point back to ECR Public Gallery
# once the amd agents are configured to use ECR Public Gallery.
# image_name="public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:${BUILDKITE_COMMIT}-rocm-omni"
-image_name="rocm/vllm-omni:${BUILDKITE_COMMIT}"
+image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}-rocm-omni"
container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
# TODO: @tjtanaa uncomment this once the amd agents are configured to use ECR Public Gallery.
diff --git a/.buildkite/scripts/upload_pipeline_with_skip_ci.sh b/.buildkite/scripts/upload_pipeline_with_skip_ci.sh
deleted file mode 100644
index 6259d39b290..00000000000
--- a/.buildkite/scripts/upload_pipeline_with_skip_ci.sh
+++ /dev/null
@@ -1,137 +0,0 @@
-#!/usr/bin/env bash
-# Evaluate docs-only skip-ci and upload continuation steps from the same `.buildkite/pipeline.yml`
-# (YAML document after the first `---`). Buildkite `if` is evaluated at upload time.
-set -euo pipefail
-
-ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
-PIPELINE_YML="${ROOT}/.buildkite/pipeline.yml"
-
-# Prints a single digit to stdout: 1 = skip image CI, 0 = run. Logs go to stderr.
-is_docs_only_change() {
- local file_path
- local has_any=0
-
- while IFS= read -r file_path; do
- [[ -z "${file_path}" ]] && continue
- has_any=1
-
- if [[ "${file_path}" == docs/* ]]; then
- continue
- fi
- if [[ "${file_path}" == *.md ]]; then
- continue
- fi
- if [[ "${file_path}" == "mkdocs.yaml" ]]; then
- continue
- fi
- return 1
- done
-
- [[ "${has_any}" -eq 1 ]]
-}
-
-resolve_skip_ci() {
- local is_pr_build=0
- local files
- local base_branch base_ref
-
- if [[ "${BUILDKITE_PULL_REQUEST:-false}" != "false" && -n "${BUILDKITE_PULL_REQUEST:-}" ]]; then
- is_pr_build=1
- fi
-
- if [[ "${is_pr_build}" -eq 1 ]]; then
- base_branch="${BUILDKITE_PULL_REQUEST_BASE_BRANCH:-main}"
- if ! git rev-parse --verify "origin/${base_branch}" >/dev/null 2>&1; then
- echo "resolve_skip_ci: origin/${base_branch} not found locally; trying fetch" >&2
- git fetch --depth=200 origin "${base_branch}" >/dev/null 2>&1 || true
- fi
-
- base_ref=""
- if git rev-parse --verify "origin/${base_branch}" >/dev/null 2>&1; then
- base_ref="origin/${base_branch}"
- elif git rev-parse --verify "${base_branch}" >/dev/null 2>&1; then
- base_ref="${base_branch}"
- else
- echo "resolve_skip_ci: cannot resolve PR base ${base_branch}; skip-ci=0" >&2
- echo -n 0
- return 0
- fi
-
- if ! files="$(git diff --name-only "${base_ref}...${BUILDKITE_COMMIT}" 2>/dev/null)"; then
- echo "resolve_skip_ci: failed to compute PR changed files; skip-ci=0" >&2
- echo -n 0
- return 0
- fi
- elif [[ "${BUILDKITE_BRANCH:-}" == "main" ]]; then
- if ! git rev-parse --verify "${BUILDKITE_COMMIT}^" >/dev/null 2>&1; then
- echo "resolve_skip_ci: commit has no parent on main; skip-ci=0" >&2
- echo -n 0
- return 0
- fi
- if ! files="$(git diff --name-only "${BUILDKITE_COMMIT}^..${BUILDKITE_COMMIT}" 2>/dev/null)"; then
- echo "resolve_skip_ci: failed to compute main changed files; skip-ci=0" >&2
- echo -n 0
- return 0
- fi
- else
- echo "resolve_skip_ci: not PR/main build; skip-ci=0" >&2
- echo -n 0
- return 0
- fi
-
- if is_docs_only_change <<< "${files}"; then
- echo "resolve_skip_ci: docs-only change detected; skip-ci=1" >&2
- echo -n 1
- return 0
- fi
-
- echo "resolve_skip_ci: non-doc changes detected; skip-ci=0" >&2
- echo -n 0
-}
-
-SKIP_CI="$(resolve_skip_ci)"
-
-if [[ ! -f "${PIPELINE_YML}" ]]; then
- echo "upload_pipeline_with_skip_ci: missing ${PIPELINE_YML}" >&2
- exit 1
-fi
-
-export ROOT SKIP_CI PIPELINE_YML
-python3 <<'PY' | buildkite-agent pipeline upload
-import os
-import pathlib
-
-path = pathlib.Path(os.environ["PIPELINE_YML"])
-text = path.read_text(encoding="utf-8")
-sep = "\n---\n"
-if sep not in text:
- raise SystemExit(
- "upload_pipeline_with_skip_ci: .buildkite/pipeline.yml must contain a '\\n---\\n' separator "
- "(document 1 = bootstrap, document 2 = uploaded steps)"
- )
-_, continuation = text.split(sep, 1)
-
-skip = os.environ.get("SKIP_CI") == "1"
-# When docs-only skip-ci: skip default CI image, but still build for L4 nightly (PR label nightly-test or
-# main NIGHTLY=1), otherwise upload-nightly (depends_on image-build) would be skipped too.
-nightly_only = (
- '(build.pull_request.labels includes "nightly-test") '
- '|| (build.branch == "main" && build.env("NIGHTLY") == "1")'
-)
-# Placeholder in pipeline.yml is `if: __IMAGE_BUILD_IF__` (valid YAML); replace value only.
-if skip:
- rep = f"'{nightly_only}'"
- ready_rep = "'false'"
- merge_rep = "'false'"
-else:
- rep = "'true'"
- ready_rep = "'build.branch != \"main\" && build.pull_request.labels includes \"ready\"'"
- merge_rep = "'(build.branch == \"main\" && build.env(\"NIGHTLY\") != \"1\") || (build.branch != \"main\" && build.pull_request.labels includes \"merge-test\")'"
-rendered = (
- continuation
- .replace("__IMAGE_BUILD_IF__", rep)
- .replace("__UPLOAD_READY_IF__", ready_rep)
- .replace("__UPLOAD_MERGE_IF__", merge_rep)
-)
-print(rendered, end="")
-PY
diff --git a/.buildkite/test-amd-merge.yml b/.buildkite/test-amd-merge.yml
index ac52f60b35b..60ba0d9d416 100644
--- a/.buildkite/test-amd-merge.yml
+++ b/.buildkite/test-amd-merge.yml
@@ -32,6 +32,7 @@ steps:
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
+ - export GPU_ARCHS=gfx942
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- |
@@ -54,7 +55,7 @@ steps:
# - export GPU_ARCHS=gfx942
# - export VLLM_LOGGING_LEVEL=DEBUG
# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_expansion.py -m "advanced_model and diffusion and L4" --run-level advanced_model
+# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py
- label: "Diffusion Cache Backend Test"
agent_pool: mi325_1
@@ -62,12 +63,13 @@ steps:
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
+ - export GPU_ARCHS=gfx942
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- timeout 15m pytest -s -v -m "core_model and cache and diffusion and not distributed_cuda and L4"
-- label: "Diffusion Sequence Parallelism Test (Need 4 GPUs)"
- agent_pool: mi325_4
+- label: "Diffusion Sequence Parallelism Test"
+ agent_pool: mi325_2
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
@@ -75,7 +77,6 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- timeout 20m pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py
- - timeout 20m pytest -s -v tests/diffusion/distributed/test_ulysses_uaa_perf.py
# merge-only tests
- label: "Diffusion Tensor Parallelism Test"
@@ -94,14 +95,22 @@ steps:
commands:
- timeout 20m pytest -s -v tests/diffusion/test_diffusion_worker.py
-- label: "Engine Test"
- agent_pool: mi325_1
+- label: "Benchmark & Engine Test"
+ agent_pool: mi325_2
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - timeout 20m pytest -s -v tests/engine/test_async_omni_engine_abort.py
+ - |
+ timeout 20m bash -c '
+ set +e
+ pytest -s -v tests/benchmarks/test_serve_cli.py
+ EXIT1=\$?
+ pytest -s -v tests/engine/test_async_omni_engine_abort.py
+ EXIT2=\$?
+ exit \$((EXIT1 | EXIT2))
+ '
- label: "Omni Model Test Qwen2-5-Omni"
agent_pool: mi325_2
@@ -112,7 +121,6 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py
- - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen2_5_omni.py -m "advanced_model" --run-level "advanced_model"
- label: "Omni Model Test Qwen3-Omni"
agent_pool: mi325_2
@@ -123,10 +131,11 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- export VLLM_TEST_CLEAN_GPU_MEMORY=1
- - timeout 30m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py tests/e2e/online_serving/test_qwen3_omni.py tests/e2e/online_serving/test_mimo_audio.py -m "advanced_model" --run-level "advanced_model"
+ - timeout 10m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py
+ - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "advanced_model" --run-level "advanced_model"
- label: "Qwen3-TTS CustomVoice E2E Test"
- agent_pool: mi325_1
+ agent_pool: mi325_2
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
@@ -136,21 +145,21 @@ steps:
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py tests/e2e/offline_inference/test_qwen3_tts_customvoice.py -m "advanced_model" --run-level "advanced_model"
+ pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "advanced_model" --run-level "advanced_model" && pytest -s -v tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
'
- label: "Qwen3-TTS Base E2E Test"
- agent_pool: mi325_1
+ agent_pool: mi325_2
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- |
- timeout 30m bash -c '
+ timeout 20m bash -c '
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py tests/e2e/offline_inference/test_qwen3_tts_base.py -m "advanced_model" --run-level "advanced_model"
+ pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py -m "advanced_model" --run-level "advanced_model" && pytest -s -v tests/e2e/offline_inference/test_qwen3_tts_base.py
'
- label: "Diffusion Image Edit Test"
@@ -164,58 +173,43 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- timeout 20m pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
-# TODO: Bagel test on ROCm is very unstable. @tjtanaa
-# Need to debug before reneable numerical changes across large PRs
-# # split Bagel Model Test with H100 (Real Weights) into three tests
-# - label: "Bagel Text2Img Model Test (1/3)"
-# agent_pool: mi325_1
-# depends_on: amd-build
-# mirror_hardwares: [amdproduction]
-# grade: Blocking
-# commands:
-# - export GPU_ARCHS=gfx942
-# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-# - export VLLM_LOGGING_LEVEL=DEBUG
-# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - export VLLM_ROCM_USE_AITER_RMSNORM=0
-# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "advanced_model" --run-level "advanced_model" -k "shared_memory" -k "rocm"
-
-# - label: "Bagel Img2Img Model Test (2/3)"
-# agent_pool: mi325_1
-# depends_on: amd-build
-# mirror_hardwares: [amdproduction]
-# grade: Blocking
-# commands:
-# - export GPU_ARCHS=gfx942
-# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-# - export VLLM_LOGGING_LEVEL=DEBUG
-# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - export VLLM_ROCM_USE_AITER_RMSNORM=0
-# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "advanced_model" --run-level "advanced_model" -k "rocm"
+# Bagel Model Test (Real Weights), split into three tests: Text2Img, Img2Img, and Online Serving (each runs on agent_pool mi325_1).
+- label: "Bagel Text2Img Model Test"
+ agent_pool: mi325_1
+ depends_on: amd-build
+ mirror_hardwares: [amdproduction]
+ grade: Blocking
+ commands:
+ - export GPU_ARCHS=gfx942
+ - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+ - export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export VLLM_ROCM_USE_AITER_RMSNORM=0
+ - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "advanced_model" --run-level "advanced_model" -k "shared_memory" -k "rocm"
-# - label: "Bagel Online Serving Test (3/3)"
-# agent_pool: mi325_1
-# depends_on: amd-build
-# mirror_hardwares: [amdproduction]
-# grade: Blocking
-# commands:
-# - export GPU_ARCHS=gfx942
-# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-# - export VLLM_IMAGE_FETCH_TIMEOUT=60
-# - export VLLM_LOGGING_LEVEL=DEBUG
-# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - export VLLM_ROCM_USE_AITER_RMSNORM=0
-# - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "advanced_model" --run-level "advanced_model" -k "rocm"
+- label: "Bagel Img2Img Model Test"
+ agent_pool: mi325_1
+ depends_on: amd-build
+ mirror_hardwares: [amdproduction]
+ grade: Blocking
+ commands:
+ - export GPU_ARCHS=gfx942
+ - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+ - export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export VLLM_ROCM_USE_AITER_RMSNORM=0
+ - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "advanced_model" --run-level "advanced_model" -k "rocm"
-- label: "Voxtral-TTS E2E Test"
+- label: "Bagel Online Serving Test"
agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - |
- timeout 20m bash -c '
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/online_serving/test_voxtral_tts.py tests/e2e/offline_inference/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model"
- '
+ - export GPU_ARCHS=gfx942
+ - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+ - export VLLM_IMAGE_FETCH_TIMEOUT=60
+ - export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export VLLM_ROCM_USE_AITER_RMSNORM=0
+ - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "advanced_model" --run-level "advanced_model" -k "rocm"
diff --git a/.buildkite/test-amd-ready.yaml b/.buildkite/test-amd-ready.yaml
index 30bbc769412..6e31163accb 100644
--- a/.buildkite/test-amd-ready.yaml
+++ b/.buildkite/test-amd-ready.yaml
@@ -9,37 +9,13 @@ steps:
- export VLLM_ROCM_USE_AITER=0
- "timeout 20m pytest -v -s -m 'core_model and cpu' --cov=vllm_omni --cov-branch --cov-report=term-missing --cov-report=html --cov-report=xml"
-- label: "Voxtral TTS CUDA Unit Test"
- agent_pool: mi325_1
- depends_on: amd-build
- mirror_hardwares: [amdproduction]
- grade: Blocking
- commands:
- - timeout 10m pytest -s -v tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py
-
- label: "Diffusion Model Test"
- agent_pool: mi325_1
- depends_on: amd-build
- mirror_hardwares: [amdproduction]
- grade: Blocking
- commands:
- - timeout 30m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model"
-
-- label: "Diffusion Batching Test"
- agent_pool: mi325_1
- depends_on: amd-build
- mirror_hardwares: [amdproduction]
- grade: Blocking
- commands:
- - timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py -m "core_model and diffusion" --run-level "core_model"
-
-- label: "Custom Pipeline Test"
- agent_pool: mi325_1
+ agent_pool: mi325_2
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - timeout 20m pytest -s -v tests/e2e/offline_inference/custom_pipeline/ -m "core_model"
+ - timeout 20m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model"
- label: "Diffusion Model CPU offloading Test"
agent_pool: mi325_1
@@ -47,6 +23,7 @@ steps:
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
+ - export GPU_ARCHS=gfx942
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- |
@@ -69,7 +46,7 @@ steps:
# - export GPU_ARCHS=gfx942
# - export VLLM_LOGGING_LEVEL=DEBUG
# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_expansion.py -m "advanced_model and diffusion and L4" --run-level advanced_model
+# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py
- label: "Diffusion Cache Backend Test"
agent_pool: mi325_1
@@ -100,58 +77,47 @@ steps:
commands:
- timeout 20m pytest -s -v tests/diffusion/test_diffusion_worker.py
-- label: "Engine Test"
- agent_pool: mi325_1
+- label: "Benchmark & Engine Test"
+ agent_pool: mi325_2
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- |
- timeout 15m bash -c '
- pytest -s -v tests/engine/test_async_omni_engine_abort.py
+ timeout 30m bash -c '
+ set +e
+ pytest -s -v tests/benchmarks/test_serve_cli.py
+ EXIT1=\$?
+ pytest -s -v tests/engine/test_async_omni_engine_abort.py
+ EXIT2=\$?
+ exit \$((EXIT1 | EXIT2))
'
+- label: "Omni Model Test Qwen2-5-Omni"
+ agent_pool: mi325_2
+ depends_on: amd-build
+ mirror_hardwares: [amdproduction]
+ grade: Blocking
+ commands:
+ - export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - timeout 17m pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py
-# NOTE: This test is not running any thing. It is skipped and deselected.
-# Currently it is = 1 skipped, 1 deselected, 17 warnings in 0.03s ======
-# - label: "Omni Model Test Qwen2-5-Omni"
-# agent_pool: mi325_2
-# depends_on: amd-build
-# mirror_hardwares: [amdproduction]
-# grade: Blocking
-# commands:
-# - export VLLM_LOGGING_LEVEL=DEBUG
-# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py -m "core_model" --run-level "core_model"
-
-# - label: "Omni Model Test Qwen3-Omni"
-# agent_pool: mi325_2
-# depends_on: amd-build
-# mirror_hardwares: [amdproduction]
-# grade: Blocking
-# commands:
-# - export VLLM_LOGGING_LEVEL=DEBUG
-# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-# - timeout 10m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py
-# - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "core_model" --run-level "core_model"
-
-- label: "MiMo-Audio E2E Test with H100"
- agent_pool: mi325_1
+- label: "Omni Model Test Qwen3-Omni"
+ agent_pool: mi325_2
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - |
- timeout 30m bash -c '
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/online_serving/test_mimo_audio.py -m "core_model" --run-level "core_model"
- '
+ - export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+ - timeout 10m pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py
+ - timeout 10m pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "core_model" --run-level "core_model"
- label: "Qwen3-TTS E2E Test"
- agent_pool: mi325_1
+ agent_pool: mi325_2
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
@@ -159,82 +125,55 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- - timeout 30m pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "core_model" --run-level "core_model"
+ - timeout 20m pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "core_model" --run-level "core_model"
-- label: "Voxtral-TTS E2E Test"
+- label: "Diffusion Image Edit Test"
agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - |
- timeout 20m bash -c '
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/online_serving/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model"
- pytest -s -v tests/e2e/offline_inference/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model"
- '
+ - export GPU_ARCHS=gfx942
+ - export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - timeout 20m pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
-- label: "Diffusion Image Edit Test"
+- label: "Bagel Text2Img Model Test"
agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- export GPU_ARCHS=gfx942
+ - export VLLM_TEST_CLEAN_GPU_MEMORY=1
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - timeout 20m pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
-
-# TODO: Bagel test on ROCm is very unstable. @tjtanaa
-# Need to debug before reneable numerical changes across large PRs
-# - label: "Bagel Text2Img Model Test"
-# agent_pool: mi325_1
-# depends_on: amd-build
-# mirror_hardwares: [amdproduction]
-# grade: Blocking
-# commands:
-# - export GPU_ARCHS=gfx942
-# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-# - export VLLM_LOGGING_LEVEL=DEBUG
-# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - export VLLM_ROCM_USE_AITER_RMSNORM=0
-# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "core_model" --run-level "core_model" -k "rocm"
-
-# - label: "Bagel Img2Img Model Test"
-# agent_pool: mi325_1
-# depends_on: amd-build
-# mirror_hardwares: [amdproduction]
-# grade: Blocking
-# commands:
-# - export GPU_ARCHS=gfx942
-# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-# - export VLLM_LOGGING_LEVEL=DEBUG
-# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - export VLLM_ROCM_USE_AITER_RMSNORM=0
-# - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "core_model" --run-level "core_model" -k "rocm"
+ - export VLLM_ROCM_USE_AITER_RMSNORM=0
+ - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "core_model" --run-level "core_model" -k "rocm"
-# - label: "Bagel Online Serving Test"
-# agent_pool: mi325_1
-# depends_on: amd-build
-# mirror_hardwares: [amdproduction]
-# grade: Blocking
-# commands:
-# - export GPU_ARCHS=gfx942
-# - export VLLM_TEST_CLEAN_GPU_MEMORY=1
-# - export VLLM_IMAGE_FETCH_TIMEOUT=60
-# - export VLLM_LOGGING_LEVEL=DEBUG
-# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
-# - export VLLM_ROCM_USE_AITER_RMSNORM=0
-# - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "core_model" --run-level "core_model" -k "rocm"
+- label: "Bagel Img2Img Model Test"
+ agent_pool: mi325_1
+ depends_on: amd-build
+ mirror_hardwares: [amdproduction]
+ grade: Blocking
+ commands:
+ - export GPU_ARCHS=gfx942
+ - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+ - export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export VLLM_ROCM_USE_AITER_RMSNORM=0
+ - timeout 30m pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "core_model" --run-level "core_model" -k "rocm"
-- label: "CosyVoice3-TTS E2E Test"
+- label: "Bagel Online Serving Test"
agent_pool: mi325_1
depends_on: amd-build
mirror_hardwares: [amdproduction]
grade: Blocking
commands:
- - |
- timeout 20m bash -c '
- pytest -s -v tests/e2e/online_serving/test_cosyvoice3_tts.py -m "core_model" --run-level "core_model"
- '
+ - export GPU_ARCHS=gfx942
+ - export VLLM_TEST_CLEAN_GPU_MEMORY=1
+ - export VLLM_IMAGE_FETCH_TIMEOUT=60
+ - export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export VLLM_ROCM_USE_AITER_RMSNORM=0
+ - timeout 40m pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "core_model" --run-level "core_model" -k "rocm"
diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index 0b9a3f47aba..e175385ff0d 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -117,17 +117,3 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
-
-
-- label: "Omni Sleep Mode Test"
- timeout_in_minutes: 40
- agent_pool: mi325_2
- depends_on: amd-build
- mirror_hardwares: [amdproduction]
- grade: Blocking
- commands:
- - export GPU_ARCHS=gfx942
- - export VLLM_LOGGING_LEVEL=DEBUG
- - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- - export VLLM_TEST_CLEAN_GPU_MEMORY="1"
- - pytest -s -v tests/e2e/offline_inference/test_omni_sleep_mode.py -m "advanced_model and omni and MI325" --run-level "advanced_model"
diff --git a/.buildkite/test-merge.yml b/.buildkite/test-merge.yml
index 691f3f8764d..b0b5a639618 100644
--- a/.buildkite/test-merge.yml
+++ b/.buildkite/test-merge.yml
@@ -1,8 +1,3 @@
-env:
- VLLM_WORKER_MULTIPROC_METHOD: spawn
- HF_HUB_DOWNLOAD_TIMEOUT: 300
- HF_HUB_ETAG_TIMEOUT: 60
-
steps:
- label: "Simple Unit Test"
depends_on: upload-merge-pipeline
@@ -76,6 +71,24 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
+ - label: "Audio Generation Model Test"
+ timeout_in_minutes: 20
+ depends_on: upload-merge-pipeline
+ commands:
+ - pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py
+ agents:
+ queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ always-pull: true
+ propagate-environment: true
+ environment:
+ - "HF_HOME=/fsx/hf_cache"
+ - "HF_TOKEN"
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
+
- label: "Diffusion Cache Backend Test"
timeout_in_minutes: 15
depends_on: upload-merge-pipeline
@@ -95,7 +108,7 @@ steps:
- "/fsx/hf_cache:/fsx/hf_cache"
- label: "Diffusion Sequence Parallelism Test"
- timeout_in_minutes: 25
+ timeout_in_minutes: 20
depends_on: upload-merge-pipeline
commands:
- pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py tests/diffusion/distributed/test_ulysses_uaa_perf.py
@@ -156,6 +169,7 @@ steps:
commands:
- |
timeout 15m bash -c '
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
pytest -s -v tests/engine/test_async_omni_engine_abort.py
'
agents:
@@ -177,6 +191,7 @@ steps:
depends_on: upload-merge-pipeline
commands:
- export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/offline_inference/test_qwen2_5_omni.py tests/e2e/online_serving/test_qwen2_5_omni.py -m "advanced_model" --run-level "advanced_model"
agents:
queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
@@ -197,6 +212,7 @@ steps:
- |
timeout 20m bash -c '
export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py tests/e2e/offline_inference/test_qwen3_tts_customvoice.py -m "advanced_model" --run-level "advanced_model"
'
@@ -219,6 +235,7 @@ steps:
- |
timeout 20m bash -c '
export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py tests/e2e/offline_inference/test_qwen3_tts_base.py -m "advanced_model" --run-level "advanced_model"
'
@@ -239,6 +256,7 @@ steps:
timeout_in_minutes: 30
depends_on: upload-merge-pipeline
commands:
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- export VLLM_TEST_CLEAN_GPU_MEMORY="1"
- pytest -s -v tests/e2e/offline_inference/test_qwen3_omni.py tests/e2e/online_serving/test_qwen3_omni.py tests/e2e/online_serving/test_mimo_audio.py -m "advanced_model" --run-level "advanced_model"
agents:
@@ -275,50 +293,11 @@ steps:
path: /mnt/hf-cache
type: DirectoryOrCreate
- - label: "Audio Streaming Input Test with H100"
- timeout_in_minutes: 30
- depends_on: upload-merge-pipeline
- commands:
- - export VLLM_TEST_CLEAN_GPU_MEMORY="1"
- - pytest -s -v tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py -m "advanced_model" --run-level "advanced_model"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- label: "Diffusion Image Edit Test with H100 (1 GPU)"
timeout_in_minutes: 20
depends_on: upload-merge-pipeline
commands:
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
agents:
queue: "mithril-h100-pool"
@@ -361,6 +340,7 @@ steps:
- |
timeout 55m bash -c '
set -e
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_TEST_CLEAN_GPU_MEMORY=1
export VLLM_IMAGE_FETCH_TIMEOUT=60
pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "advanced_model" --run-level "advanced_model" -k "shared_memory"
@@ -400,46 +380,6 @@ steps:
path: /mnt/hf-cache
type: DirectoryOrCreate
- - label: "Omni Sleep Mode Test with H100"
- timeout_in_minutes: 30
- depends_on: upload-merge-pipeline
- commands:
- - export VLLM_TEST_CLEAN_GPU_MEMORY="1"
- - pytest -s -v tests/e2e/offline_inference/test_omni_sleep_mode.py -m "advanced_model and H100 and omni" --run-level "advanced_model"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- label: "Voxtral-TTS E2E Test"
timeout_in_minutes: 20
depends_on: upload-merge-pipeline
@@ -447,8 +387,19 @@ steps:
- |
timeout 20m bash -c '
export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
pytest -s -v tests/e2e/online_serving/test_voxtral_tts.py tests/e2e/offline_inference/test_voxtral_tts.py -m "advanced_model" --run-level "advanced_model"
'
+
+ - label: "CosyVoice3-TTS E2E Test" # NOTE(review): this step is inserted between the Voxtral-TTS commands above and the unchanged `agents:` lines below, so in the patched file that agents/plugins block appears to attach to this step and Voxtral-TTS is left without one -- confirm
+ timeout_in_minutes: 20 # same 20-minute budget as the inner `timeout 20m` wrapper
+ depends_on: upload-merge-pipeline
+ commands:
+ - |
+ timeout 20m bash -c '
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ pytest -s -v tests/e2e/online_serving/test_cosyvoice3_tts.py -m "advanced_model" --run-level "advanced_model"
+ '
agents:
queue: "mithril-h100-pool"
plugins:
diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml
index ce67b76d921..9dc88850618 100644
--- a/.buildkite/test-nightly.yml
+++ b/.buildkite/test-nightly.yml
@@ -1,693 +1,407 @@
-env:
- VLLM_WORKER_MULTIPROC_METHOD: spawn
- HF_HUB_DOWNLOAD_TIMEOUT: 300
- HF_HUB_ETAG_TIMEOUT: 60
-
steps:
- # Group: collapses under one heading in the Buildkite UI; child steps still run in parallel.
- - group: ":card_index_dividers: Omni Model Test"
- key: nightly-omni-test-group
+ - label: ":full_moon: Omni Model Test with H100"
+ timeout_in_minutes: 90
depends_on: upload-nightly-pipeline
- if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test" || build.pull_request.labels includes "omni-test"
- steps:
- - label: ":full_moon: Omni · Function Test"
- timeout_in_minutes: 90
- commands:
- - pytest -s -v tests/e2e/ -m "full_model and H100 and omni" --run-level "full_model" --ignore=tests/e2e/accuracy
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - label: ":full_moon: Omni · Doc Test with L4"
- timeout_in_minutes: 90
- commands:
- - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- - pytest -s -v tests/examples/ -m "full_model and omni and L4" --run-level "full_model"
- agents:
- queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
- plugins:
- - docker#v5.2.0:
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- always-pull: true
- propagate-environment: true
- shm-size: "8gb"
- environment:
- - "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
- volumes:
- - "/fsx/hf_cache:/fsx/hf_cache"
-
- - label: ":full_moon: Omni · Doc Test with H100"
- timeout_in_minutes: 90
- commands:
- - pytest -s -v tests/examples/ -m "full_model and omni and H100" --run-level "full_model"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - label: ":full_moon: Omni · Accuracy Test"
- timeout_in_minutes: 180
- commands:
- - export SEED_TTS_WER_EVAL=1
- - export SEED_TTS_EVAL_DEVICE=cuda:1
- - |
- set +e
- pytest -s -v tests/e2e/accuracy/qwen3_omni/test_qwen3_omni.py -m "full_model" --run-level full_model
- EXIT=$$?
- buildkite-agent artifact upload "tests/e2e/accuracy/qwen3_omni/results/qwen_omni_acc/*.json"
- exit $$EXIT
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - label: ":full_moon: Omni · Perf Test"
- key: nightly-omni-performance
- timeout_in_minutes: 180
- commands:
- - export BENCHMARK_DIR=tests/dfx/perf/results
- - |
- set +e
- pytest -s -v tests/dfx/perf/scripts/run_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_omni.json
- EXIT=$$?
- buildkite-agent artifact upload "tests/dfx/perf/results/*.json"
- exit $$EXIT
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands: # exports VLLM_WORKER_MULTIPROC_METHOD per step because this change removes the file-level env block; NOTE(review): HF_HUB_DOWNLOAD_TIMEOUT and HF_HUB_ETAG_TIMEOUT from that block are not re-exported -- confirm that is intended
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - |
+ pytest -s -v \
+ tests/examples/ \
+ tests/e2e/online_serving/test_*_expansion.py \
+ -m "advanced_model and H100 and omni" --run-level "advanced_model"
+ agents:
+ queue: "mithril-h100-pool" # H100 queue; the pod's nodeSelector further down pins gpu-h100-sxm nodes
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT # vllm-ci-test-repo image under the pull-through-cache path, tagged with the build's commit
+ resources:
+ limits:
+ nvidia.com/gpu: 2 # GPU limit for the test container
+ volumeMounts:
- name: devshm
- emptyDir:
- medium: Memory
+ mountPath: /dev/shm
- name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
- - group: ":card_index_dividers: TTS Model Test"
- key: nightly-tts-test-group
+ - label: ":full_moon: Omni Model Test with L4"
+ timeout_in_minutes: 90
depends_on: upload-nightly-pipeline
- if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test" || build.pull_request.labels includes "tts-test"
- steps:
- - label: ":full_moon: TTS · Function Test"
- timeout_in_minutes: 90
- commands:
- - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- - pytest -s -v tests/e2e/ -m "full_model and L4 and omni" --run-level "full_model" --ignore=tests/e2e/accuracy
- agents:
- queue: "gpu_1_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
- plugins:
- - docker#v5.2.0:
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- always-pull: true
- propagate-environment: true
- shm-size: "8gb"
- environment:
- - "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
- volumes:
- - "/fsx/hf_cache:/fsx/hf_cache"
-
- - label: ":full_moon: TTS · Perf Test"
- key: nightly-tts-performance
- timeout_in_minutes: 180
- commands:
- - export BENCHMARK_DIR=tests/dfx/perf/results
- - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- - |
- set +e
- pytest -s -v tests/dfx/perf/scripts/run_benchmark.py --test-config-file tests/dfx/perf/tests/test_tts.json
- EXIT=$$?
- buildkite-agent artifact upload "tests/dfx/perf/results/*.json"
- exit $$EXIT
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 1
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands: # same paths as the H100 omni step, selected with the L4 marker instead of H100
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
+ - pytest -s -v tests/examples/ tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and L4 and omni" --run-level "advanced_model"
+ agents:
+ queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPUs
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT # CI image tagged with the build's commit
+ always-pull: true
+ propagate-environment: true
+ shm-size: "8gb" # shared-memory size passed to the container
+ environment:
+ - "HF_HOME=/fsx/hf_cache" # same path as the volume mounted below
+ - "HF_TOKEN" # name only, no value; presumably forwarded from the agent environment -- confirm
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
- # Diffusion X2I suite: x2i / x2a / x2t and related non-video paths; x2v is only in "Diffusion X2V Model Test" below.
- - group: ":card_index_dividers: Diffusion X2I(&A&T) Model Test"
- key: nightly-diffusion-x2iat-group
+ - label: ":full_moon: Diffusion Model Test with H100"
+ timeout_in_minutes: 120
depends_on: upload-nightly-pipeline
- if: >-
- build.env("NIGHTLY") == "1" ||
- build.pull_request.labels includes "nightly-test" ||
- build.pull_request.labels includes "diffusion-x2iat-test"
- steps:
- - label: ":full_moon: Diffusion X2I(&A&T) · Function Test with H100"
- timeout_in_minutes: 120
- commands:
- - pytest -sv tests/e2e/ -k "not test_wan and not hunyuan" -m "full_model and diffusion and H100" --run-level "full_model" --ignore=tests/e2e/accuracy
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands: # the -k filter deselects test_wan22_expansion, which the separate Wan2.2 step runs on its own
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -k "not test_wan22_expansion" -m "advanced_model and diffusion and H100" --run-level "advanced_model"
+ agents:
+ queue: "mithril-h100-pool" # H100 queue; the pod's nodeSelector further down pins gpu-h100-sxm nodes
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT # vllm-ci-test-repo image under the pull-through-cache path, tagged with the build's commit
+ resources:
+ limits:
+ nvidia.com/gpu: 2 # GPU limit for the test container
+ volumeMounts:
- name: devshm
- emptyDir:
- medium: Memory
+ mountPath: /dev/shm
- name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
-
- - label: ":full_moon: Diffusion X2I(&A&T) · Function Test with L4"
- timeout_in_minutes: 60
- commands:
- - pytest -sv tests/e2e/ -k "not test_wan and not hunyuan" -m "full_model and diffusion and L4" --run-level "full_model" --ignore=tests/e2e/accuracy
- agents:
- queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
- plugins:
- - docker#v5.2.0:
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- always-pull: true
- propagate-environment: true
- shm-size: "8gb"
- environment:
- - "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
- volumes:
- - "/fsx/hf_cache:/fsx/hf_cache"
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
- - label: ":full_moon: Diffusion X2I(&A&T) · Doc Test"
- timeout_in_minutes: 60
- commands:
- - export VLLM_TEST_CLEAN_GPU_MEMORY="1"
- - pytest -s -v tests/examples/*/test_text_to_image.py -m "full_model and example and H100" --run-level "full_model"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
+ - label: ":full_moon: Diffusion Model (Wan2.2) Test with H100" # runs only test_wan22_expansion.py, which the general Diffusion H100 step deselects with -k
+ timeout_in_minutes: 90
+ depends_on: upload-nightly-pipeline # waits for the step that uploads this nightly pipeline
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands: # selects on the advanced_model marker only; no GPU-type or diffusion marker is added to -m here
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - pytest -s -v tests/e2e/online_serving/test_wan22_expansion.py -m "advanced_model" --run-level "advanced_model"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT # vllm-ci-test-repo image under the pull-through-cache path, tagged with the build's commit
+ resources:
+ limits:
+ nvidia.com/gpu: 2 # GPU limit for the test container
+ volumeMounts:
- name: devshm
- emptyDir:
- medium: Memory
+ mountPath: /dev/shm
- name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
- - label: ":full_moon: Diffusion X2I(&A&T) · GEBench Accuracy Test"
- timeout_in_minutes: 60
- commands:
- - pytest -s -v tests/e2e/accuracy/test_gebench_h100_smoke.py --run-level full_model --gebench-model Qwen/Qwen-Image-2512 --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gebench-port 8093 --accuracy-workers 1
- - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gebench_qwen-image-2512/summary*.json"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 1
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
+ - label: ":full_moon: Diffusion Model Test" # L4 variant: the -m expression below selects the L4 marker and the queue is the L4 one; NOTE(review): sibling labels say "with H100"/"with L4", this one does not
+ timeout_in_minutes: 60
+ depends_on: upload-nightly-pipeline # waits for the step that uploads this nightly pipeline
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands: # unlike the H100 diffusion step there is no -k filter here; selection relies on the markers alone
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - pytest -s -v tests/e2e/online_serving/test_*_expansion.py -m "advanced_model and diffusion and L4" --run-level "advanced_model"
+ agents:
+ queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPUs
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT # CI image tagged with the build's commit
+ always-pull: true
+ propagate-environment: true
+ shm-size: "8gb" # shared-memory size passed to the container
+ environment:
+ - "HF_HOME=/fsx/hf_cache" # same path as the volume mounted below
+ - "HF_TOKEN" # name only, no value; presumably forwarded from the agent environment -- confirm
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
- - label: ":full_moon: Diffusion X2I(&A&T) · GEdit-Bench Accuracy Test"
- timeout_in_minutes: 60
- commands:
- - pytest -s -v tests/e2e/accuracy/test_gedit_bench_h100_smoke.py --run-level full_model --gedit-model Qwen/Qwen-Image-Edit --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gedit-port 8093 --gedit-samples-per-group 20 --accuracy-workers 1
- - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_vie_score_*.csv"
- - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_summary_*.json"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 1
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: VLLM_HTTP_TIMEOUT_KEEP_ALIVE
- value: "120"
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
- - name: devshm
- emptyDir:
- medium: Memory
- - name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
- - label: ":full_moon: Diffusion X2I(&A&T) · Accuracy Test"
- timeout_in_minutes: 180
- commands:
- - pytest -s -v tests/e2e/accuracy/test_qwen_image*.py --run-level full_model
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 1
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: VLLM_HTTP_TIMEOUT_KEEP_ALIVE
- value: "120"
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
+ - label: ":full_moon: Doc Example Code Test with H100" # runs the text-to-image example tests for online serving and offline inference
+ timeout_in_minutes: 60
+ depends_on: upload-nightly-pipeline # waits for the step that uploads this nightly pipeline
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands: # both example files are listed explicitly; the removed implementation used the glob tests/examples/*/test_text_to_image.py
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export VLLM_TEST_CLEAN_GPU_MEMORY="1"
+ - pytest -s -v tests/examples/online_serving/test_text_to_image.py tests/examples/offline_inference/test_text_to_image.py -m "advanced_model and example and H100" --run-level "advanced_model"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT # vllm-ci-test-repo image under the pull-through-cache path, tagged with the build's commit
+ resources:
+ limits:
+ nvidia.com/gpu: 2 # GPU limit for the test container
+ volumeMounts:
- name: devshm
- emptyDir:
- medium: Memory
+ mountPath: /dev/shm
- name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
- - label: ":full_moon: Diffusion X2I(&A&T) · Perf Test"
- key: nightly-diffusion-x2iat-performance
- timeout_in_minutes: 180
- commands:
- - export DIFFUSION_BENCHMARK_DIR=tests/dfx/perf/results
- - export DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN
- - export CACHE_DIT_VERSION=1.3.0
- # [HACK]: run upload in the same command block as pytest.
- # Because `exit` aborts the entire commands list.
- - |
- set +e
- pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
- EXIT1=$$?
- pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json
- EXIT2=$$?
- pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json
- EXIT3=$$?
- pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_layered_vllm_omni.json
- EXIT4=$$?
- buildkite-agent artifact upload "tests/dfx/perf/results/diffusion_result_*.json"
- buildkite-agent artifact upload "tests/dfx/perf/results/logs/*.log"
- exit $$((EXIT1 | EXIT2 | EXIT3 | EXIT4))
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 4
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
+ - label: ":full_moon: Omni Model Perf Test & Testcase Statistics with H100" # benchmark run plus the HTML test-case statistics report
+ key: nightly-omni-performance
+ timeout_in_minutes: 180
+ depends_on: upload-nightly-pipeline # waits for the step that uploads this nightly pipeline
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands: # pytest's exit status is held in EXIT so the artifact uploads and the statistics report still run after a failed benchmark, then the step exits with that status; `$$` is the pipeline-upload escape for a literal `$`
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export BENCHMARK_DIR=tests/dfx/perf/results
+ - export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
+ - |
+ set +e; pytest -s -v tests/dfx/perf/scripts/run_benchmark.py; EXIT=$$?; set -e
+ buildkite-agent artifact upload "tests/dfx/perf/results/*.json"
+ python tools/nightly/buildkite_testcase_statistics.py -o tests/dfx/perf/results/buildkite_testcase_statistics.html; buildkite-agent artifact upload "tests/dfx/perf/results/*.html"; exit $$EXIT
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ volumeMounts:
- name: devshm
- emptyDir:
- medium: Memory
+ mountPath: /dev/shm
- name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
- # Diffusion x2v only (Wan, HunyuanVideo, …). x2i/x2a/x2t live in the X2I group above, not here.
- - group: ":card_index_dividers: Diffusion X2V Model Test"
- key: nightly-diffusion-x2v-group
+ - label: ":full_moon: GEBench Accuracy Test with H100"
+ key: nightly-gebench-accuracy
+ timeout_in_minutes: 60
depends_on: upload-nightly-pipeline
- if: >-
- build.env("NIGHTLY") == "1" ||
- build.pull_request.labels includes "nightly-test" ||
- build.pull_request.labels includes "diffusion-x2v-test"
- steps:
- - label: ":full_moon: Diffusion X2V · Function Test"
- timeout_in_minutes: 90
- commands:
- - pytest -s -v tests/e2e/online_serving/test_wan22_expansion.py tests/e2e/online_serving/test_wan_2_1_vace_expansion.py tests/e2e/online_serving/test_hunyuan_video_15_expansion.py --run-level "full_model"
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands: # NOTE(review): the artifact upload is a separate entry after pytest, so it is presumably skipped when pytest exits non-zero -- confirm whether summaries from failed runs should still be uploaded
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - pytest -s -v tests/e2e/accuracy/test_gebench_h100_smoke.py --run-level advanced_model --gebench-model Qwen/Qwen-Image-2512 --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gebench-port 8093 --accuracy-workers 1
+ - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gebench_qwen-image-2512/summary*.json"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT # vllm-ci-test-repo image under the pull-through-cache path, tagged with the build's commit
+ resources:
+ limits:
+ nvidia.com/gpu: 1 # single GPU limit; the pytest command passes --accuracy-gpu 0
+ volumeMounts:
- name: devshm
- emptyDir:
- medium: Memory
+ mountPath: /dev/shm
- name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
- - label: ":full_moon: Diffusion X2V · Accuracy Test"
- timeout_in_minutes: 180
- commands:
- - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py -m full_model --run-level full_model
- - pytest -s -v tests/e2e/accuracy/test_ltx2_3_video_similarity.py -m advanced_model --run-level advanced_model
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
+ - label: ":full_moon: GEdit-Bench Accuracy Test with H100" # GEdit-Bench smoke accuracy run for Qwen/Qwen-Image-Edit, uploading score CSV and summary JSON
+ key: nightly-gedit-bench-accuracy
+ timeout_in_minutes: 60
+ depends_on: upload-nightly-pipeline # waits for the step that uploads this nightly pipeline
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands: # NOTE(review): both artifact uploads are separate entries after pytest, so they are presumably skipped when pytest exits non-zero -- confirm
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - pytest -s -v tests/e2e/accuracy/test_gedit_bench_h100_smoke.py --run-level advanced_model --gedit-model Qwen/Qwen-Image-Edit --accuracy-judge-model QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ --accuracy-gpu 0 --gedit-port 8093 --gedit-samples-per-group 20 --accuracy-workers 1
+ - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_vie_score_*.csv"
+ - buildkite-agent artifact upload "tests/e2e/accuracy/artifacts/gedit_scores_qwen-image-edit/qwen-image-edit_all_all_summary_*.json"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT # vllm-ci-test-repo image under the pull-through-cache path, tagged with the build's commit
+ resources:
+ limits:
+ nvidia.com/gpu: 1 # single GPU limit; the pytest command passes --accuracy-gpu 0
+ volumeMounts:
- name: devshm
- emptyDir:
- medium: Memory
+ mountPath: /dev/shm
- name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: VLLM_HTTP_TIMEOUT_KEEP_ALIVE
+ value: "120"
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
- - label: ":full_moon: Diffusion X2V · Perf Test"
- key: nightly-diffusion-x2v-performance
- timeout_in_minutes: 180
- commands:
- - export DIFFUSION_BENCHMARK_DIR=tests/dfx/perf/results
- - export DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN
- - |
- set +e
- pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_wan22_i2v_vllm_omni.json
- EXIT1=$$?
- buildkite-agent artifact upload "tests/dfx/perf/results/diffusion_result_*.json"
- buildkite-agent artifact upload "tests/dfx/perf/results/logs/*.log"
- exit $$EXIT1
- agents:
- queue: "mithril-h100-pool"
- plugins:
- - kubernetes:
- podSpec:
- containers:
- - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- resources:
- limits:
- nvidia.com/gpu: 2
- volumeMounts:
- - name: devshm
- mountPath: /dev/shm
- - name: hf-cache
- mountPath: /root/.cache/huggingface
- env:
- - name: HF_HOME
- value: /root/.cache/huggingface
- - name: HF_TOKEN
- valueFrom:
- secretKeyRef:
- name: hf-token-secret
- key: token
- nodeSelector:
- node.kubernetes.io/instance-type: gpu-h100-sxm
- volumes:
+ - label: ":full_moon: Wan22 I2V Accuracy Test with H100"
+ key: nightly-wan22-i2v-accuracy
+ timeout_in_minutes: 180
+ depends_on: upload-nightly-pipeline
+ if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
+ commands:
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py --run-level advanced_model
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ volumeMounts:
- name: devshm
- emptyDir:
- medium: Memory
+ mountPath: /dev/shm
- name: hf-cache
- hostPath:
- path: /mnt/hf-cache
- type: DirectoryOrCreate
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
- - label: ":bar_chart: Testcase Statistics"
- key: nightly-testcase-statistics
- timeout_in_minutes: 120
+ - label: ":full_moon: Diffusion Perf Test with H100"
+ key: nightly-qwen-image-performance
+ timeout_in_minutes: 180
depends_on: upload-nightly-pipeline
if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
commands:
- - python tools/nightly/buildkite_testcase_statistics.py -o tests/dfx/perf/results/buildkite_testcase_statistics.html
- - buildkite-agent artifact upload "tests/dfx/perf/results/*.html"
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export DIFFUSION_BENCHMARK_DIR=tests/dfx/perf/results
+ - export CACHE_DIT_VERSION=1.3.0
+ - pytest -s -v tests/dfx/perf/scripts/run_diffusion_benchmark.py --config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
+ - buildkite-agent artifact upload "tests/dfx/perf/results/benchmark_results_*.json"
+ - buildkite-agent artifact upload "tests/dfx/perf/results/logs/*.log"
agents:
queue: "mithril-h100-pool"
plugins:
@@ -697,7 +411,7 @@ steps:
- image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
resources:
limits:
- nvidia.com/gpu: 2
+ nvidia.com/gpu: 4
volumeMounts:
- name: devshm
mountPath: /dev/shm
@@ -727,20 +441,15 @@ steps:
key: nightly-perf-distribution
depends_on:
- nightly-omni-performance
- - nightly-tts-performance
- - nightly-diffusion-x2iat-performance
- - nightly-diffusion-x2v-performance
- - nightly-testcase-statistics
+ - nightly-qwen-image-performance
if: build.env("NIGHTLY") == "1"
commands:
- pip install openpyxl
- export DEFAULT_INPUT_DIR=tests/dfx/perf/results
- export DEFAULT_OUTPUT_DIR=tests/dfx/perf/results
- - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-tts-performance
- buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-omni-performance
- - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-diffusion-x2iat-performance
- - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-diffusion-x2v-performance
- - buildkite-agent artifact download "tests/dfx/perf/results/*.html" . --step nightly-testcase-statistics
+ - buildkite-agent artifact download "tests/dfx/perf/results/*.json" . --step nightly-qwen-image-performance
+ - buildkite-agent artifact download "tests/dfx/perf/results/*.html" . --step nightly-omni-performance
- python tools/nightly/generate_nightly_perf_excel.py
- python tools/nightly/generate_nightly_perf_html.py
- python tools/nightly/send_nightly_email.py --report-file "tests/dfx/perf/results/*.xlsx, tests/dfx/perf/results/*.html"
diff --git a/.buildkite/test-ready.yml b/.buildkite/test-ready.yml
index 080f18885ef..be528b316cd 100644
--- a/.buildkite/test-ready.yml
+++ b/.buildkite/test-ready.yml
@@ -1,8 +1,3 @@
-env:
- VLLM_WORKER_MULTIPROC_METHOD: spawn
- HF_HUB_DOWNLOAD_TIMEOUT: 300
- HF_HUB_ETAG_TIMEOUT: 60
-
steps:
- label: "Simple Unit Test"
depends_on: upload-ready-pipeline
@@ -21,10 +16,11 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: "CUDA Unit Test with single card"
+ - label: "Voxtral TTS CUDA Unit Test"
+ timeout_in_minutes: 10
depends_on: upload-ready-pipeline
commands:
- - timeout 10m pytest -v -s -m 'core_model and cuda and L4 and not distributed_cuda' --ignore=tests/e2e --ignore=tests/engine/test_async_omni_engine_abort.py --cov=vllm_omni --cov-branch --cov-report=term-missing --cov-report=html --cov-report=xml
+ - "timeout 10m pytest -s -v tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py"
agents:
queue: "gpu_1_queue"
plugins:
@@ -37,12 +33,12 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: "CUDA Unit Test with multi cards"
+ - label: "Diffusion Model Test"
depends_on: upload-ready-pipeline
commands:
- - timeout 10m pytest -v -s -m 'core_model and cuda and L4 and distributed_cuda' --ignore=tests/e2e --cov=vllm_omni --cov-branch --cov-report=term-missing --cov-report=html --cov-report=xml
+ - timeout 30m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model"
agents:
- queue: "gpu_4_queue"
+ queue: "gpu_1_queue"
plugins:
- docker#v5.2.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
@@ -50,15 +46,16 @@ steps:
propagate-environment: true
environment:
- "HF_HOME=/fsx/hf_cache"
+ - "HF_TOKEN"
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: "Diffusion Model Test"
+ - label: "Diffusion Batching Test"
depends_on: upload-ready-pipeline
commands:
- - timeout 30m pytest -s -v tests/e2e/offline_inference/test_t2i_model.py -m "core_model and diffusion" --run-level "core_model"
+ - timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py -m "core_model and diffusion" --run-level "core_model"
agents:
- queue: "gpu_1_queue"
+ queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU
plugins:
- docker#v5.2.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
@@ -70,12 +67,12 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: "Diffusion Batching Test"
+ - label: "Custom Pipeline Test"
depends_on: upload-ready-pipeline
commands:
- - timeout 20m pytest -s -v tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py -m "core_model and diffusion" --run-level "core_model"
+ - timeout 20m pytest -s -v tests/e2e/offline_inference/custom_pipeline/ -m "core_model"
agents:
- queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU
+ queue: "gpu_1_queue"
plugins:
- docker#v5.2.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
@@ -83,16 +80,15 @@ steps:
propagate-environment: true
environment:
- "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: "Custom Pipeline Test"
+ - label: "Diffusion Model CPU offloading Test"
depends_on: upload-ready-pipeline
commands:
- - timeout 20m pytest -s -v tests/e2e/offline_inference/custom_pipeline/ -m "core_model"
+ - timeout 10m pytest -s -v tests/e2e/offline_inference/test_diffusion_cpu_offload.py
agents:
- queue: "gpu_1_queue"
+ queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU
plugins:
- docker#v5.2.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
@@ -100,13 +96,14 @@ steps:
propagate-environment: true
environment:
- "HF_HOME=/fsx/hf_cache"
+ - "HF_TOKEN"
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: "Diffusion Model CPU offloading Test"
+ - label: "Audio Generation Model Test"
depends_on: upload-ready-pipeline
commands:
- - timeout 10m pytest -s -v tests/e2e/offline_inference/test_diffusion_cpu_offload.py
+ - timeout 20m pytest -s -v tests/e2e/offline_inference/test_stable_audio_model.py
agents:
queue: "gpu_1_queue" # g6.4xlarge instance on AWS, has 1 L4 GPU
plugins:
@@ -155,12 +152,31 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
+ - label: "Diffusion GPU Worker Test"  # Runs pytest on tests/diffusion/test_diffusion_worker.py with a 20-minute cap.
+ depends_on: upload-ready-pipeline  # Waits for the step keyed upload-ready-pipeline in pipeline.yml.
+ commands:
+ - timeout 20m pytest -s -v tests/diffusion/test_diffusion_worker.py
+ agents:
+ queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT  # Image tag is the build's commit SHA.
+ always-pull: true
+ propagate-environment: true
+ shm-size: "8gb"
+ environment:
+ - "HF_HOME=/fsx/hf_cache"  # Same path as the container side of the volume mounted below.
+ - "HF_TOKEN"  # Name only, no value: passed through from the agent's environment.
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
+
- label: "Engine Test"
depends_on: upload-ready-pipeline
commands:
- |
timeout 15m bash -c '
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
pytest -s -v tests/engine/test_async_omni_engine_abort.py
'
agents:
@@ -177,11 +193,35 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
+
+ - label: "Omni Model Test"  # Runs the core_model-marked tests in tests/e2e/online_serving/test_qwen2_5_omni.py with a 17-minute cap.
+ depends_on: upload-ready-pipeline  # Waits for the step keyed upload-ready-pipeline in pipeline.yml.
+ commands:
+ - |
+ timeout 17m bash -c '
+ export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ pytest -s -v tests/e2e/online_serving/test_qwen2_5_omni.py -m "core_model" --run-level "core_model"
+ '
+ agents:
+ queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT  # Image tag is the build's commit SHA.
+ always-pull: true
+ propagate-environment: true
+ environment:
+ - "HF_HOME=/fsx/hf_cache"  # Same path as the container side of the volume mounted below.
+ - "HF_TOKEN"  # Name only, no value: passed through from the agent's environment.
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
+
- label: "Omni Model Test with H100"
depends_on: upload-ready-pipeline
commands:
- |
timeout 20m bash -c '
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
pytest -s -v tests/e2e/online_serving/test_qwen3_omni.py -m "core_model" --run-level "core_model"
'
agents:
@@ -219,6 +259,7 @@ steps:
- |
timeout 30m bash -c '
export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
pytest -s -v tests/e2e/online_serving/test_mimo_audio.py -m "core_model" --run-level "core_model"
'
agents:
@@ -261,6 +302,7 @@ steps:
- |
timeout 20m bash -c '
export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
pytest -s -v tests/e2e/online_serving/test_qwen3_tts_customvoice.py -m "core_model" --run-level "core_model"
'
@@ -278,90 +320,15 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
- - label: "VoxCPM E2E Test"
- timeout_in_minutes: 20
- depends_on: upload-ready-pipeline
- commands:
- - |
- timeout 20m bash -c '
- pip install voxcpm
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/offline_inference/test_voxcpm.py -m "core_model" --run-level "core_model"
- '
- agents:
- queue: "gpu_1_queue"
- plugins:
- - docker#v5.2.0:
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- always-pull: true
- propagate-environment: true
- shm-size: "8gb"
- environment:
- - "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
- volumes:
- - "/fsx/hf_cache:/fsx/hf_cache"
-
- - label: "VoxCPM2 Native AR E2E Test"
- timeout_in_minutes: 20
- depends_on: upload-ready-pipeline
- commands:
- - |
- timeout 20m bash -c '
- pip install voxcpm
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/offline_inference/test_voxcpm2.py -m "core_model" --run-level "core_model"
- '
- agents:
- queue: "gpu_1_queue"
- plugins:
- - docker#v5.2.0:
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- always-pull: true
- propagate-environment: true
- shm-size: "8gb"
- environment:
- - "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
- volumes:
- - "/fsx/hf_cache:/fsx/hf_cache"
-
- label: "OmniVoice E2E Test"
timeout_in_minutes: 20
depends_on: upload-ready-pipeline
- commands:
- - |
- timeout 20m bash -c '
- export VLLM_LOGGING_LEVEL=DEBUG
- pytest -s -v tests/e2e/online_serving/test_omnivoice.py -m "core_model" --run-level "core_model"
- '
- agents:
- queue: "gpu_1_queue"
- plugins:
- - docker#v5.2.0:
- image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
- always-pull: true
- propagate-environment: true
- shm-size: "8gb"
- environment:
- - "HF_HOME=/fsx/hf_cache"
- volumes:
- - "/fsx/hf_cache:/fsx/hf_cache"
-
- - label: "Qwen3-TTS Base E2E Test (ModelRunner V2)"
- depends_on: upload-ready-pipeline
- soft_fail:
- - exit_status: 1
commands:
- |
timeout 20m bash -c '
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_WORKER_MULTIPROC_METHOD=spawn
- export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
- export VLLM_OMNI_USE_V2_RUNNER="1"
- pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py -m "core_model" --run-level "core_model"
+ pytest -s -v tests/e2e/online_serving/test_omnivoice.py -m "core_model" --run-level "core_model"
'
agents:
queue: "gpu_1_queue"
@@ -373,7 +340,6 @@ steps:
shm-size: "8gb"
environment:
- "HF_HOME=/fsx/hf_cache"
- - "HF_TOKEN"
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
@@ -384,6 +350,7 @@ steps:
- |
timeout 20m bash -c '
export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
pytest -s -v tests/e2e/online_serving/test_voxtral_tts.py -m "core_model" --run-level "core_model"
'
agents:
@@ -420,6 +387,7 @@ steps:
# commands:
# - |
# timeout 20m bash -c '
+ # export VLLM_WORKER_MULTIPROC_METHOD=spawn
# pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
# '
# agents:
@@ -456,6 +424,7 @@ steps:
commands:
- |
timeout 30m bash -c '
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_TEST_CLEAN_GPU_MEMORY=1
pytest -s -v tests/e2e/offline_inference/test_bagel_text2img.py -m "core_model" --run-level "core_model"
'
@@ -498,6 +467,7 @@ steps:
commands:
- |
timeout 30m bash -c '
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_TEST_CLEAN_GPU_MEMORY=1
pytest -s -v tests/e2e/offline_inference/test_bagel_img2img.py -m "core_model" --run-level "core_model"
'
@@ -540,6 +510,7 @@ steps:
commands:
- |
timeout 40m bash -c '
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_TEST_CLEAN_GPU_MEMORY=1
export VLLM_IMAGE_FETCH_TIMEOUT=60
pytest -s -v tests/e2e/online_serving/test_bagel_online.py -m "core_model" --run-level "core_model"
@@ -584,6 +555,7 @@ steps:
commands:
- |
timeout 20m bash -c '
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
pytest -s -v tests/e2e/online_serving/test_cosyvoice3_tts.py -m "core_model" --run-level "core_model"
'
agents:
diff --git a/.buildkite/test-template-amd-omni.j2 b/.buildkite/test-template-amd-omni.j2
index 78f47d1aec0..8dc91a11727 100644
--- a/.buildkite/test-template-amd-omni.j2
+++ b/.buildkite/test-template-amd-omni.j2
@@ -3,7 +3,7 @@
Last synced: 2025-12-15
Modifications: Removed unused CUDA/NVIDIA logic, keeping only AMD tests
#}
-{% set docker_image_amd = "rocm/vllm-omni:$BUILDKITE_COMMIT" %}
+{% set docker_image_amd = "rocm/vllm-ci:$BUILDKITE_COMMIT-rocm-omni" %}
{% set default_working_dir = "/app/vllm-omni" %}
- group: "AMD Tests"
@@ -48,9 +48,6 @@
DOCKER_BUILDKIT: "1"
TEST_COMMAND: |-
(command rocm-smi || true) && cd {{ (step.working_dir or default_working_dir) | safe }}
-{% if "mi250" in step.agent_pool %}
- python3 -m pip uninstall -y amd-aiter
-{% endif %}
{{ indented_cmd | safe }}
priority: 100
{% if step.grade and step.grade == "Blocking" %}
diff --git a/.claude/skills/add-diffusion-model/SKILL.md b/.claude/skills/add-diffusion-model/SKILL.md
deleted file mode 100644
index 0b979e1a984..00000000000
--- a/.claude/skills/add-diffusion-model/SKILL.md
+++ /dev/null
@@ -1,566 +0,0 @@
----
-name: add-diffusion-model
-description: Add a new diffusion model (text-to-image, text-to-video, image-to-video, text-to-audio, image editing) to vLLM-Omni, including Cache-DiT acceleration and parallelism support (TP, SP/USP, CFG-Parallel, HSDP). Use when integrating a new diffusion model, porting a diffusers pipeline or a custom model repo to vllm-omni, creating a new DiT transformer adapter, adding diffusion model support, or enabling multi-GPU parallelism and cache acceleration for an existing model.
----
-
-# Adding a Diffusion Model to vLLM-Omni
-
-## Overview
-
-This skill guides you through adding a new diffusion model to vLLM-Omni. The model may come from HuggingFace Diffusers (structured pipeline) or from a private/custom repo. The workflow differs significantly depending on the source.
-
-## Prerequisites
-
-Before starting, determine:
-
-1. **Model category**: Text-to-Image, Text-to-Video, Image-to-Video, Image Editing, Text-to-Audio, or Omni
-2. **Reference source**: Diffusers pipeline, custom repo, or a combination
-3. **Model HuggingFace ID** or local checkpoint path
-4. **Architecture**: Scheduler, text encoder, VAE, transformer/backbone
-
-## Step 0: Classify the Migration Path
-
-Check the model's HF repo for `model_index.json`. This determines your path:
-
-| Scenario | How to identify | Migration path |
-|----------|----------------|----------------|
-| **Already supported** | `_class_name` in `model_index.json` matches a key in `_DIFFUSION_MODELS` in `registry.py` | Skip to Step 5 (test) and Step 7 (docs) |
-| **Diffusers-based** | Has standard `model_index.json` with `_diffusers_version`, subfolders for `transformer/`, `vae/`, etc. | Follow **Path A** below |
-| **Custom/private repo** | No diffusers `model_index.json`, weights in non-standard format, custom model code in a separate git repo | Follow **Path B** below |
-| **Hybrid** | Has some diffusers components (VAE) but custom transformer/fusion | Mix of Path A and Path B |
-
-## Path A: Diffusers-Based Model
-
-For models with a standard diffusers layout. See [references/transformer-adaptation.md](references/transformer-adaptation.md) for detailed code patterns.
-
-### A1. Analyze `model_index.json`
-
-Identify components: `transformer`, `scheduler`, `vae`, `text_encoder`, `tokenizer`.
-
-### A2. Create model directory
-
-```
-vllm_omni/diffusion/models/your_model_name/
-├── __init__.py
-├── pipeline_your_model.py
-└── your_model_transformer.py
-```
-
-### A3. Adapt transformer
-
-1. Copy from diffusers source. Remove mixins (`ModelMixin`, `ConfigMixin`, `AttentionModuleMixin`).
-2. Replace attention with `vllm_omni.diffusion.attention.layer.Attention` (QKV shape: `[B, seq, heads, head_dim]`).
-3. Add `od_config: OmniDiffusionConfig | None = None` to `__init__`.
-4. Add `load_weights()` method mapping diffusers weight names to vllm-omni names.
-5. Add class attributes: `_repeated_blocks`, `_layerwise_offload_blocks_attr`.
-
-### A4. Adapt pipeline
-
-Inherit from `nn.Module`. The key contract:
-
-```python
-class YourPipeline(nn.Module):
- def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
- # Load VAE, text encoder, tokenizer via from_pretrained()
- # Instantiate transformer (weights loaded later via weights_sources)
- self.weights_sources = [
- DiffusersPipelineLoader.ComponentSource(
- model_or_path=od_config.model, subfolder="transformer",
- prefix="transformer.", fall_back_to_pt=True)]
-
- def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
- # Encode prompt → prepare latents → denoise loop → VAE decode
- return DiffusionOutput(output=output)
-
- def load_weights(self, weights):
- return AutoWeightsLoader(self).load_weights(weights)
-```
-
-Add post/pre-process functions in the same pipeline file. Register them in `registry.py`.
-
-### A5. Register, test, docs → continue at Step 4 below.
-
----
-
-## Path B: Custom/Private Repo Model
-
-For models without a diffusers pipeline — weights in custom format, model code in a private repo. Real examples: DreamID-Omni, BAGEL, HunyuanImage3.
-
-### B1. Understand the reference repo
-
-Study the original model's code to identify:
-- **Model architecture files** (transformers, fusion modules, embeddings)
-- **Weight format** (safetensors, `.pth`, custom checkpoint structure)
-- **Weight loading helpers** (custom init functions, checkpoint loaders)
-- **Pre/post-processing** (image/audio transforms, tokenization, VAE encode/decode)
-- **External dependencies** (packages not on PyPI)
-- **Config format** (JSON config files, hardcoded dicts)
-
-### B2. Decide what lives WHERE
-
-This is the key design decision for custom models. Follow these placement rules:
-
-| Code type | Where to place | Example |
-|-----------|---------------|---------|
-| **Pipeline orchestration** (init, forward, denoise loop) | `vllm_omni/diffusion/models/<model_name>/pipeline_<model_name>.py` | Always required |
-| **Custom transformer/backbone** (ported and adapted to vllm-omni) | `vllm_omni/diffusion/models/<model_name>/<model_name>_transformer.py` or similar | `wan2_2.py`, `fusion.py`, `bagel_transformer.py` |
-| **Custom sub-models** (VAE, fusion, autoencoder) | `vllm_omni/diffusion/models/<model_name>/` as separate files | `autoencoder.py`, `fusion.py` |
-| **External dependency code** (original repo utilities) | **External repo**, installed via download script or pip | `dreamid_omni` package via git clone |
-| **Hardcoded model configs** | Module-level dicts in pipeline file | `VIDEO_CONFIG`, `AUDIO_CONFIG` dicts |
-| **Download/setup script** | `examples/offline_inference/<model_name>/download_<model_name>.py` | `download_dreamid_omni.py` |
-| **Custom `model_index.json`** | Generated by download script, placed at model root | Minimal: `{"_class_name": "YourPipeline", ...}` |
-
-### B3. Handle external dependencies
-
-If the model's code lives in a separate git repo:
-
-**Option 1: Import with graceful fallback** (recommended for models with external utils)
-
-```python
-try:
- from external_model.utils import init_vae, load_checkpoint
-except ImportError:
- raise ImportError(
- "Failed to import from dependency 'external_model'. "
- "Please run the download script first."
- )
-```
-
-**Option 2: Port the code directly** (preferred when feasible)
-
-Copy the essential model files into `vllm_omni/diffusion/models/<model_name>/` and adapt them. This avoids external dependencies. BAGEL does this — `autoencoder.py` and `bagel_transformer.py` are ported directly.
-
-**Decision criteria**: Port if the code is self-contained and won't diverge. Use external deps if the model repo is actively maintained and the code is complex.
-
-### B4. Handle custom weight loading
-
-Custom models have two patterns for weight loading:
-
-**Pattern 1: Bypass standard loader** (DreamID-Omni style)
-
-When the original model has complex custom init functions that load weights in `__init__`:
-
-```python
-class CustomPipeline(nn.Module):
- def __init__(self, *, od_config, prefix=""):
- super().__init__()
- model = od_config.model
- # Load everything eagerly in __init__ using custom helpers
- self.vae = custom_init_vae(model, device=self.device)
- self.text_encoder = custom_init_text_encoder(model, device=self.device)
- self.transformer = CustomFusionModel(CONFIG)
- load_custom_checkpoint(self.transformer,
- checkpoint_path=os.path.join(model, "model.safetensors"))
- # NO weights_sources defined — bypasses standard loader
-
- def load_weights(self, weights):
- pass # No-op — all weights loaded in __init__
-```
-
-**Pattern 2: Use standard loader with custom `load_weights`** (BAGEL style)
-
-When weights are in safetensors format but need name remapping:
-
-```python
-class CustomPipeline(nn.Module):
- def __init__(self, *, od_config, prefix=""):
- super().__init__()
- # Instantiate model architecture without weights
- self.bagel = BagelModel(config)
- self.vae = AutoEncoder(ae_params)
-
- # Point loader at the safetensors in the model root
- self.weights_sources = [
- DiffusersPipelineLoader.ComponentSource(
- model_or_path=od_config.model,
- subfolder=None, # weights at root, not in subfolder
- prefix="",
- fall_back_to_pt=False,
- )
- ]
-
- def load_weights(self, weights):
- # Custom name remapping for non-diffusers weight names
- params = dict(self.named_parameters())
- loaded = set()
- for name, tensor in weights:
- # Remap original weight names to vllm-omni module names
- name = self._remap_weight_name(name)
- if name in params:
- default_weight_loader(params[name], tensor)
- loaded.add(name)
- return loaded
-```
-
-### B5. Create the `model_index.json`
-
-Custom models need a `model_index.json` at the model root for vllm-omni to discover them. For custom models, this is minimal:
-
-```json
-{
- "_class_name": "YourModelPipeline",
- "custom_key": "path/to/custom_weights.safetensors"
-}
-```
-
-The `_class_name` must match a key in `_DIFFUSION_MODELS` in `registry.py`. Additional keys are model-specific (accessed via `od_config.model_config`).
-
-If the model's weights come from multiple HF repos, write a **download script** that:
-1. Downloads from each repo
-2. Assembles into a single directory
-3. Generates `model_index.json`
-4. Installs any external dependencies (git clone + `.pth` file)
-
-Place at: `examples/offline_inference/<model_name>/download_<model_name>.py`
-
-### B6. Handle multi-modal inputs
-
-If the model accepts images, audio, or other multi-modal inputs, implement the protocol classes from `vllm_omni/diffusion/models/interface.py`:
-
-```python
-from vllm_omni.diffusion.models.interface import SupportImageInput, SupportAudioInput
-
-class MyPipeline(nn.Module, SupportImageInput, SupportAudioInput):
- # Protocol markers — the engine uses these to enable proper input routing
- pass
-```
-
-Preprocessing for custom models is typically done **inside `forward()`** rather than via registered pre-process functions, since the logic is often tightly coupled to the model.
-
-### B7. Continue at Step 4 below.
-
----
-
-## Common Steps (Both Paths)
-
-### Step 4: Register Model in registry.py
-
-Edit `vllm_omni/diffusion/registry.py`:
-
-```python
-_DIFFUSION_MODELS = {
- "YourModelPipeline": ("your_model_name", "pipeline_your_model", "YourModelPipeline"),
-}
-_DIFFUSION_POST_PROCESS_FUNCS = {
- "YourModelPipeline": "get_your_model_post_process_func", # if applicable
-}
-_DIFFUSION_PRE_PROCESS_FUNCS = {
- "YourModelPipeline": "get_your_model_pre_process_func", # if applicable
-}
-```
-
-The registry key is the `_class_name` from `model_index.json`. The tuple is `(folder_name, module_file, class_name)`.
-
-Create `__init__.py` exporting the pipeline class and any factory functions.
-
-### Step 5: Run, Test, Debug
-
-Use the appropriate existing example script:
-
-| Category | Script |
-|----------|--------|
-| Text-to-Image | `examples/offline_inference/text_to_image/text_to_image.py` |
-| Text-to-Video | `examples/offline_inference/text_to_video/text_to_video.py` |
-| Image-to-Video | `examples/offline_inference/image_to_video/image_to_video.py` |
-| Image-to-Image | `examples/offline_inference/image_to_image/image_edit.py` |
-| Text-to-Audio | `examples/offline_inference/text_to_audio/text_to_audio.py` |
-
-For custom/Omni models that don't fit these categories, create a dedicated example script.
-
-**Validation**: No errors, output is meaningful, quality matches reference implementation.
-
-See [references/troubleshooting.md](references/troubleshooting.md) for common errors.
-
-### Step 6: Add Example Scripts
-
-For Omni or custom models, create:
-- `examples/offline_inference/your_model_name/` — offline script + README
-- `examples/online_serving/your_model_name/` — server script + client
-- Download script if weights require assembly from multiple sources
-
-### Step 7: Update Documentation
-
-Required updates:
-1. `docs/user_guide/diffusion/parallelism_acceleration.md` — parallelism support table
-2. `docs/user_guide/diffusion/cpu_offload_diffusion.md` — if CPU offload supported (add to supported models table)
-3. `docs/user_guide/diffusion/teacache.md` — if TeaCache supported
-4. `docs/user_guide/diffusion/cache_dit_acceleration.md` — if Cache-DiT supported
-5. `examples/offline_inference/xxx/README.md` — offline example docs
-6. `examples/online_serve/xxx/README.md` — online serve docs
-
-### Step 8: Add E2E Tests (Recommended)
-
-Create `tests/e2e/online_serving/test_your_model_expansion.py`.
-
-### Step 9: Add Cache-DiT Acceleration
-
-Cache-DiT accelerates inference by caching intermediate computation results across denoising steps. After your model is working correctly on a single GPU, add cache-dit support.
-
-See [references/cache-dit-patterns.md](references/cache-dit-patterns.md) for detailed code patterns.
-
-#### 9a. Determine your model type
-
-| Model Type | Description | Action |
-|------------|-------------|--------|
-| **Standard single-transformer** | One transformer with one `ModuleList` of blocks | No code needed — `CacheDiTBackend` auto-detects via `enable_cache_for_dit()` |
-| **Multi-block-list** | One transformer with multiple block lists (e.g., `transformer_blocks` + `single_transformer_blocks`) | Write custom enabler with `BlockAdapter` |
-| **Dual-transformer** | Two transformers (e.g., high-noise + low-noise) | Write custom enabler with `BlockAdapter` wrapping both |
-
-#### 9b. Standard models — verify automatic support
-
-For standard single-transformer models, test directly:
-
-```python
-omni = Omni(
- model="your-model-name",
- cache_backend="cache_dit",
- cache_config={
- "Fn_compute_blocks": 1,
- "Bn_compute_blocks": 0,
- "max_warmup_steps": 4,
- }
-)
-```
-
-Check logs for "Cache-dit enabled successfully on xxx". If it works, skip to Step 9e.
-
-#### 9c. Custom architectures — write a custom enabler
-
-For multi-block-list or dual-transformer models, write a custom enabler function:
-
-```python
-from cache_dit import BlockAdapter, ForwardPattern, ParamsModifier, DBCacheConfig
-
-def enable_cache_for_your_model(pipeline, cache_config):
- db_cache_config = DBCacheConfig(
- num_inference_steps=None,
- Fn_compute_blocks=cache_config.Fn_compute_blocks,
- Bn_compute_blocks=cache_config.Bn_compute_blocks,
- max_warmup_steps=cache_config.max_warmup_steps,
- max_cached_steps=cache_config.max_cached_steps,
- max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
- residual_diff_threshold=cache_config.residual_diff_threshold,
- )
-
- cache_dit.enable_cache(
- BlockAdapter(
- transformer=pipeline.transformer,
- blocks=[
- pipeline.transformer.transformer_blocks,
- pipeline.transformer.single_transformer_blocks,
- ],
- forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
- params_modifiers=[ParamsModifier(...)],
- ),
- cache_config=db_cache_config,
- )
-
- def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
- cache_dit.refresh_context(
- pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose
- )
- return refresh_cache_context
-```
-
-#### 9d. Register the custom enabler
-
-Add your enabler to `CUSTOM_DIT_ENABLERS` in `vllm_omni/diffusion/cache/cache_dit_backend.py`:
-
-```python
-CUSTOM_DIT_ENABLERS = {
- "Wan22Pipeline": enable_cache_for_wan22,
- "LongCatImagePipeline": enable_cache_for_longcat_image,
- "YourModelPipeline": enable_cache_for_your_model, # Add here
-}
-```
-
-#### 9e. Test Cache-DiT
-
-```python
-omni = Omni(
- model="your-model-name",
- cache_backend="cache_dit",
- cache_config={
- "Fn_compute_blocks": 1, "Bn_compute_blocks": 0,
- "max_warmup_steps": 4, "residual_diff_threshold": 0.24,
- }
-)
-images = omni.generate("a beautiful landscape",
- OmniDiffusionSamplingParams(num_inference_steps=50))
-```
-
-**Verify**: 1) logs show cache enabled, 2) 1.5-2x speedup, 3) output quality acceptable vs baseline.
-
-If quality degrades, lower `residual_diff_threshold` (try 0.12-0.18) or increase `max_warmup_steps` (try 6-8).
-
----
-
-### Step 10: Add Parallelism Support
-
-After the model works on a single GPU, add multi-GPU parallelism. Add each type incrementally, testing after each addition.
-
-See [references/parallelism-patterns.md](references/parallelism-patterns.md) for detailed code patterns and API reference.
-
-**Recommended order**: TP → SP/USP → CFG Parallel → HSDP
-
-#### 10a. Tensor Parallelism (TP)
-
-Shards DiT linear layers across GPUs. Requires code changes in the transformer.
-
-**What to change in the transformer**:
-1. Replace `nn.Linear` with `ColumnParallelLinear` / `RowParallelLinear` / `QKVParallelLinear`
-2. Update `load_weights()` to handle QKV fusion with `stacked_params_mapping`
-3. Use `self.to_qkv.num_heads` (local heads) instead of total heads for split sizes
-
-```python
-from vllm.model_executor.layers.linear import (
- QKVParallelLinear, RowParallelLinear, ColumnParallelLinear,
-)
-
-# Attention: QKV → RowParallel output
-self.to_qkv = QKVParallelLinear(dim, head_dim, num_heads, num_kv_heads)
-self.to_out = RowParallelLinear(dim, dim, input_is_parallel=True)
-
-# FFN: ColumnParallel → RowParallel
-self.w1 = ColumnParallelLinear(dim, ffn_dim)
-self.w2 = RowParallelLinear(ffn_dim, dim, input_is_parallel=True)
-```
-
-**Constraints**: `num_heads % tp_size == 0` and `num_kv_heads % tp_size == 0`.
-
-**Test**: `--tensor-parallel-size 2`
-
-#### 10b. Sequence Parallelism (SP / USP)
-
-Splits sequence tokens across GPUs. Non-intrusive via `_sp_plan` on the transformer class — no changes to `forward()`.
-
-**What to change in the transformer**:
-
-Add `_sp_plan` class attribute:
-
-```python
-from vllm_omni.diffusion.distributed.sp_plan import (
- SequenceParallelInput, SequenceParallelOutput,
-)
-
-class YourTransformer(nn.Module):
- _sp_plan = {
- "blocks.0": {
- "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
- },
- "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
- }
-```
-
-If inline tensor ops (e.g., `torch.cat`) exist between shard/gather points, extract them into `nn.Module` submodules so hooks can intercept them.
-
-For RoPE that needs splitting, add an entry for the RoPE module with `split_output=True`.
-
-**Test**: `--ulysses-degree 2` (offline) or `--usp 2` (online serving)
-
-#### 10c. CFG Parallel
-
-Distributes positive/negative CFG branches across 2 GPUs. Requires the pipeline to inherit `CFGParallelMixin`.
-
-**What to change in the pipeline**:
-
-```python
-from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
-
-class YourPipeline(nn.Module, CFGParallelMixin):
- def diffuse(self, ...) -> torch.Tensor:
- for i, t in enumerate(timesteps):
- positive_kwargs = {...}
- negative_kwargs = {...} if do_true_cfg else None
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg, true_cfg_scale=cfg_scale,
- positive_kwargs=positive_kwargs, negative_kwargs=negative_kwargs,
- )
- latents = self.scheduler_step_maybe_with_cfg(
- noise_pred, t, latents, do_true_cfg
- )
- return latents
-```
-
-Override `predict_noise()` if your transformer call is non-standard. Override `combine_cfg_noise()` for multi-output models (e.g., video + audio).
-
-**Constraint**: Exactly 2 GPUs. Only for models using classifier-free guidance.
-
-**Test**: `--cfg-parallel-size 2`
-
-#### 10d. HSDP (Hybrid Sharded Data Parallel)
-
-Shards transformer weights via PyTorch FSDP2 to reduce per-GPU VRAM. No code changes to the forward pass — just add a class attribute.
-
-**What to change in the transformer**:
-
-```python
-class YourTransformer(nn.Module):
- @staticmethod
- def _is_transformer_block(name: str, module) -> bool:
- return "blocks" in name and name.split(".")[-1].isdigit()
-
- _hsdp_shard_conditions = [_is_transformer_block]
-```
-
-**Constraint**: Cannot combine with TP. For standalone HSDP, set `hsdp_shard_size` explicitly.
-
-**Test**: `--use-hsdp` or `DiffusionParallelConfig(use_hsdp=True)`
-
-#### 10e. Update parallelism documentation
-
-After adding parallelism support, update:
-1. `docs/user_guide/diffusion/parallelism_acceleration.md` — add your model to the support table
-2. Record which parallelism methods are supported (USP, Ring, CFG, TP, HSDP, VAE-Patch)
-
-### Step 11: Add CPU Offload Support
-
-Implement `SupportsModuleOffload` on your pipeline class to enable
-`--enable-cpu-offload` and `--enable-layerwise-offload`. The protocol
-declares which submodules the offloader should manage:
-
-```python
-from typing import ClassVar
-from vllm_omni.diffusion.models.interface import SupportsModuleOffload
-
-class YourPipeline(nn.Module, SupportsModuleOffload):
- _dit_modules: ClassVar[list[str]] = ["transformer"]
- _encoder_modules: ClassVar[list[str]] = ["text_encoder"]
- _vae_modules: ClassVar[list[str]] = ["vae"]
- _resident_modules: ClassVar[list[str]] = [] # optional
-```
-
-- `_dit_modules`: denoising submodules (kept on GPU during diffusion loop)
-- `_encoder_modules`: encoder/vision submodules (offloaded to CPU during diffusion loop)
-- `_vae_modules`: VAE(s) (handled by both sequential and layerwise backends)
-- `_resident_modules`: additional modules to pin on GPU during layerwise
- offloading (e.g. embedders, connectors). Only used by the layerwise
- backend. Optional — defaults to `[]`.
-
-All attribute names support dotted paths for nested submodules
-(e.g. `"pipe.transformer"`, `"bagel.time_embedder"`).
-
-Pipelines without `SupportsModuleOffload` fall back to scanning
-well-known attribute names (`transformer`, `text_encoder`, `vae`,
-etc.), which fails for non-standard names.
-
----
-
-## Iterative Development Tips
-
-1. **Start minimal**: Basic generation first, no parallelism/caching
-2. **Use `--enforce-eager`**: Disable torch.compile during debugging
-3. **Use small models**: Test with smaller variants first
-4. **Check tensor shapes**: Most errors are reshape mismatches in attention
-5. **Add features incrementally**: Single GPU → Cache-DiT → TP → SP → CFG → HSDP (the order of Steps 9 and 10; re-test after each addition)
-6. **For custom models**: Get the model running with the original code first, then progressively replace components with vllm-omni equivalents
-7. **Cache-DiT before parallelism tuning**: Cache-DiT is lossy — verify quality at baseline before combining with parallelism
-8. **Combine lossless + lossy**: e.g., TP + SP + Cache-DiT for maximum throughput
-
-## Reference Files
-
-- [Transformer Adaptation](references/transformer-adaptation.md) — porting transformers from diffusers
-- [Custom Model Patterns](references/custom-model-patterns.md) — patterns for non-diffusers models
-- [Parallelism Patterns](references/parallelism-patterns.md) — TP, SP/USP, CFG parallel, HSDP implementation details
-- [Cache-DiT Patterns](references/cache-dit-patterns.md) — cache-dit acceleration for standard and custom architectures
-- [Troubleshooting](references/troubleshooting.md) — common errors and fixes
diff --git a/.claude/skills/add-diffusion-model/references/cache-dit-patterns.md b/.claude/skills/add-diffusion-model/references/cache-dit-patterns.md
deleted file mode 100644
index d34ce0e0f43..00000000000
--- a/.claude/skills/add-diffusion-model/references/cache-dit-patterns.md
+++ /dev/null
@@ -1,254 +0,0 @@
-# Cache-DiT Patterns Reference
-
-## Overview
-
-Cache-DiT accelerates Diffusion Transformers by caching intermediate computation results across denoising steps. Adjacent steps produce similar features, so redundant computations can be skipped.
-
-Three caching strategies:
-- **DBCache**: Dynamic block-level caching — selectively computes or caches transformer blocks based on residual differences
-- **TaylorSeer**: Calibration-based prediction using Taylor expansion to estimate block outputs
-- **SCM** (Step Computation Masking): Dynamic step skipping based on configurable policies
-
-**Typical speedup**: 1.5-2.5x depending on model and configuration.
-
-**Official docs**: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/cache_dit
-
-## Architecture
-
-vLLM-Omni integrates cache-dit through `CacheDiTBackend`:
-
-| Component | Purpose |
-|-----------|---------|
-| `CacheDiTBackend` | Unified backend — auto-selects enabler (standard or custom) |
-| `enable_cache_for_dit()` | Default enabler for standard single-transformer models |
-| `CUSTOM_DIT_ENABLERS` dict | Registry of custom enablers keyed by pipeline class name |
-| `BlockAdapter` | Wraps complex architectures (multi-block-list or multi-transformer) |
-| `ForwardPattern` | Specifies block forward signature: `Pattern_0`, `Pattern_1`, `Pattern_2` |
-| `ParamsModifier` | Per-transformer or per-block-list config customization |
-| `DBCacheConfig` | Configuration for DBCache parameters |
-| `cache_dit.refresh_context()` | Updates cache context when `num_inference_steps` changes |
-
-**Source files**:
-- `vllm_omni/diffusion/cache/cache_dit_backend.py` — `CacheDiTBackend`, enablers, `CUSTOM_DIT_ENABLERS`
-- `vllm_omni/diffusion/cache/` — cache backend implementations
-
-## Standard Models: Automatic Support
-
-Most DiT models follow this pattern:
-- Single transformer with one `nn.ModuleList` of blocks
-- Standard forward signature
-- Compatible with cache-dit's automatic detection
-
-**Examples**: Qwen-Image, Z-Image, FLUX
-
-No code changes needed. `CacheDiTBackend` automatically uses `enable_cache_for_dit()`:
-
-```python
-from vllm_omni import Omni
-
-omni = Omni(
- model="Qwen/Qwen-Image",
- cache_backend="cache_dit",
- cache_config={
- "Fn_compute_blocks": 1,
- "Bn_compute_blocks": 0,
- "max_warmup_steps": 4,
- }
-)
-```
-
-What happens automatically:
-
-```python
-def enable_cache_for_dit(pipeline, cache_config):
- db_cache_config = DBCacheConfig(
- num_inference_steps=None,
- Fn_compute_blocks=cache_config.Fn_compute_blocks,
- Bn_compute_blocks=cache_config.Bn_compute_blocks,
- max_warmup_steps=cache_config.max_warmup_steps,
- max_cached_steps=cache_config.max_cached_steps,
- max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
- residual_diff_threshold=cache_config.residual_diff_threshold,
- )
-
- cache_dit.enable_cache(pipeline.transformer, cache_config=db_cache_config)
-
- def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
- cache_dit.refresh_context(
- pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose
- )
- return refresh_cache_context
-```
-
-## Custom Architectures: Writing Custom Enablers
-
-### When you need a custom enabler
-
-- Model has multiple block lists in one transformer (e.g., `transformer_blocks` + `single_transformer_blocks`)
-- Model has two transformers (e.g., high-noise + low-noise like Wan2.2)
-- Model uses non-standard block forward signature
-
-### Pattern 1: Multi-Block-List (LongCat-Image style)
-
-Single transformer with two block lists:
-
-```python
-import cache_dit
-from cache_dit import BlockAdapter, ForwardPattern, ParamsModifier, DBCacheConfig
-
-def enable_cache_for_your_model(pipeline, cache_config):
- db_cache_config = DBCacheConfig(
- num_inference_steps=None,
- Fn_compute_blocks=cache_config.Fn_compute_blocks,
- Bn_compute_blocks=cache_config.Bn_compute_blocks,
- max_warmup_steps=cache_config.max_warmup_steps,
- max_cached_steps=cache_config.max_cached_steps,
- max_continuous_cached_steps=cache_config.max_continuous_cached_steps,
- residual_diff_threshold=cache_config.residual_diff_threshold,
- )
-
- cache_dit.enable_cache(
- BlockAdapter(
- transformer=pipeline.transformer,
- blocks=[
- pipeline.transformer.transformer_blocks,
- pipeline.transformer.single_transformer_blocks,
- ],
- forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
- params_modifiers=[ParamsModifier(...)],
- ),
- cache_config=db_cache_config,
- )
-
- def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
- cache_dit.refresh_context(
- pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose
- )
- return refresh_cache_context
-```
-
-For single transformer with multiple block lists, `refresh_context` works the same as standard models — call it once on the transformer.
-
-### Pattern 2: Dual-Transformer (Wan2.2 style)
-
-Two transformers with separate configs:
-
-```python
-def enable_cache_for_dual_transformer(pipeline, cache_config):
- db_cache_config = DBCacheConfig(...)
-
- cache_dit.enable_cache(
- BlockAdapter(
- transformer=[pipeline.transformer, pipeline.transformer_2],
- blocks=[pipeline.transformer.blocks, pipeline.transformer_2.blocks],
- forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2],
- params_modifiers=[
- ParamsModifier(...), # Config for transformer 1
- ParamsModifier(...), # Config for transformer 2
- ],
- ),
- cache_config=db_cache_config,
- )
-
- def refresh_cache_context(pipeline, num_inference_steps, verbose=True):
- high_steps, low_steps = _split_inference_steps(num_inference_steps)
- cache_dit.refresh_context(
- pipeline.transformer, num_inference_steps=high_steps, verbose=verbose
- )
- cache_dit.refresh_context(
- pipeline.transformer_2, num_inference_steps=low_steps, verbose=verbose
- )
- return refresh_cache_context
-```
-
-Key difference: `refresh_context` must be called on **each transformer separately** with its own step count.
-
-### Choosing the ForwardPattern
-
-| Pattern | Block forward signature | Example models |
-|---------|------------------------|----------------|
-| `Pattern_0` | `block(hidden_states, **kwargs)` → residual added inside block | Default |
-| `Pattern_1` | `block(hidden_states, **kwargs)` → returns `(hidden_states, ...)` tuple | FLUX-style single blocks |
-| `Pattern_2` | `block(hidden_states, **kwargs)` → `(hidden_states, ...)` with different residual pattern | Wan2.2 blocks |
-
-Inspect your block's `forward()` return type and residual connection pattern to choose the right one. See [Cache-DiT API Reference](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/) for details.
-
-## Registering Custom Enablers
-
-Add your enabler to `CUSTOM_DIT_ENABLERS` in `vllm_omni/diffusion/cache/cache_dit_backend.py`:
-
-```python
-CUSTOM_DIT_ENABLERS = {
- "Wan22Pipeline": enable_cache_for_wan22,
- "LongCatImagePipeline": enable_cache_for_longcat_image,
- "YourModelPipeline": enable_cache_for_your_model,
-}
-```
-
-The key must match `pipeline.__class__.__name__`.
-
-## Configuration Parameters
-
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| `Fn_compute_blocks` | 1 | Number of blocks to always compute at the front |
-| `Bn_compute_blocks` | 0 | Number of blocks to always compute at the back |
-| `max_warmup_steps` | 4 | Steps to run without caching at the beginning |
-| `max_cached_steps` | — | Max total cached steps |
-| `max_continuous_cached_steps` | — | Max consecutive cached steps |
-| `residual_diff_threshold` | 0.24 | Residual-difference threshold for reusing cached outputs; higher values cache more aggressively (faster, lower quality) |
-
-### Tuning for quality vs speed
-
-| Goal | Adjustments |
-|------|-------------|
-| **More speed, acceptable quality loss** | Higher `residual_diff_threshold` (0.24-0.4), lower `max_warmup_steps` (2-4) |
-| **Better quality, less speed** | Lower `residual_diff_threshold` (0.12-0.18), higher `max_warmup_steps` (6-8), lower `max_continuous_cached_steps` (2) |
-
-## Testing
-
-```python
-from vllm_omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-omni = Omni(
- model="your-model-name",
- cache_backend="cache_dit",
- cache_config={
- "Fn_compute_blocks": 1,
- "Bn_compute_blocks": 0,
- "max_warmup_steps": 4,
- "residual_diff_threshold": 0.24,
- }
-)
-images = omni.generate(
- "a beautiful landscape",
- OmniDiffusionSamplingParams(num_inference_steps=50),
-)
-```
-
-CLI (online serving):
-
-```bash
-vllm serve your-model --omni --port 8098 \
- --cache-backend cache_dit \
- --cache-config '{"Fn_compute_blocks": 1, "Bn_compute_blocks": 0, "max_warmup_steps": 4}'
-```
-
-**Verification checklist**:
-1. Logs show "Cache-dit enabled successfully on xxx"
-2. Performance: 1.5-2x speedup vs no cache
-3. Quality: compare output with `cache_backend=None`
-
-## Excluded Models
-
-Models listed in `_NO_CACHE_ACCELERATION` in `vllm_omni/diffusion/registry.py` do not support cache-dit (e.g., `NextStep11Pipeline`, `StableDiffusionPipeline`). Check this set before attempting to enable cache-dit.
-
-## Reference Implementations
-
-| Model | Path | Notes |
-|-------|------|-------|
-| Standard DiT | `cache_dit_backend.py::enable_cache_for_dit` | Default enabler, automatic |
-| Wan2.2 | `cache_dit_backend.py::enable_cache_for_wan22` | Dual-transformer, auto-detects mode |
-| LongCat | `cache_dit_backend.py::enable_cache_for_longcat_image` | Multi-block-list |
-| BAGEL | `cache_dit_backend.py::enable_cache_for_bagel` | Complex omni model |
diff --git a/.claude/skills/add-diffusion-model/references/custom-model-patterns.md b/.claude/skills/add-diffusion-model/references/custom-model-patterns.md
deleted file mode 100644
index 2434e0b5da0..00000000000
--- a/.claude/skills/add-diffusion-model/references/custom-model-patterns.md
+++ /dev/null
@@ -1,273 +0,0 @@
-# Custom Model Patterns Reference
-
-Patterns for adding models that don't come from the standard diffusers pipeline format.
-
-## Directory Structure Comparison
-
-### Diffusers-based model (e.g., Wan2.2)
-
-```
-vllm_omni/diffusion/models/wan2_2/
-├── __init__.py # Exports pipeline + transformer + helpers
-├── pipeline_wan2_2.py # Pipeline: loads components via from_pretrained()
-├── pipeline_wan2_2_i2v.py # Variant pipeline for image-to-video
-└── wan2_2_transformer.py # Transformer: ported from diffusers, uses Attention layer
-```
-
-The transformer is loaded separately via `weights_sources` + `load_weights()`. Non-transformer components (VAE, text encoder) are loaded in `__init__` via `from_pretrained()`.
-
-### Custom model with external deps (e.g., DreamID-Omni)
-
-```
-vllm_omni/diffusion/models/dreamid_omni/
-├── __init__.py # Exports pipeline only
-├── pipeline_dreamid_omni.py # Pipeline: loads ALL weights in __init__ via custom helpers
-├── fusion.py # Custom fusion architecture (video + audio cross-attention)
-└── wan2_2.py # Re-implemented Wan backbone with split API
-
-examples/offline_inference/x_to_video_audio/
-└── download_dreamid_omni.py # Downloads weights from 3 HF repos + clones code repo
-```
-
-All weights loaded eagerly in `__init__`. `load_weights()` is a no-op. External dependency (`dreamid_omni` package) imported with try/except.
-
-### Custom model with ported code (e.g., BAGEL)
-
-```
-vllm_omni/diffusion/models/bagel/
-├── __init__.py
-├── pipeline_bagel.py # Pipeline: instantiates models, uses weights_sources
-├── bagel_transformer.py # Full LLM backbone (Qwen2-MoT) ported into vllm-omni
-└── autoencoder.py # Custom VAE ported from original repo
-```
-
-Model code is fully ported (no external dependency). Uses `weights_sources` and `load_weights()` with custom name remapping to handle non-diffusers safetensors format.
-
-## Weight Loading Patterns
-
-### Pattern 1: Standard diffusers flow (Wan2.2, Z-Image, FLUX)
-
-```
-init → create transformer (empty) → set weights_sources → [loader calls load_weights()]
-```
-
-- `weights_sources` points to safetensors in HF subfolder (e.g., `transformer/`)
-- `load_weights()` receives `(name, tensor)` pairs from the loader
-- Name remapping handles diffusers→vllm-omni differences (QKV fusion, Sequential index removal)
-
-### Pattern 2: Custom safetensors at root (BAGEL)
-
-```
-init → create all models (empty) → set weights_sources(subfolder=None) → [loader calls load_weights()]
-```
-
-- `weights_sources` points to **root** of model directory, not a subfolder
-- Weights have non-diffusers names (e.g., `bagel.language_model.model.layers.0.self_attn.q_proj.weight`)
-- `load_weights()` does heavy name normalization
-
-```python
-self.weights_sources = [
- DiffusersPipelineLoader.ComponentSource(
- model_or_path=od_config.model,
- subfolder=None, # root directory
- prefix="", # no prefix stripping
- fall_back_to_pt=False,
- )
-]
-```
-
-### Pattern 3: Fully custom loading (DreamID-Omni)
-
-```
-init → load ALL weights eagerly via custom helpers → load_weights() = no-op
-```
-
-- No `weights_sources` attribute — standard loader finds nothing to iterate
-- Custom init functions (e.g., `init_wan_vae_2_2()`, `load_fusion_checkpoint()`) handle downloading and loading
-- `load_weights()` is `pass`
-- Weights may come from multiple HF repos in different formats (`.pth`, `.safetensors`)
-
-Use this when:
-- The original model has complex, well-tested loading code you don't want to rewrite
-- Weights span multiple HF repos
-- Weight format is non-standard (e.g., a single `.pth` file, not sharded safetensors)
-
-## model_index.json for Custom Models
-
-Standard diffusers `model_index.json`:
-```json
-{
- "_class_name": "WanPipeline",
- "_diffusers_version": "0.35.0.dev0",
- "scheduler": ["diffusers", "UniPCMultistepScheduler"],
- "transformer": ["diffusers", "WanTransformer3DModel"],
- "vae": ["diffusers", "AutoencoderKLWan"]
-}
-```
-
-Custom model `model_index.json` (minimal):
-```json
-{
- "_class_name": "DreamIDOmniPipeline",
- "fusion": "DreamID-Omni/dreamid_omni.safetensors"
-}
-```
-
-The only **required** field is `_class_name` — it must match a key in `_DIFFUSION_MODELS` in `registry.py`. Other fields are model-specific and accessible via `od_config.model_config` dict.
-
-## External Dependency Management
-
-### Git clone + .pth injection (DreamID-Omni pattern)
-
-```python
-def download_dependency():
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
- with open(LOCK_FILE, "w") as f:
- fcntl.flock(f, fcntl.LOCK_EX)
- if not DEPENDENCY_DIR.exists():
- subprocess.run([
- "git", "clone", "--depth", "1",
- REPO_URL, "--branch", BRANCH,
- str(DEPENDENCY_DIR)
- ], check=True)
- fcntl.flock(f, fcntl.LOCK_UN)
-
- # Add to Python path via .pth file
- site_packages = Path(site.getsitepackages()[0])
- pth_file = site_packages / "vllm_omni_dependency.pth"
- pth_file.write_text(str(DEPENDENCY_DIR))
-```
-
-### Direct port (BAGEL pattern)
-
-Copy essential files from the original repo into `vllm_omni/diffusion/models/<your_model>/`. Adapt imports to use vllm-omni utilities. Benefits: no external dependency, no git clone step. Drawback: must maintain the ported code.
-
-## Multi-Modal Input/Output Protocols
-
-Custom models that handle images, audio, or video I/O should implement protocol classes:
-
-```python
-from vllm_omni.diffusion.models.interface import (
- SupportImageInput, # Model accepts image input
- SupportAudioInput, # Model accepts audio input
- SupportAudioOutput, # Model produces audio output
-)
-
-class MyPipeline(nn.Module, SupportImageInput, SupportAudioInput, SupportAudioOutput):
- pass # Protocol markers enable proper engine routing
-```
-
-The engine checks `isinstance(pipeline, SupportImageInput)` at startup to configure input validation and warmup behavior.
-
-## Hardcoded Config vs Config Files
-
-Diffusers models use `config.json` in each subfolder. Custom models often use:
-
-**Module-level config dicts** (DreamID-Omni):
-```python
-VIDEO_CONFIG = {
- "patch_size": [1, 2, 2], "model_type": "ti2v",
- "dim": 3072, "ffn_dim": 14336, "num_heads": 24, "num_layers": 30, ...
-}
-```
-
-**Loaded from custom JSON** (BAGEL):
-```python
-cfg_path = os.path.join(model_path, "config.json")
-with open(cfg_path) as f:
- bagel_cfg = json.load(f)
-vae_cfg = bagel_cfg.get("vae_config", {})
-```
-
-## Custom Architecture Patterns
-
-### Split forward API (DreamID-Omni)
-
-When a fusion model needs to interleave blocks from two backbones:
-
-```python
-class WanModel(nn.Module):
- def prepare_transformer_block_kwargs(self, x, t, context, ...):
- # Patch embed, time embed, text embed, RoPE
- return x, e, kwargs
-
- def post_transformer_block_out(self, x, grid_sizes, e):
- # Output projection, unpatchify
- return output
-
- def forward(self, *args, **kwargs):
- raise NotImplementedError # Fusion model handles block iteration
-```
-
-The `FusionModel` then iterates blocks in lock-step:
-```python
-for video_block, audio_block in zip(self.video_model.blocks, self.audio_model.blocks):
- video_out = video_block(video_hidden, ...)
- audio_out = audio_block(audio_hidden, ...)
- # Cross-attend between modalities
- video_out = cross_attention(video_out, audio_out)
- audio_out = cross_attention(audio_out, video_out)
-```
-
-### LLM-as-denoiser (BAGEL)
-
-When the backbone is a language model that also does diffusion:
-
-```python
-class BagelModel(nn.Module):
- def __init__(self):
- self.language_model = Qwen2MoTForCausalLM(config)
- self.vit_model = SiglipVisionModel(vit_config)
-```
-
-The LLM processes both text tokens and latent image tokens in a single forward pass, using KV caching for the text portion.
-
-## Pre/Post Processing for Custom Models
-
-Custom models typically handle pre/post processing **inside `forward()`** rather than via registered functions, because the logic is tightly coupled:
-
-```python
-def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
- # Inline preprocessing
- image = self._load_and_resize_image(req.prompts[0].get("multi_modal_data", {}).get("image"))
- image_latent = self._vae_encode(image)
-
- # ... denoising loop ...
-
- # Inline postprocessing
- pil_image = self._decode_to_pil(latents)
- return DiffusionOutput(output=[pil_image])
-```
-
-If pre/post functions are not registered in `_DIFFUSION_PRE_PROCESS_FUNCS` / `_DIFFUSION_POST_PROCESS_FUNCS`, the engine simply skips those steps.
-
-## Download Script Template
-
-```python
-# examples/offline_inference/<task_category>/download_<your_model>.py
-from huggingface_hub import snapshot_download
-import json, os
-
-def main(output_dir):
- # Download model weights from HF
- snapshot_download(repo_id="org/model-weights", local_dir=os.path.join(output_dir, "weights"))
-
- # Download additional components if from separate repos
- snapshot_download(repo_id="org/vae-weights", local_dir=os.path.join(output_dir, "vae"),
- allow_patterns=["*.safetensors"])
-
- # Generate model_index.json
- config = {"_class_name": "YourPipeline", "custom_key": "weights/model.safetensors"}
- with open(os.path.join(output_dir, "model_index.json"), "w") as f:
- json.dump(config, f, indent=2)
-
- # Install external code dependency (if needed)
- download_dependency()
-
-if __name__ == "__main__":
- import argparse
- parser = argparse.ArgumentParser()
- parser.add_argument("--output-dir", default="./your_model")
- args = parser.parse_args()
- main(args.output_dir)
-```
diff --git a/.claude/skills/add-diffusion-model/references/parallelism-patterns.md b/.claude/skills/add-diffusion-model/references/parallelism-patterns.md
deleted file mode 100644
index 933e2d23204..00000000000
--- a/.claude/skills/add-diffusion-model/references/parallelism-patterns.md
+++ /dev/null
@@ -1,571 +0,0 @@
-# Parallelism Patterns Reference
-
-## Overview
-
-vLLM-Omni supports multiple parallelism strategies for diffusion models. Each targets a different bottleneck:
-
-| Strategy | Splits | Best For | Constraint |
-|----------|--------|----------|------------|
-| Tensor Parallel (TP) | Model layers across GPUs | Latency reduction, large models | Requires fast GPU interconnect, `num_heads % tp == 0` |
-| Sequence Parallel (SP/USP) | Sequence tokens across GPUs | Long sequences (video, high-res) | Near-linear scaling |
-| CFG Parallel | Positive/negative CFG branches | Models using classifier-free guidance | Exactly 2 GPUs |
-| HSDP | Weight shards via FSDP2 | VRAM reduction | Cannot combine with TP |
-| VAE Patch Parallel | VAE decode spatial tiles | Large VAE outputs | Auto-enables tiling |
-
-**Recommended integration order**: TP → SP → CFG Parallel → HSDP
-
-**Official design docs**:
-- TP: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/tensor_parallel
-- SP: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/sequence_parallel
-- CFG: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/cfg_parallel
-- HSDP: https://docs.vllm.ai/projects/vllm-omni/en/latest/design/feature/hsdp
-
----
-
-## Tensor Parallelism (TP)
-
-Replace standard `nn.Linear` with vLLM's parallel linear layers. This is the most invasive change but provides direct VRAM savings and compute speedup.
-
-### Layer replacement rules
-
-| Pattern | vLLM Layer | When to Use |
-|---------|-----------|-------------|
-| Fan-out (first in FFN) | `ColumnParallelLinear` | Projection that splits output across ranks |
-| Fan-in (second in FFN) | `RowParallelLinear` | Projection that gathers across ranks |
-| QKV projection | `QKVParallelLinear` | Fused Q/K/V for self-attention |
-| Single Q or K or V | `ColumnParallelLinear` | Separate projections (cross-attention) |
-| Attention output | `RowParallelLinear` | Output projection after attention |
-| Must not shard | `ReplicatedLinear` | Layers that must stay replicated |
-
-### MLP Block (Up-Down Pattern)
-
-```python
-from vllm.model_executor.layers.linear import (
- ColumnParallelLinear, RowParallelLinear,
-)
-
-class TPFeedForward(nn.Module):
- def __init__(self, dim, ffn_dim):
- super().__init__()
- self.fc1 = ColumnParallelLinear(dim, ffn_dim, bias=False, return_bias=False)
- self.fc2 = RowParallelLinear(
- ffn_dim, dim, bias=False,
- input_is_parallel=True, # Input already sharded from fc1
- return_bias=False,
- )
-
- def forward(self, x):
- x = self.fc1(x)
- x = torch.nn.functional.gelu(x)
- x = self.fc2(x)
- return x
-```
-
-### Attention Block (QKV-Out Pattern)
-
-```python
-from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
-from vllm_omni.diffusion.attention.layer import Attention
-
-class TPSelfAttention(nn.Module):
- def __init__(self, dim, num_heads, num_kv_heads=None):
- super().__init__()
- num_kv_heads = num_kv_heads or num_heads
- self.head_dim = dim // num_heads
-
- self.to_qkv = QKVParallelLinear(
- hidden_size=dim,
- head_size=self.head_dim,
- total_num_heads=num_heads,
- total_num_kv_heads=num_kv_heads,
- bias=False,
- return_bias=False,
- )
- self.to_out = RowParallelLinear(
- dim, dim, bias=False,
- input_is_parallel=True,
- return_bias=False,
- )
- self.attn = Attention(
- num_heads=self.to_qkv.num_heads, # Local heads per GPU
- head_size=self.head_dim,
- softmax_scale=1.0 / (self.head_dim ** 0.5),
- causal=False,
- num_kv_heads=self.to_qkv.num_kv_heads, # Local KV heads per GPU
- )
-
- def forward(self, x):
- qkv = self.to_qkv(x)
- q, k, v = qkv.split(
- [self.to_qkv.num_heads * self.head_dim,
- self.to_qkv.num_kv_heads * self.head_dim,
- self.to_qkv.num_kv_heads * self.head_dim],
- dim=-1,
- )
- B, S, _ = x.shape
- q = q.view(B, S, self.to_qkv.num_heads, self.head_dim)
- k = k.view(B, S, self.to_qkv.num_kv_heads, self.head_dim)
- v = v.view(B, S, self.to_qkv.num_kv_heads, self.head_dim)
- out = self.attn(q, k, v)
- out = out.reshape(B, S, -1)
- out = self.to_out(out)
- return out
-```
-
-### QKV Fusion in load_weights
-
-When you fuse separate Q/K/V into `QKVParallelLinear`, map diffusers' separate weight names:
-
-```python
-stacked_params_mapping = [
- ("to_qkv", "to_q", "q"),
- ("to_qkv", "to_k", "k"),
- ("to_qkv", "to_v", "v"),
-]
-
-def load_weights(self, weights):
- params = dict(self.named_parameters())
- loaded = set()
- for name, tensor in weights:
- for fused_name, orig_name, shard_id in stacked_params_mapping:
- if orig_name in name:
- name = name.replace(orig_name, fused_name)
- param = params[name]
- param.weight_loader(param, tensor, shard_id)
- loaded.add(name)
- break
- else:
- if name in params:
- param = params[name]
- if hasattr(param, "weight_loader"):
- param.weight_loader(param, tensor)
- else:
- default_weight_loader(param, tensor)
- loaded.add(name)
- return loaded
-```
-
-### RMSNorm with TP
-
-When RMSNorm sits between TP-sharded dimensions, use `DistributedRMSNorm` — it computes global RMS via all-reduce across TP ranks. See the Wan2.2 implementation for the pattern.
-
-### TP Constraints
-
-- `num_heads % tp_size == 0`
-- `num_kv_heads % tp_size == 0`
-- Use `self.to_qkv.num_heads` (local per-GPU count), not total heads, for split sizes
-
-### Testing TP
-
-```bash
-python text_to_image.py --model Your-org/your-model \
- --tensor-parallel-size 2 --output "tp_test.png"
-```
-
-**Verify**: speedup, memory reduction proportional to TP size, quality matches single-GPU.
-
-### Reference implementations
-
-| Model | Path |
-|-------|------|
-| Z-Image | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` |
-| FLUX | `vllm_omni/diffusion/models/flux/flux_transformer.py` |
-| Qwen-Image | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` |
-
----
-
-## Sequence Parallelism (SP / USP)
-
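-SP splits sequence tokens across GPUs using Ulysses (all-to-all) or Ring (P2P) communication. It is normally applied non-intrusively via the `_sp_plan` class attribute, so the core `forward()` logic stays unchanged; inline tensor ops at shard/gather points may need to be extracted into submodules (see Step 2).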
-
-### Approach 1: Non-Intrusive `_sp_plan` (Recommended)
-
-The framework automatically registers hooks to shard inputs and gather outputs at `nn.Module` boundaries.
-
-#### Step 1: Identify module boundaries
-
-Find where tensors need sharding/gathering:
-
-```python
-class MyTransformer(nn.Module):
- def __init__(self):
- self.patch_embed = PatchEmbed() # Before blocks
- self.pos_embed = RoPE() # RoPE may need splitting
- self.blocks = nn.ModuleList([...]) # Blocks process sharded x
- self.norm_out = LayerNorm()
- self.proj_out = Linear() # Gather after this
-
- def forward(self, x):
- x = self.patch_embed(x)
- pos = self.pos_embed(x)
- for block in self.blocks:
- x = block(x, pos)
- x = self.norm_out(x)
- return self.proj_out(x)
-```
-
-#### Step 2: Handle inline operations
-
-`_sp_plan` hooks only work at `nn.Module` boundaries. Inline ops like `torch.cat()` must be extracted into submodules:
-
-```python
-# BAD: Inline — hooks can't intercept
-unified = torch.cat([x, cap_feats], dim=1)
-
-# GOOD: Extract into submodule
-class UnifiedPrepare(nn.Module):
- def forward(self, x, cap_feats):
- return torch.cat([x, cap_feats], dim=1)
-
-self.unified_prepare = UnifiedPrepare()
-unified = self.unified_prepare(x, cap_feats)
-```
-
-Common cases: `torch.cat()`, `pad_sequence()`, `tensor.reshape()`, complex preprocessing.
-
-#### Step 3: Write `_sp_plan`
-
-**Pattern 1: Shard at first block, gather at output** (most common)
-
-```python
-from vllm_omni.diffusion.distributed.sp_plan import (
- SequenceParallelInput, SequenceParallelOutput,
-)
-
-class StandardTransformer(nn.Module):
- _sp_plan = {
- "blocks.0": {
- "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
- },
- "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
- }
-```
-
-**Pattern 2: Shard RoPE outputs separately**
-
-```python
-class TransformerWithRoPE(nn.Module):
- _sp_plan = {
- "rope": {
- 0: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),
- 1: SequenceParallelInput(split_dim=1, expected_dims=4, split_output=True),
- },
- "blocks.0": {
- "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
- },
- "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
- }
-```
-
-**Pattern 3: Dual-stream (shard image, replicate text)**
-
-```python
-class DualStreamTransformer(nn.Module):
- _sp_plan = {
- "rope_preparer": {
- 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),
- 3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True),
- },
- "transformer_blocks.0": {
- "hidden_states": SequenceParallelInput(split_dim=1, expected_dims=3),
- },
- "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
- }
-```
-
-### API Reference
-
-**SequenceParallelInput**:
-
-| Parameter | Type | Description |
-|-----------|------|-------------|
-| `split_dim` | int | Dimension to split (usually 1 for sequence) |
-| `expected_dims` | int/None | Expected tensor rank for validation |
-| `split_output` | bool | `False`: shard input params; `True`: shard output tensors |
-| `auto_pad` | bool | Auto-pad if sequence not divisible by world_size |
-
-**SequenceParallelOutput**:
-
-| Parameter | Type | Description |
-|-----------|------|-------------|
-| `gather_dim` | int | Dimension to gather (usually 1 for sequence) |
-| `expected_dims` | int/None | Expected tensor rank for validation |
-
-**Module naming**:
-
-| Key | Meaning |
-|-----|---------|
-| `"blocks.0"` | First element of ModuleList |
-| `"blocks.*"` | All elements of ModuleList |
-| `"rope"` | Named submodule |
-
-**Dictionary key types**:
-
-| Key type | split_output | Description |
-|----------|-------------|-------------|
-| `"param_name"` (str) | False | Shard input parameter by name |
-| `0, 1, ...` (int) | True | Shard output tuple by index |
-
-### Approach 2: Intrusive Modification (Complex Cases)
-
-For dynamic sharding logic that can't be expressed via `_sp_plan`:
-
-```python
-from vllm_omni.diffusion.distributed.sp_sharding import sp_shard, sp_gather
-
-def forward(self, hidden_states, ...):
- if self.parallel_config.sequence_parallel_size > 1:
- hidden_states = sp_shard(hidden_states, dim=1)
- for block in self.blocks:
- hidden_states = block(hidden_states)
- if self.parallel_config.sequence_parallel_size > 1:
- hidden_states = sp_gather(hidden_states, dim=1)
- return hidden_states
-```
-
-Use intrusive modification as a last resort — `_sp_plan` is preferred for maintainability.
-
-### UAA Mode (Experimental)
-
-`ulysses_mode="advanced_uaa"` handles arbitrary sequence lengths and head counts that aren't divisible by `ulysses_degree`. Uses variable all-to-all split sizes and temporary head padding.
-
-### Combining SP methods
-
-Ulysses and Ring can be combined: `ulysses_degree × ring_degree = total SP GPUs`.
-
-```python
-DiffusionParallelConfig(ulysses_degree=2, ring_degree=2) # 4 GPUs total
-```
-
-### Testing SP
-
-```bash
-# Offline
-python text_to_image.py --model Your-model --ulysses-degree 2
-
-# Online serving
-vllm serve Your-model --omni --usp 2
-```
-
-### Reference implementations
-
-| Model | Path |
-|-------|------|
-| Qwen-Image | `vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py` |
-| Wan2.2 | `vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py` |
-| Z-Image | `vllm_omni/diffusion/models/z_image/z_image_transformer.py` |
-
----
-
-## CFG Parallelism
-
-Distributes positive/negative Classifier-Free Guidance branches across 2 GPUs.
-
-### Implementation
-
-Inherit `CFGParallelMixin` and implement `diffuse()`:
-
-```python
-from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
-
-class YourPipeline(nn.Module, CFGParallelMixin):
- def diffuse(self, latents, timesteps, prompt_embeds, negative_embeds,
- do_true_cfg, true_cfg_scale, **kwargs):
- for i, t in enumerate(timesteps):
- positive_kwargs = {
- "hidden_states": latents,
- "encoder_hidden_states": prompt_embeds,
- "timestep": t,
- }
- negative_kwargs = {
- "hidden_states": latents,
- "encoder_hidden_states": negative_embeds,
- "timestep": t,
- } if do_true_cfg else None
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=true_cfg_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- )
- latents = self.scheduler_step_maybe_with_cfg(
- noise_pred, t, latents, do_true_cfg
- )
- return latents
-```
-
-### Customization hooks
-
-| Method | Override when |
-|--------|-------------|
-| `predict_noise()` | Non-standard transformer call (e.g., dual-transformer like Wan2.2) |
-| `cfg_normalize_function()` | Custom normalization (e.g., LongCat with clamping) |
-| `combine_cfg_noise()` | Multi-output models (e.g., video + audio: CFG on video, positive-only on audio) |
-
-**Custom predict_noise** (Wan2.2 — selects active transformer):
-
-```python
-def predict_noise(self, current_model=None, **kwargs):
- if current_model is None:
- current_model = self.transformer
- return current_model(**kwargs)[0]
-```
-
-**Custom combine_cfg_noise** (multi-output):
-
-```python
-def combine_cfg_noise(self, positive_pred, negative_pred, scale, normalize):
- video_pos, audio_pos = positive_pred
- video_neg, audio_neg = negative_pred
- video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize)
- return (video_combined, audio_pos)
-```
-
-### Composite scheduler for multi-output
-
-When each output has its own schedule:
-
-```python
-class VideoAudioScheduler:
- def __init__(self, video_scheduler, audio_scheduler):
- self.video_scheduler = video_scheduler
- self.audio_scheduler = audio_scheduler
-
- def step(self, noise_pred, t, latents, return_dict=False, generator=None):
- video_out = self.video_scheduler.step(
- noise_pred[0], t[0], latents[0], return_dict=False, generator=generator
- )[0]
- audio_out = self.audio_scheduler.step(
- noise_pred[1], t[1], latents[1], return_dict=False, generator=generator
- )[0]
- return ((video_out, audio_out),)
-```
-
-### Testing CFG Parallel
-
-```bash
-python text_to_image.py --model Your-model \
- --cfg-parallel-size 2 --cfg-scale 4.0 \
- --negative-prompt "ugly, unclear"
-```
-
-**Constraint**: CFG parallel requires `guidance_scale > 1.0`, and a negative prompt must be provided.
-
-### Reference implementations
-
-| Model | Path |
-|-------|------|
-| Qwen-Image | `vllm_omni/diffusion/models/qwen_image/cfg_parallel.py` |
-| Wan2.2 | `vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py` |
-| Mixin base | `vllm_omni/diffusion/distributed/cfg_parallel.py` |
-
----
-
-## HSDP (Hybrid Sharded Data Parallel)
-
-Shards model weights across GPUs using PyTorch FSDP2. Reduces per-GPU VRAM without changing computation.
-
-### Implementation
-
-Add `_hsdp_shard_conditions` to the transformer class:
-
-```python
-class YourTransformer(nn.Module):
- @staticmethod
- def _is_transformer_block(name: str, module) -> bool:
- return "blocks" in name and name.split(".")[-1].isdigit()
-
- _hsdp_shard_conditions = [_is_transformer_block]
-```
-
-For MoE models, add additional conditions:
-
-```python
-class MoETransformer(nn.Module):
- @staticmethod
- def _is_transformer_block(name, module):
- return "blocks" in name and name.split(".")[-1].isdigit()
-
- @staticmethod
- def _is_moe_expert(name, module):
- return "experts" in name and name.split(".")[-1].isdigit()
-
- _hsdp_shard_conditions = [_is_transformer_block, _is_moe_expert]
-```
-
-A module is sharded if **any** condition returns `True`.
-
-### Constraints
-
-- Cannot combine with Tensor Parallelism
-- For standalone HSDP (no other parallelism), `hsdp_shard_size` must be specified explicitly
-- Can combine with SP: HSDP reduces memory while SP distributes sequence
-
-### Testing HSDP
-
-```python
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-
-parallel_config = DiffusionParallelConfig(use_hsdp=True, hsdp_shard_size=8)
-omni = Omni(model="your-model", parallel_config=parallel_config)
-```
-
-Or CLI:
-
-```bash
-vllm serve Your-model --omni --use-hsdp
-```
-
-**Verify**: logs show "HSDP Inference: replicate_size=..., shard_size=..." and "Sharded N modules + root". Check VRAM reduction.
-
-### Reference implementations
-
-| Model | Path |
-|-------|------|
-| Wan2.2 | `vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py` |
-| HSDP Core | `vllm_omni/diffusion/distributed/hsdp.py` |
-
----
-
-## VAE Patch Parallelism
-
-Shards VAE decode spatially across ranks using tiling:
-
-```bash
-python text_to_image.py --model Your-model --vae-patch-parallel-size 4
-```
-
-Auto-enables `--vae-use-tiling`. Uses `DistributedAutoencoderKLWan` or similar distributed VAE. Set `vae_patch_parallel_size` in `DiffusionParallelConfig`.
-
----
-
-## Combining Parallelism Methods
-
-Common multi-GPU recipes:
-
-```bash
-# 4 GPUs: CFG (2) × Ulysses (2)
-python text_to_image.py --model Qwen/Qwen-Image \
- --cfg-parallel-size 2 --ulysses-degree 2
-
-# 8 GPUs: Ulysses (4) × Ring (2) + VAE patch (8)
-python text_to_video.py --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \
- --ulysses-degree 4 --ring-degree 2 --vae-patch-parallel-size 8
-
-# 2 GPUs: HSDP + Ulysses (cannot combine HSDP with TP)
-vllm serve Your-model --omni --use-hsdp --usp 2
-```
-
-## Discovering Parallelism Support
-
-Check which parallelism methods a model supports:
-
-| Check | How |
-|-------|-----|
-| **Ulysses / Ring SP** | Transformer defines `_sp_plan`. Search: `grep -r '_sp_plan' vllm_omni/diffusion/models/` |
-| **CFG Parallel** | Pipeline inherits `CFGParallelMixin`. Search: `grep -r 'CFGParallelMixin' vllm_omni/diffusion/models/` |
-| **TP** | Uses `ColumnParallelLinear` / `QKVParallelLinear`. Search: `grep -r 'ParallelLinear\|QKVParallel' vllm_omni/diffusion/models/<model>/` |
-| **HSDP** | Transformer defines `_hsdp_shard_conditions`. Search: `grep -r '_hsdp_shard_conditions' vllm_omni/diffusion/models/` |
-
-The canonical per-model support table is in `docs/user_guide/diffusion/parallelism_acceleration.md`.
diff --git a/.claude/skills/add-diffusion-model/references/transformer-adaptation.md b/.claude/skills/add-diffusion-model/references/transformer-adaptation.md
deleted file mode 100644
index 6e344b6a66e..00000000000
--- a/.claude/skills/add-diffusion-model/references/transformer-adaptation.md
+++ /dev/null
@@ -1,218 +0,0 @@
-# Transformer Adaptation Reference
-
-## Adapting a Diffusers Transformer to vLLM-Omni
-
-### Step-by-step Checklist
-
-1. Copy the transformer class from diffusers source
-2. Remove all mixin classes — inherit only from `nn.Module`
-3. Replace attention dispatch with `vllm_omni.diffusion.attention.layer.Attention`
-4. Replace logger with `vllm.logger.init_logger`
-5. Add `od_config: OmniDiffusionConfig | None = None` to `__init__`
-6. Remove training-only code (gradient checkpointing, dropout)
-7. Add `load_weights()` method for weight loading from safetensors
-8. Add class-level attributes for acceleration features
-
-### Mixin Removal
-
-Remove these diffusers mixins (and their imports):
-
-```python
-# Remove all of these:
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.attention_processor import AttentionModuleMixin
-from diffusers.loaders import PeftAdapterMixin, FromOriginalModelMixin
-
-# Replace:
-class MyTransformer(ModelMixin, ConfigMixin, AttentionModuleMixin):
-# With:
-class MyTransformer(nn.Module):
-```
-
-Also remove `@register_to_config` decorators from `__init__`.
-
-### Attention Replacement
-
-The vLLM-Omni `Attention` layer wraps backend selection (FlashAttention, SDPA, SageAttn, etc.) and supports sequence parallelism hooks.
-
-**QKV tensor shape must be `[batch, seq_len, num_heads, head_dim]`.**
-
-#### Self-Attention Pattern
-
-```python
-from vllm_omni.diffusion.attention.layer import Attention
-from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
-
-class SelfAttentionBlock(nn.Module):
- def __init__(self, dim, num_heads):
- super().__init__()
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
-
- self.to_q = nn.Linear(dim, dim)
- self.to_k = nn.Linear(dim, dim)
- self.to_v = nn.Linear(dim, dim)
- self.to_out = nn.Linear(dim, dim)
-
- self.attn = Attention(
- num_heads=num_heads,
- head_size=self.head_dim,
- softmax_scale=1.0 / (self.head_dim ** 0.5),
- causal=False,
- num_kv_heads=num_heads,
- )
-
- def forward(self, x, attn_mask=None):
- B, S, _ = x.shape
- q = self.to_q(x).view(B, S, self.num_heads, self.head_dim)
- k = self.to_k(x).view(B, S, self.num_heads, self.head_dim)
- v = self.to_v(x).view(B, S, self.num_heads, self.head_dim)
-
- attn_metadata = AttentionMetadata(attn_mask=attn_mask)
- out = self.attn(q, k, v, attn_metadata=attn_metadata)
- out = out.reshape(B, S, -1)
- return self.to_out(out)
-```
-
-#### Fused QKV with TP (Advanced)
-
-For tensor parallelism, use vLLM's parallel linear layers:
-
-```python
-from vllm.model_executor.layers.linear import (
- QKVParallelLinear, RowParallelLinear
-)
-
-class TPSelfAttention(nn.Module):
- def __init__(self, dim, num_heads):
- super().__init__()
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
-
- self.to_qkv = QKVParallelLinear(
- hidden_size=dim,
- head_size=self.head_dim,
- total_num_heads=num_heads,
- total_num_kv_heads=num_heads,
- )
- self.to_out = RowParallelLinear(dim, dim)
-
- self.attn = Attention(
- num_heads=self.to_qkv.num_heads, # Local heads per GPU, not the total
- head_size=self.head_dim,
- softmax_scale=1.0 / (self.head_dim ** 0.5),
- causal=False,
- num_kv_heads=self.to_qkv.num_kv_heads, # Local KV heads per GPU
- )
-```
-
-### Logger Replacement
-
-```python
-# Replace:
-from diffusers.utils import logging
-logger = logging.get_logger(__name__)
-
-# With:
-from vllm.logger import init_logger
-logger = init_logger(__name__)
-```
-
-### Custom Layers from vLLM-Omni
-
-Available utility layers:
-
-```python
-from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm_omni.diffusion.layers.rope import RotaryEmbedding
-from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm
-```
-
-### Config Support
-
-```python
-from vllm_omni.diffusion.data import OmniDiffusionConfig
-
-class MyTransformer(nn.Module):
- def __init__(self, *, od_config=None, num_layers=28, hidden_size=3072, **kwargs):
- super().__init__()
- self.od_config = od_config
- self.parallel_config = od_config.parallel_config if od_config else None
- # ... build layers
-```
-
-The transformer config values come from `model_index.json` → `config.json` in the transformer subfolder. The pipeline uses `get_transformer_config_kwargs(od_config.tf_model_config, TransformerClass)` to filter config keys to match the `__init__` signature.
-
-### Weight Loading
-
-The `load_weights` method receives an iterable of `(name, tensor)` from safetensors files, with the prefix (e.g., `"transformer."`) already stripped by the loader.
-
-```python
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
-class MyTransformer(nn.Module):
- def load_weights(self, weights):
- params = dict(self.named_parameters())
- loaded = set()
- for name, tensor in weights:
- # Optional: remap names from diffusers to vllm-omni naming
- # e.g., "ff.net.0.proj" -> "ff.net_0.proj"
-
- if name in params:
- param = params[name]
- if hasattr(param, "weight_loader"):
- param.weight_loader(param, tensor)
- else:
- default_weight_loader(param, tensor)
- loaded.add(name)
- return loaded
-```
-
-#### QKV Fusion in load_weights
-
-If you fused separate Q/K/V into a `QKVParallelLinear`, you need to map diffusers' separate weight names:
-
-```python
-stacked_params_mapping = [
- ("to_qkv", "to_q", "q"),
- ("to_qkv", "to_k", "k"),
- ("to_qkv", "to_v", "v"),
-]
-
-def load_weights(self, weights):
- params = dict(self.named_parameters())
- loaded = set()
- for name, tensor in weights:
- for fused_name, orig_name, shard_id in stacked_params_mapping:
- if orig_name in name:
- name = name.replace(orig_name, fused_name)
- param = params[name]
- param.weight_loader(param, tensor, shard_id)
- loaded.add(name)
- break
- else:
- # Normal loading
- ...
- return loaded
-```
-
-### Class-Level Attributes for Features
-
-```python
-class MyTransformer(nn.Module):
- # torch.compile: list block class names that repeat and can be compiled
- _repeated_blocks = ["MyTransformerBlock"]
-
- # CPU offload: attribute name of the nn.ModuleList containing blocks
- _layerwise_offload_blocks_attr = "blocks"
-
- # LoRA: mapping of fused param names to original param names
- packed_modules_mapping = {"to_qkv": ["to_q", "to_k", "to_v"]}
-
- # Sequence parallelism plan (advanced — add after basic impl works)
- _sp_plan = {
- "blocks.0": {"hidden_states": SequenceParallelInput(split_dim=1)},
- "proj_out": SequenceParallelOutput(gather_dim=1),
- }
-```
diff --git a/.claude/skills/add-diffusion-model/references/troubleshooting.md b/.claude/skills/add-diffusion-model/references/troubleshooting.md
deleted file mode 100644
index 27acdd8d154..00000000000
--- a/.claude/skills/add-diffusion-model/references/troubleshooting.md
+++ /dev/null
@@ -1,178 +0,0 @@
-# Troubleshooting Reference
-
-## Common Errors When Adding a Diffusion Model
-
-### ImportError / ModuleNotFoundError
-
-**Cause**: Missing or incorrect registration.
-
-**Fix checklist**:
-1. Model registered in `vllm_omni/diffusion/registry.py` `_DIFFUSION_MODELS` dict
-2. `__init__.py` exports the pipeline class
-3. Pipeline file exists at the correct path: `vllm_omni/diffusion/models/{folder}/{file}.py`
-4. Class name in registry matches the actual class name in the file
-
-### Shape Mismatch in Attention
-
-**Symptom**: `RuntimeError: shape mismatch` or `expected 4D tensor`
-
-**Cause**: QKV tensors not reshaped to `[batch, seq_len, num_heads, head_dim]`.
-
-**Fix**: Before calling `self.attn(q, k, v, ...)`, ensure:
-```python
-q = q.view(batch, seq_len, self.num_heads, self.head_dim)
-k = k.view(batch, kv_seq_len, self.num_kv_heads, self.head_dim)
-v = v.view(batch, kv_seq_len, self.num_kv_heads, self.head_dim)
-```
-
-After attention, reshape back:
-```python
-out = out.reshape(batch, seq_len, -1)
-```
-
-### Weight Loading Failures
-
-**Symptom**: `RuntimeError: size mismatch for parameter ...` or missing keys
-
-**Debugging**:
-1. Print diffusers weight names: `safetensors.safe_open(path, "pt").keys()`
-2. Print model parameter names: `dict(model.named_parameters()).keys()`
-3. Compare and add name remappings in `load_weights()`
-
-**Common remappings needed**:
-- `ff.net.0.proj` → `ff.net_0.proj` (PyTorch Sequential indexing)
-- `.to_out.0.` → `.to_out.` (Sequential unwrapping)
-- `scale_shift_table` → moved to a wrapper module
-
-### Black/Blank/Noisy Output
-
-**Possible causes**:
-1. **Wrong latent normalization**: Check whether the VAE expects latents to be scaled by `vae.config.scaling_factor`
-2. **Wrong scheduler**: Using the wrong scheduler class or wrong `flow_shift`
-3. **Missing CFG**: Some models require `guidance_scale > 1.0` with negative prompt
-4. **Wrong timestep format**: Some schedulers expect float, others expect int/long
-5. **Missing post-processing**: Raw VAE output may need denormalization
-
-**Quick test**: Run with diffusers directly using the same seed and compare latents at each step.
-
-### OOM (Out of Memory)
-
-**Solutions** (in order of preference):
-1. `--enforce-eager` to disable torch.compile (saves compile memory)
-2. `--enable-cpu-offload` for model-level offload
-3. `--enable-layerwise-offload` for block-level offload (better for large models)
-4. `--vae-use-slicing --vae-use-tiling` for VAE memory reduction
-5. Reduce resolution: `--height 480 --width 832`
-6. Use TP: `--tensor-parallel-size 2`
-
-### Different Output vs Diffusers Reference
-
-**Common causes**:
-1. **Attention backend difference**: FlashAttention vs SDPA may produce slightly different results. Set `DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA` to match diffusers
-2. **Float precision**: vLLM-Omni may use bfloat16 where diffusers uses float32 for some operations
-3. **Missing normalization**: Check all LayerNorm/RMSNorm are preserved
-4. **Scheduler rounding**: Some schedulers have numerical sensitivity
-
-### Tensor Parallel Errors
-
-**Symptom**: `AssertionError: not divisible` or incorrect output with TP>1
-
-**Fix**:
-1. Verify `num_heads % tp_size == 0` and `num_kv_heads % tp_size == 0`
-2. Ensure `ColumnParallelLinear` / `RowParallelLinear` are used correctly
-3. Check that norms between parallel layers use distributed norm if needed
-4. Verify `load_weights` handles TP sharding for norm weights
-5. Use `self.to_qkv.num_heads` (local heads per GPU) for QKV split sizes, not total heads
-
-**Missing `input_is_parallel=True`**:
-
-`RowParallelLinear` expects sharded input from `ColumnParallelLinear`:
-```python
-self.w1 = ColumnParallelLinear(dim, hidden_dim, return_bias=False)
-self.w2 = RowParallelLinear(hidden_dim, dim, input_is_parallel=True, return_bias=False)
-```
-
-### Sequence Parallel Errors
-
-**Symptom**: Incorrect output or crashes with `--ulysses-degree N` or `--usp N`
-
-**Possible causes**:
-1. **Inline operations between shard/gather points**: `torch.cat()`, `pad_sequence()` etc. not at `nn.Module` boundaries. Fix: extract into submodule.
-2. **Wrong `split_dim`**: Check the tensor shape at the shard point. Sequence dimension is typically `dim=1` for `[B, S, D]` tensors.
-3. **RoPE not sharded**: If RoPE is computed separately, add it to `_sp_plan` with `split_output=True`.
-4. **Sequence not divisible by SP degree**: Use `auto_pad=True` in `SequenceParallelInput` or switch to `ulysses_mode="advanced_uaa"`.
-
-**Debugging**: Add `expected_dims=N` to `SequenceParallelInput`/`Output` for shape validation at runtime.
-
-### CFG Parallel Errors
-
-**Symptom**: CFG parallel not activating, no speedup
-
-**Fix checklist**:
-1. Pipeline inherits `CFGParallelMixin`
-2. `guidance_scale > 1.0`
-3. Negative prompt provided (even if empty string)
-4. `--cfg-parallel-size 2` specified
-5. `diffuse()` method calls `predict_noise_maybe_with_cfg()` and `scheduler_step_maybe_with_cfg()`
-
-**Symptom**: Different output with CFG parallel vs sequential
-
-**Possible cause**: Non-deterministic scheduler. Fix: pass `generator=torch.Generator(device).manual_seed(seed)` to `scheduler_step_maybe_with_cfg()`.
-
-### HSDP Errors
-
-**Symptom**: HSDP not activating or errors during weight loading
-
-**Fix checklist**:
-1. Transformer defines `_hsdp_shard_conditions` class attribute
-2. Shard condition functions return `True` for correct modules (test with `model.named_modules()`)
-3. Not combining with TP (HSDP and TP are incompatible)
-4. For standalone HSDP, `hsdp_shard_size` is specified explicitly
-
-**Verify**: Check logs for "HSDP Inference: replicate_size=..., shard_size=..." and "Sharded N modules + root".
-
-### Cache-DiT Not Applied
-
-**Symptom**: No speedup, no cache-related log messages
-
-**Fix checklist**:
-1. Model not in `_NO_CACHE_ACCELERATION` in `registry.py`
-2. Pipeline class name matches `CUSTOM_DIT_ENABLERS` key (if using custom enabler)
-3. `cache_backend="cache_dit"` specified
-4. Check logs for "Cache-dit enabled successfully on xxx"
-
-**Verify pipeline name**: `print(pipeline.__class__.__name__)` — must match registry key.
-
-### Cache-DiT Quality Degradation
-
-**Symptom**: Artifacts or lower quality with cache-dit
-
-**Fix**: Reduce aggressiveness:
-```python
-cache_config={
- "residual_diff_threshold": 0.12, # Lower from 0.24
- "max_warmup_steps": 6, # Increase from 4
- "max_continuous_cached_steps": 2, # Reduce if higher
-}
-```
-
-If quality is still poor, the model may need a custom enabler with per-block-list `ParamsModifier` tuning.
-
-### Model Not Detected / Wrong Pipeline Class
-
-**Symptom**: `ValueError: Model class ... not found in diffusion model registry`
-
-**Cause**: The model's `model_index.json` has a `_class_name` for the pipeline that doesn't match registry keys.
-
-**Fix**: The registry key must match the diffusers pipeline class name from `model_index.json`. If using a different name, map it in the registry:
-```python
-"DiffusersPipelineClassName": ("your_folder", "your_file", "YourVllmClassName"),
-```
-
-## Debugging Workflow
-
-1. **Add verbose logging**: Use `logger.info()` to print tensor shapes at each stage
-2. **Compare step-by-step**: Run diffusers and vllm-omni side by side, comparing tensors after each major operation
-3. **Use small configs**: Reduce `num_inference_steps=2`, small resolution for fast iteration
-4. **Test transformer isolation**: Feed the same input to both diffusers and vllm-omni transformers, compare outputs
-5. **Binary search for bugs**: Comment out blocks/layers to isolate where divergence starts
diff --git a/.claude/skills/add-tts-model/SKILL.md b/.claude/skills/add-tts-model/SKILL.md
deleted file mode 100644
index 963ffb4f64d..00000000000
--- a/.claude/skills/add-tts-model/SKILL.md
+++ /dev/null
@@ -1,504 +0,0 @@
----
-name: add-tts-model
-description: "Integrate a new text-to-speech model into vLLM-Omni from HuggingFace reference implementation through production-ready serving with streaming and CUDA graph acceleration. Use when adding a new TTS model, wiring stage separation for speech synthesis, enabling online voice generation serving, debugging TTS integration behavior, or building audio output pipelines."
----
-
-# TTS Model Integration Workflow
-
-## Overview
-
-```
-HF Reference -> Stage Separation -> Online Serving -> Async Chunk -> CUDA Graph -> Pre-commit/DCO
- (Phase 1) (Phase 2) (Phase 3) (Phase 4) (Phase 5) (Phase 6)
-```
-
-Three architecture patterns are supported:
-
-- **Two-stage pipeline** (e.g. Qwen3-TTS, Fish Speech, CosyVoice3): AR
- code-predictor → audio decoder, connected via async_chunk for low-latency
- streaming. Use this for maximum performance.
-- **Single-stage AR via generator** (e.g. MOSS-TTS-Nano): entire model runs
- inside one AR worker, streaming audio chunks through a per-request
- `inference_stream()` generator. Use this when the upstream model bundles AR
- + codec inseparably. See [references/single-stage-ar.md](references/single-stage-ar.md).
-- **Single-stage, vLLM-native base LM + side computation** (e.g. VoxCPM2):
- the base language model runs under vLLM's PagedAttention as a normal AR
- model; diffusion / VAE / side computations run outside vLLM and are
- attached via the runner post-processing hook. This is a distinct pattern
- from the generator approach above — do not confuse the two.
-
-The single-stage variants skip Phase 4 (async_chunk) but Phase 5 (CUDA graph)
-is still encouraged for the inner AR loop.
-
-## Cross-Cutting Invariants
-
-These rules apply to every TTS model regardless of architecture (AR vs AR+diffusion, single-stage vs two-stage, codec-based vs VAE-based). They surface repeatedly across PRs — check them at the end of every phase.
-
-### I1. Streaming output contract
-
-Pick exactly one per-step semantics for `forward()` and document it in the docstring:
-
-- **Delta**: yield only new audio samples produced this step. Preferred — linear cost, low memory.
-- **Cumulative**: re-decode from step 0 every call. O(N²); only acceptable if the codec has no streaming decode path.
-
-If you choose **delta**, verify the full emit→consolidate→consume chain:
-
-1. `forward()` returns `{"model_outputs": <delta_audio_tensor>, ...}`
-2. `_consolidate_multimodal_tensors()` in `vllm_omni/engine/output_processor.py` concatenates the audio key into one tensor at finish. If it skips the key (`continue`), offline consumers receive only the final chunk. See `output_processor.py` for the concrete list of handled modality keys.
-3. Streaming consumers (SSE, Gradio) receive per-step deltas; offline consumers (`engine.generate()`) receive a single concatenated tensor.
-
-Cumulative-vs-delta mismatch is the most common silent bug — offline RTF benchmarks pass, but users hear replays or truncation.
-
-### I2. Multimodal output consumer hygiene
-
-`outputs[0].outputs[0].multimodal_output[<key>]` can be any of `Tensor`, `list[Tensor]` (pre-consolidation snapshot), `np.ndarray`, or scalar. When writing tests, examples, and benchmarks:
-
-- **Never** use `dict.get("a") or dict.get("b")` on tensor values — Python evaluates the tensor's boolean, raising `RuntimeError: Boolean value of Tensor with more than one value is ambiguous`. Use explicit `if x is None` chains.
-- Always defensively handle the list form: `if isinstance(x, list): x = torch.cat([t.reshape(-1) for t in x], dim=0)`.
-- Assert `shape` / `dtype` / `duration` explicitly; do not rely on truthiness for presence checks.
-
-### I3. Hot-loop GPU discipline
-
-Inside any per-step model loop (AR decode, diffusion solver, CFM Euler, vocoder block loop):
-
-- No `tensor.item()`, `.cpu()`, or `.tolist()` — each triggers a GPU→CPU sync; at 10 steps × 60 frames × 4 ops that is 2400 syncs per request.
-- Prefer `dst.copy_(src)` over `dst.fill_(src.item())` when writing a scalar tensor into a buffer.
-- Prefer `torch.compile(Model.forward, fullgraph=False)` on the whole forward over per-submodule compile — fewer dispatch boundaries, larger fusion regions. Measure before choosing granularity.
-- No Python-side control flow that depends on tensor values; use `torch.where` / masking instead.
-
-Profile first, optimize second. See the profiling docs / project memory for the trace-analysis workflow.
-
-### I4. Validation pyramid
-
-Offline RTF alone is necessary but not sufficient. Every new TTS model must pass all three:
-
-| Layer | Catches | Tool |
-|-------|---------|------|
-| Offline RTF / duration check | Throughput regressions, missing audio, wrong sample rate | `end2end.py`, pytest e2e |
-| Browser streaming playback | Delta/cumulative bugs, chunk boundary glitches, TTFP regressions | Gradio demo over `/v1/audio/speech?stream=true` |
-| Concurrent requests | Per-request state leaks, codec window round-robin gaps | `max_num_seqs>1` smoke test with 4+ parallel prompts |
-
-Declaring a model "done" without all three has shipped regressions more than once.
-
-### I5. Per-request state is owned by the request, not the model
-
-If the model caches *anything* across `forward()` calls (streaming generators, codec buffers, sliding-window pads, CUDA graph state), key it by request ID:
-
-```python
-self._state: dict[str, YourState] = {} # request_key → state
-# fetch: request_key = str(info.get("_omni_req_id", "0"))
-# free on finish: del self._state[request_key]
-```
-
-A shared buffer silently corrupts audio across concurrent requests — the symptom is crosstalk or truncation only under load.
-
-## Phase 1: HuggingFace Reference
-
-**Goal**: Understand the reference implementation and verify it produces correct audio.
-
-### Steps
-
-1. **Run the reference model** end-to-end using the official HuggingFace / GitHub code
-2. **Document the architecture**:
- - What are the sub-models? (AR decoder, codec decoder, vocoder, etc.)
- - What is the token vocabulary? (semantic codes, RVQ codebooks, special tokens)
- - What is the output format? (sample rate, channels, codec type)
-3. **Capture reference outputs** for comparison during integration
-4. **Identify the config structure**: `config.json` fields, `model_type`, sub-model configs
-
-### Key Questions
-
-- How many codebooks? What are the codebook sizes?
-- What special tokens exist? (`<|voice|>`, `<|audio_start|>`, `<|im_end|>`, etc.)
-- What is the token-to-ID mapping for codec codes?
-- What is the hop length / frame rate of the codec?
-- Does the model support voice cloning? How? (reference audio encoding, speaker embeddings, etc.)
-
-### Deliverables
-
-- Working reference script that produces audio
-- Architecture diagram / notes
-- Token vocabulary mapping
-- Reference audio samples for regression testing
-
-## Phase 2: Stage Separation (Offline Inference)
-
-**Goal**: Split the model into vLLM-Omni stages and get offline inference working.
-
-### Steps
-
-1. **Register the model** in `vllm_omni/model_executor/models/registry.py`
-2. **Create config classes** (`configuration_<model_name>.py`) with `model_type` registration
-3. **Implement Stage 0** (AR model):
- - Subclass appropriate base (e.g., wrap Qwen3 decoder layers)
- - Implement `forward()` for autoregressive token generation
- - Handle special token logic (start/stop tokens, codec token mapping)
- - If dual-AR (like Fish Speech), implement Fast AR as a nested module
-4. **Implement Stage 1** (Decoder):
- - Load codec weights (may need lazy loading from separate checkpoint)
- - Implement `forward()`: codec codes -> audio waveform
- - Return `OmniOutput` with `multimodal_outputs`
-5. **Create stage config YAML** defining both stages, memory allocation, and model paths
-6. **Create stage input processor** for prompt building
-7. **Write end2end.py** test script
-
-### Critical Parameters to Get Right
-
-| Parameter | Impact if Wrong |
-|-----------|----------------|
-| Hop length | Audio duration wrong, streaming noise |
-| Token ID mapping | Garbage codes -> noise output |
-| Codebook count/size | Shape mismatch crashes |
-| Stop token | Generation never stops or stops too early |
-| dtype / autocast | Numerical issues, silent quality degradation |
-| Repetition penalty | Must match reference (often 1.0 for TTS) |
-
-### Debugging Priority (from experience)
-
-When audio output is wrong, check in this order:
-
-1. **RoPE / attention**: Are position encodings correct? Is the attention mask right?
-2. **Normalization**: RMSNorm epsilon, layer norm placement (pre vs post)
-3. **Hop length**: Product of all upsample rates in the codec decoder
-4. **Token mapping**: Are codec IDs correctly offset from the vocabulary base?
-5. **Sampling parameters**: Temperature, top_k, top_p, repetition_penalty
-6. **Tensor layout**: Codebook-major vs frame-major ordering
-7. **dtype**: Float32 for codec decoders (autocast can corrupt audio)
-
-### Streaming Correctness Rules (single-stage and two-stage)
-
-These bugs appear in almost every new TTS PR. Check all before the first push. See also the cross-cutting invariants I1 (output contract) and I5 (per-request state) above — the rules below are the Phase 2-specific instances of those invariants:
-
-- **Accumulate codes across AR steps** — each `forward()` appends new codes; do not reset between steps or audio will be truncated (fish speech: `fix: accumulate audio_codes across steps`)
-- **Emit delta audio, not full waveform** — in streaming mode yield only the new chunk per step, not the re-decoded full waveform from step 0 (fish speech: `fix: emit delta audio not full waveform`)
-- **All return paths must emit `model_outputs`** — if any early-return branch skips setting `model_outputs`, the serving layer silently drops that step's audio (fish speech: `fix: ensure ALL return paths emit model_outputs`)
-- **Per-request state isolation** — for batched concurrent requests, key all state by request ID; a shared buffer corrupts audio across requests (fish speech: `fix: per-request vocode + delta emission`)
-- **Codec tensor device** — move codec codes to the codec decoder's device before calling decode; mismatches cause silent CPU fallback or crashes (fish speech: `fix: use model device for CUDA stream`)
-- **AR stage `max_num_seqs`** — set to at least 4 in production configs; for single-stage models this is the only stage. For two-stage models, Stage 0 (AR) needs `max_num_seqs ≥ 4` to pipeline concurrent requests; Stage 1 (codec decoder) typically uses `max_num_seqs: 1` intentionally. Default of 1 everywhere causes audio gaps under concurrency because the codec window round-robins across requests (RFC #2568)
-
-### Optional Dependency Handling
-
-Patch optional dependencies (`torchaudio` / `torchcodec` / `soundfile`) at
-the top of `load_weights()`, not at module import. Failures to do so cause
-cryptic errors only on environments missing the optional package — after
-the model is already deployed. See
-[references/optional-deps.md](references/optional-deps.md) for the full
-pattern, signature constraints, and MOSS-TTS-Nano reference.
-
-### Single-Stage AR Pattern (alternative to two-stage)
-
-When the upstream model cannot be cleanly split into an AR stage and a
-separate decoder, run the full pipeline inside a single AR worker and
-stream audio through a per-request `inference_stream()` generator keyed by
-`_omni_req_id`. Stage config must set `worker_type: ar`,
-`engine_output_type: audio`, `final_output: true`, `is_comprehension: true`,
-and `async_chunk: false` at the top level. Only extract params from
-`additional_information` that you actually forward, or pre-commit fails
-`ruff F841`.
-
-Full walkthrough with the complete `forward()` / `_create_stream_gen()`
-skeleton and stage-config fields:
-[references/single-stage-ar.md](references/single-stage-ar.md). For an
-in-tree reference, look for any single-stage AR model under
-`vllm_omni/model_executor/models/` — e.g. the MOSS-TTS-Nano integration when
-it lands.
-
-**VoxCPM2 is a different pattern** and should not reuse this skeleton — it
-runs the base LM under vLLM PagedAttention with external side-computation.
-See `plan/voxcpm2_native_ar_design.md`.
-
-### Deliverables
-
-- Model files in `vllm_omni/model_executor/models/<model_name>/`
-- Stage config YAML
-- Working `end2end.py` with correct audio output
-- README.md in the example directory
-
-## Phase 3: Online Serving
-
-**Goal**: Expose the model via `/v1/audio/speech` API endpoint.
-
-### Steps
-
-1. **Register in `serving_speech.py`** — add all 5 points in a **single commit**;
- partial integration causes hard-to-debug failures. This file is modified by every
- model PR and is the most common source of rebase conflicts — see conflict note below.
-
- **Point 1** — stage constant (near the top, alongside the other `_*_TTS_MODEL_STAGES` sets):
- ```python
- _YOUR_MODEL_TTS_MODEL_STAGES = {"your_stage_key"}
- ```
-
- **Point 2** — union into `_TTS_MODEL_STAGES`:
- ```python
- _TTS_MODEL_STAGES: set[str] = (
- ...
- | _YOUR_MODEL_TTS_MODEL_STAGES
- )
- ```
-
- **Point 3** — model type detection in `_detect_tts_model_type()`:
- ```python
- if model_stage in _YOUR_MODEL_TTS_MODEL_STAGES:
- return "your_model"
- ```
-
- **Point 4** — validation dispatch in `_validate_tts_request()`:
- ```python
- if self._tts_model_type == "your_model":
- return self._validate_your_model_request(request)
- ```
-
- **Point 5** — validation + parameter-builder methods:
- ```python
- def _validate_your_model_request(self, request) -> str | None:
- if not request.input or not request.input.strip():
- return "Input text cannot be empty"
- return None
-
- def _build_your_model_params(self, request) -> dict:
- params = {"text": [request.input]}
- if request.voice is not None:
- params["voice"] = [request.voice]
- return params
- ```
- Wire `_build_your_model_params` into `_create_tts_request()` alongside the other
- model-specific param builders.
-
- > **Two dispatch patterns coexist**: Fish Speech uses a `self._is_fish_speech` boolean
- > instance attribute checked before `elif self._is_tts`, while all newer models
- > (CosyVoice3, MOSS-TTS-Nano) use the `_tts_model_type` string returned by
- > `_detect_tts_model_type()`. For new models, always use the `_tts_model_type` string
- > pattern — do not add new `_is_*` flags.
-
- > **Unused variable rule**: only extract fields in `_build_your_model_params` that
- > are actually forwarded to the model. Unused extractions fail `ruff F841`.
- > For voice-cloning fields (`ref_audio` → `prompt_audio_path`, `ref_text` →
- > `prompt_text`), add them to the param builder and verify they reach the model call.
-
- **Rebase conflict note**: when rebasing onto `main` after another model was merged,
- `serving_speech.py` will conflict. Resolution: always keep *both* the upstream
- model's additions and your own — never discard either side.
-
-2. **Handle model-specific parameters**:
- - Voice cloning: `ref_audio` encoding and prompt injection
- - `max_new_tokens` override in sampling params
- - Model-specific default values
-3. **Create client scripts**: `speech_client.py`, `run_server.sh`
-4. **Test all response formats**: wav, mp3, flac, pcm
-5. **Add Gradio demo**: Interactive web UI with streaming support
-
-### Voice Cloning Pattern
-
-```python
-import base64
-from pathlib import Path
-
-def build_voice_clone_prompt(ref_audio_path: str, text: str, codec) -> list:
- """Build prompt with reference audio for voice cloning in serving_speech.py."""
- audio_bytes = Path(ref_audio_path).read_bytes()
- codes = codec.encode(audio_bytes) # Encode on CPU using model's codec (e.g., DAC)
- token_ids = [code + codec.vocab_offset for code in codes.flatten().tolist()]
- return [
- {"role": "system", "content": f"<|voice|>{''.join(chr(t) for t in token_ids)}"},
- {"role": "user", "content": text},
- ]
-```
-
-### Deliverables
-
-- Updated `serving_speech.py` with all 5 integration points (single commit)
-- Client scripts and server launcher
-- Gradio demo with streaming and voice cloning UI
-- E2E online serving test (`tests/e2e/online_serving/test_<model_name>.py`)
-- Buildkite CI entry in `.buildkite/test-merge.yml`
-- Documentation (offline + online serving docs)
-
-### E2E test pitfalls to avoid
-
-- **One `OmniServerParams` set per file.** `omni_server` is module-scoped; a second
- id in the same file forces mid-module teardown/restart and exposes startup
- races (`APIConnectionError` on the first request post-restart). Split variants
- into separate files instead.
-- **No external URL fetches from the server.** CI and some dev hosts can't
- reach `raw.githubusercontent.com` over TLS. Inline ref audio as
- `data:audio/wav;base64,...`; the serving layer accepts both URL and data URL.
-- **Use the harness readiness gate.** The fixture waits for HTTP 200 on
- `/health`; don't add `time.sleep` in tests. If warmup is incomplete, make
- `/health` return non-200 until you're actually ready.
-- **Mark with `@pytest.mark.core_model` + `hardware_test(res={"cuda": "H100"})`**
- so the test lands in `test-ready.yml` (triggered by the `ready` label) rather
- than only nightly.
-
-## Phase 4: Async Chunk (Streaming)
-
-**Goal**: Enable inter-stage streaming so audio chunks are produced while AR generation continues.
-
-### Steps
-
-1. **Update stage config YAML**:
- ```yaml
- async_chunk: true
- codec_chunk_frames: 25 # frames per chunk
- codec_left_context_frames: 25 # overlap for smooth boundaries
- ```
-2. **Implement chunk handling in Stage 1**:
- - Accept partial input (chunk of codec codes)
- - Handle left context for smooth audio boundaries
- - Return partial audio in `OmniOutput`
-3. **Test streaming**:
- - Verify audio quality matches non-streaming output
- - Check for artifacts at chunk boundaries
- - Measure TTFA (time to first audio)
-4. **Update online serving** to support `stream=true` with PCM output
-
-### Streaming Architecture
-
-```
-Stage 0 (AR) Stage 1 (Decoder)
- | |
- |-- chunk 0 (25 frames) ------> decode -> audio chunk 0 -> client
- |-- chunk 1 (25 frames) ------> decode -> audio chunk 1 -> client
- |-- chunk 2 (25 frames) ------> decode -> audio chunk 2 -> client
- ...
-```
-
-### Key Considerations
-
-- **Left context overlap**: Prevents audible artifacts at chunk boundaries
-- **Hop length matters**: `context_audio_samples = context_frames * hop_length`
-- **First chunk latency**: Can use larger initial chunk for better quality, then smaller chunks
-
-### Deliverables
-
-- Updated stage config with async_chunk enabled
-- Smooth streaming audio without boundary artifacts
-- TTFA metrics
-
-## Phase 5: CUDA Graph Acceleration
-
-**Goal**: Capture the AR loop as a CUDA graph for significant speedup.
-
-### Steps
-
-1. **Identify the hot loop**: The AR decoding loop that runs N steps per token
-2. **Create static buffers**:
- - KV caches with fixed max sequence length
- - Pre-built causal masks and position tensors per step
- - Static input/output tensors
-3. **Implement graph capture**:
- - Warm up with real data
- - Capture the forward pass
- - Replay with updated inputs
-4. **Handle constraints**:
- - Use `torch.argmax` instead of `torch.multinomial` (graph-safe)
- - Fixed batch size (fall back to eager for other sizes)
- - No dynamic control flow inside the graph
-
-See [references/cuda-graph-example.md](references/cuda-graph-example.md) for
-a worked skeleton (Qwen3-TTS code predictor, 16-step AR loop), performance
-expectations (3–5× on the graphed component for fixed batch_size=1), and the
-graph-safety constraints you must honor inside the captured region.
-
-### Deliverables
-
-- CUDA graph implementation for the AR hot loop
-- Benchmark script comparing eager vs graph performance
-- Documentation of constraints and fallback behavior
-
-## Phase 6: Pre-commit and DCO
-
-**Goal**: Every commit passes `pre-commit` lint and carries a DCO
-`Signed-off-by` line that matches the author email.
-
-- Install hooks once: `pre-commit install`.
-- Run `pre-commit run --files <changed-files>` before every push; accept any
- auto-fixes, stage, re-commit.
-- Sign every commit with `git commit -s`. DCO checks that author email and
- `Signed-off-by` email match — `git config user.email` must match your
- GitHub account email.
-
-Common pre-commit failures, recovery commands for missing sign-off, and the
-full `pre-commit run` invocation for a TTS model:
-[references/precommit-dco.md](references/precommit-dco.md).
-
-## Integration Checklist
-
-Use this checklist when integrating a new TTS model:
-
-### Cross-Cutting Invariants (verify at end of every phase)
-- [ ] I1: `forward()` docstring states cumulative vs delta; consolidation path audited end-to-end
-- [ ] I2: Tests / examples / benchmarks never use `dict.get(a) or dict.get(b)` on tensor values; list form handled
-- [ ] I3: No `.item()` / `.cpu()` / `.tolist()` / Python branch on tensor values inside per-step loops
-- [ ] I4: Offline RTF, browser streaming playback, and concurrent-request smoke test all pass
-- [ ] I5: Any cross-step cache keyed by `_omni_req_id`; entries freed when the request finishes
-
-### Phase 1: HF Reference
-- [ ] Reference model runs and produces correct audio
-- [ ] Architecture documented (stages, codebooks, tokens, sample rate)
-- [ ] Reference audio samples saved for comparison
-
-### Phase 2: Stage Separation
-- [ ] Model registered in `registry.py`
-- [ ] Config classes created with `model_type` registration
-- [ ] Stage 0 (AR) implemented and generates correct tokens
-- [ ] Stage 1 (Decoder) produces correct audio from tokens — dtype float32 for codec decoder
-- [ ] Stage 0 (AR) `max_num_seqs` ≥ 4 in production config (default 1 causes gaps under concurrency); Stage 1 (codec decoder) typically stays at `max_num_seqs: 1`
-- [ ] Optional dependency fallbacks handled at `load_weights()` time (torchaudio/soundfile/etc.)
-- [ ] Streaming: codec codes accumulated across AR steps (not reset per step)
-- [ ] Streaming: delta audio emitted per chunk, not full re-decoded waveform
-- [ ] Streaming: all `forward()` return paths emit `model_outputs`
-- [ ] Streaming: per-request state keyed by request ID (not shared across requests)
-- [ ] Streaming: codec tensors moved to codec decoder device before decode
-- [ ] Stage config YAML created
-- [ ] `end2end.py` produces audio matching reference quality
-- [ ] README.md written
-
-### Phase 3: Online Serving
-- [ ] All 5 `serving_speech.py` integration points added in one commit
-- [ ] Only extract params in `_build_*_params` that are forwarded to the model call (ruff F841)
-- [ ] Prompt builder handles text input correctly
-- [ ] Voice cloning works (if supported)
-- [ ] All response formats work (wav, mp3, flac, pcm)
-- [ ] Client scripts and server launcher created
-- [ ] E2E online serving test written (`tests/e2e/online_serving/test_.py`)
-- [ ] Buildkite CI entry added to `.buildkite/test-merge.yml`
-- [ ] Gradio demo working
-- [ ] Documentation added (offline + online docs, nav, supported models)
-
-### Phase 4: Async Chunk
-- [ ] Stage config updated with `async_chunk: true`
-- [ ] Stage 1 handles partial chunks correctly
-- [ ] No audio artifacts at chunk boundaries
-- [ ] Streaming via API (`stream=true`) works
-- [ ] TTFA measured and acceptable
-
-### Phase 5: CUDA Graph
-- [ ] Hot loop identified and profiled
-- [ ] Static buffers allocated
-- [ ] Graph captured and replays correctly
-- [ ] Benchmark shows meaningful speedup
-- [ ] Fallback to eager works for unsupported configs
-
-### Phase 6: Pre-commit and DCO
-- [ ] `pre-commit run --files ` passes before every push
-- [ ] Every commit has `Signed-off-by` matching the author email (`git commit -s`)
-- [ ] `git config user.email` matches the email registered on your GitHub account
-- [ ] Details and failure-recovery commands: [references/precommit-dco.md](references/precommit-dco.md)
-
-## References
-
-In-skill references (details split out of the main body):
-
-- [references/single-stage-ar.md](references/single-stage-ar.md) — full `forward()` / generator skeleton for the MOSS-TTS-Nano-style pattern
-- [references/optional-deps.md](references/optional-deps.md) — torchaudio / torchcodec fallback pattern
-- [references/cuda-graph-example.md](references/cuda-graph-example.md) — Qwen3-TTS code-predictor CUDA graph skeleton
-- [references/precommit-dco.md](references/precommit-dco.md) — full pre-commit invocation, failure table, DCO recovery
-
-Project docs and adjacent skills:
-
-- [TTS audio skill](../vllm-omni-audio-tts/SKILL.md) — supported models and usage
-- [Fish Speech integration](../vllm-omni-audio-tts/references/fish-speech.md) — complete example of Phases 1–3
-- [Qwen3-TTS reference](../vllm-omni-audio-tts/references/qwen-tts.md) — complete example of Phases 1–5
-- [Adding a TTS model (developer guide)](https://github.com/vllm-project/vllm-omni/blob/main/docs/contributing/model/adding_tts_model.md)
-- `plan/voxcpm2_native_ar_design.md` — VoxCPM2's vLLM-native AR + side-computation pattern (distinct from the generator-based single-stage described above)
diff --git a/.claude/skills/add-tts-model/references/cuda-graph-example.md b/.claude/skills/add-tts-model/references/cuda-graph-example.md
deleted file mode 100644
index 6f4993b5c4c..00000000000
--- a/.claude/skills/add-tts-model/references/cuda-graph-example.md
+++ /dev/null
@@ -1,42 +0,0 @@
-# CUDA Graph Example: Qwen3-TTS Code Predictor
-
-Reference sketch for capturing the 16-step code-predictor AR loop as a single
-CUDA graph. Adapt the shapes, number of steps, and KV-head layout to your
-model.
-
-```python
-import torch
-
-class CodePredictorGraph:
- """Captures the 16-step code predictor AR loop as a single CUDA graph."""
-
- def setup_graph(self, device: torch.device, kv_heads: int = 4, head_dim: int = 64):
- self.num_steps = 16
- self.kv_cache = torch.zeros(1, kv_heads, self.num_steps, head_dim, device=device)
- self.positions = torch.arange(self.num_steps, device=device)
- self.causal_mask = torch.tril(torch.ones(self.num_steps, self.num_steps, device=device))
- self.input_buf = torch.zeros(1, 1, kv_heads * head_dim, device=device)
- self.output_buf = torch.zeros(1, self.num_steps, device=device, dtype=torch.long)
- # Warm up, then: self.graph = torch.cuda.CUDAGraph(); self.graph.capture(...)
-
- def run_graph(self, initial_input: torch.Tensor) -> torch.Tensor:
- self.input_buf.copy_(initial_input)
- self.graph.replay()
- return self.output_buf.clone()
-```
-
-## Performance expectations (Qwen3-TTS code predictor)
-
-- **3–5× speedup** on the graphed component.
-- Effective only for fixed batch sizes (typically `batch_size=1`).
-- Fall back to eager for any shape/config that wasn't captured — do not try
- to recapture per request.
-
-## Graph-safety constraints
-
-- `torch.argmax` instead of `torch.multinomial`.
-- Fixed batch size.
-- No Python control flow that branches on tensor values inside the captured
- region (use `torch.where` / masks).
-- No `.item()`, `.cpu()`, `.tolist()` — each would break the capture or
- cause a GPU→CPU sync during replay.
diff --git a/.claude/skills/add-tts-model/references/optional-deps.md b/.claude/skills/add-tts-model/references/optional-deps.md
deleted file mode 100644
index 0a55f30f05c..00000000000
--- a/.claude/skills/add-tts-model/references/optional-deps.md
+++ /dev/null
@@ -1,47 +0,0 @@
-# Optional Dependency Handling
-
-Models that rely on `torchaudio`, `torchcodec`, `soundfile`, or other optional
-packages must handle the missing-package case up front, at the top of
-`load_weights()` — not at module import time and not lazily at first call.
-Failing to do this causes cryptic errors only on environments without
-the optional package — after the model is already deployed.
-
-## Pattern (used in MOSS-TTS-Nano)
-
-```python
-def _patch_torchaudio_load() -> None:
- """Fallback torchaudio.load/save to soundfile if torchcodec is unavailable."""
- try:
- import torchcodec # noqa: F401
- return # torchcodec present, torchaudio works as-is
- except ImportError:
- pass
-
- import soundfile as sf
- import torchaudio
-
- def _sf_load(path, **kwargs):
- data, sr = sf.read(str(path), dtype="float32", always_2d=True)
- return torch.from_numpy(data).T, sr
-
- torchaudio.load = _sf_load
- # patch .save similarly if needed
-```
-
-## Rules
-
-- Mirror the full signature of the replaced function. `torchaudio.load`
- accepts `frame_offset`, `num_frames`, `normalize`, `channels_first`,
- `format` — missing any of them causes `TypeError` from calling code.
-- Catch `except Exception`, not just `ImportError`. `import torchaudio`
- itself can fail with non-`ImportError` errors on broken installs. Log the
- exception type and message (`logger.warning("torchaudio probe failed: %s: %s",
- type(exc).__name__, exc)`) before falling back, so unrelated errors are not
- silently swallowed.
-- Call the patch function at the top of `load_weights()` before loading any
- audio assets. Do not call it at module import time.
-
-## Reference implementation
-
-Any in-tree model that patches `torchaudio.load` in its `load_weights()` —
-e.g. MOSS-TTS-Nano's `modeling_moss_tts_nano.py` once that integration
-lands.
diff --git a/.claude/skills/add-tts-model/references/precommit-dco.md b/.claude/skills/add-tts-model/references/precommit-dco.md
deleted file mode 100644
index 86a1f42cefb..00000000000
--- a/.claude/skills/add-tts-model/references/precommit-dco.md
+++ /dev/null
@@ -1,54 +0,0 @@
-# Pre-commit and DCO
-
-Every commit must pass `pre-commit` lint and carry a `Signed-off-by` line
-that matches the commit author email.
-
-## Pre-commit
-
-Install hooks once:
-
-```bash
-pre-commit install
-```
-
-Run before every push on the files you changed:
-
-```bash
-pre-commit run --files \
- vllm_omni/model_executor/models//*.py \
- vllm_omni/entrypoints/openai/serving_speech.py \
- vllm_omni/model_executor/models/registry.py \
- tests/e2e/offline_inference/test_.py \
- tests/e2e/online_serving/test_.py
-```
-
-When pre-commit **modifies files** (ruff format auto-fix), it exits non-zero
-but the changes are correct — stage the modified files and re-commit.
-
-| Failure | Root cause | Fix |
-|---------|-----------|-----|
-| `ruff F841` | Variable extracted but never forwarded to model call | Remove the extraction or wire it through |
-| `ruff E402` | Import added below function definitions | Move to top-level import block |
-| `ruff format` | Line length, spacing, quote style | Accept auto-fix, stage, re-commit |
-
-## DCO sign-off
-
-Every commit must carry `Signed-off-by: Your Name <your.email@example.com>`. Use
-`-s`:
-
-```bash
-git commit -s -m "feat(): add TTS support"
-```
-
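-Note: `git config format.signOff true` only affects `git format-patch`, not `git commit`. To sign off every commit automatically, define an alias (e.g. `git config alias.ci 'commit -s'`) or add a `prepare-commit-msg` hook.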
-
-The DCO check verifies that the commit author email matches the
-`Signed-off-by` line. Confirm `git config user.email` matches your GitHub
-account email before committing.
-
-Fix a missing or mismatched sign-off on the latest commit:
-
-```bash
-git commit --amend -s --no-edit
-git push origin --force-with-lease
-```
diff --git a/.claude/skills/add-tts-model/references/single-stage-ar.md b/.claude/skills/add-tts-model/references/single-stage-ar.md
deleted file mode 100644
index ed53d30261c..00000000000
--- a/.claude/skills/add-tts-model/references/single-stage-ar.md
+++ /dev/null
@@ -1,108 +0,0 @@
-# Single-Stage AR Pattern
-
-When the upstream model cannot be cleanly split into an AR stage and a separate
-decoder (e.g. MOSS-TTS-Nano, or any model that bundles AR + codec via an
-`inference_stream()` generator), run the whole pipeline inside a single AR
-worker that yields audio chunks per request.
-
-This is distinct from VoxCPM2's pattern, which also runs in a single stage but
-uses vLLM's native PagedAttention on the base language model with diffusion /
-VAE side-computation outside vLLM — see
-`plan/voxcpm2_native_ar_design.md` for that variant.
-
-## Implementation
-
-1. **Single model file** — load both AR LM and codec inside
- `modeling_<model_name>.py`.
-2. **Load weights in `load_weights()`**, not `__init__()` — vLLM initializes
- distributed state before any CUDA allocations.
-3. **Stream via a per-request generator** stored in `self._stream_gens`:
-
-```python
-class YourModelForCausalLM(nn.Module):
- def __init__(self, *, vllm_config, prefix=""):
- super().__init__()
- self._lm = None # populated in load_weights()
- self._stream_gens: dict = {} # request_key → generator
-
- def load_weights(self, weights):
- # Load self._lm here, after vLLM distributed init
- ...
-
- def forward(
- self,
- input_ids,
- positions,
- intermediate_tensors=None,
- inputs_embeds=None,
- runtime_additional_information: list[dict] | None = None, # one dict per request
- **kwargs,
- ) -> OmniOutput:
- infos = runtime_additional_information or [{}]
- # Skip dummy/profiling calls
- if not runtime_additional_information or all(i.get("_is_dummy") for i in infos):
- self._ar_emit_stop_token = True
- return OmniOutput(...) # return empty outputs
-
- outputs, last_flags = [], []
- for info in infos:
- request_key = str(info.get("_omni_req_id", "0")) # per-request ID from vLLM
- if request_key not in self._stream_gens:
- self._stream_gens[request_key] = self._create_stream_gen(info)
- try:
- chunk, is_last = next(self._stream_gens[request_key])
- except StopIteration:
- chunk, is_last = torch.zeros(0), True
- if is_last:
- del self._stream_gens[request_key]
- outputs.append(chunk)
- last_flags.append(is_last)
-
- self._ar_emit_stop_token = all(last_flags)
- return OmniOutput(multimodal_outputs={"model_outputs": outputs, ...})
-
- def _create_stream_gen(self, info: dict):
- """Yield (waveform_tensor, is_last) tuples from inference_stream()."""
- for event in self._lm.inference_stream(...):
- if event["type"] == "audio":
- yield event["waveform"], False
- elif event["type"] == "result":
- # Fallback: some models emit a single "result" event instead of
- # incremental "audio" events — handle both paths
- yield event.get("waveform", torch.zeros(0)), True
- return
- yield torch.zeros(0), True
-
- def compute_logits(self, hidden_states, sampling_metadata):
- # Emit EOS only after the last chunk so the AR scheduler ends the request
- ...
-```
-
-## Key points
-
-- `runtime_additional_information` is the correct parameter name (not
- `**kwargs`) — it carries one dict per request in the batch.
-- The request ID is `info.get("_omni_req_id")` — set by vLLM, not by user code.
-- Handle both `"audio"` (incremental) and `"result"` (final combined) event
- types from upstream models.
-
-## Stage config
-
-Single stage with `worker_type: ar`, `engine_output_type: audio`,
-`final_output: true`, `is_comprehension: true`, and `async_chunk: false` at
-the top level. Omitting any of these causes silent misclassification in the
-serving layer.
-
-## Lint discipline
-
-Only extract variables from `additional_information` that you actually
-forward to the model call — unused extractions trip `ruff F841` in
-pre-commit.
-
-## Reference implementation
-
-Look for any single-stage AR model under
-`vllm_omni/model_executor/models/` — e.g. `moss_tts_nano/` when its
-integration lands. If none is in tree yet, follow the skeleton above and
-cross-check against the `is_comprehension: true` / `async_chunk: false`
-dispatch in `vllm_omni/entrypoints/openai/serving_speech.py`.
diff --git a/.claude/skills/readme.md b/.claude/skills/readme.md
deleted file mode 100644
index b66f2ecd131..00000000000
--- a/.claude/skills/readme.md
+++ /dev/null
@@ -1,34 +0,0 @@
-# Claude Skills for vLLM-Omni
-
-This directory contains Claude Code skills maintained for the `vllm-omni`
-repository. These skills capture repeatable workflows for common contributor
-tasks such as model integration, pull request review, and release note
-generation.
-
-## Directory Structure
-
-Each skill lives in its own directory under `.claude/skills/`. A skill may
-include:
-
-- `SKILL.md`: the main workflow and operating instructions
-- `references/`: focused reference material used by the skill
-- `scripts/`: small helper scripts used by the skill
-
-## Available Skills
-
-- `add-diffusion-model`: guides integration of a new diffusion model into
- `vllm-omni`
-- `add-omni-model`: covers addition of new omni-modality model support
-- `add-tts-model`: covers integration of new TTS models and related serving
- workflows
-- `generate-release-note`: helps prepare release notes for repository changes
-- `review-pr`: provides a structured workflow for reviewing pull requests
-
-## Maintenance Guidelines
-
-- Keep skill names short and task-oriented.
-- Prefer repository-local paths, commands, and examples.
-- Avoid hardcoding fast-changing support matrices unless the skill is actively
- maintained alongside those changes.
-- Treat skills as contributor tooling: optimize for clarity, actionability, and
- low maintenance overhead.
diff --git a/.claude/skills/vllm-omni-npu-upgrade/SKILL.md b/.claude/skills/vllm-omni-npu-upgrade/SKILL.md
deleted file mode 100644
index 1ef7ab39301..00000000000
--- a/.claude/skills/vllm-omni-npu-upgrade/SKILL.md
+++ /dev/null
@@ -1,300 +0,0 @@
----
-name: vllm-omni-npu-model-runner-upgrade
-description: "Upgrade vllm-omni NPU model runners (OmniNPUModelRunner, NPUARModelRunner, NPUGenerationModelRunner) to align with the latest vllm-ascend NPUModelRunner while preserving omni-specific logic."
----
-
-# vLLM-Omni NPU Model Runner Upgrade Skill
-
-## Overview
-
-This skill guides the process of upgrading vllm-omni's NPU model runners to align with the latest vllm-ascend codebase while preserving omni-specific enhancements. The NPU runners are designed to run omni multimodal models (like Qwen3-Omni, Bagel, MiMoAudio) on Ascend NPUs.
-
-## File Structure
-
-### NPU Model Runner Files
-```
-vllm-omni/vllm_omni/platforms/npu/worker/
-├── __init__.py
-├── npu_model_runner.py # OmniNPUModelRunner (base class)
-├── npu_ar_model_runner.py # NPUARModelRunner (autoregressive)
-├── npu_ar_worker.py # AR worker
-├── npu_generation_model_runner.py # NPUGenerationModelRunner (diffusion/non-AR)
-└── npu_generation_worker.py # Generation worker
-```
-
-### GPU Reference Files (for omni-specific logic sync)
-```
-vllm-omni/vllm_omni/worker/
-├── __init__.py
-├── gpu_model_runner.py # OmniGPUModelRunner
-├── gpu_ar_model_runner.py # GPUARModelRunner
-├── gpu_ar_worker.py
-├── gpu_generation_model_runner.py
-├── gpu_generation_worker.py
-├── mixins.py
-├── base.py
-└── gpu_memory_utils.py
-```
-
-### vllm-ascend Reference Files
-```
-vllm-ascend/vllm_ascend/worker/
-├── model_runner_v1.py # NPUModelRunner (base class to copy from)
-├── npu_input_batch.py
-├── block_table.py
-├── pcp_utils.py
-└── worker.py
-```
-
-## Inheritance Hierarchy
-
-```
- GPUModelRunner (vllm)
- |
- +----------------+----------------+
- | |
- OmniGPUModelRunner NPUModelRunner (vllm-ascend)
- (vllm_omni/worker) (vllm_ascend/worker)
- | |
- +----------- OmniNPUModelRunner --+
- (multiple inheritance)
- |
- +---------------+---------------+
- | |
- NPUARModelRunner NPUGenerationModelRunner
- (autoregressive) (non-autoregressive/diffusion)
-```
-
-## Omni-Specific Comment Markers
-
-Omni-specific logic is marked with comment blocks:
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-# ... omni-specific code ...
-# -------------------------------------- Omni-new -------------------------------------------------
-```
-
-Or simpler variations:
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-# ------------------------------------------------------------------------------------------------
-```
-
-**Important**:
-- Always preserve and add these markers when modifying code.
-- **The reference documents (`references/omni-specific-blocks.md`) may not be up-to-date.** Always grep for `Omni-new` in the GPU implementations to find the authoritative list of omni-specific blocks.
-- When you discover new omni-specific code that is not documented in the references, please update the reference files.
-
-## Key Methods Requiring Attention
-
-### OmniNPUModelRunner (npu_model_runner.py)
-
-| Method | Description | Omni-Specific Logic |
-|--------|-------------|---------------------|
-| `load_model` | Load model and initialize talker_mtp | Uses `ACLGraphWrapper` instead of `CUDAGraphWrapper`, initializes talker buffers |
-| `_dummy_run` | Warmup/profiling run | talker_mtp dummy forward, `extract_multimodal_outputs` |
-| `_model_forward` | Forward pass wrapper | Injects `model_kwargs_extra`, wraps with `OmniOutput`, NPU-specific graph updates |
-| `_talker_mtp_forward` | Talker MTP forward for Qwen3-Omni | Uses `set_ascend_forward_context` |
-
-### NPUARModelRunner (npu_ar_model_runner.py)
-
-| Method | Description | Omni-Specific Logic |
-|--------|-------------|---------------------|
-| `__init__` | Initialize with KV transfer manager | `OmniKVTransferManager` setup |
-| `execute_model` | Main inference entry | KV transfer handling, `_update_states` override, `extract_multimodal_outputs` |
-| `sample_tokens` | Token sampling | Hidden states extraction, multimodal outputs processing, `OmniModelRunnerOutput` |
-| `_resolve_global_request_id` | Request ID resolution | For disaggregated inference |
-
-### NPUGenerationModelRunner (npu_generation_model_runner.py)
-
-| Method | Description | Omni-Specific Logic |
-|--------|-------------|---------------------|
-| `_update_request_states` | Update request states for async chunk | async_chunk handling |
-| `execute_model` | Generation forward | async_chunk, `seq_token_counts`, `_run_generation_model` |
-| `sample_tokens` | Output processing | multimodal output packaging to `OmniModelRunnerOutput` |
-| `_dummy_run` | Dummy run override | model_kwargs initialization, multimodal extraction |
-| `_run_generation_model` | Run generation model | Calls `_model_forward` with sampler |
-
-## Upgrade Workflow
-
-### Step 1: Preparation
-
-1. **Identify target versions** (use the `gh` CLI to check):
- - We're using vllm-omni main branch
- - Check the last release of vllm-omni
- - Target vllm-ascend version (use the latest local vllm-ascend checkout directly)
-
-2. **Check GPU-side changes** (since last release):
- ```bash
- cd /root/vllm-workspace/vllm-omni
- git log --oneline --since="<last-release-date>" -- vllm_omni/worker/
- ```
-
-3. **Read latest vllm-ascend code**:
- - We don't track vllm-ascend changes - just directly use the latest code from `/root/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py`
- - Copy the relevant methods and re-insert omni-specific blocks
-
-### Step 2: Analyze Omni-Specific Logic
-
-For each NPU model runner file:
-
-1. **Extract existing omni-specific blocks**:
- ```bash
- grep -n "Omni-new" vllm_omni/platforms/npu/worker/npu_model_runner.py
- ```
-
-2. **Document each omni block**:
- - Which method it belongs to
- - What functionality it provides
- - Dependencies on other omni code
-
-### Step 3: Update Base Class (OmniNPUModelRunner)
-
-**Note**: Always check the GPU implementation `gpu_model_runner.py` for any new omni logic not yet documented in references.
-
-1. **Read the latest vllm-ascend `NPUModelRunner.load_model`**
-2. **Copy the method, keeping the structure**
-3. **Re-insert omni-specific logic** (check GPU `gpu_model_runner.py` for authoritative list):
- - Replace `CUDAGraphWrapper` with `ACLGraphWrapper`
- - Keep talker_mtp initialization
- - Preserve buffer allocations for talker
- - Check for any new omni blocks added since last sync
-
-4. **Update `_dummy_run`**:
- - Copy from vllm-ascend
- - Compare with GPU `_dummy_run` for omni-specific blocks
- - Re-insert all `Omni-new` marked code from GPU version
-
-5. **Update `_model_forward`**:
- - Keep the omni wrapper logic
- - Update NPU-specific parts (graph params, SP all-gather)
- - Check GPU version for any new omni logic
-
-### Step 4: Update AR Model Runner
-
-1. **Compare with GPU `gpu_ar_model_runner.py`** for any new omni features
-2. **Copy `execute_model` from vllm-ascend**
-3. **Re-insert omni blocks** (reference `references/omni-specific-blocks.md`, but note it may be incomplete):
- - **IMPORTANT**: Always check the GPU implementation `gpu_ar_model_runner.py` for all `Omni-new` marked code blocks
- - The reference doc may not include newly added omni logic - treat it as a starting point, not exhaustive
- - When discovering new omni code blocks, please update `references/omni-specific-blocks.md`
- - Common omni blocks include but are not limited to: KV transfer, multimodal outputs, sampling_metadata handling, etc.
-
-4. **Update `sample_tokens`** (also compare with GPU implementation):
- - Compare with `gpu_ar_model_runner.py`'s `sample_tokens` method
- - Identify all `Omni-new` marked code blocks
- - Ensure NPU version includes all omni-specific logic
-
-### Step 5: Update Generation Model Runner
-
-**Note**: Generation model runner may have unique omni logic for diffusion/non-AR models.
-
-1. **Compare with GPU `gpu_generation_model_runner.py`** - grep for all `Omni-new` blocks
-2. **Update `execute_model`**:
- - Check GPU version for all omni-specific blocks
- - Keep async_chunk handling
- - Keep `seq_token_counts` injection
- - Update forward/context setup from vllm-ascend
- - Look for any new omni logic not documented in references
-
-3. **Update `_dummy_run`**:
- - Copy from vllm-ascend base
- - Compare with GPU `_dummy_run` if exists
- - Re-insert all omni-specific logic
-
-### Step 6: Update Imports
-
-Check and update imports at the top of each file:
-
-```python
-# Common vllm-ascend imports
-from vllm_ascend.ascend_forward_context import get_forward_context, set_ascend_forward_context
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.utils import using_paged_attention
-from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
-from vllm_ascend.ops.rotary_embedding import update_cos_sin
-from vllm_ascend.utils import enable_sp, lmhead_tp_enable
-from vllm_ascend.worker.model_runner_v1 import SEQ_LEN_WITH_MAX_PA_WORKSPACE, NPUModelRunner
-
-# Omni-specific imports
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
-from vllm_omni.outputs import OmniModelRunnerOutput
-from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
-```
-
-### Step 7: Sync GPU-Side Omni Changes
-
-1. **Check recent GPU worker changes**:
- ```bash
- git diff <last-release-tag>..HEAD -- vllm_omni/worker/gpu_model_runner.py
- git diff <last-release-tag>..HEAD -- vllm_omni/worker/gpu_ar_model_runner.py
- ```
-
-2. **Identify new omni features** that need to be ported to NPU
-
-3. **Apply corresponding changes** to NPU runners
-
-### Step 8: Validation
-
-1. **Run syntax checking** (`py_compile` only byte-compiles; it does not check types):
- ```bash
- cd /root/vllm-workspace/vllm-omni
- python -m py_compile vllm_omni/platforms/npu/worker/npu_model_runner.py
- python -m py_compile vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
- python -m py_compile vllm_omni/platforms/npu/worker/npu_generation_model_runner.py
- ```
-
-2. **Run import test**:
- ```bash
- python -c "from vllm_omni.platforms.npu.worker import *"
- ```
-
-3. **Run model serving test** (if hardware available):
- ```bash
- vllm serve <model> --trust-remote-code
- ```
-
-## Common Pitfalls
-
-### 1. Forward Context Differences
-- GPU uses `set_forward_context`
-- NPU uses `set_ascend_forward_context`
-- Parameters may differ slightly
-
-### 2. Graph Wrapper Differences
-- GPU: `CUDAGraphWrapper`
-- NPU: `ACLGraphWrapper`
-- Constructor parameters may differ
-
-### 3. Buffer Creation
-- GPU: `_make_buffer` returns different structure
-- NPU: May need numpy=True/False parameter
-
-### 4. Attention Metadata
-- GPU: Uses vllm attention metadata builders
-- NPU: Uses `AscendCommonAttentionMetadata`
-
-### 5. Sampling
-- GPU: Uses vllm sampler
-- NPU: Uses `AscendSampler`
-
-## Checklist Before Commit
-
-- [ ] All omni-specific comment markers preserved
-- [ ] New omni logic from GPU side synced
-- [ ] Imports updated to latest vllm-ascend
-- [ ] No `CUDAGraphWrapper` references in NPU code
-- [ ] `set_ascend_forward_context` used instead of `set_forward_context`
-- [ ] `ACLGraphWrapper` used for talker_mtp wrapping
-- [ ] Type hints match vllm-ascend signatures
-- [ ] No duplicate code blocks
-- [ ] Python syntax valid (py_compile passes)
-
-## Reference Files for Comparison
-
-When upgrading, keep these files open for reference:
-
-1. **vllm-ascend NPUModelRunner**: `/root/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py`
-2. **vllm GPUModelRunner**: `/root/vllm-workspace/vllm/vllm/v1/worker/gpu_model_runner.py`
-3. **vllm-omni OmniGPUModelRunner**: `/root/vllm-workspace/vllm-omni/vllm_omni/worker/gpu_model_runner.py`
diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md b/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md
deleted file mode 100644
index 89067d37b2d..00000000000
--- a/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md
+++ /dev/null
@@ -1,335 +0,0 @@
-# GPU to NPU Translation Patterns
-
-This document provides a quick reference for translating GPU code patterns to NPU equivalents when porting omni-specific logic.
-
-## Import Translations
-
-### Forward Context
-```python
-# GPU
-from vllm.forward_context import set_forward_context
-
-# NPU
-from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-```
-
-### Graph Wrapper
-```python
-# GPU
-from vllm.compilation.cuda_graph import CUDAGraphWrapper
-
-# NPU
-from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
-```
-
-### Attention State
-```python
-# GPU (no equivalent - uses FlashAttention states directly)
-
-# NPU
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
-```
-
-### Utilities
-```python
-# GPU
-# (directly use torch.cuda functions)
-
-# NPU
-from vllm_ascend.utils import enable_sp, lmhead_tp_enable
-from vllm_ascend.ops.rotary_embedding import update_cos_sin
-```
-
-## Context Manager Translations
-
-### Forward Context Setup
-```python
-# GPU
-with set_forward_context(
- attn_metadata,
- self.vllm_config,
- num_tokens=num_tokens_padded,
- num_tokens_across_dp=num_tokens_across_dp,
- cudagraph_runtime_mode=cudagraph_mode,
- batch_descriptor=batch_desc,
-):
- # forward pass
-
-# NPU
-with set_ascend_forward_context(
- attn_metadata,
- self.vllm_config,
- num_tokens=num_tokens_padded,
- num_tokens_across_dp=num_tokens_across_dp,
- aclgraph_runtime_mode=cudagraph_mode, # Note: 'aclgraph' not 'cudagraph'
- batch_descriptor=batch_desc,
- num_actual_tokens=scheduler_output.total_num_scheduled_tokens,
- model_instance=self.model,
-):
- # forward pass
-```
-
-### Graph Capture Context
-```python
-# GPU
-from vllm.compilation.cuda_graph import graph_capture as cuda_graph_capture
-with cuda_graph_capture(self.device):
- # capture
-
-# NPU
-from vllm_ascend.worker.model_runner_v1 import graph_capture
-with graph_capture(self.device):
- # capture
-```
-
-## Graph Wrapper Usage
-
-### Creating Graph Wrapper
-```python
-# GPU
-if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
- self.talker_mtp = CUDAGraphWrapper(
- talker_mtp,
- self.vllm_config,
- runtime_mode=CUDAGraphMode.FULL
- )
-
-# NPU
-if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
- self.talker_mtp = ACLGraphWrapper(
- talker_mtp,
- self.vllm_config,
- runtime_mode=CUDAGraphMode.FULL
- )
-```
-
-### Checking Graph Wrapper Type
-```python
-# GPU
-if not isinstance(self.talker_mtp, CUDAGraphWrapper):
- _cudagraph_mode = CUDAGraphMode.NONE
-
-# NPU
-if not isinstance(self.talker_mtp, ACLGraphWrapper):
- _cudagraph_mode = CUDAGraphMode.NONE
-```
-
-## Device Operations
-
-### Synchronization
-```python
-# GPU
-torch.cuda.synchronize()
-
-# NPU
-torch.npu.synchronize()
-```
-
-### Stream Operations
-```python
-# GPU
-stream = torch.cuda.Stream(device=device)
-torch.cuda.current_stream()
-
-# NPU
-stream = torch.npu.Stream(device=device)
-torch.npu.current_stream()
-```
-
-## Attention Metadata
-
-### State Setting (NPU-specific)
-```python
-# GPU - handled internally by attention backends
-
-# NPU - explicit state setting required
-self.attn_state = AscendAttentionState.DecodeOnly
-if self.speculative_config and self.speculative_config.method == "mtp":
- if self.vllm_config.model_config.use_mla:
- self.attn_state = AscendAttentionState.SpecDecoding
- else:
- self.attn_state = AscendAttentionState.ChunkedPrefill
-```
-
-### Building Attention Metadata
-```python
-# GPU - uses vllm attention builders
-
-# NPU - may need additional parameters
-(attn_metadata, spec_decode_common_attn_metadata) = self._build_attention_metadata(
- num_tokens=num_tokens_unpadded,
- num_tokens_padded=num_tokens_padded,
- num_reqs=num_reqs,
- num_reqs_padded=num_reqs_padded,
- max_query_len=max_num_scheduled_tokens,
- ubatch_slices=ubatch_slices_attn,
- logits_indices=logits_indices,
- use_spec_decode=use_spec_decode,
- num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
- num_scheduled_tokens_np=num_scheduled_tokens_np,
- cascade_attn_prefix_lens=cascade_attn_prefix_lens,
-)
-```
-
-## Rotary Embedding
-
-### Update Cos/Sin Cache
-```python
-# GPU - typically handled inside attention
-
-# NPU - explicit update required before forward
-from vllm_ascend.ops.rotary_embedding import update_cos_sin
-update_cos_sin(positions)
-```
-
-## Sequence Parallelism
-
-### Enable SP Check
-```python
-# GPU - use vllm distributed utilities
-
-# NPU - use vllm-ascend wrapper
-from vllm_ascend.utils import enable_sp
-
-if enable_sp():
- # sequence parallelism enabled
-```
-
-## Sampler
-
-### Sampler Type
-```python
-# GPU - uses vllm sampler
-self.sampler = Sampler()
-
-# NPU - uses AscendSampler
-from vllm_ascend.sample.sampler import AscendSampler
-self.sampler = AscendSampler()
-```
-
-## Input Batch
-
-### Batch Class
-```python
-# GPU
-from vllm.v1.worker.gpu_input_batch import InputBatch
-
-# NPU
-from vllm_ascend.worker.npu_input_batch import NPUInputBatch
-```
-
-## Graph Parameter Updates
-
-### Full Graph Params Update (NPU-specific)
-```python
-# GPU - not needed
-
-# NPU - required for FULL graph mode
-from vllm_ascend.compilation.acl_graph import update_full_graph_params
-
-forward_context = get_forward_context()
-if (
- forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL
- and not forward_context.capturing
- and not self.use_sparse
-):
- update_full_graph_params(
- self.attn_backend,
- self.update_stream,
- forward_context,
- num_tokens_padded,
- self.vllm_config,
- self.speculative_config,
- positions.shape[0],
- )
-```
-
-## Paged Attention Check
-
-```python
-# GPU - not typically needed
-
-# NPU
-from vllm_ascend.attention.utils import using_paged_attention
-
-if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config):
- seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE
-```
-
-## Common Method Signature Differences
-
-### _dummy_run Parameters
-```python
-# GPU (v0.17.0)
-def _dummy_run(
- self,
- num_tokens: int,
- cudagraph_runtime_mode: CUDAGraphMode | None = None,
- force_attention: bool = False,
- uniform_decode: bool = False,
- allow_microbatching: bool = True,
- skip_eplb: bool = False,
- is_profile: bool = False,
- create_mixed_batch: bool = False,
- remove_lora: bool = True,
- is_graph_capturing: bool = False,
- num_active_loras: int = 0,
-) -> tuple[torch.Tensor, torch.Tensor]:
-
-# NPU (v0.17.0) - adds with_prefill, activate_lora
-def _dummy_run(
- self,
- num_tokens: int,
- with_prefill: bool = False,
- cudagraph_runtime_mode: CUDAGraphMode | None = None,
- force_attention: bool = False,
- uniform_decode: bool = False,
- is_profile: bool = False,
- create_mixed_batch: bool = False,
- allow_microbatching: bool = True,
- skip_eplb: bool = False,
- remove_lora: bool = True,
- activate_lora: bool = False,
- is_graph_capturing: bool = False,
- num_active_loras: int = 0,
-) -> tuple[torch.Tensor, torch.Tensor]:
-```
-
-### _model_forward Parameters
-```python
-# GPU - no num_tokens_padded
-def _model_forward(
- self,
- input_ids: torch.Tensor | None = None,
- positions: torch.Tensor | None = None,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **model_kwargs: dict[str, Any],
-):
-
-# NPU - has num_tokens_padded as first parameter
-def _model_forward(
- self,
- num_tokens_padded: int,
- input_ids: torch.Tensor | None = None,
- positions: torch.Tensor | None = None,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **model_kwargs: dict[str, Any],
-):
-```
-
-## Quick Reference Table
-
-| Feature | GPU | NPU |
-|---------|-----|-----|
-| Graph wrapper | `CUDAGraphWrapper` | `ACLGraphWrapper` |
-| Forward context | `set_forward_context` | `set_ascend_forward_context` |
-| Runtime mode param | `cudagraph_runtime_mode` | `aclgraph_runtime_mode` |
-| Device sync | `torch.cuda.synchronize()` | `torch.npu.synchronize()` |
-| Stream | `torch.cuda.Stream` | `torch.npu.Stream` |
-| Current stream | `torch.cuda.current_stream()` | `torch.npu.current_stream()` |
-| Input batch | `InputBatch` | `NPUInputBatch` |
-| Sampler | `Sampler` | `AscendSampler` |
-| Attention state | N/A | `AscendAttentionState` |
-| RoPE update | N/A | `update_cos_sin()` |
diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md b/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md
deleted file mode 100644
index 8c5d32ab4c1..00000000000
--- a/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md
+++ /dev/null
@@ -1,374 +0,0 @@
-# Omni-Specific Code Blocks Reference
-
-This document catalogs omni-specific code blocks in the NPU model runners, making it easier to identify what needs to be preserved during upgrades.
-
-> **IMPORTANT**: This document may not be complete or up-to-date!
->
-> - Always grep for `Omni-new` in the GPU implementations (`vllm_omni/worker/`) to find the authoritative list
-> - New omni features may be added that are not yet documented here
-> - When you discover new omni-specific blocks during an upgrade, please update this document
-> - Last verified: Check git history for this file
-
-## OmniNPUModelRunner (npu_model_runner.py)
-
-### load_model - Talker MTP Initialization
-
-```python
-def load_model(self, *args, **kwargs) -> None:
- NPUModelRunner.load_model(self, *args, **kwargs)
- # Initialize enable_sp cache to avoid get_current_vllm_config() error
- # in _pad_for_sequence_parallelism during execute_model.
- # This is a workaround for vllm-ascend not passing vllm_config to enable_sp().
- enable_sp(self.vllm_config)
- # TODO move this model specific logic to a separate class
- # TTS model IS the talker (no .talker sub-attr); use getattr to support both Omni and TTS.
- talker_mtp = getattr(self.model, "talker_mtp", None)
- if talker_mtp is not None:
- self.talker_mtp = talker_mtp # type: ignore[assignment]
- cudagraph_mode = self.compilation_config.cudagraph_mode
- assert cudagraph_mode is not None
- # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that
- # have a separate .talker sub-module. TTS models' code predictor
- # has internal AR loops / torch.multinomial — not graph-safe.
- has_separate_talker = getattr(self.model, "talker", None) is not None
- if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
- # NOTE: Use ACLGraphWrapper on NPU, not CUDAGraphWrapper
- self.talker_mtp = ACLGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
- # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size.
- hidden_size = int(
- getattr(self.model, "mtp_hidden_size", 0) or getattr(self.model_config.hf_text_config, "hidden_size")
- )
- max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size)
- self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32)
- self.talker_mtp_inputs_embeds = self._make_buffer(
- max_batch_size, hidden_size, dtype=self.dtype, numpy=False
- )
- self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
- self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
-```
-
-### _dummy_run - Talker MTP Dummy Forward
-
-Location: Inside `set_ascend_forward_context` block, before main model forward
-
-```python
-# ---------------------------------------Omni-new----------------------------------------------
-if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"):
- num_tokens_padded_talker_mtp = num_tokens_padded
- if num_tokens_padded_talker_mtp == self.max_num_tokens:
- num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0]
- outputs = self.talker_mtp(
- self.talker_mtp_input_ids.gpu[:num_tokens_padded_talker_mtp],
- self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded_talker_mtp],
- self.last_talker_hidden.gpu[:num_tokens_padded_talker_mtp],
- self.text_step.gpu[:num_tokens_padded_talker_mtp],
- )
- self.compilation_config.cache_dir = None
-# ---------------------------------------Omni-new----------------------------------------------
-```
-
-### _dummy_run - Extract Multimodal Outputs
-
-Location: After model forward, before dummy_compute_logits
-
-```python
-# ---------------------------------------Omni-new----------------------------------------------
-hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states)
-# ---------------------------------------Omni-new----------------------------------------------
-```
-
-### _model_forward - Omni Output Wrapping
-
-```python
-def _model_forward(
- self,
- num_tokens_padded: int,
- input_ids: torch.Tensor | None = None,
- positions: torch.Tensor | None = None,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **model_kwargs: dict[str, Any],
-):
- """Override to combine NPUModelRunner's signature with OmniGPUModelRunner's logic."""
- # Omni-specific: build and inject extra model kwargs
- model_kwargs_extra = self._build_model_kwargs_extra()
-
- # Call the model forward (same as NPUModelRunner)
- assert self.model is not None
- model_output = self.model(
- input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- **model_kwargs,
- **model_kwargs_extra,
- )
-
- # Omni-specific: wrap output if needed
- if not isinstance(model_output, OmniOutput) and hasattr(self.model, "make_omni_output"):
- model_output = self.model.make_omni_output(model_output, **model_kwargs_extra)
-
- # Omni-specific: cache model output for later sample_tokens
- self._omni_last_model_output = model_output
-
- # NPU-specific: update full graph params (keep from vllm-ascend)
- forward_context = get_forward_context()
- # ... NPU graph update logic ...
-
- # NPU-specific: all-gather for sequence parallelism (keep from vllm-ascend)
- if get_forward_context().sp_enabled and not isinstance(model_output, IntermediateTensors):
- model_output = self._all_gather_hidden_states_and_aux(model_output)
-
- return model_output
-```
-
----
-
-## NPUARModelRunner (npu_ar_model_runner.py)
-
-### __init__ - KV Transfer Manager
-
-```python
-def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
- # each model stage has their own hidden size
- self.hidden_size = self.model_config.hf_text_config.hidden_size
- self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False)
- # Initialize KV cache manager (preserve vllm_config fallback behavior)
- self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config)
-```
-
-### execute_model - KV Transfer Before Update States
-
-Location: At the very beginning of execute_model
-
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-# [Omni] Handle KV transfer BEFORE updating states (which removes finished requests)
-self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer(
- finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}),
- kv_caches=self.kv_caches,
- block_size=self.cache_config.block_size,
- cache_dtype=str(self.cache_config.cache_dtype),
- request_id_resolver=self._resolve_global_request_id,
-)
-# -------------------------------------- Omni-new -------------------------------------------------
-```
-
-### execute_model - Custom _update_states Call
-
-Location: Inside synchronize_input_prep context
-
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-self._update_states(scheduler_output)
-# ------------------------------------------------------------------------------------------------
-```
-
-### execute_model - Extract Multimodal Outputs
-
-Location: In post process section, after hidden_states assignment
-
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states)
-
-if multimodal_outputs is not None:
- keys_or_type = (
- list(multimodal_outputs.keys())
- if isinstance(multimodal_outputs, dict)
- else type(multimodal_outputs)
- )
- logger.debug(f"[AR] execute_model: multimodal_outputs keys = {keys_or_type}")
-else:
- logger.debug("[AR] execute_model: multimodal_outputs is None")
-# -------------------------------------- Omni-new -------------------------------------------------
-```
-
-### execute_model - Compute Logits with sampling_metadata
-
-Location: In both broadcast_pp_output True and False branches
-
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-# Try with sampling_metadata first; fall back to without for models that don't support it
-try:
- logits = self.model.compute_logits(
- sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata
- )
-except TypeError:
- logits = self.model.compute_logits(sample_hidden_states)
-# -------------------------------------- Omni-new -------------------------------------------------
-```
-
-### sample_tokens - KV Extracted Req IDs
-
-Location: At the beginning of sample_tokens
-
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None)
-self.kv_extracted_req_ids = None
-# -------------------------------------- Omni-new -------------------------------------------------
-```
-
-### sample_tokens - Process Additional Information and Build Output
-
-Location: After bookkeeping sync, replacing the original output construction
-
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-hidden_states_cpu = hidden_states.detach().to("cpu").contiguous()
-num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None)
-if num_scheduled_tokens_np is None:
- req_ids = self.input_batch.req_ids
- num_scheduled_tokens_np = np.array(
- [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids],
- dtype=np.int32,
- )
-
-self._process_additional_information_updates(
- hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output
-)
-
-pooler_output: list[dict[str, object]] = []
-for rid in req_ids_output_copy:
- idx = req_id_to_index_output_copy[rid]
- start = int(self.query_start_loc.cpu[idx])
- sched = int(num_scheduled_tokens_np[idx])
- end = start + sched
- hidden_slice = hidden_states_cpu[start:end]
- payload: dict[str, object] = {"hidden": hidden_slice}
- if isinstance(multimodal_outputs, dict) and multimodal_outputs:
- # ... multimodal output slicing logic ...
- pooler_output.append(payload)
-
-model_runner_output = OmniModelRunnerOutput(
- req_ids=req_ids_output_copy,
- req_id_to_index=req_id_to_index_output_copy,
- sampled_token_ids=valid_sampled_token_ids,
- logprobs=logprobs_lists,
- prompt_logprobs_dict=prompt_logprobs_dict,
- pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None),
- kv_connector_output=kv_connector_output,
-)
-model_runner_output.kv_extracted_req_ids = kv_extracted_req_ids
-# -------------------------------------- Omni-new -------------------------------------------------
-```
-
----
-
-## NPUGenerationModelRunner (npu_generation_model_runner.py)
-
-### execute_model - Async Chunk Update
-
-Location: Inside prepare input section, before synchronize_input_prep
-
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-if self.model_config.async_chunk and num_scheduled_tokens:
- self._update_request_states(scheduler_output)
-# -------------------------------------- Omni-new -------------------------------------------------
-```
-
-### execute_model - Seq Token Counts
-
-Location: After _preprocess call
-
-```python
-# [Omni] Pass token counts per request for code2wav output slicing
-model_kwargs["seq_token_counts"] = tokens
-```
-
-### execute_model - Run Generation Model
-
-Location: Inside forward context
-
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-outputs = self._run_generation_model(
- num_tokens_padded=num_tokens_padded,
- input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- model_kwargs=model_kwargs,
- logits_indices=logits_indices,
-)
-_, multimodal_outputs = self.extract_multimodal_outputs(outputs)
-# -------------------------------------- Omni-new -------------------------------------------------
-```
-
-### sample_tokens - Multimodal Output Processing
-
-The entire sample_tokens method body is omni-specific for generation models:
-
-```python
-# -------------------------------------- Omni-new -------------------------------------------------
-pooler_output: list[object] = []
-if isinstance(multimodal_outputs, torch.Tensor):
- # ... tensor handling ...
-elif isinstance(multimodal_outputs, list):
- # ... list handling ...
-elif isinstance(multimodal_outputs, dict):
- # ... dict handling per request ...
-else:
- raise RuntimeError("Unsupported diffusion output type")
-# [Omni] Copy req_id mappings to avoid async scheduling mutation.
-req_ids_output_copy = self.input_batch.req_ids.copy()
-req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
-output = OmniModelRunnerOutput(
- req_ids=req_ids_output_copy,
- req_id_to_index=req_id_to_index_output_copy,
- sampled_token_ids=[],
- logprobs=None,
- prompt_logprobs_dict={},
- pooler_output=pooler_output,
- kv_connector_output=kv_connector_output,
- num_nans_in_logits={},
- ec_connector_output=ec_connector_output if self.supports_mm_inputs else None,
-)
-# -------------------------------------- Omni-new -------------------------------------------------
-```
-
-### _dummy_run - Model Kwargs Init and Multimodal Extract
-
-Location: Before model forward and after
-
-```python
-model_kwargs = self._init_model_kwargs() # Before forward
-
-# ... forward ...
-
-# -------------------------------------- Omni-new -------------------------------------------------
-hidden_states, _ = self.extract_multimodal_outputs(hidden_states)
-# -------------------------------------------------------------------------------------------------
-```
-
----
-
-## ExecuteModelState Extension
-
-The `ExecuteModelState` NamedTuple is extended for omni:
-
-```python
-class ExecuteModelState(NamedTuple):
- """Ephemeral cached state transferred between execute_model() and
- sample_tokens(), after execute_model() returns None."""
-
- scheduler_output: SchedulerOutput
- logits: torch.Tensor
- spec_decode_metadata: SpecDecodeMetadata | None
- spec_decode_common_attn_metadata: AscendCommonAttentionMetadata | None
- hidden_states: torch.Tensor
- sample_hidden_states: torch.Tensor
- aux_hidden_states: list[torch.Tensor] | None
- attn_metadata: PerLayerAttnMetadata
- positions: torch.Tensor
- ec_connector_output: ECConnectorOutput | None
- cudagraph_stats: CUDAGraphStat | None
- multimodal_outputs: Any # <-- Omni extension
-```
-
-This extended state must be imported from `npu_ar_model_runner` in `npu_generation_model_runner`.
diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md b/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md
deleted file mode 100644
index 4f184df0ecb..00000000000
--- a/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md
+++ /dev/null
@@ -1,222 +0,0 @@
-# NPU Model Runner Upgrade Workflow Checklist
-
-> **Note**: Reference documents (`omni-specific-blocks.md`) may not be complete. Always grep for `Omni-new` in GPU implementations to find all omni-specific code blocks. Update the reference docs when discovering new blocks.
-
-## Pre-Upgrade Preparation
-
-### 1. Version Information
-- [ ] Identify current vllm-omni version: `_________`
-- [ ] Identify target vllm-ascend version: `_________`
-- [ ] Identify target vllm version: `_________`
-- [ ] Last release date for GPU worker changes: `_________`
-
-### 2. Gather Git History
-```bash
-# GPU-side omni changes since last release
-cd /root/vllm-workspace/vllm-omni
-git log --oneline --since="YYYY-MM-DD" -- vllm_omni/worker/
-
-# vllm-ascend NPUModelRunner changes
-cd /root/vllm-workspace/vllm-ascend
-git log --oneline .. -- vllm_ascend/worker/model_runner_v1.py
-```
-
-### 3. Backup Current Files
-- [ ] Create backup of current NPU runners:
- ```bash
- cp -r vllm_omni/platforms/npu/worker vllm_omni/platforms/npu/worker.backup
- ```
-
----
-
-## OmniNPUModelRunner (npu_model_runner.py)
-
-### Read and Understand
-- [ ] Read current `npu_model_runner.py`
-- [ ] Read latest `vllm_ascend/worker/model_runner_v1.py`
-- [ ] Read latest `vllm_omni/worker/gpu_model_runner.py`
-
-### Method: load_model
-- [ ] Document existing omni-specific logic
-- [ ] Copy latest NPUModelRunner.load_model structure
-- [ ] Re-insert: `enable_sp(self.vllm_config)` call
-- [ ] Re-insert: talker_mtp detection and setup
-- [ ] Replace: `CUDAGraphWrapper` → `ACLGraphWrapper`
-- [ ] Re-insert: Buffer allocations (talker_mtp_input_ids, etc.)
-
-### Method: _dummy_run
-- [ ] Document existing omni-specific logic locations
-- [ ] Copy latest NPUModelRunner._dummy_run
-- [ ] Re-insert: talker_mtp dummy forward block (inside context)
-- [ ] Re-insert: `extract_multimodal_outputs` call
-- [ ] Verify: Comment markers are present
-
-### Method: _model_forward
-- [ ] Copy latest NPUModelRunner._model_forward structure
-- [ ] Re-insert: `_build_model_kwargs_extra()` call
-- [ ] Re-insert: OmniOutput wrapping logic
-- [ ] Re-insert: `_omni_last_model_output` caching
-- [ ] Keep: NPU graph params update
-- [ ] Keep: SP all-gather logic
-
-### Method: _talker_mtp_forward
-- [ ] Verify: Uses `set_ascend_forward_context`
-- [ ] Verify: Uses `ACLGraphWrapper` check
-- [ ] Sync any changes from GPU `_talker_mtp_forward`
-
-### Imports
-- [ ] Update vllm-ascend imports to latest paths
-- [ ] Verify all omni imports are present
-- [ ] Remove any deprecated imports
-
----
-
-## NPUARModelRunner (npu_ar_model_runner.py)
-
-### Read and Understand
-- [ ] Read current `npu_ar_model_runner.py`
-- [ ] Read latest `vllm_ascend/worker/model_runner_v1.py` execute_model
-- [ ] Read latest `vllm_omni/worker/gpu_ar_model_runner.py`
-
-### Method: __init__
-- [ ] Sync any new initialization from GPU side
-- [ ] Keep: `OmniKVTransferManager` setup
-- [ ] Keep: Custom buffer allocations
-
-### Method: execute_model
-- [ ] Document all omni blocks with line numbers
-- [ ] Copy latest NPUModelRunner.execute_model structure
-- [ ] Re-insert: KV transfer handling (beginning)
-- [ ] Re-insert: Custom `_update_states` call
-- [ ] Re-insert: `extract_multimodal_outputs`
-- [ ] Re-insert: `compute_logits` with sampling_metadata try/except
-- [ ] Update: ExecuteModelState to include multimodal_outputs
-
-### Method: sample_tokens
-- [ ] Document all omni blocks
-- [ ] Copy latest NPUModelRunner.sample_tokens structure
-- [ ] Re-insert: `kv_extracted_req_ids` handling
-- [ ] Re-insert: Hidden states CPU copy
-- [ ] Re-insert: `_process_additional_information_updates`
-- [ ] Re-insert: `OmniModelRunnerOutput` construction
-
-### ExecuteModelState
-- [ ] Verify: `multimodal_outputs` field is present
-- [ ] Verify: Imported/used correctly in execute_model
-
-### Imports
-- [ ] Update all vllm-ascend imports
-- [ ] Keep omni-specific imports
-
----
-
-## NPUGenerationModelRunner (npu_generation_model_runner.py)
-
-### Read and Understand
-- [ ] Read current `npu_generation_model_runner.py`
-- [ ] Read latest GPU `gpu_generation_model_runner.py`
-
-### Method: _update_request_states
-- [ ] Verify: async_chunk handling is correct
-- [ ] Sync any changes from GPU side
-
-### Method: execute_model
-- [ ] Document all omni blocks
-- [ ] Copy latest NPUModelRunner.execute_model base structure
-- [ ] Re-insert: async_chunk update logic
-- [ ] Re-insert: `seq_token_counts` injection
-- [ ] Re-insert: `_run_generation_model` call
-- [ ] Re-insert: `extract_multimodal_outputs`
-- [ ] Use: ExecuteModelState from npu_ar_model_runner
-
-### Method: sample_tokens
-- [ ] Keep: Entire omni multimodal output processing
-- [ ] Update: Any new output fields needed
-- [ ] Keep: `OmniModelRunnerOutput` construction
-
-### Method: _run_generation_model
-- [ ] Sync any changes from GPU side
-- [ ] Keep: `_model_forward` call with sampler
-
-### Method: _dummy_run
-- [ ] Copy latest NPUModelRunner._dummy_run
-- [ ] Re-insert: `model_kwargs = self._init_model_kwargs()`
-- [ ] Re-insert: `extract_multimodal_outputs` at end
-
-### Imports
-- [ ] Import ExecuteModelState from npu_ar_model_runner
-- [ ] Update vllm-ascend imports
-
----
-
-## Post-Upgrade Validation
-
-### Syntax Validation
-- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_model_runner.py`
-- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_ar_model_runner.py`
-- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_generation_model_runner.py`
-
-### Import Validation
-- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_model_runner import OmniNPUModelRunner"`
-- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_ar_model_runner import NPUARModelRunner"`
-- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_generation_model_runner import NPUGenerationModelRunner"`
-
-### Comment Markers
-- [ ] Grep for "Omni-new" in all three files
-- [ ] Verify all omni blocks have closing markers
-
-### Code Review
-- [ ] No `CUDAGraphWrapper` references
-- [ ] All `set_forward_context` replaced with `set_ascend_forward_context`
-- [ ] Parameter names correct (`aclgraph_runtime_mode` not `cudagraph_runtime_mode`)
-- [ ] No duplicate code blocks
-- [ ] No missing imports
-
----
-
-## Git Commit
-
-### Commit Message Template
-```
-[NPU] Upgrade model runners to align with vllm-ascend vX.Y.Z
-
-- Update OmniNPUModelRunner with latest NPUModelRunner base
-- Update NPUARModelRunner execute_model and sample_tokens
-- Update NPUGenerationModelRunner for async_chunk changes
-- Sync GPU-side omni changes from vX.Y.Z release
-- Preserve all omni-specific logic (marked with Omni-new comments)
-
-Changes from vllm-ascend:
--
-
-Changes synced from GPU:
--
-```
-
-### Files to Stage
-- [ ] `vllm_omni/platforms/npu/worker/npu_model_runner.py`
-- [ ] `vllm_omni/platforms/npu/worker/npu_ar_model_runner.py`
-- [ ] `vllm_omni/platforms/npu/worker/npu_generation_model_runner.py`
-- [ ] Any other modified files
-
----
-
-## Troubleshooting
-
-### Import Errors
-- Check if vllm-ascend module paths have changed
-- Verify PYTHONPATH includes both vllm-ascend and vllm-omni
-
-### Type Errors
-- Check method signatures match between GPU and NPU
-- Verify NamedTuple fields match expected structure
-
-### Runtime Errors
-- Enable debug logging: `export VLLM_LOGGING_LEVEL=DEBUG`
-- Check graph capture issues: try `--enforce-eager`
-- Check attention issues: verify AscendAttentionState usage
-
-### Performance Regression
-- Compare with previous version on same model
-- Check if graph capture is working: look for ACLGraph logs
-- Verify SP/EP configurations are correct
diff --git a/.github/scripts/pr_reviewer.py b/.github/scripts/pr_reviewer.py
new file mode 100755
index 00000000000..da629a64587
--- /dev/null
+++ b/.github/scripts/pr_reviewer.py
@@ -0,0 +1,629 @@
+#!/usr/bin/env python3
+"""
+PR Reviewer using GLM API for vllm-omni project.
+"""
+
+import json
+import logging
+import os
+import sys
+import time
+from dataclasses import dataclass
+from typing import Any, TypedDict
+
+import requests
+
+
+# Type definitions for API responses
+class PRDetails(TypedDict):
+ """Type definition for GitHub PR details response."""
+
+ title: str
+ body: str
+ number: int
+ state: str
+ user: dict[str, Any]
+
+
+class GLMMessage(TypedDict):
+ """Type definition for GLM API message."""
+
+ role: str
+ content: str
+
+
+class GLMChoice(TypedDict):
+ """Type definition for GLM API choice."""
+
+ message: GLMMessage
+ finish_reason: str
+
+
+class GLMResponse(TypedDict):
+ """Type definition for GLM API response."""
+
+ choices: list[GLMChoice]
+ usage: dict[str, int] | None
+
+
+class GitHubComment(TypedDict):
+ """Type definition for GitHub comment."""
+
+ id: int
+ body: str
+ created_at: str
+ user: dict[str, Any]
+
+
+# Configuration
+TRIGGER_PHRASE: str = "@vllm-omni-reviewer"
+DEFAULT_GLM_API_URL: str = "https://open.bigmodel.cn/api/paas/v4/chat/completions" # noqa: E501
+DEFAULT_GLM_MODEL: str = "glm-5"
+DEFAULT_COOLDOWN_MINUTES: int = 5
+DEFAULT_MAX_RETRIES: int = 3
+DEFAULT_RETRY_DELAY: float = 1.0
+MAX_DIFF_SIZE: int = 100_000 # Maximum diff size in characters
+
+
+@dataclass
+class Config:
+ """Configuration for the PR reviewer."""
+
+ glm_api_url: str
+ glm_model: str
+ cooldown_minutes: int
+ max_retries: int
+ retry_delay: float
+ max_diff_size: int
+
+
+# Setup logging
+logging.basicConfig(
+ level=logging.INFO,
+ format="[PR Reviewer] %(asctime)s - %(levelname)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def get_config() -> Config:
+ """Load configuration from environment variables with defaults."""
+ return Config(
+ glm_api_url=os.getenv("GLM_API_URL", DEFAULT_GLM_API_URL),
+ glm_model=os.getenv("GLM_MODEL", DEFAULT_GLM_MODEL),
+ cooldown_minutes=int(
+ os.getenv(
+ "PR_REVIEWER_COOLDOWN_MINUTES",
+ str(DEFAULT_COOLDOWN_MINUTES),
+ )
+ ),
+ max_retries=int(
+ os.getenv(
+ "PR_REVIEWER_MAX_RETRIES",
+ str(DEFAULT_MAX_RETRIES),
+ )
+ ),
+ retry_delay=float(os.getenv("PR_REVIEWER_RETRY_DELAY", str(DEFAULT_RETRY_DELAY))),
+ max_diff_size=int(os.getenv("PR_REVIEWER_MAX_DIFF_SIZE", str(MAX_DIFF_SIZE))), # noqa: E501
+ )
+
+
+def get_env_var(name: str) -> str:
+ """
+ Get an environment variable or raise an error.
+
+ Args:
+ name: Name of the environment variable.
+
+ Returns:
+ The value of the environment variable.
+
+ Raises:
+ SystemExit: If the environment variable is unset or empty.
+ """
+ value = os.environ.get(name)
+ if not value:
+ logger.error(f"Environment variable {name} is not set")
+ sys.exit(1)
+ return value
+
+
+def check_trigger(comment_body: str) -> bool:
+ """
+ Check if the comment contains the trigger phrase.
+
+ Args:
+ comment_body: The body of the comment to check.
+
+ Returns:
+ True if the trigger phrase is found, False otherwise.
+ """
+ return TRIGGER_PHRASE in comment_body
+
+
+def fetch_pr_diff(
+ repo_name: str,
+ pr_number: int,
+ token: str,
+ max_size: int = MAX_DIFF_SIZE,
+) -> str | None:
+ """
+ Fetch the diff for a pull request.
+
+ Args:
+ repo_name: The repository name in format "owner/repo".
+ pr_number: The pull request number.
+ token: GitHub authentication token.
+ max_size: Maximum diff size in characters.
+
+ Returns:
+ The diff content as a string, or None if fetching failed.
+ Truncated to max_size characters plus a marker if it is larger.
+ """
+ url: str = f"https://api.github.com/repos/{repo_name}/pulls/{pr_number}"
+ headers: dict[str, str] = {
+ "Authorization": f"Bearer {token}",
+ "Accept": "application/vnd.github.v3.diff",
+ }
+
+ logger.info(f"Fetching PR diff from {url}")
+ response = requests.get(url, headers=headers, timeout=30)
+
+ if response.status_code == 200:
+ diff: str = response.text
+ if len(diff) > max_size:
+ logger.warning(
+ f"Diff size ({len(diff)} bytes) exceeds maximum "
+ f"({max_size} bytes), truncating to first "
+ f"{max_size} bytes"
+ )
+ return diff[:max_size] + "\n\n... [Diff truncated due to size] ..."
+ logger.info(f"Successfully fetched diff ({len(diff)} bytes)")
+ return diff
+ else:
+ logger.error(f"Failed to fetch PR diff: {response.status_code}")
+ logger.error(f"Response: {response.text}")
+ return None
+
+
+def fetch_pr_details(
+ repo_name: str,
+ pr_number: int,
+ token: str,
+) -> PRDetails | None:
+ """
+ Fetch PR details including title and description.
+
+ Args:
+ repo_name: The repository name in format "owner/repo".
+ pr_number: The pull request number.
+ token: GitHub authentication token.
+
+ Returns:
+ A dictionary containing PR details, or None if fetching failed.
+ """
+ url: str = f"https://api.github.com/repos/{repo_name}/pulls/{pr_number}"
+ headers: dict[str, str] = {
+ "Authorization": f"Bearer {token}",
+ "Accept": "application/vnd.github.v3+json",
+ }
+
+ logger.info(f"Fetching PR details from {url}")
+ response = requests.get(url, headers=headers, timeout=30)
+
+ if response.status_code == 200:
+ return response.json()
+ else:
+ logger.error(f"Failed to fetch PR details: {response.status_code}")
+ return None
+
+
+def build_review_prompt(pr_title: str, pr_description: str, diff: str) -> str:
+ """
+ Build the code-review prompt sent to the GLM API.
+
+ Args:
+ pr_title: The title of the pull request.
+ pr_description: The description/body of the pull request.
+ diff: The diff content of the pull request.
+
+ Returns:
+ The formatted prompt string for the API.
+ """
+ return f"""You are an expert code reviewer for the VLLM-Omni project. \
+Please review the following pull request:
+
+## Pull Request Details
+**Title:** {pr_title}
+
+**Description:**
+{pr_description if pr_description else "No description provided."}
+
+## Code Changes (Diff)
+{diff}
+
+## Review Guidelines
+
+Please provide a comprehensive code review with the following sections:
+
+### 1. Overview
+- Brief summary of the changes
+- Overall assessment (positive, neutral, or concerns)
+
+### 2. Code Quality
+- Code style and consistency
+- Potential bugs or edge cases
+- Performance considerations
+- Error handling
+
+### 3. Architecture & Design
+- Integration with existing codebase
+- Design patterns and best practices
+- Potential improvements
+
+### 4. Security & Safety
+- Security concerns (if any)
+- Resource management
+- Input validation
+
+### 5. Testing & Documentation
+- Test coverage considerations
+- Documentation completeness
+- Examples and usage clarity
+
+### 6. Specific Suggestions
+- Line-by-line specific feedback (use `file:line` format)
+- Concrete actionable suggestions
+- Code examples for improvements (if applicable)
+
+### 7. Approval Status
+- **LGTM** (Looks Good To Me) if the PR is ready to merge
+- **LGTM with suggestions** if the PR is good but has minor suggestions
+- **Changes requested** if significant changes are needed
+
+## Important Notes
+- Be constructive and helpful
+- Focus on objective technical feedback
+- Acknowledge good practices when you see them
+- Prioritize critical issues over nitpicks
+- If the diff is empty or minimal, acknowledge this and provide
+ any relevant context-specific guidance
+
+Please format your response in Markdown with clear section headers.
+"""
+
+
+def validate_glm_response(data: dict[str, Any]) -> str | None:
+    """
+    Validate a GLM chat-completion response and extract the review text.
+
+    The expected shape is ``{"choices": [{"message": {"content": "<text>"}}]}``.
+    Every level is checked explicitly and the first mismatch is logged, so a
+    malformed response makes this function return None instead of raising.
+
+    Args:
+        data: The decoded JSON body of the GLM API response. Values that are
+            not a JSON object (``null``, a list, a bare string) are rejected.
+
+    Returns:
+        The review content string if valid, None otherwise.
+    """
+    # A decoded JSON body is not guaranteed to be an object. Without this
+    # guard the membership test / subscript below raises TypeError for
+    # ``null`` or a bare string, which no caller handles.
+    if not isinstance(data, dict):
+        logger.error(f"GLM API response is not a JSON object: {type(data)}")
+        return None
+
+    # Check if choices exists and is a non-empty list
+    if "choices" not in data:
+        logger.error("GLM API response missing 'choices' field")
+        logger.error(f"Response structure: {json.dumps(data, indent=2)}")
+        return None
+
+    choices = data["choices"]
+    if not isinstance(choices, list):
+        logger.error(f"GLM API 'choices' is not a list: {type(choices)}")
+        return None
+
+    if len(choices) == 0:
+        logger.error("GLM API 'choices' is an empty list")
+        return None
+
+    # Only the first choice is inspected. Each access below is preceded by a
+    # type/membership check, so no exception handler is needed.
+    first_choice = choices[0]
+    if not isinstance(first_choice, dict):
+        logger.error(f"GLM API choice is not a dict: {type(first_choice)}")
+        return None
+
+    if "message" not in first_choice:
+        logger.error("GLM API choice missing 'message' field")
+        logger.error(f"Choice structure: {json.dumps(first_choice, indent=2)}")  # noqa: E501
+        return None
+
+    message = first_choice["message"]
+    if not isinstance(message, dict):
+        logger.error(f"GLM API message is not a dict: {type(message)}")
+        return None
+
+    if "content" not in message:
+        logger.error("GLM API message missing 'content' field")
+        logger.error(f"Message structure: {json.dumps(message, indent=2)}")
+        return None
+
+    content = message["content"]
+    if not isinstance(content, str):
+        logger.error(f"GLM API content is not a string: {type(content)}")
+        return None
+
+    return content
+
+
+def call_glm_api(prompt: str, api_key: str, config: Config) -> str | None:
+    """
+    Call the GLM chat-completions API to get a code review, with retry logic.
+
+    At most ``config.max_retries`` attempts are made. Between attempts the
+    function sleeps ``config.retry_delay * 2**attempt`` seconds (exponential
+    backoff). A 4xx response other than 408/429 aborts immediately, because
+    repeating the same request cannot fix a rejected key or payload.
+
+    Args:
+        prompt: The prompt to send to the API.
+        api_key: The GLM API key.
+        config: Configuration object (model name, API URL, retry settings).
+
+    Returns:
+        The review content as a string, or None if no attempt succeeded.
+    """
+    headers: dict[str, str] = {
+        "Authorization": f"Bearer {api_key}",
+        "Content-Type": "application/json",
+    }
+
+    payload: dict[str, Any] = {
+        "model": config.glm_model,
+        "messages": [{"role": "user", "content": prompt}],
+        "temperature": 0.3,
+        "max_tokens": 32000,
+        "top_p": 0.9,
+    }
+
+    last_error: str | None = None
+    attempts_made: int = 0
+
+    for attempt in range(config.max_retries):
+        attempts_made = attempt + 1
+        retryable = True
+        try:
+            logger.info(f"Calling GLM API ({config.glm_model}) - Attempt {attempt + 1}/{config.max_retries}")
+            response = requests.post(
+                config.glm_api_url,
+                headers=headers,
+                json=payload,
+                timeout=120,
+            )
+
+            if response.status_code == 200:
+                data = response.json()
+                review = validate_glm_response(data)
+                if review:
+                    logger.info(f"Successfully received review ({len(review)} chars)")  # noqa: E501
+                    return review
+                last_error = "Failed to validate API response structure"
+                logger.error(last_error)
+            else:
+                last_error = f"GLM API request failed: {response.status_code} - {response.text}"
+                logger.error(last_error)
+                # Client errors mean the request itself was rejected; only
+                # request-timeout (408) and rate-limit (429) are transient.
+                status = response.status_code
+                retryable = not (400 <= status < 500) or status in (408, 429)
+
+        except requests.exceptions.Timeout:
+            last_error = f"GLM API request timed out (attempt {attempt + 1})"
+            logger.error(last_error)
+        # Must precede RequestException: in requests >= 2.27 the error raised
+        # by response.json() derives from both classes, so the generic handler
+        # would otherwise always win and this message would never be logged.
+        except json.JSONDecodeError as e:
+            last_error = f"Failed to decode GLM API response as JSON: {e}"
+            logger.error(last_error)
+        except requests.exceptions.RequestException as e:
+            last_error = f"GLM API request exception: {e}"
+            logger.error(last_error)
+
+        if not retryable:
+            logger.error("GLM API returned a non-retryable client error, giving up")
+            break
+
+        # Exponential backoff before retry
+        if attempt < config.max_retries - 1:
+            wait_time: float = config.retry_delay * (2**attempt)
+            logger.info(f"Waiting {wait_time}s before retry...")  # noqa: E501
+            time.sleep(wait_time)
+
+    logger.error(
+        f"GLM API call failed after {attempts_made} attempt(s). Last error: {last_error}"  # noqa: E501
+    )
+    return None
+
+
+def check_cooldown(  # noqa: E501
+    repo_name: str,
+    pr_number: int,
+    token: str,
+    cooldown_minutes: int,
+) -> bool:
+    """
+    Check if the PR is within the cooldown period.
+
+    A PR is in cooldown when it has a comment containing one of the bot's
+    marker strings that was created less than ``cooldown_minutes`` ago.
+    The check is best-effort: if the comments cannot be fetched or decoded,
+    the function returns False so the review still runs.
+
+    Args:
+        repo_name: The repository name in format "owner/repo".
+        pr_number: The pull request number.
+        token: GitHub authentication token.
+        cooldown_minutes: Cooldown period in minutes.
+
+    Returns:
+        True if within cooldown period (should skip), False otherwise.
+    """
+    from datetime import datetime, timedelta, timezone
+
+    url: str = (
+        f"https://api.github.com/repos/{repo_name}/issues/"
+        f"{pr_number}/comments"  # noqa: E501
+    )
+    headers: dict[str, str] = {
+        "Authorization": f"Bearer {token}",
+        "Accept": "application/vnd.github.v3+json",
+    }
+
+    # Timezone-aware "now"; datetime.utcnow() is deprecated.
+    cutoff_time: datetime = datetime.now(timezone.utc) - timedelta(minutes=cooldown_minutes)  # noqa: E501
+
+    # The endpoint lists comments oldest-first in pages of 30 by default, so an
+    # unfiltered request would only ever see the 30 oldest comments and miss a
+    # recent bot comment on a busy PR. `since` restricts the result to comments
+    # updated at or after the cutoff, which is the only window that matters.
+    params: dict[str, str] = {
+        "since": cutoff_time.strftime("%Y-%m-%dT%H:%M:%SZ"),
+        "per_page": "100",
+    }
+
+    logger.info(f"Checking cooldown period ({cooldown_minutes} minutes)")
+    try:
+        response = requests.get(url, headers=headers, params=params, timeout=30)
+    except requests.exceptions.RequestException as e:
+        logger.warning(f"Failed to check cooldown: {e}, proceeding with review")
+        return False
+
+    if response.status_code != 200:
+        logger.warning(f"Failed to check cooldown: {response.status_code}, proceeding with review")
+        return False
+
+    try:
+        comments = response.json()
+    except ValueError as e:
+        logger.warning(f"Failed to check cooldown: {e}, proceeding with review")
+        return False
+    if not isinstance(comments, list):
+        logger.warning("Failed to check cooldown: unexpected response, proceeding with review")
+        return False
+
+    for comment in reversed(comments):
+        # Check if this is a bot comment (`or ""` guards against a null body)
+        body: str = comment.get("body") or ""
+        if "VLLM-Omni PR Review" in body or "PR Reviewer Bot" in body:
+            created_at_str: str = comment.get("created_at") or ""
+            try:
+                # Parse GitHub timestamp format ("...Z" -> explicit UTC offset)
+                created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
+            except ValueError:
+                logger.warning(f"Failed to parse comment timestamp: {created_at_str}")  # noqa: E501
+                continue
+            # A timestamp without an offset is taken to be UTC so that the
+            # comparison with the aware cutoff cannot raise.
+            if created_at.tzinfo is None:
+                created_at = created_at.replace(tzinfo=timezone.utc)
+            if created_at > cutoff_time:
+                logger.info(f"PR is within cooldown period (last review: {created_at_str})")
+                return True
+
+    logger.info("PR is outside cooldown period, proceeding with review")
+    return False
+
+
+def post_review_comment(  # noqa: E501
+    repo_name: str,
+    pr_number: int,
+    token: str,
+    review: str,
+) -> bool:
+    """
+    Post the review as a comment on the PR.
+
+    The review is wrapped in a fixed header and footer (the header text is
+    what the cooldown check searches for). Reviews that would exceed GitHub's
+    comment size limit are truncated with a visible notice. Network failures
+    are logged and reported through the return value rather than raised.
+
+    Args:
+        repo_name: The repository name in format "owner/repo".
+        pr_number: The pull request number.
+        token: GitHub authentication token.
+        review: The review content to post.
+
+    Returns:
+        True if posting succeeded, False otherwise.
+    """
+    url: str = (
+        f"https://api.github.com/repos/{repo_name}/issues/"
+        f"{pr_number}/comments"  # noqa: E501
+    )
+    headers: dict[str, str] = {
+        "Authorization": f"Bearer {token}",
+        "Accept": "application/vnd.github.v3+json",
+    }
+
+    # GitHub rejects comment bodies longer than 65536 characters. Cap the
+    # review below that, leaving headroom for the header and footer.
+    max_review_chars: int = 60000
+    if len(review) > max_review_chars:
+        logger.warning(f"Review is {len(review)} chars, truncating to {max_review_chars}")
+        review = review[:max_review_chars] + "\n\n*[Review truncated: it exceeded the GitHub comment size limit.]*"
+
+    # Format the review comment
+    comment_body: str = f"""## 🤖 VLLM-Omni PR Review
+
+{review}
+
+---
+*This review was generated automatically by the VLLM-Omni PR Reviewer Bot
+using {os.getenv("GLM_MODEL", DEFAULT_GLM_MODEL)}.*
+"""
+
+    payload: dict[str, str] = {"body": comment_body}
+
+    logger.info(f"Posting review comment to PR #{pr_number}")
+    try:
+        response = requests.post(url, headers=headers, json=payload, timeout=30)
+    except requests.exceptions.RequestException as e:
+        # Honour the bool contract: a connection error is a failed post.
+        logger.error(f"Failed to post comment: {e}")
+        return False
+
+    # The API answers 201 Created when the comment was stored.
+    if response.status_code == 201:
+        logger.info("Successfully posted review comment")
+        return True
+
+    logger.error(f"Failed to post comment: {response.status_code}")
+    logger.error(f"Response: {response.text}")
+    return False
+
+
+def main() -> int:
+    """
+    Main entry point for the PR reviewer bot.
+
+    Reads its inputs from environment variables, exits early when the
+    triggering comment lacks the trigger phrase or the PR is in cooldown,
+    and otherwise fetches the PR, asks the GLM API for a review and posts
+    the result as a PR comment.
+
+    Returns:
+        0 on success (including the early exits), 1 on error.
+    """
+    logger.info("VLLM-Omni PR Reviewer Bot starting...")
+
+    # Load configuration
+    config: Config = get_config()
+    logger.info(
+        f"Configuration: model={config.glm_model}, "
+        f"cooldown={config.cooldown_minutes}min, "
+        f"max_retries={config.max_retries}"
+    )
+
+    # Get environment variables
+    token: str = get_env_var("GITHUB_TOKEN")
+    api_key: str = get_env_var("GLM_API_KEY")
+    repo_name: str = get_env_var("REPO_NAME")
+    pr_number_str: str = get_env_var("PR_NUMBER")
+    comment_body: str = get_env_var("COMMENT_BODY")
+
+    try:
+        pr_number: int = int(pr_number_str)
+    except ValueError:
+        logger.error(f"Invalid PR number: {pr_number_str}")
+        return 1
+
+    logger.info(f"Repository: {repo_name}")
+    logger.info(f"PR Number: {pr_number}")
+
+    # Check if the comment contains the trigger phrase
+    if not check_trigger(comment_body):
+        logger.info(
+            f"Comment does not contain trigger phrase '{TRIGGER_PHRASE}', exiting"  # noqa: E501
+        )
+        return 0
+
+    logger.info("Trigger phrase detected! Starting review process...")
+
+    # Check cooldown period
+    if check_cooldown(repo_name, pr_number, token, config.cooldown_minutes):
+        logger.info("Skipping review due to cooldown period")
+        return 0
+
+    # Fetch PR details
+    logger.info("Step 1/4: Fetching PR details...")
+    pr_details: PRDetails | None = fetch_pr_details(repo_name, pr_number, token)  # noqa: E501
+    if not pr_details:
+        logger.error("Failed to fetch PR details")
+        return 1
+
+    # GitHub sends `"body": null` for a PR without a description. The key is
+    # then present, so a .get() default would not apply and None would reach
+    # the prompt builder. `or` covers both the missing and the null case.
+    pr_title: str = pr_details.get("title") or "Unknown"
+    pr_description: str = pr_details.get("body") or ""
+
+    logger.info(f"PR Title: {pr_title}")
+
+    # Fetch PR diff
+    logger.info("Step 2/4: Fetching PR diff...")
+    diff: str | None = fetch_pr_diff(repo_name, pr_number, token, config.max_diff_size)
+    if diff is None:
+        logger.error("Failed to fetch PR diff")
+        return 1
+
+    # An empty diff is not an error: the review still runs on title and body.
+    if not diff:
+        logger.warning("Warning: Empty diff - this might be a draft PR or no code changes")
+
+    # Build prompt
+    logger.info("Step 3/4: Building review prompt...")
+    prompt: str = build_review_prompt(pr_title, pr_description, diff)
+
+    # Call GLM API
+    logger.info("Step 4/4: Calling GLM API...")
+    review: str | None = call_glm_api(prompt, api_key, config)
+    if not review:
+        logger.error("Failed to get review from GLM API")
+        return 1
+
+    # Post review comment
+    logger.info("Posting review comment...")
+    if not post_review_comment(repo_name, pr_number, token, review):
+        logger.error("Failed to post review comment")
+        return 1
+
+    logger.info("PR review completed successfully!")
+    return 0
+
+
+# Script entry point: the return value of main() becomes the process exit code.
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/.github/workflows/pr-reviewer.yml b/.github/workflows/pr-reviewer.yml
new file mode 100644
index 00000000000..8a41c69375e
--- /dev/null
+++ b/.github/workflows/pr-reviewer.yml
@@ -0,0 +1,62 @@
+name: VLLM-Omni PR Reviewer
+
+# Runs on every new issue/PR comment; the job-level `if` below narrows this
+# down to trusted users asking for a review on a pull request.
+on:
+  issue_comment:
+    types: [created]
+
+permissions:
+  contents: read
+  pull-requests: write
+  issues: write
+
+jobs:
+  pr-reviewer:
+    name: Review Pull Request
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    # Only run when the comment is on a pull request, mentions the bot, and
+    # is from a collaborator/owner/member
+    if: |
+      github.event_name == 'issue_comment' &&
+      github.event.issue.pull_request != null &&
+      contains(github.event.comment.body, '@vllm-omni-reviewer') &&
+      (github.event.comment.author_association == 'MEMBER' ||
+       github.event.comment.author_association == 'COLLABORATOR' ||
+       github.event.comment.author_association == 'OWNER')
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v6.0.2
+        with:
+          fetch-depth: 0
+
+      - name: Set up Python
+        uses: actions/setup-python@v6.2.0
+        with:
+          python-version: '3.11'
+          cache: 'pip'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip==24.0
+          pip install requests==2.31.0 pyyaml==6.0.1
+
+      - name: Run PR Reviewer
+        id: reviewer
+        env:
+          GLM_API_KEY: ${{ secrets.GLM_API_KEY }}
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+          PR_NUMBER: ${{ github.event.issue.number || github.event.pull_request.number }}
+          # Passed through the environment (not interpolated into the script)
+          # so that comment text is never evaluated by the shell.
+          COMMENT_BODY: ${{ github.event.comment.body }}
+          REPO_NAME: ${{ github.repository }}
+          PR_REVIEWER_COOLDOWN_MINUTES: "5"
+          PR_REVIEWER_MAX_RETRIES: "3"
+        run: |
+          # The default shell is `bash -e` without pipefail, so the pipeline
+          # would report tee's exit status and hide a failing reviewer run.
+          set -o pipefail
+          python .github/scripts/pr_reviewer.py 2>&1 | tee "/tmp/pr_review_${PR_NUMBER}.log"
+
+      - name: Upload review logs
+        # Keep the log even when the reviewer step failed.
+        if: always()
+        uses: actions/upload-artifact@v7.0.0
+        with:
+          name: pr-review-logs-${{ github.event.issue.number || github.event.pull_request.number }}
+          path: /tmp/pr_review_*.log
+          retention-days: 7
+          if-no-files-found: ignore
diff --git a/.gitignore b/.gitignore
index a7cd8f74eb4..28d56e0f6f0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -83,9 +83,6 @@ target/
profile_default/
ipython_config.py
-# uv
-uv.lock
-
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
@@ -158,23 +155,10 @@ cython_debug/
# Claude
CLAUDE.md
-/.claude/*
-!.claude/skills/
-!.claude/skills/readme.md
-!.claude/skills/add-diffusion-model/
-!.claude/skills/add-diffusion-model/SKILL.md
-!.claude/skills/add-diffusion-model/references/
-!.claude/skills/add-diffusion-model/references/*.md
-!.claude/skills/add-tts-model/
-!.claude/skills/add-tts-model/SKILL.md
-!.claude/skills/review-pr/
-!.claude/skills/review-pr/SKILL.md
-!.claude/skills/review-pr/references/
-!.claude/skills/review-pr/references/*.md
+.claude/
# Codex
AGENTS.md
-.codex
.codex/
# cursor
@@ -204,7 +188,6 @@ checkpoints/
# Cache directories
cache/
!vllm_omni/diffusion/cache/
-!tests/diffusion/cache/
.cache/
diffusion_cache/
kv_cache/
@@ -264,6 +247,3 @@ tmp_test
vllm_omni/_version.py
# output files
*.wav
-.worktrees/
-# CI overlay yamls materialized from tests/utils.py:_CI_OVERLAYS at test time
-tests/.ci_generated/
diff --git a/benchmarks/accuracy/README.md b/benchmarks/accuracy/README.md
index dbe20916a77..0d73215b692 100644
--- a/benchmarks/accuracy/README.md
+++ b/benchmarks/accuracy/README.md
@@ -23,5 +23,5 @@ Test guidance:
- Local static/self-checks live in `tests/benchmarks/test_accuracy_bench_utils.py`.
- End-to-end generation/evaluation should be validated in a remote GPU
environment. In the current repo marker system there is `L4` but no `L5`
- marker, so benchmark smoke tests should be wired as `full_model +
- benchmark + L4` for nightly when GPU capacity is available.
+ marker, so benchmark smoke tests should be wired as `advanced_model +
+ benchmark + L4` when GPU capacity is available.
diff --git a/benchmarks/accuracy/image_to_image/README.md b/benchmarks/accuracy/image_to_image/README.md
index 86e7b0cf328..ee1d58f108b 100644
--- a/benchmarks/accuracy/image_to_image/README.md
+++ b/benchmarks/accuracy/image_to_image/README.md
@@ -99,5 +99,5 @@ Notes:
- This flow requires the optional Hugging Face `datasets` package.
- `generate` writes `generation_manifest.json` with local output coverage.
- The current repo marker set exposes `L4` but not `L5`, so if you promote an
- end-to-end smoke test into CI, use the `full_model`, `benchmark`,
- and `L4` markers for nightly (or `advanced_model` for merge) or introduce a new repo-wide marker explicitly first.
+ end-to-end smoke test into CI, use the existing `advanced_model`, `benchmark`,
+ and `L4` markers or introduce a new repo-wide marker explicitly first.
diff --git a/benchmarks/build_dataset/download_process_data_seedtts.md b/benchmarks/build_dataset/download_process_data_seedtts.md
index faf072303b8..ec16f64424a 100644
--- a/benchmarks/build_dataset/download_process_data_seedtts.md
+++ b/benchmarks/build_dataset/download_process_data_seedtts.md
@@ -27,7 +27,7 @@ pip install gdown
Download the dataset from Google Drive:
```bash
-gdown 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
+gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
```
### 4. Extract the Dataset
@@ -74,7 +74,7 @@ rm meta.lst
# Full setup and benchmark
cd benchmarks/build_dataset
pip install gdown
-gdown 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
+gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
tar -xf seedtts_testset.tar
cp seedtts_testset/en/meta.lst meta.lst
python extract_tts_prompts.py -i meta.lst -o top100.txt -n 100
diff --git a/benchmarks/build_dataset/seed_tts_design/en/meta.lst b/benchmarks/build_dataset/seed_tts_design/en/meta.lst
deleted file mode 100644
index 7e364c2e517..00000000000
--- a/benchmarks/build_dataset/seed_tts_design/en/meta.lst
+++ /dev/null
@@ -1,20 +0,0 @@
-vd001|||The quick brown fox jumps over the lazy dog.|A warm, friendly female voice with a slight American Midwest accent, speaking at a moderate pace with natural inflection.
-vd002|||Welcome to the future of text-to-speech synthesis.|A deep, authoritative male news anchor voice, clear and professional with a measured cadence.
-vd003|||The sunset painted the sky in brilliant shades of orange and pink.|A gentle elderly female voice, soft and wise, with a slight Southern American accent.
-vd004|||Scientists have discovered a new species of deep-sea creature.|A young male voice with an Australian accent, curious and enthusiastic.
-vd005|||Breaking news: a major climate summit opens today in Geneva.|A crisp female newsreader voice, neutral accent, confident and precise.
-vd006|||In the beginning, there was darkness and silence across the universe.|A rich, dramatic bass male narrator voice, slow and deeply resonant.
-vd007|||Come closer, I have something important to tell you.|A soft, intimate female voice, slightly whispery, warm and gentle.
-vd008|||And they're off! The horses race toward the first turn at incredible speed.|An energetic male sports commentator, fast-paced and excited.
-vd009|||Once upon a time, in a land far away, lived a very clever fox.|A light, playful voice with childlike enthusiasm, bright and clear.
-vd010|||The ancient manuscript reveals secrets hidden for a thousand years.|A wise, measured elderly male voice, slow and deliberate, British English accent.
-vd011|||Good evening, ladies and gentlemen, and welcome to our show.|A sophisticated female voice with a slight French accent speaking English, elegant and refined.
-vd012|||System initialized. Running diagnostics. All systems nominal.|A clear, precise robotic-sounding voice, neutral and monotone with slight synthetic quality.
-vd013|||I hear what you are saying, and it is completely understandable to feel that way.|A warm, empathetic female therapist voice, calm and reassuring, unhurried pace.
-vd014|||Attention all units: proceed to grid reference seven-seven-alpha.|A firm, authoritative military male voice, clipped and commanding.
-vd015|||Oh my goodness, you have to try this amazing new recipe I just found!|An enthusiastic, bubbly female voice, high energy and friendly.
-vd016|||Dude, the waves were totally amazing out there today. Super happy about it!|A relaxed male voice with a California accent, casual and laid-back.
-vd017|||The quarterly results exceed expectations across all major metrics.|A sharp, businesslike female voice, confident and efficient, fast-paced delivery.
-vd018|||Chapter one. The morning sun filtered gently through the forest canopy.|A smooth, rich male audiobook narrator voice, expressive and engaging.
-vd019|||To be or not to be, that is the question.|A theatrical female voice, dramatic and expressive, stage projection quality.
-vd020|||And that is all for tonight. Stay well out there, everyone.|A warm, velvety male late-night radio DJ voice, smooth and intimate.
diff --git a/benchmarks/build_dataset/seed_tts_smoke/en/meta.lst b/benchmarks/build_dataset/seed_tts_smoke/en/meta.lst
deleted file mode 100644
index afe4bc8abcd..00000000000
--- a/benchmarks/build_dataset/seed_tts_smoke/en/meta.lst
+++ /dev/null
@@ -1,20 +0,0 @@
-smoke001|||The quick brown fox jumps over the lazy dog near the riverbank at sunset.
-smoke002|||Welcome to the future of text-to-speech synthesis in production systems.
-smoke003|||Yesterday the team finished rolling out the new authentication flow.
-smoke004|||She walked carefully across the wet cobblestones, careful not to slip.
-smoke005|||The conference call is scheduled for nine in the morning, Pacific time.
-smoke006|||Please remember to save your work before closing the editor.
-smoke007|||Two plus two equals four, but five hundred and forty three digits is long.
-smoke008|||I would like a coffee with oat milk and a chocolate croissant please.
-smoke009|||The library closes at eight on weekdays and six on Saturdays.
-smoke010|||During the Renaissance, art and science flourished in European cities.
-smoke011|||He whispered the secret word so quietly that no one else could hear.
-smoke012|||Our flight departs from gate twenty three at eleven fifteen.
-smoke013|||The storm knocked out power for six hours, but the backup generator kicked in.
-smoke014|||Reading a good book on a rainy afternoon is one of life's great pleasures.
-smoke015|||When the kettle whistled, she poured the hot water over the fresh tea leaves.
-smoke016|||The algorithm runs in linear time, which is a big improvement over the previous approach.
-smoke017|||In the distance, the mountains were shrouded in thick morning fog.
-smoke018|||Our company reported record revenue for the fourth quarter of the fiscal year.
-smoke019|||She explained the new policy in detail during the staff meeting this morning.
-smoke020|||The children laughed and played in the garden until the sun began to set.
diff --git a/benchmarks/diffusion/backends.py b/benchmarks/diffusion/backends.py
index d33160f1377..fa53f87aed7 100644
--- a/benchmarks/diffusion/backends.py
+++ b/benchmarks/diffusion/backends.py
@@ -122,18 +122,6 @@ async def async_request_chat_completions(
output.peak_memory_mb = first_item.get("peak_memory_mb", 0.0)
except (IndexError, TypeError, AttributeError):
pass
-
- if (not output.stage_durations or output.peak_memory_mb == 0.0) and isinstance(
- resp_json.get("metrics"), dict
- ):
- m = resp_json["metrics"]
- if not output.stage_durations and isinstance(m.get("stage_durations"), dict):
- output.stage_durations = m.get("stage_durations") or {}
- if output.peak_memory_mb == 0.0 and m.get("peak_memory_mb") is not None:
- try:
- output.peak_memory_mb = float(m.get("peak_memory_mb") or 0.0)
- except (TypeError, ValueError):
- pass
else:
output.error = f"HTTP {response.status}: {await response.text()}"
output.success = False
@@ -318,8 +306,6 @@ async def async_request_v1_videos(
video_bytes = await content_response.read()
output.response_body = video_bytes
output.success = True
- if "stage_durations" in poll_json:
- output.stage_durations = poll_json["stage_durations"] or {}
if "peak_memory_mb" in poll_json:
output.peak_memory_mb = poll_json["peak_memory_mb"]
elif "peak_memory_mb" in resp_json:
diff --git a/benchmarks/diffusion/diffusion_benchmark_serving.py b/benchmarks/diffusion/diffusion_benchmark_serving.py
index 77b36b3d9c0..aad955b0d1d 100644
--- a/benchmarks/diffusion/diffusion_benchmark_serving.py
+++ b/benchmarks/diffusion/diffusion_benchmark_serving.py
@@ -12,15 +12,15 @@
- v1/videos: Use /v1/videos endpoint
Usage:
- # Video (v1/videos backend)
+ # Video (vllm-omni backend)
t2v:
python3 benchmarks/diffusion/diffusion_benchmark_serving.py \
- --backend v1/videos --dataset vbench --task t2v --num-prompts 10 \
+ --backend vllm-omni --dataset vbench --task t2v --num-prompts 10 \
--height 480 --width 640 --fps 16 --num-frames 80
i2v:
python3 benchmarks/diffusion/diffusion_benchmark_serving.py \
- --backend v1/videos --dataset vbench --task i2v --num-prompts 10
+ --backend vllm-omni --dataset vbench --task i2v --num-prompts 10
# Image (vllm-omni backend)
@@ -49,7 +49,7 @@
--backend openai --dataset vbench --task t2i --num-prompts 10 \
--height 1024 --width 1024 --port 3000
- # Video (v1/videos)
+ # Video (v1/videos)
t2v:
python3 benchmarks/diffusion/diffusion_benchmark_serving.py \
--backend v1/videos --dataset random --task t2v --num-prompts 1 \
@@ -558,7 +558,6 @@ def __init__(self, args, api_url: str, model: str, enable_negative_prompt: bool
super().__init__(args, api_url, model)
self.num_prompts = args.num_prompts
self.enable_negative_prompt = enable_negative_prompt
- self.num_input_images = max(1, args.num_input_images)
self.random_request_config = getattr(args, "random_request_config", None)
if self.random_request_config:
self.random_request_config = json.loads(self.random_request_config)
@@ -581,7 +580,11 @@ def __init__(self, args, api_url: str, model: str, enable_negative_prompt: bool
# Random image generate
if self.args.task in ["i2v", "ti2v", "ti2i", "i2i"]:
- self._random_image_path = self._generate_random_image_paths()
+ img = Image.new("RGB", (512, 512), (255, 255, 255))
+
+ image_path = os.path.join(tempfile.gettempdir(), "diffusion_benchmark_random_image.png")
+ self._random_image_path = [image_path]
+ img.save(image_path)
else:
self._random_image_path = None
@@ -616,18 +619,6 @@ def __getitem__(self, idx: int) -> RequestFuncInput:
def get_requests(self) -> list[RequestFuncInput]:
return [self[i] for i in range(len(self))]
- def _generate_random_image_paths(self) -> list[str]:
- image_paths: list[str] = []
- for image_idx in range(self.num_input_images):
- img = Image.new("RGB", (512, 512), (255, 255, 255))
- image_path = os.path.join(
- tempfile.gettempdir(),
- f"diffusion_benchmark_random_image_{image_idx}.png",
- )
- img.save(image_path)
- image_paths.append(image_path)
- return image_paths
-
def _compute_expected_latency_ms_from_base(req: RequestFuncInput, args, base_time_ms: float | None) -> float | None:
"""Compute expected execution time (ms) based on a base per-step-per-frame unit time.
@@ -1124,15 +1115,6 @@ async def limited_request_func(req, session, pbar):
'{"width":768,"height":768,"num_inference_steps":20,"weight":0.85}]'
),
)
- parser.add_argument(
- "--num-input-images",
- type=int,
- default=1,
- help=(
- "Number of synthetic input images to attach for image-conditioned tasks "
- "(i2v, ti2v, ti2i, i2i) when using random dataset."
- ),
- )
args = parser.parse_args()
diff --git a/benchmarks/fish-speech/bench_voice_cache.py b/benchmarks/fish-speech/bench_voice_cache.py
deleted file mode 100644
index 8d465d6489f..00000000000
--- a/benchmarks/fish-speech/bench_voice_cache.py
+++ /dev/null
@@ -1,290 +0,0 @@
-"""Benchmark Fish Speech voice cache: inline ref_audio vs uploaded voice.
-
-Measures TTFP improvement from DAC-code caching when using uploaded voices.
-
-Setup:
- 1. Start vllm-omni with Fish Speech S2 Pro (use our feat branch)
- 2. Provide a reference audio file for voice cloning
-
-Usage:
- python bench_voice_cache.py \
- --ref-audio /path/to/reference.wav \
- --ref-text "Transcript of the reference audio." \
- --num-prompts 20 \
- --port 8091
-
-The script runs two rounds:
- A) Inline ref_audio: every request sends base64 audio (no cache)
- B) Uploaded voice: upload once, then use voice name (cache hits after 1st)
-"""
-
-import argparse
-import asyncio
-import base64
-import json
-import os
-import sys
-import time
-from pathlib import Path
-
-import aiohttp
-
-# Allow imports from benchmarks/fish-speech/
-sys.path.insert(0, str(Path(__file__).resolve().parent))
-
-from fish_bench_utils import ( # noqa: E402
- BenchmarkResult,
- RequestResult,
- compute_stats,
- print_benchmark_results,
- send_streaming_request,
-)
-
-SAMPLE_RATE = 44100
-SAMPLE_WIDTH = 2
-
-PROMPTS = [
- "Hello, welcome to the voice synthesis benchmark test.",
- "She said she would be here by noon, but nobody showed up.",
- "The quick brown fox jumps over the lazy dog near the riverbank.",
- "I can't believe how beautiful the sunset looks from up here.",
- "Please remember to bring your identification documents tomorrow morning.",
- "Have you ever wondered what it would be like to travel through time?",
- "The restaurant on the corner serves the best pasta I have ever tasted.",
- "After the meeting, we should discuss the quarterly results.",
- "Learning a new language takes patience and genuine curiosity.",
- "The train leaves at half past seven, so we need to arrive early.",
- "Could you please turn down the music, I'm trying to concentrate.",
- "It was a dark and stormy night when the keeper heard a knock.",
-]
-
-
-def encode_audio_to_base64(audio_path: str) -> str:
- """Encode a local audio file to base64 data URL."""
- ext = audio_path.lower().rsplit(".", 1)[-1]
- mime_map = {"wav": "audio/wav", "mp3": "audio/mpeg", "flac": "audio/flac"}
- mime_type = mime_map.get(ext, "audio/wav")
- with open(audio_path, "rb") as f:
- audio_b64 = base64.b64encode(f.read()).decode("utf-8")
- return f"data:{mime_type};base64,{audio_b64}"
-
-
-async def upload_voice(
- host: str,
- port: int,
- audio_path: str,
- ref_text: str,
- voice_name: str = "bench_voice",
-) -> dict:
- """Upload a voice via POST /v1/audio/voices."""
- url = f"http://{host}:{port}/v1/audio/voices"
- data = aiohttp.FormData()
- data.add_field("name", voice_name)
- data.add_field("consent", "true")
- if ref_text:
- data.add_field("ref_text", ref_text)
- data.add_field(
- "audio_sample",
- open(audio_path, "rb"),
- filename=os.path.basename(audio_path),
- content_type="audio/wav",
- )
-
- async with aiohttp.ClientSession() as session:
- async with session.post(url, data=data) as resp:
- result = await resp.json()
- print(f" Upload response ({resp.status}): {json.dumps(result, indent=2)}")
- return result
-
-
-async def delete_voice(host: str, port: int, voice_name: str) -> None:
- """Delete an uploaded voice."""
- url = f"http://{host}:{port}/v1/audio/voices/{voice_name}"
- async with aiohttp.ClientSession() as session:
- async with session.delete(url) as resp:
- if resp.status == 200:
- print(f" Deleted voice '{voice_name}'")
-
-
-async def run_round(
- host: str,
- port: int,
- num_prompts: int,
- create_payload_fn,
- label: str,
- num_warmups: int = 2,
- timeout_s: float = 120.0,
-) -> BenchmarkResult:
- """Run one benchmark round and return results."""
- api_url = f"http://{host}:{port}/v1/audio/speech"
- connector = aiohttp.TCPConnector(limit=1, limit_per_host=1)
- session = aiohttp.ClientSession(
- connector=connector,
- timeout=aiohttp.ClientTimeout(total=timeout_s),
- )
-
- try:
- # Warmup.
- if num_warmups > 0:
- print(f" [{label}] Warming up ({num_warmups} requests)...")
- for i in range(num_warmups):
- payload = create_payload_fn(PROMPTS[i % len(PROMPTS)])
- r = await send_streaming_request(
- session,
- api_url,
- payload,
- SAMPLE_RATE,
- SAMPLE_WIDTH,
- )
- status = "OK" if r.success else f"FAIL: {r.error[:80]}"
- print(f" warmup {i + 1}: ttfp={r.ttfp * 1000:.0f}ms {status}")
-
- # Benchmark.
- print(f" [{label}] Running {num_prompts} requests (concurrency=1)...")
- results: list[RequestResult] = []
- start = time.perf_counter()
- for i in range(num_prompts):
- prompt = PROMPTS[i % len(PROMPTS)]
- payload = create_payload_fn(prompt)
- r = await send_streaming_request(
- session,
- api_url,
- payload,
- SAMPLE_RATE,
- SAMPLE_WIDTH,
- )
- results.append(r)
- tag = "HIT" if i > 0 and label == "uploaded_voice" else ""
- print(
- f" req {i + 1:3d}: ttfp={r.ttfp * 1000:7.1f}ms "
- f"e2e={r.e2e * 1000:7.1f}ms "
- f"{'OK' if r.success else 'FAIL'} {tag}"
- )
- wall_time = time.perf_counter() - start
- finally:
- await session.close()
-
- bench = compute_stats(results, wall_time)
- bench.concurrency = 1
- bench.num_prompts = num_prompts
- bench.config_name = label
- return bench
-
-
-async def main():
- parser = argparse.ArgumentParser(
- description="Benchmark Fish Speech voice cache (inline vs uploaded)",
- )
- parser.add_argument("--host", default="127.0.0.1")
- parser.add_argument("--port", type=int, default=8091)
- parser.add_argument("--ref-audio", required=True, help="Path to reference audio file")
- parser.add_argument("--ref-text", required=True, help="Transcript of reference audio")
- parser.add_argument("--num-prompts", type=int, default=20)
- parser.add_argument("--num-warmups", type=int, default=2)
- parser.add_argument("--voice-name", default="bench_voice")
- args = parser.parse_args()
-
- if not os.path.exists(args.ref_audio):
- print(f"Error: ref_audio not found: {args.ref_audio}")
- sys.exit(1)
-
- ref_audio_b64 = encode_audio_to_base64(args.ref_audio)
- print(f"Reference audio: {args.ref_audio} ({len(ref_audio_b64) // 1024}KB base64)")
-
- # ---- Round A: Inline ref_audio (no cache) ----
- print(f"\n{'=' * 60}")
- print("Round A: INLINE ref_audio (every request sends full audio)")
- print(f"{'=' * 60}")
-
- def make_inline_payload(prompt: str) -> dict:
- return {
- "input": prompt,
- "voice": "default",
- "stream": True,
- "response_format": "pcm",
- "ref_audio": ref_audio_b64,
- "ref_text": args.ref_text,
- "max_new_tokens": 2048,
- }
-
- bench_inline = await run_round(
- args.host,
- args.port,
- args.num_prompts,
- make_inline_payload,
- "inline_ref_audio",
- num_warmups=args.num_warmups,
- )
- print_benchmark_results(bench_inline)
-
- # ---- Upload voice ----
- print(f"\n{'=' * 60}")
- print("Uploading voice for cache test...")
- print(f"{'=' * 60}")
- await delete_voice(args.host, args.port, args.voice_name)
- await upload_voice(
- args.host,
- args.port,
- args.ref_audio,
- args.ref_text,
- args.voice_name,
- )
-
- # ---- Round B: Uploaded voice (cache hits after 1st request) ----
- print(f"\n{'=' * 60}")
- print("Round B: UPLOADED VOICE (cache hits after 1st request)")
- print(f"{'=' * 60}")
-
- def make_uploaded_payload(prompt: str) -> dict:
- return {
- "input": prompt,
- "voice": args.voice_name,
- "stream": True,
- "response_format": "pcm",
- "ref_text": args.ref_text,
- "max_new_tokens": 2048,
- }
-
- bench_cached = await run_round(
- args.host,
- args.port,
- args.num_prompts,
- make_uploaded_payload,
- "uploaded_voice",
- num_warmups=args.num_warmups,
- )
- print_benchmark_results(bench_cached)
-
- # ---- Comparison ----
- print(f"\n{'=' * 60}")
- print("COMPARISON: Inline ref_audio vs Uploaded voice (cached)")
- print(f"{'=' * 60}")
- print(f"{'Metric':<30} {'Inline':>12} {'Cached':>12} {'Speedup':>10}")
- print(f"{'-' * 64}")
-
- def fmt_speedup(inline_val: float, cached_val: float) -> str:
- if cached_val > 0 and inline_val > 0:
- ratio = inline_val / cached_val
- return f"{ratio:.2f}x"
- return "N/A"
-
- rows = [
- ("Mean TTFP (ms)", bench_inline.mean_ttfp_ms, bench_cached.mean_ttfp_ms),
- ("Median TTFP (ms)", bench_inline.median_ttfp_ms, bench_cached.median_ttfp_ms),
- ("P99 TTFP (ms)", bench_inline.p99_ttfp_ms, bench_cached.p99_ttfp_ms),
- ("Mean E2E (ms)", bench_inline.mean_e2e_ms, bench_cached.mean_e2e_ms),
- ("Median E2E (ms)", bench_inline.median_e2e_ms, bench_cached.median_e2e_ms),
- ("Mean RTF", bench_inline.mean_rtf, bench_cached.mean_rtf),
- ]
- for label, a, b in rows:
- print(f"{label:<30} {a:>12.1f} {b:>12.1f} {fmt_speedup(a, b):>10}")
-
- print("\nNote: Round B request #1 is a cache MISS (cold start).")
- print(" Requests #2+ are cache HITs (skip DAC encoding).")
-
- # Cleanup.
- await delete_voice(args.host, args.port, args.voice_name)
-
-
-if __name__ == "__main__":
- asyncio.run(main())
diff --git a/benchmarks/fish-speech/fish_bench_utils.py b/benchmarks/fish-speech/fish_bench_utils.py
deleted file mode 100644
index cc84c4037fe..00000000000
--- a/benchmarks/fish-speech/fish_bench_utils.py
+++ /dev/null
@@ -1,501 +0,0 @@
-"""Shared benchmark infrastructure for Fish Speech serving benchmarks.
-
-Provides common dataclasses, metrics computation, streaming HTTP client,
-and result formatting used by model-specific benchmark scripts.
-
-Model-specific scripts supply a ``create_payload_fn(prompt) -> dict``
-callback and audio parameters; everything else is handled here.
-"""
-
-import asyncio
-import base64
-import json
-import time
-from collections.abc import Callable
-from dataclasses import asdict, dataclass, field
-from datetime import datetime
-from pathlib import Path
-
-import aiohttp
-import numpy as np
-from tqdm.asyncio import tqdm
-
-# ---------------------------------------------------------------------------
-# Shared test prompts (varying length for realistic workload)
-# ---------------------------------------------------------------------------
-PROMPTS = [
- "Hello, welcome to the voice synthesis benchmark test.",
- "She said she would be here by noon, but nobody showed up.",
- "The quick brown fox jumps over the lazy dog near the riverbank.",
- "I can't believe how beautiful the sunset looks from up here on the mountain.",
- "Please remember to bring your identification documents to the appointment tomorrow morning.",
- "Have you ever wondered what it would be like to travel through time and visit ancient civilizations?",
- "The restaurant on the corner serves the best pasta I have ever tasted in my entire life.",
- "After the meeting, we should discuss the quarterly results and plan for the next phase.",
- "Learning a new language takes patience, practice, and a genuine curiosity about other cultures.",
- "The train leaves at half past seven, so we need to arrive at the station before then.",
- "Could you please turn down the music a little bit, I'm trying to concentrate on my work.",
- "It was a dark and stormy night when the old lighthouse keeper heard a knock at the door.",
-]
-
-
-# ---------------------------------------------------------------------------
-# Dataclasses
-# ---------------------------------------------------------------------------
-@dataclass
-class RequestResult:
- success: bool = False
- ttfp: float = 0.0 # Time to first audio packet (seconds)
- e2e: float = 0.0 # End-to-end latency (seconds)
- audio_bytes: int = 0 # Total audio bytes received
- audio_duration: float = 0.0 # Audio duration in seconds
- rtf: float = 0.0 # Real-time factor = e2e / audio_duration
- prompt: str = ""
- error: str = ""
-
-
-@dataclass
-class BenchmarkResult:
- config_name: str = ""
- concurrency: int = 0
- num_prompts: int = 0
- completed: int = 0
- failed: int = 0
- duration_s: float = 0.0
- # TTFP stats (ms)
- mean_ttfp_ms: float = 0.0
- median_ttfp_ms: float = 0.0
- std_ttfp_ms: float = 0.0
- p90_ttfp_ms: float = 0.0
- p95_ttfp_ms: float = 0.0
- p99_ttfp_ms: float = 0.0
- # E2E stats (ms)
- mean_e2e_ms: float = 0.0
- median_e2e_ms: float = 0.0
- std_e2e_ms: float = 0.0
- p90_e2e_ms: float = 0.0
- p95_e2e_ms: float = 0.0
- p99_e2e_ms: float = 0.0
- # RTF stats
- mean_rtf: float = 0.0
- median_rtf: float = 0.0
- std_rtf: float = 0.0
- p99_rtf: float = 0.0
- # Audio stats
- mean_audio_duration_s: float = 0.0
- total_audio_duration_s: float = 0.0
- audio_throughput: float = 0.0 # audio_duration / wall_time
- request_throughput: float = 0.0 # requests / second
- # Per-request details
- per_request: list = field(default_factory=list)
-
-
-# ---------------------------------------------------------------------------
-# Audio helpers
-# ---------------------------------------------------------------------------
-def pcm_bytes_to_duration(
- num_bytes: int,
- sample_rate: int = 24000,
- sample_width: int = 2,
-) -> float:
- """Convert raw PCM byte count to duration in seconds."""
- return num_bytes / sample_width / sample_rate
-
-
-def _is_sse_response(response: aiohttp.ClientResponse) -> bool:
- content_type = (response.headers.get("Content-Type") or "").lower()
- return "text/event-stream" in content_type
-
-
-async def _read_raw_audio_stream(
- response: aiohttp.ClientResponse,
- *,
- start_time: float,
-) -> tuple[int, float]:
- first_audio_at = 0.0
- total_bytes = 0
-
- async for chunk in response.content.iter_any():
- if chunk and first_audio_at <= 0:
- first_audio_at = time.perf_counter() - start_time
- total_bytes += len(chunk)
-
- return total_bytes, first_audio_at
-
-
-def _extract_sse_payload(raw_event: bytes) -> bytes | None:
- data_lines: list[bytes] = []
- for raw_line in raw_event.splitlines():
- line = raw_line.rstrip(b"\r")
- if line.startswith(b"data: "):
- data_lines.append(line[6:])
- elif line.startswith(b"data:"):
- data_lines.append(line[5:].lstrip())
-
- if not data_lines:
- return None
- return b"\n".join(data_lines).strip()
-
-
-async def _read_sse_audio_stream(
- response: aiohttp.ClientResponse,
- *,
- start_time: float,
-) -> tuple[int, float]:
- """Decode SSE events and count raw audio bytes from base64 payloads."""
- first_audio_at = 0.0
- total_bytes = 0
- pending = b""
-
- async for chunk in response.content.iter_any():
- if not chunk:
- continue
- pending += chunk
- pending = pending.replace(b"\r\n", b"\n")
-
- while b"\n\n" in pending:
- raw_event, pending = pending.split(b"\n\n", 1)
- payload_bytes = _extract_sse_payload(raw_event)
- if payload_bytes is None:
- continue
- if payload_bytes == b"[DONE]":
- return total_bytes, first_audio_at
-
- try:
- payload = json.loads(payload_bytes)
- except json.JSONDecodeError as exc:
- raise ValueError(f"Invalid SSE JSON payload: {exc}") from exc
-
- audio = payload.get("audio")
- if not isinstance(audio, dict):
- continue
-
- audio_b64 = audio.get("data")
- if not audio_b64:
- continue
-
- try:
- audio_bytes = base64.b64decode(audio_b64)
- except Exception as exc:
- raise ValueError(f"Invalid base64 audio chunk: {exc}") from exc
-
- if audio_bytes and first_audio_at <= 0:
- first_audio_at = time.perf_counter() - start_time
- total_bytes += len(audio_bytes)
-
- return total_bytes, first_audio_at
-
-
-# ---------------------------------------------------------------------------
-# Metrics
-# ---------------------------------------------------------------------------
-def compute_stats(
- results: list[RequestResult],
- wall_time: float,
-) -> BenchmarkResult:
- """Compute aggregate statistics from per-request results."""
- successful = [r for r in results if r.success]
- failed = [r for r in results if not r.success]
-
- bench = BenchmarkResult(
- completed=len(successful),
- failed=len(failed),
- duration_s=wall_time,
- )
-
- if not successful:
- return bench
-
- ttfps = [r.ttfp * 1000 for r in successful]
- e2es = [r.e2e * 1000 for r in successful]
- rtfs = [r.rtf for r in successful]
- audio_durs = [r.audio_duration for r in successful]
-
- bench.mean_ttfp_ms = float(np.mean(ttfps))
- bench.median_ttfp_ms = float(np.median(ttfps))
- bench.std_ttfp_ms = float(np.std(ttfps))
- bench.p90_ttfp_ms = float(np.percentile(ttfps, 90))
- bench.p95_ttfp_ms = float(np.percentile(ttfps, 95))
- bench.p99_ttfp_ms = float(np.percentile(ttfps, 99))
-
- bench.mean_e2e_ms = float(np.mean(e2es))
- bench.median_e2e_ms = float(np.median(e2es))
- bench.std_e2e_ms = float(np.std(e2es))
- bench.p90_e2e_ms = float(np.percentile(e2es, 90))
- bench.p95_e2e_ms = float(np.percentile(e2es, 95))
- bench.p99_e2e_ms = float(np.percentile(e2es, 99))
-
- bench.mean_rtf = float(np.mean(rtfs))
- bench.median_rtf = float(np.median(rtfs))
- bench.std_rtf = float(np.std(rtfs))
- bench.p99_rtf = float(np.percentile(rtfs, 99))
-
- bench.mean_audio_duration_s = float(np.mean(audio_durs))
- bench.total_audio_duration_s = float(np.sum(audio_durs))
- bench.audio_throughput = bench.total_audio_duration_s / wall_time
- bench.request_throughput = len(successful) / wall_time
-
- bench.per_request = [
- {
- "ttfp_ms": r.ttfp * 1000,
- "e2e_ms": r.e2e * 1000,
- "rtf": r.rtf,
- "audio_duration_s": r.audio_duration,
- "prompt": r.prompt,
- }
- for r in successful
- ]
-
- return bench
-
-
-# ---------------------------------------------------------------------------
-# Output formatting
-# ---------------------------------------------------------------------------
-def print_benchmark_results(bench: BenchmarkResult) -> None:
- """Print benchmark results in standardized format."""
- W = 50
- print("")
- print(f"{'=' * W}")
- print(f"{'Serving Benchmark Result':^{W}}")
- print(f"{'=' * W}")
- print(f"{'Successful requests:':<40}{bench.completed:<10}")
- print(f"{'Failed requests:':<40}{bench.failed:<10}")
- print(f"{'Maximum request concurrency:':<40}{bench.concurrency:<10}")
- print(f"{'Benchmark duration (s):':<40}{bench.duration_s:<10.2f}")
- print(f"{'Request throughput (req/s):':<40}{bench.request_throughput:<10.2f}")
- print(f"{'-' * W}")
- print(f"{'End-to-end Latency':^{W}}")
- print(f"{'-' * W}")
- print(f"{'Mean E2EL (ms):':<40}{bench.mean_e2e_ms:<10.2f}")
- print(f"{'Median E2EL (ms):':<40}{bench.median_e2e_ms:<10.2f}")
- print(f"{'P99 E2EL (ms):':<40}{bench.p99_e2e_ms:<10.2f}")
- print(f"{'=' * W}")
- print(f"{'Audio Result':^{W}}")
- print(f"{'=' * W}")
- print(f"{'Total audio duration generated (s):':<40}{bench.total_audio_duration_s:<10.2f}")
- print(f"{'Audio throughput (audio duration/s):':<40}{bench.audio_throughput:<10.2f}")
- print(f"{'-' * W}")
- print(f"{'Time to First Packet':^{W}}")
- print(f"{'-' * W}")
- print(f"{'Mean AUDIO_TTFP (ms):':<40}{bench.mean_ttfp_ms:<10.2f}")
- print(f"{'Median AUDIO_TTFP (ms):':<40}{bench.median_ttfp_ms:<10.2f}")
- print(f"{'P99 AUDIO_TTFP (ms):':<40}{bench.p99_ttfp_ms:<10.2f}")
- print(f"{'-' * W}")
- print(f"{'Real Time Factor':^{W}}")
- print(f"{'-' * W}")
- print(f"{'Mean AUDIO_RTF:':<40}{bench.mean_rtf:<10.3f}")
- print(f"{'Median AUDIO_RTF:':<40}{bench.median_rtf:<10.3f}")
- print(f"{'P99 AUDIO_RTF:':<40}{bench.p99_rtf:<10.3f}")
- print(f"{'=' * W}")
- print("")
-
-
-def save_results(
- all_results: list[dict],
- result_dir: str,
- config_name: str,
-) -> Path:
- """Save benchmark results as JSON and return the file path."""
- out = Path(result_dir)
- out.mkdir(parents=True, exist_ok=True)
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- result_file = out / f"bench_{config_name}_{timestamp}.json"
-
- with open(result_file, "w") as f:
- json.dump(all_results, f, indent=2)
- print(f"Results saved to {result_file}")
- return result_file
-
-
-# ---------------------------------------------------------------------------
-# Streaming HTTP client
-# ---------------------------------------------------------------------------
-async def send_streaming_request(
- session: aiohttp.ClientSession,
- api_url: str,
- payload: dict,
- sample_rate: int,
- sample_width: int,
- pbar: tqdm | None = None,
-) -> RequestResult:
- """Send a streaming TTS request and measure latency metrics."""
- result = RequestResult(prompt=payload.get("input", ""))
- st = time.perf_counter()
-
- try:
- async with session.post(api_url, json=payload) as response:
- if response.status != 200:
- result.error = f"HTTP {response.status}: {await response.text()}"
- else:
- if _is_sse_response(response):
- total_bytes, result.ttfp = await _read_sse_audio_stream(
- response,
- start_time=st,
- )
- else:
- total_bytes, result.ttfp = await _read_raw_audio_stream(
- response,
- start_time=st,
- )
-
- result.e2e = time.perf_counter() - st
- result.audio_bytes = total_bytes
- result.audio_duration = pcm_bytes_to_duration(total_bytes, sample_rate, sample_width)
-
- if total_bytes <= 0 or result.ttfp <= 0:
- result.error = "HTTP 200 but no audio bytes were received"
- else:
- if result.audio_duration > 0:
- result.rtf = result.e2e / result.audio_duration
- result.success = True
-
- except Exception as e:
- result.error = str(e)
- result.e2e = time.perf_counter() - st
-
- finally:
- if pbar:
- pbar.update(1)
- return result
-
-
-# ---------------------------------------------------------------------------
-# Benchmark runner
-# ---------------------------------------------------------------------------
-async def run_benchmark(
- host: str,
- port: int,
- num_prompts: int,
- max_concurrency: int,
- create_payload_fn: Callable[[str], dict],
- sample_rate: int,
- sample_width: int = 2,
- num_warmups: int = 3,
- request_timeout_s: float = 120.0,
-) -> BenchmarkResult:
- """Run a TTS streaming benchmark at a given concurrency level.
-
- Args:
- create_payload_fn: Model-specific function that takes a prompt string
- and returns the request JSON payload dict.
- sample_rate: PCM sample rate for audio duration calculation.
- sample_width: PCM sample width in bytes (default 2 for 16-bit).
- """
- api_url = f"http://{host}:{port}/v1/audio/speech"
-
- connector = aiohttp.TCPConnector(
- limit=max_concurrency,
- limit_per_host=max_concurrency,
- keepalive_timeout=60,
- )
- session = aiohttp.ClientSession(
- connector=connector,
- timeout=aiohttp.ClientTimeout(
- total=request_timeout_s,
- connect=min(10.0, request_timeout_s),
- sock_connect=min(10.0, request_timeout_s),
- sock_read=request_timeout_s,
- ),
- )
-
- try:
- # Warmup
- if num_warmups > 0:
- print(f" Warming up with {num_warmups} requests...")
- warmup_tasks = [
- send_streaming_request(
- session,
- api_url,
- create_payload_fn(PROMPTS[i % len(PROMPTS)]),
- sample_rate,
- sample_width,
- )
- for i in range(num_warmups)
- ]
- warmup_results = await asyncio.gather(*warmup_tasks)
- warmup_ok = sum(1 for r in warmup_results if r.success)
- if warmup_ok == 0:
- print(" WARNING: All warmup requests failed!")
- for r in warmup_results:
- if r.error:
- print(f" {r.error[:200]}")
- print(f" Warmup done ({warmup_ok}/{num_warmups} succeeded).")
-
- # Build request list
- request_prompts = [PROMPTS[i % len(PROMPTS)] for i in range(num_prompts)]
-
- # Run
- print(f" Running {num_prompts} requests with concurrency={max_concurrency}...")
- semaphore = asyncio.Semaphore(max_concurrency)
- pbar = tqdm(total=num_prompts, desc=f" concurrency={max_concurrency}")
-
- async def limited_request(prompt: str) -> RequestResult:
- async with semaphore:
- return await send_streaming_request(
- session,
- api_url,
- create_payload_fn(prompt),
- sample_rate,
- sample_width,
- pbar,
- )
-
- start_time = time.perf_counter()
- tasks = [asyncio.create_task(limited_request(p)) for p in request_prompts]
- results: list[RequestResult] = await asyncio.gather(*tasks)
- wall_time = time.perf_counter() - start_time
- pbar.close()
-
- finally:
- await session.close()
-
- # Compute stats
- bench = compute_stats(results, wall_time)
- bench.concurrency = max_concurrency
- bench.num_prompts = num_prompts
-
- print_benchmark_results(bench)
-
- # Print sample errors
- failed = [r for r in results if not r.success]
- if failed:
- for r in failed[:3]:
- print(f" [ERROR] {r.error[:200]}")
-
- return bench
-
-
-async def run_benchmark_sweep(
- host: str,
- port: int,
- num_prompts: int,
- concurrency_levels: list[int],
- create_payload_fn: Callable[[str], dict],
- sample_rate: int,
- sample_width: int = 2,
- num_warmups: int = 3,
- request_timeout_s: float = 120.0,
- config_name: str = "benchmark",
- result_dir: str = "results",
-) -> list[dict]:
- """Run benchmarks across multiple concurrency levels and save results."""
- all_results = []
-
- for concurrency in concurrency_levels:
- result = await run_benchmark(
- host=host,
- port=port,
- num_prompts=num_prompts,
- max_concurrency=concurrency,
- create_payload_fn=create_payload_fn,
- sample_rate=sample_rate,
- sample_width=sample_width,
- num_warmups=num_warmups,
- request_timeout_s=request_timeout_s,
- )
- result.config_name = config_name
- all_results.append(asdict(result))
-
- save_results(all_results, result_dir, config_name)
- return all_results
diff --git a/benchmarks/glm_image/README.md b/benchmarks/glm_image/README.md
deleted file mode 100644
index 485e081426f..00000000000
--- a/benchmarks/glm_image/README.md
+++ /dev/null
@@ -1,157 +0,0 @@
-# GLM-Image Benchmarks
-
-Benchmark GLM-Image T2I (text-to-image) and I2I (image-to-image) performance across three backends: HuggingFace baseline, vLLM-Omni offline, and vLLM-Omni online serving.
-
-## Benchmarks
-
-| Benchmark | Script | Description |
-|-----------|--------|-------------|
-| HuggingFace Baseline | `huggingface/inference.py` | Single-GPU transformers + diffusers pipeline |
-| vLLM-Omni Offline | `vllm-omni/inference.py` | Offline inference with continuous batching |
-| vLLM-Omni Online | `benchmark_glm_image.py` | Online serving via `/v1/chat/completions` |
-
-## HuggingFace Baseline
-
-Single-request sequential inference using the reference HuggingFace pipeline.
-
-```bash
-# T2I
-CUDA_VISIBLE_DEVICES=0 python benchmarks/glm_image/huggingface/inference.py \
- --model-path /path/to/GLM-Image --mode t2i --num-prompts 10
-
-# I2I
-CUDA_VISIBLE_DEVICES=0 python benchmarks/glm_image/huggingface/inference.py \
- --model-path /path/to/GLM-Image --mode i2i --num-prompts 10
-```
-
-### Options
-
-| Flag | Default | Description |
-|------|---------|-------------|
-| `--model-path` | `zai-org/GLM-Image` | Model path |
-| `--mode` | `t2i` | `t2i` or `i2i` |
-| `--dataset-path` | `prompt/prompt.json` | Path to prompt.json |
-| `--num-prompts` | `10` | Number of images to generate |
-| `--width` / `--height` | `1024` | Output image size |
-| `--num-inference-steps` | `50` | Diffusion denoising steps |
-| `--output-dir` | `benchmarks/glm_image/huggingface/outputs` | Output directory |
-| `--output-file` | - | JSON file for metrics |
-
-## vLLM-Omni Offline
-
-Multi-GPU offline inference with pipeline parallelism and continuous batching.
-
-```bash
-# T2I
-CUDA_VISIBLE_DEVICES=0,1 python benchmarks/glm_image/vllm-omni/inference.py \
- --model-path /path/to/GLM-Image --mode t2i --num-prompts 10
-
-# I2I
-CUDA_VISIBLE_DEVICES=0,1 python benchmarks/glm_image/vllm-omni/inference.py \
- --model-path /path/to/GLM-Image --mode i2i --num-prompts 10
-```
-
-### Options
-
-| Flag | Default | Description |
-|------|---------|-------------|
-| `--model-path` | `zai-org/GLM-Image` | Model path |
-| `--deploy-config` | - | Deploy config YAML |
-| `--mode` | `t2i` | `t2i` or `i2i` |
-| `--dataset-path` | `prompt/prompt.json` | Path to prompt.json |
-| `--num-prompts` | `10` | Number of images to generate |
-| `--width` / `--height` | `1024` | Output image size |
-| `--num-inference-steps` | `50` | Diffusion denoising steps |
-| `--output-dir` | `benchmarks/glm_image/vllm-omni/outputs` | Output directory |
-| `--output-file` | - | JSON file for metrics |
-| `--stage-init-timeout` | `600` | Stage initialization timeout (s) |
-
-### Latency Computation
-
-In offline mode all requests are submitted simultaneously and processed with continuous batching. The per-request latency is computed by summing the actual per-stage times (with `stage_0_gen_ms` diffed against the previous request to remove accumulated queue/scheduling wait).
-
-## vLLM-Omni Online Serving
-
-### Start the server
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1 vllm serve /path/to/GLM-Image \
- --omni --port 8091 --host 0.0.0.0 \
- --served-model-name glm-image
-```
-
-### Run the benchmark
-
-```bash
-# T2I
-python benchmarks/glm_image/benchmark_glm_image.py \
- --mode t2i --num-prompts 10 --model glm-image
-
-# I2I
-python benchmarks/glm_image/benchmark_glm_image.py \
- --mode i2i --num-prompts 10 --model glm-image
-
-# Custom dataset
-python benchmarks/glm_image/benchmark_glm_image.py \
- --mode i2i --dataset custom \
- --dataset-path prompts.json --num-prompts 5
-```
-
-### Options
-
-| Flag | Default | Description |
-|------|---------|-------------|
-| `--mode` | `t2i` | `t2i` or `i2i` |
-| `--dataset` | `prompt` | `prompt`, `random`, or `custom` |
-| `--dataset-path` | - | JSON file path (required for `custom`) |
-| `--num-prompts` | `10` | Number of benchmark requests |
-| `--max-concurrency` | `1` | Max concurrent requests |
-| `--request-rate` | `inf` | Requests per second (Poisson arrival) |
-| `--warmup-requests` | `1` | Warmup requests before measurement |
-| `--width` / `--height` | `1024` | Output image size |
-| `--num-inference-steps` | `50` | Diffusion denoising steps |
-| `--seed` | - | Random seed |
-| `--model` | `default` | Model name (must match `--served-model-name`) |
-| `--host` | `localhost` | Server host |
-| `--port` | `8091` | Server port |
-| `--output-file` | - | JSON output file for metrics |
-| `--num-input-images` | `1` | Number of input images for random I2I |
-
-## Dataset
-
-The default dataset is hosted on [HuggingFace](https://huggingface.co/datasets/JaredforReal/glm-image-bench) (`prompt.json`). It is automatically downloaded and cached to `prompt/prompt.json` on first run. No manual setup needed.
-
-Each entry contains:
-
-- `t2i_prompt`: Text prompt for text-to-image generation
-- `i2i_prompt`: Text prompt for image-to-image editing
-- `image_url`: Source image URL for I2I (downloaded and cached on first use)
-
-Custom datasets use the same JSON format and can be provided via `--dataset-path`.
-
-## Pipeline Timings
-
-All three benchmarks report per-stage pipeline timings (in milliseconds):
-
-| Key | Description |
-|-----|-------------|
-| `preprocess_ms` | Input preprocessing (tokenization, multimodal encoding) |
-| `stage_0_gen_ms` | AR (autoregressive) model generation time |
-| `ar2diffusion_ms` | AR output to diffusion input conversion |
-| `stage_1_gen_ms` | Diffusion model denoising time |
-| `queue_wait_ms` | Queue wait time before processing |
-
-The stages are ordered by execution: `preprocess → stage_0 (AR) → ar2diffusion → stage_1 (Diffusion)`.
-
-## Sample Results
-
-Tested on 2x GPU with 10 prompts, 1024x1024, 50 denoising steps:
-
-| Backend | Mode | Latency Mean (s) | Throughput (img/s) |
-|---------|------|-------------------|--------------------|
-| HuggingFace | T2I | 72.6 | 0.014 |
-| HuggingFace | I2I | 70.9 | 0.014 |
-| vLLM-Omni Offline | T2I | 35.0 | 0.044 |
-| vLLM-Omni Offline | I2I | 31.0 | 0.053 |
-| vLLM-Omni Online | T2I | 38.8 | 0.026 |
-| vLLM-Omni Online | I2I | 34.7 | 0.029 |
diff --git a/benchmarks/glm_image/__init__.py b/benchmarks/glm_image/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/benchmarks/glm_image/benchmark_glm_image.py b/benchmarks/glm_image/benchmark_glm_image.py
deleted file mode 100644
index 9f8df3f1986..00000000000
--- a/benchmarks/glm_image/benchmark_glm_image.py
+++ /dev/null
@@ -1,464 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Online serving benchmark for GLM-Image (T2I and I2I modes).
-
-Sends requests to the /v1/chat/completions endpoint and reports end-to-end
-latency, throughput, and per-stage durations (when the server is started with
---enable-diffusion-pipeline-profiler and/or --enable-ar-profiler).
-
-Supports three dataset types:
- - prompt: Use prompt.json (default). T2I uses t2i_prompt, I2I uses i2i_prompt
- and sends source images from image_url.
- - random: Generate synthetic prompts (and random images for I2I).
- - custom: Load from a user-specified JSON file.
-
-Usage:
- # T2I with prompt.json (default)
- python benchmarks/glm_image/benchmark_glm_image.py \
- --mode t2i --num-prompts 10
-
- # I2I with prompt.json (downloads source images automatically)
- python benchmarks/glm_image/benchmark_glm_image.py \
- --mode i2i --num-prompts 10
-
- # Random dataset
- python benchmarks/glm_image/benchmark_glm_image.py \
- --mode t2i --dataset random --num-prompts 20
-
- # Custom dataset
- python benchmarks/glm_image/benchmark_glm_image.py \
- --mode i2i --dataset custom \
- --dataset-path my_prompts.json --num-prompts 5
-"""
-
-import argparse
-import asyncio
-import base64
-import json
-import os
-import sys
-import tempfile
-import time
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any
-
-import aiohttp
-import numpy as np
-import requests as sync_requests
-from PIL import Image
-from tqdm.asyncio import tqdm
-
-# Import backends from the diffusion benchmark (add parent dirs to path)
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "diffusion"))
-from backends import RequestFuncOutput
-
-BENCHMARK_DIR = Path(__file__).resolve().parent
-DEFAULT_PROMPT_JSON = BENCHMARK_DIR / "prompt" / "prompt.json"
-IMAGE_CACHE_DIR = BENCHMARK_DIR / "prompt" / "images"
-
-DATASET_REPO = "JaredforReal/glm-image-bench"
-DATASET_FILE = "prompt.json"
-
-
-def _ensure_prompt_json(dataset_path: str | None) -> str:
- """Return path to prompt.json, downloading from HuggingFace if needed."""
- if dataset_path:
- return dataset_path
- local = DEFAULT_PROMPT_JSON
- if local.exists():
- return str(local)
- print(f"Downloading {DATASET_FILE} from {DATASET_REPO} ...")
- try:
- from huggingface_hub import hf_hub_download
-
- downloaded = hf_hub_download(
- repo_id=DATASET_REPO,
- filename=DATASET_FILE,
- repo_type="dataset",
- )
- local.parent.mkdir(parents=True, exist_ok=True)
- import shutil
-
- shutil.copy2(downloaded, local)
- print(f"Saved to {local}")
- except ImportError:
- url = f"https://huggingface.co/datasets/{DATASET_REPO}/resolve/main/{DATASET_FILE}"
- import urllib.request
-
- local.parent.mkdir(parents=True, exist_ok=True)
- urllib.request.urlretrieve(url, local)
- print(f"Saved to {local}")
- return str(local)
-
-
-# ---------------------------------------------------------------------------
-# Helpers
-# ---------------------------------------------------------------------------
-
-
-@dataclass
-class GLMImageRequest:
- prompt: str
- image_path: str | None = None # Only for I2I mode
-
-
-def download_image(url: str) -> str:
- """Download an image to cache and return the local path."""
- IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
- fname = url.rsplit("/", 1)[-1]
- local_path = IMAGE_CACHE_DIR / fname
- if local_path.exists():
- return str(local_path)
- resp = sync_requests.get(url, timeout=30)
- resp.raise_for_status()
- local_path.write_bytes(resp.content)
- return str(local_path)
-
-
-def encode_image_as_data_url(path: str) -> str:
- """Encode a local image file as a base64 data URL."""
- with open(path, "rb") as f:
- encoded = base64.b64encode(f.read()).decode("utf-8")
- ext = Path(path).suffix.lower()
- mime = {"png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg"}.get(ext, "image/png")
- return f"data:{mime};base64,{encoded}"
-
-
-# ---------------------------------------------------------------------------
-# Datasets
-# ---------------------------------------------------------------------------
-
-
-class PromptDataset:
- """Load from prompt.json. T2I uses t2i_prompt, I2I uses i2i_prompt + image_url."""
-
- def __init__(self, args: argparse.Namespace):
- path = _ensure_prompt_json(args.dataset_path)
- with open(path, encoding="utf-8") as f:
- raw = json.load(f)
-
- prompt_key = "t2i_prompt" if args.mode == "t2i" else "i2i_prompt"
- self.items: list[GLMImageRequest] = []
-
- for entry in raw:
- prompt = entry.get(prompt_key, "").strip()
- if not prompt:
- continue
- image_path = None
- if args.mode == "i2i":
- url = entry.get("image_url", "")
- if url:
- image_path = download_image(url)
- self.items.append(GLMImageRequest(prompt=prompt, image_path=image_path))
-
- if args.num_prompts and len(self.items) > args.num_prompts:
- self.items = self.items[: args.num_prompts]
-
- def __len__(self) -> int:
- return len(self.items)
-
- def __getitem__(self, idx: int) -> GLMImageRequest:
- return self.items[idx]
-
- def get_requests(self) -> list[GLMImageRequest]:
- return list(self.items)
-
-
-class RandomDataset:
- """Generate synthetic prompts (and optional random images for I2I)."""
-
- def __init__(self, args: argparse.Namespace):
- self.args = args
- self.num_prompts = args.num_prompts
- self._random_image_paths: list[str] | None = None
- if args.mode == "i2i":
- self._random_image_paths = self._generate_random_images()
-
- def _generate_random_images(self) -> list[str]:
- paths: list[str] = []
- for i in range(self.args.num_input_images):
- img = Image.new("RGB", (512, 512), (128 + i * 30 % 128, 64, 192))
- path = os.path.join(tempfile.gettempdir(), f"glm_image_bench_input_{i}.png")
- img.save(path)
- paths.append(path)
- return paths
-
- def __len__(self) -> int:
- return self.num_prompts
-
- def __getitem__(self, idx: int) -> GLMImageRequest:
- image_path = None
- if self._random_image_paths is not None:
- image_path = self._random_image_paths[idx % len(self._random_image_paths)]
- return GLMImageRequest(
- prompt=f"A beautiful scene with vivid colors and intricate details, prompt {idx}",
- image_path=image_path,
- )
-
- def get_requests(self) -> list[GLMImageRequest]:
- return [self[i] for i in range(len(self))]
-
-
-class CustomDataset:
- """Load from a user-specified JSON file.
-
- Expected format:
- [
- {"prompt": "A cat sitting on a windowsill"},
- {"prompt": "Make it look like winter", "image_path": "/path/to/img.png"}
- ]
- """
-
- def __init__(self, args: argparse.Namespace):
- if not args.dataset_path:
- raise ValueError("--dataset-path is required for custom dataset")
- with open(args.dataset_path, encoding="utf-8") as f:
- raw = json.load(f)
- self.items: list[GLMImageRequest] = []
- for item in raw:
- self.items.append(
- GLMImageRequest(
- prompt=item.get("prompt", ""),
- image_path=item.get("image_path"),
- )
- )
- if args.num_prompts and len(self.items) > args.num_prompts:
- self.items = self.items[: args.num_prompts]
-
- def __len__(self) -> int:
- return len(self.items)
-
- def __getitem__(self, idx: int) -> GLMImageRequest:
- return self.items[idx]
-
- def get_requests(self) -> list[GLMImageRequest]:
- return list(self.items)
-
-
-# ---------------------------------------------------------------------------
-# Async request for GLM-Image (chat completions with image support)
-# ---------------------------------------------------------------------------
-
-
-async def async_glm_image_request(
- req: GLMImageRequest,
- api_url: str,
- model: str,
- session: aiohttp.ClientSession,
- pbar: Any,
- args: argparse.Namespace,
-) -> RequestFuncOutput:
- """Send a single T2I or I2I request via chat completions endpoint."""
- output = RequestFuncOutput()
- output.start_time = time.perf_counter()
-
- # Build messages
- if req.image_path and args.mode == "i2i":
- data_url = encode_image_as_data_url(req.image_path)
- content = [
- {"type": "text", "text": req.prompt},
- {"type": "image_url", "image_url": {"url": data_url}},
- ]
- else:
- content = req.prompt
-
- messages = [{"role": "user", "content": content}]
-
- extra_body: dict[str, Any] = {}
- if args.height:
- extra_body["height"] = args.height
- if args.width:
- extra_body["width"] = args.width
- if args.num_inference_steps:
- extra_body["num_inference_steps"] = args.num_inference_steps
- if args.seed is not None:
- extra_body["seed"] = args.seed
-
- payload: dict[str, Any] = {
- "model": model,
- "messages": messages,
- }
- if extra_body:
- payload["extra_body"] = extra_body
-
- try:
- async with session.post(api_url, json=payload) as response:
- if response.status == 200:
- resp_json = await response.json()
- output.response_body = resp_json
- output.success = True
- try:
- choices = resp_json.get("choices", [])
- if choices and isinstance(choices, list):
- msg = choices[0].get("message", {})
- if isinstance(msg, dict):
- resp_content = msg.get("content", [])
- if resp_content and isinstance(resp_content, list) and len(resp_content) > 0:
- first_item = resp_content[0]
- if isinstance(first_item, dict):
- output.stage_durations = first_item.get("stage_durations") or {}
- output.peak_memory_mb = first_item.get("peak_memory_mb", 0.0)
- except (IndexError, TypeError, AttributeError):
- pass
- else:
- output.error = f"HTTP {response.status}: {await response.text()}"
- output.success = False
- except Exception as e:
- output.error = str(e)
- output.success = False
-
- output.latency = time.perf_counter() - output.start_time
- if pbar:
- pbar.update(1)
- return output
-
-
-# ---------------------------------------------------------------------------
-# Benchmark
-# ---------------------------------------------------------------------------
-
-
-async def iter_requests(n: int, request_rate: float) -> Any:
- import random as _random
-
- for i in range(n):
- if request_rate != float("inf") and i > 0:
- await asyncio.sleep(_random.expovariate(request_rate))
- yield i
-
-
-def calculate_metrics(outputs: list[RequestFuncOutput], total_duration: float) -> dict[str, Any]:
- success = [o for o in outputs if o.success]
- errors = [o for o in outputs if not o.success]
- latencies = [o.latency for o in success]
- peak_mems = [o.peak_memory_mb for o in success if o.peak_memory_mb > 0]
-
- stage_duration_lists: dict[str, list[float]] = {}
- for o in success:
- for stage, dur in (o.stage_durations or {}).items():
- stage_duration_lists.setdefault(stage, []).append(dur)
-
- return {
- "duration": total_duration,
- "completed_requests": len(success),
- "failed_requests": len(errors),
- "throughput_qps": len(success) / total_duration if total_duration > 0 else 0,
- "latency_mean": float(np.mean(latencies)) if latencies else 0,
- "latency_median": float(np.median(latencies)) if latencies else 0,
- "latency_p99": float(np.percentile(latencies, 99)) if latencies else 0,
- "latency_p95": float(np.percentile(latencies, 95)) if latencies else 0,
- "peak_memory_mb_max": max(peak_mems) if peak_mems else 0,
- "stage_durations_mean": {s: float(np.mean(v)) for s, v in stage_duration_lists.items()},
- "stage_durations_p50": {s: float(np.percentile(v, 50)) for s, v in stage_duration_lists.items()},
- }
-
-
-async def benchmark(args: argparse.Namespace) -> None:
- api_url = f"http://{args.host}:{args.port}/v1/chat/completions"
-
- # Load dataset
- if args.dataset == "prompt":
- dataset = PromptDataset(args)
- elif args.dataset == "random":
- dataset = RandomDataset(args)
- elif args.dataset == "custom":
- dataset = CustomDataset(args)
- else:
- raise ValueError(f"Unknown dataset: {args.dataset}")
-
- glm_requests = dataset.get_requests()
- print(f"Prepared {len(glm_requests)} requests (mode={args.mode}, dataset={args.dataset})")
-
- semaphore = asyncio.Semaphore(args.max_concurrency) if args.max_concurrency else None
-
- async def limited_request(idx: int, req: GLMImageRequest, session: aiohttp.ClientSession, pbar: Any):
- if semaphore:
- async with semaphore:
- return await async_glm_image_request(req, api_url, args.model, session, pbar, args)
- return await async_glm_image_request(req, api_url, args.model, session, pbar, args)
-
- async with aiohttp.ClientSession() as session:
- # Warmup
- if args.warmup_requests and glm_requests:
- print(f"Running {args.warmup_requests} warmup request(s)...")
- for i in range(args.warmup_requests):
- await limited_request(i, glm_requests[i % len(glm_requests)], session, None)
-
- # Main benchmark
- pbar = tqdm(total=len(glm_requests), disable=args.disable_tqdm)
- start_time = time.perf_counter()
- tasks = []
- async for idx in iter_requests(len(glm_requests), args.request_rate):
- tasks.append(asyncio.create_task(limited_request(idx, glm_requests[idx], session, pbar)))
- outputs = await asyncio.gather(*tasks)
- total_duration = time.perf_counter() - start_time
- pbar.close()
-
- # Metrics
- metrics = calculate_metrics(outputs, total_duration)
- metrics["mode"] = args.mode
- metrics["model"] = args.model
- metrics["dataset"] = args.dataset
-
- print(f"\n{' GLM-Image Online Benchmark Result ':=^60}")
- print(f"{'Mode:':<40} {args.mode}")
- print(f"{'Model:':<40} {args.model}")
- print(f"{'Dataset:':<40} {args.dataset}")
- print("-" * 50)
- print(f"{'Benchmark duration (s):':<40} {metrics['duration']:.2f}")
- print(f"{'Request rate:':<40} {args.request_rate}")
- print(f"{'Max concurrency:':<40} {args.max_concurrency}")
- print(f"{'Successful requests:':<40} {metrics['completed_requests']}/{len(glm_requests)}")
- print("-" * 50)
- print(f"{'Throughput (req/s):':<40} {metrics['throughput_qps']:.2f}")
- print(f"{'Latency Mean (s):':<40} {metrics['latency_mean']:.4f}")
- print(f"{'Latency Median (s):':<40} {metrics['latency_median']:.4f}")
- print(f"{'Latency P95 (s):':<40} {metrics['latency_p95']:.4f}")
- print(f"{'Latency P99 (s):':<40} {metrics['latency_p99']:.4f}")
-
- if metrics["peak_memory_mb_max"] > 0:
- print("-" * 50)
- print(f"{'Peak Memory Max (MB):':<40} {metrics['peak_memory_mb_max']:.2f}")
-
- if metrics["stage_durations_mean"]:
- print("-" * 50)
- print("Stage Durations Mean:")
- for stage, val in sorted(metrics["stage_durations_mean"].items()):
- unit = "ms" if stage.endswith("_ms") else "s"
- print(f" {stage + ':':<38} {val:.4f} ({unit})")
-
- print("=" * 60)
-
- if args.output_file:
- with open(args.output_file, "w") as f:
- json.dump(metrics, f, indent=2)
- print(f"Metrics saved to {args.output_file}")
-
-
-def main() -> None:
- parser = argparse.ArgumentParser(description="Benchmark GLM-Image T2I/I2I online serving.")
- parser.add_argument("--mode", type=str, default="t2i", choices=["t2i", "i2i"])
- parser.add_argument("--dataset", type=str, default="prompt", choices=["prompt", "random", "custom"])
- parser.add_argument("--dataset-path", type=str, default=None)
- parser.add_argument("--num-prompts", type=int, default=10)
- parser.add_argument("--max-concurrency", type=int, default=1)
- parser.add_argument("--request-rate", type=float, default=float("inf"))
- parser.add_argument("--warmup-requests", type=int, default=1)
- parser.add_argument("--width", type=int, default=1024)
- parser.add_argument("--height", type=int, default=1024)
- parser.add_argument("--num-inference-steps", type=int, default=50)
- parser.add_argument("--seed", type=int, default=None)
- parser.add_argument("--model", type=str, default="default")
- parser.add_argument("--host", type=str, default="localhost")
- parser.add_argument("--port", type=int, default=8091)
- parser.add_argument("--output-file", type=str, default=None)
- parser.add_argument("--disable-tqdm", action="store_true")
- parser.add_argument("--num-input-images", type=int, default=1, help="For random I2I dataset.")
- args = parser.parse_args()
- asyncio.run(benchmark(args))
-
-
-if __name__ == "__main__":
- main()
diff --git a/benchmarks/glm_image/huggingface/inference.py b/benchmarks/glm_image/huggingface/inference.py
deleted file mode 100644
index ff826080e8c..00000000000
--- a/benchmarks/glm_image/huggingface/inference.py
+++ /dev/null
@@ -1,291 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-HuggingFace (transformers + diffusers) baseline benchmark for GLM-Image.
-
-Supports T2I and I2I modes with the prompt.json dataset.
-Downloads source images for I2I from image_url on first run and caches locally.
-
-Usage:
- # T2I mode (text-to-image, no source images needed)
- python benchmarks/glm_image/huggingface/inference.py \
- --model-path zai-org/GLM-Image \
- --mode t2i --num-prompts 10
-
- # I2I mode (image-to-image, downloads source images)
- python benchmarks/glm_image/huggingface/inference.py \
- --model-path zai-org/GLM-Image \
- --mode i2i --num-prompts 10
-
- # With custom prompt.json
- python benchmarks/glm_image/huggingface/inference.py \
- --model-path zai-org/GLM-Image \
- --mode i2i --dataset-path prompts.json --num-prompts 5
-"""
-
-import argparse
-import json
-import os
-import time
-from pathlib import Path
-
-import numpy as np
-import requests
-import torch
-from PIL import Image
-
-BENCHMARK_DIR = Path(__file__).resolve().parent.parent
-DEFAULT_PROMPT_JSON = BENCHMARK_DIR / "prompt" / "prompt.json"
-IMAGE_CACHE_DIR = BENCHMARK_DIR / "prompt" / "images"
-
-DATASET_REPO = "JaredforReal/glm-image-bench"
-DATASET_FILE = "prompt.json"
-
-
-def _ensure_prompt_json(dataset_path: str | None) -> str:
- """Return path to prompt.json, downloading from HuggingFace if needed."""
- if dataset_path:
- return dataset_path
- local = DEFAULT_PROMPT_JSON
- if local.exists():
- return str(local)
- print(f"Downloading {DATASET_FILE} from {DATASET_REPO} ...")
- try:
- from huggingface_hub import hf_hub_download
-
- downloaded = hf_hub_download(
- repo_id=DATASET_REPO,
- filename=DATASET_FILE,
- repo_type="dataset",
- )
- local.parent.mkdir(parents=True, exist_ok=True)
- import shutil
-
- shutil.copy2(downloaded, local)
- print(f"Saved to {local}")
- except ImportError:
- url = f"https://huggingface.co/datasets/{DATASET_REPO}/resolve/main/{DATASET_FILE}"
- import urllib.request
-
- local.parent.mkdir(parents=True, exist_ok=True)
- urllib.request.urlretrieve(url, local)
- print(f"Saved to {local}")
- return str(local)
-
-
-HEIGHT = 1024
-WIDTH = 1024
-SEED = 42
-NUM_INFERENCE_STEPS = 50
-GUIDANCE_SCALE = 1.5
-
-
-# ---------------------------------------------------------------------------
-# Dataset
-# ---------------------------------------------------------------------------
-
-
-def load_dataset(
- dataset_path: str | None,
- mode: str,
- num_prompts: int,
-) -> list[dict]:
- """Load prompts from prompt.json and prepare per-request data."""
- path = _ensure_prompt_json(dataset_path)
- with open(path, encoding="utf-8") as f:
- raw = json.load(f)
-
- items = []
- for entry in raw:
- if mode == "t2i":
- prompt_key = "t2i_prompt"
- else:
- prompt_key = "i2i_prompt"
-
- prompt_text = entry.get(prompt_key, "").strip()
- if not prompt_text:
- continue
-
- item = {"prompt": prompt_text}
- if mode == "i2i":
- item["image_url"] = entry.get("image_url", "")
- items.append(item)
-
- if num_prompts and len(items) > num_prompts:
- items = items[:num_prompts]
- return items
-
-
-def download_image(url: str, cache_dir: Path) -> str:
- """Download an image to cache_dir and return the local path."""
- cache_dir.mkdir(parents=True, exist_ok=True)
- fname = url.rsplit("/", 1)[-1]
- local_path = cache_dir / fname
- if local_path.exists():
- return str(local_path)
- print(f" Downloading {url} ...")
- resp = requests.get(url, timeout=30)
- resp.raise_for_status()
- local_path.write_bytes(resp.content)
- return str(local_path)
-
-
-# ---------------------------------------------------------------------------
-# Benchmark
-# ---------------------------------------------------------------------------
-
-
-def benchmark(args: argparse.Namespace) -> None:
- from diffusers.pipelines.glm_image import GlmImagePipeline
-
- print("=" * 60)
- print("GLM-Image HuggingFace Baseline Benchmark")
- print(f"Mode: {args.mode} | Model: {args.model_path}")
- print(f"Size: {args.height}x{args.width} | Steps: {args.num_inference_steps}")
- print("=" * 60)
-
- # Load dataset
- items = load_dataset(args.dataset_path, args.mode, args.num_prompts)
- if not items:
- print("No prompts loaded. Exiting.")
- return
- print(f"Loaded {len(items)} prompts for {args.mode} mode")
-
- # Download I2I source images
- if args.mode == "i2i":
- print("Preparing source images...")
- for item in items:
- url = item.get("image_url", "")
- if url:
- item["image_path"] = download_image(url, IMAGE_CACHE_DIR)
- else:
- item["image_path"] = None
-
- # Load pipeline
- print(f"\nLoading pipeline from {args.model_path} ...")
- t0 = time.perf_counter()
- pipe = GlmImagePipeline.from_pretrained(
- args.model_path,
- torch_dtype=torch.bfloat16,
- device_map="cuda",
- )
- init_time = time.perf_counter() - t0
- print(f"Pipeline loaded in {init_time:.2f}s")
-
- # Create output dir
- os.makedirs(args.output_dir, exist_ok=True)
-
- # Run benchmark
- generator = torch.Generator(device="cuda").manual_seed(args.seed)
- latencies = []
- success = 0
- failed = 0
-
- print(f"\nRunning {len(items)} requests sequentially...")
- print("-" * 60)
-
- for i, item in enumerate(items):
- prompt = item["prompt"]
- gen_kwargs: dict = {
- "prompt": prompt,
- "height": args.height,
- "width": args.width,
- "num_inference_steps": args.num_inference_steps,
- "guidance_scale": args.guidance_scale,
- "generator": generator,
- }
-
- if args.mode == "i2i":
- img_path = item.get("image_path")
- if img_path and os.path.exists(img_path):
- gen_kwargs["image"] = [Image.open(img_path).convert("RGB")]
- else:
- print(f" [{i + 1}] SKIP: no source image")
- failed += 1
- continue
-
- t_start = time.perf_counter()
- try:
- result = pipe(**gen_kwargs)
- image = result.images[0]
- elapsed = time.perf_counter() - t_start
- latencies.append(elapsed)
- success += 1
-
- out_path = os.path.join(args.output_dir, f"{i:04d}.png")
- image.save(out_path)
- print(f" [{i + 1}/{len(items)}] {elapsed:.3f}s -> {out_path}")
- except Exception as e:
- elapsed = time.perf_counter() - t_start
- failed += 1
- print(f" [{i + 1}/{len(items)}] FAILED ({elapsed:.3f}s): {e}")
-
- # Report
- total_gen_time = sum(latencies) if latencies else 0
- print("\n" + "=" * 60)
- print("HuggingFace Baseline Results")
- print("=" * 60)
- print(f"{'Mode:':<40} {args.mode}")
- print(f"{'Model:':<40} {args.model_path}")
- print(f"{'Image size:':<40} {args.height}x{args.width}")
- print(f"{'Num inference steps:':<40} {args.num_inference_steps}")
- print("-" * 50)
- print(f"{'Pipeline init time (s):':<40} {init_time:.2f}")
- print(f"{'Successful:':<40} {success}/{len(items)}")
- print(f"{'Failed:':<40} {failed}")
- print("-" * 50)
- if latencies:
- arr = np.array(latencies)
- print(f"{'Total generation time (s):':<40} {total_gen_time:.2f}")
- print(f"{'Throughput (img/s):':<40} {success / total_gen_time:.4f}")
- print(f"{'Latency Mean (s):':<40} {arr.mean():.4f}")
- print(f"{'Latency Median (s):':<40} {np.median(arr):.4f}")
- print(f"{'Latency P95 (s):':<40} {np.percentile(arr, 95):.4f}")
- print(f"{'Latency P99 (s):':<40} {np.percentile(arr, 99):.4f}")
-
- print(f"\n{'Output dir:':<40} {args.output_dir}")
- print("=" * 60)
-
- # Save metrics JSON
- metrics = {
- "backend": "huggingface",
- "mode": args.mode,
- "model": args.model_path,
- "height": args.height,
- "width": args.width,
- "num_inference_steps": args.num_inference_steps,
- "init_time_s": init_time,
- "completed_requests": success,
- "failed_requests": failed,
- "total_gen_time_s": total_gen_time,
- "throughput_qps": success / total_gen_time if total_gen_time > 0 else 0,
- "latency_mean": float(np.mean(latencies)) if latencies else 0,
- "latency_median": float(np.median(latencies)) if latencies else 0,
- "latency_p95": float(np.percentile(latencies, 95)) if latencies else 0,
- "latency_p99": float(np.percentile(latencies, 99)) if latencies else 0,
- }
- if args.output_file:
- with open(args.output_file, "w") as f:
- json.dump(metrics, f, indent=2)
- print(f"Metrics saved to {args.output_file}")
-
-
-def main() -> None:
- parser = argparse.ArgumentParser(description="GLM-Image HuggingFace baseline benchmark")
- parser.add_argument("--model-path", type=str, default="zai-org/GLM-Image")
- parser.add_argument("--mode", type=str, default="t2i", choices=["t2i", "i2i"])
- parser.add_argument("--dataset-path", type=str, default=None, help="Path to prompt.json")
- parser.add_argument("--num-prompts", type=int, default=10)
- parser.add_argument("--height", type=int, default=HEIGHT)
- parser.add_argument("--width", type=int, default=WIDTH)
- parser.add_argument("--num-inference-steps", type=int, default=NUM_INFERENCE_STEPS)
- parser.add_argument("--guidance-scale", type=float, default=GUIDANCE_SCALE)
- parser.add_argument("--seed", type=int, default=SEED)
- parser.add_argument("--output-dir", type=str, default="benchmarks/glm_image/huggingface/outputs")
- parser.add_argument("--output-file", type=str, default=None, help="JSON file for metrics")
- args = parser.parse_args()
- benchmark(args)
-
-
-if __name__ == "__main__":
- main()
diff --git a/benchmarks/glm_image/vllm-omni/inference.py b/benchmarks/glm_image/vllm-omni/inference.py
deleted file mode 100644
index 5729da07174..00000000000
--- a/benchmarks/glm_image/vllm-omni/inference.py
+++ /dev/null
@@ -1,505 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-vLLM-Omni offline benchmark for GLM-Image.
-
-Supports T2I and I2I modes with the prompt.json dataset.
-Downloads source images for I2I from image_url on first run and caches locally.
-
-Usage:
- # T2I mode
- python benchmarks/glm_image/vllm-omni/inference.py \
- --model-path zai-org/GLM-Image \
- --mode t2i --num-prompts 10
-
- # I2I mode (downloads source images)
- python benchmarks/glm_image/vllm-omni/inference.py \
- --model-path zai-org/GLM-Image \
- --mode i2i --num-prompts 10
-"""
-
-import argparse
-import json
-import math
-import os
-import time
-from pathlib import Path
-
-import numpy as np
-import requests
-from PIL import Image
-from vllm import SamplingParams
-
-from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-BENCHMARK_DIR = Path(__file__).resolve().parent.parent
-DEFAULT_PROMPT_JSON = BENCHMARK_DIR / "prompt" / "prompt.json"
-IMAGE_CACHE_DIR = BENCHMARK_DIR / "prompt" / "images"
-DEFAULT_DEPLOY_CONFIG = "vllm_omni/deploy/glm_image.yaml"
-
-DATASET_REPO = "JaredforReal/glm-image-bench"
-DATASET_FILE = "prompt.json"
-
-
-def _ensure_prompt_json(dataset_path: str | None) -> str:
- """Return path to prompt.json, downloading from HuggingFace if needed."""
- if dataset_path:
- return dataset_path
- local = DEFAULT_PROMPT_JSON
- if local.exists():
- return str(local)
- print(f"Downloading {DATASET_FILE} from {DATASET_REPO} ...")
- try:
- from huggingface_hub import hf_hub_download
-
- downloaded = hf_hub_download(
- repo_id=DATASET_REPO,
- filename=DATASET_FILE,
- repo_type="dataset",
- )
- local.parent.mkdir(parents=True, exist_ok=True)
- import shutil
-
- shutil.copy2(downloaded, local)
- print(f"Saved to {local}")
- except ImportError:
- url = f"https://huggingface.co/datasets/{DATASET_REPO}/resolve/main/{DATASET_FILE}"
- import urllib.request
-
- local.parent.mkdir(parents=True, exist_ok=True)
- urllib.request.urlretrieve(url, local)
- print(f"Saved to {local}")
- return str(local)
-
-
-SEED = 42
-HEIGHT = 1024
-WIDTH = 1024
-NUM_INFERENCE_STEPS = 50
-GUIDANCE_SCALE = 1.5
-
-GLM_IMAGE_EOS_TOKEN_ID = 16385
-GLM_IMAGE_VISION_VOCAB_SIZE = 16512
-
-
-# ---------------------------------------------------------------------------
-# Dataset
-# ---------------------------------------------------------------------------
-
-
-def load_dataset(
- dataset_path: str | None,
- mode: str,
- num_prompts: int,
-) -> list[dict]:
- path = _ensure_prompt_json(dataset_path)
- with open(path, encoding="utf-8") as f:
- raw = json.load(f)
-
- items = []
- for entry in raw:
- prompt_key = "t2i_prompt" if mode == "t2i" else "i2i_prompt"
- prompt_text = entry.get(prompt_key, "").strip()
- if not prompt_text:
- continue
-
- item = {"prompt": prompt_text}
- if mode == "i2i":
- item["image_url"] = entry.get("image_url", "")
- items.append(item)
-
- if num_prompts and len(items) > num_prompts:
- items = items[:num_prompts]
- return items
-
-
-def download_image(url: str, cache_dir: Path) -> str:
- cache_dir.mkdir(parents=True, exist_ok=True)
- fname = url.rsplit("/", 1)[-1]
- local_path = cache_dir / fname
- if local_path.exists():
- return str(local_path)
- print(f" Downloading {url} ...")
- resp = requests.get(url, timeout=30)
- resp.raise_for_status()
- local_path.write_bytes(resp.content)
- return str(local_path)
-
-
-# ---------------------------------------------------------------------------
-# Helpers
-# ---------------------------------------------------------------------------
-
-
-def compute_max_tokens(height: int, width: int, is_i2i: bool = False) -> int:
- factor = 32
- token_h = height // factor
- token_w = width // factor
- large_tokens = token_h * token_w
-
- # Small preview tokens (half resolution in each dimension)
-
- ratio = token_h / token_w if token_w > 0 else 1.0
- small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
- small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
- small_tokens = small_token_h * small_token_w
-
- # Mode-dependent totals:
- # - t2i: small + large + EOS
- # - i2i: large + EOS
- if is_i2i:
- return large_tokens + 1
- return small_tokens + large_tokens + 1
-
-
-def build_prompt_t2i(prompt: str, height: int, width: int, **gen_kw) -> dict:
- return {
- "prompt": prompt,
- "height": height,
- "width": width,
- "mm_processor_kwargs": {"target_h": height, "target_w": width},
- **gen_kw,
- }
-
-
-def build_prompt_i2i(prompt: str, image_path: str, height: int, width: int, **gen_kw) -> dict:
- return {
- "prompt": prompt,
- "height": height,
- "width": width,
- "mm_processor_kwargs": {"target_h": height, "target_w": width},
- "multi_modal_data": {"image": Image.open(image_path).convert("RGB")},
- **gen_kw,
- }
-
-
-def resolve_deploy_config(args: argparse.Namespace) -> str:
- if args.deploy_config:
- return args.deploy_config
- if os.path.exists(DEFAULT_DEPLOY_CONFIG):
- return DEFAULT_DEPLOY_CONFIG
- fallback = Path(__file__).resolve().parents[3] / DEFAULT_DEPLOY_CONFIG
- if fallback.exists():
- return str(fallback)
- raise FileNotFoundError("Deploy config not found. Specify --deploy-config.")
-
-
-# ---------------------------------------------------------------------------
-# Benchmark
-# ---------------------------------------------------------------------------
-
-
-def benchmark(args: argparse.Namespace) -> None:
- is_i2i = args.mode == "i2i"
-
- print("=" * 60)
- print("GLM-Image vLLM-Omni Benchmark")
- print(f"Mode: {args.mode} | Model: {args.model_path}")
- print(f"Size: {args.height}x{args.width} | Steps: {args.num_inference_steps}")
- print("=" * 60)
-
- # Load dataset
- items = load_dataset(args.dataset_path, args.mode, args.num_prompts)
- if not items:
- print("No prompts loaded. Exiting.")
- return
- print(f"Loaded {len(items)} prompts for {args.mode} mode")
-
- # Download I2I source images
- if is_i2i:
- print("Preparing source images...")
- for item in items:
- url = item.get("image_url", "")
- if url:
- item["image_path"] = download_image(url, IMAGE_CACHE_DIR)
- else:
- item["image_path"] = None
-
- # Init Omni
- deploy_config = resolve_deploy_config(args)
- print(f"\nInitializing vLLM-Omni (deploy config: {deploy_config}) ...")
- t0 = time.perf_counter()
-
- omni = Omni(
- model=args.model_path,
- deploy_config=deploy_config,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
- enable_ar_profiler=args.enable_ar_profiler,
- )
-
- init_time = time.perf_counter() - t0
- print(f"Initialized in {init_time:.2f}s")
-
- # Sampling params
- max_tokens = compute_max_tokens(args.height, args.width, is_i2i=is_i2i)
- ar_params = SamplingParams(
- temperature=0.9,
- top_p=0.75,
- top_k=GLM_IMAGE_VISION_VOCAB_SIZE,
- max_tokens=max_tokens,
- stop_token_ids=[GLM_IMAGE_EOS_TOKEN_ID],
- seed=args.seed,
- detokenize=False,
- extra_args={"target_h": args.height, "target_w": args.width},
- )
- diff_params = OmniDiffusionSamplingParams(
- num_inference_steps=args.num_inference_steps,
- guidance_scale=args.guidance_scale,
- height=args.height,
- width=args.width,
- seed=args.seed,
- )
- sampling_params_list = [ar_params, diff_params]
-
- # Build all prompts
- gen_kw = {
- "seed": args.seed,
- "num_inference_steps": args.num_inference_steps,
- "guidance_scale": args.guidance_scale,
- }
- all_prompts = []
- for item in items:
- if is_i2i:
- img_path = item.get("image_path")
- if not img_path or not os.path.exists(img_path):
- continue
- all_prompts.append(build_prompt_i2i(item["prompt"], img_path, args.height, args.width, **gen_kw))
- else:
- all_prompts.append(build_prompt_t2i(item["prompt"], args.height, args.width, **gen_kw))
-
- valid = len(all_prompts)
- print(f"Valid prompts: {valid}")
-
- # Create output dir
- os.makedirs(args.output_dir, exist_ok=True)
-
- # Warmup: run 1 request to prime caches, CUDA graphs, etc.
- if all_prompts:
- print("Running warmup request...")
- try:
- warmup_prompt = [all_prompts[0]]
- omni.generate(warmup_prompt, sampling_params_list, py_generator=False)
- print("Warmup done.\n")
- except Exception as e:
- print(f"Warmup failed (continuing): {e}")
-
- # Run
- print(f"\nRunning {valid} requests...")
- print("-" * 60)
-
- latencies = []
- all_stage_durations: list[dict[str, float]] = []
- success = 0
- failed = 0
- wall_start = time.perf_counter()
-
- try:
- output_idx = 0
- for stage_outputs in omni.generate(all_prompts, sampling_params_list, py_generator=True):
- if stage_outputs.final_output_type == "image":
- request_output = stage_outputs.request_output
- request_id = getattr(request_output, "request_id", "")
-
- images = getattr(request_output, "images", [])
- if not images and hasattr(request_output, "multimodal_output"):
- mm = request_output.multimodal_output
- if isinstance(mm, dict):
- images = mm.get("images", [])
-
- elapsed = time.perf_counter() - wall_start
- if images:
- for img in images:
- if isinstance(img, Image.Image):
- out_path = os.path.join(args.output_dir, f"{output_idx:04d}.png")
- img.save(out_path)
- success += 1
- latencies.append(elapsed)
- stage_durations = getattr(stage_outputs, "stage_durations", {})
- if stage_durations:
- all_stage_durations.append(stage_durations)
- # Show wall-clock elapsed and pipeline breakdown if available
- preprocess_str = ""
- if "preprocess_ms" in stage_durations:
- preprocess_str = f" preprocess={stage_durations['preprocess_ms'] / 1000.0:.2f}s"
- print(f" [{success}/{valid}] id={request_id[:8]} {elapsed:.2f}s{preprocess_str}")
- output_idx += 1
- else:
- failed += 1
- except Exception as e:
- print(f"Error: {e}")
- failed = valid - success
-
- total_gen_time = time.perf_counter() - wall_start
-
- # Diff stage_0_gen_ms with previous request to remove accumulated wait time.
- # stage_0_gen_ms is measured from submit_ts (same for all requests submitted
- # at once), so it accumulates queue/scheduling overhead across requests.
- # Other stages and pipeline timings are per-request already.
- _TIMING_ORDER = [
- "preprocess_ms",
- "stage_0_gen_ms",
- "ar2diffusion_ms",
- "stage_1_gen_ms",
- "queue_wait_ms",
- ]
-
- per_request_actual: list[dict[str, float]] = []
- prev_stage_0_ms = 0.0
- for sd in all_stage_durations:
- actual = dict(sd)
- s0 = sd.get("stage_0_gen_ms", 0.0)
- actual["stage_0_gen_ms"] = s0 - prev_stage_0_ms
- prev_stage_0_ms = s0
- per_request_actual.append(actual)
-
- per_request_e2e_ms: list[float] = []
- for actual in per_request_actual:
- e2e_ms = sum(v for k, v in actual.items() if k in _TIMING_ORDER)
- if e2e_ms > 0:
- per_request_e2e_ms.append(e2e_ms)
-
- # Report
- print("\n" + "=" * 60)
- print("vLLM-Omni Benchmark Results")
- print("=" * 60)
- print(f"{'Mode:':<40} {args.mode}")
- print(f"{'Model:':<40} {args.model_path}")
- print(f"{'Image size:':<40} {args.height}x{args.width}")
- print(f"{'Num inference steps:':<40} {args.num_inference_steps}")
- print("-" * 50)
- print(f"{'Init time (s):':<40} {init_time:.2f}")
- print(f"{'Successful:':<40} {success}/{valid}")
- print(f"{'Failed:':<40} {failed}")
- print("-" * 50)
-
- if per_request_e2e_ms:
- per_request_s = np.array(per_request_e2e_ms) / 1000.0
- print(f"{'Total generation time (s):':<40} {total_gen_time:.2f}")
- print(f"{'Throughput (img/s):':<40} {success / total_gen_time:.4f}")
- print(f"{'Latency Mean (s):':<40} {per_request_s.mean():.4f}")
- print(f"{'Latency Median (s):':<40} {np.median(per_request_s):.4f}")
- print(f"{'Latency P95 (s):':<40} {np.percentile(per_request_s, 95):.4f}")
- print(f"{'Latency P99 (s):':<40} {np.percentile(per_request_s, 99):.4f}")
- print(f"{'Latency Min (s):':<40} {per_request_s.min():.4f}")
- print(f"{'Latency Max (s):':<40} {per_request_s.max():.4f}")
- elif latencies:
- per_request = np.diff([0.0] + list(latencies))
- print(f"{'Total generation time (s):':<40} {total_gen_time:.2f}")
- print(f"{'Throughput (img/s):':<40} {success / total_gen_time:.4f}")
- print(f"{'Latency Mean (s) [wall-clock]:':<40} {per_request.mean():.4f}")
- print(f"{'Latency Median (s) [wall-clock]:':<40} {np.median(per_request):.4f}")
- print(f"{'Latency P95 (s) [wall-clock]:':<40} {np.percentile(per_request, 95):.4f}")
- print(f"{'Latency P99 (s) [wall-clock]:':<40} {np.percentile(per_request, 99):.4f}")
- print(f"{'Latency Min (s) [wall-clock]:':<40} {per_request.min():.4f}")
- print(f"{'Latency Max (s) [wall-clock]:':<40} {per_request.max():.4f}")
-
- if per_request_actual:
- print("-" * 50)
- print("Pipeline Timings Mean:")
- for key in _TIMING_ORDER:
- vals = [d.get(key, 0.0) for d in per_request_actual]
- if any(v != 0 for v in vals):
- unit = "ms" if key.endswith("_ms") else "s"
- print(f" {key + ':':<38} {np.mean(vals):.4f} ({unit})")
- # Show any extra keys not in the ordered list
- ordered_set = set(_TIMING_ORDER)
- extra_keys = sorted(k for k in per_request_actual[0].keys() if k not in ordered_set)
- for key in extra_keys:
- vals = [d.get(key, 0.0) for d in per_request_actual]
- if any(v != 0 for v in vals):
- unit = "ms" if key.endswith("_ms") else "s"
- print(f" {key + ':':<38} {np.mean(vals):.4f} ({unit})")
-
- print(f"\n{'Output dir:':<40} {args.output_dir}")
- print("=" * 60)
-
- # Metrics JSON
- metrics = {
- "backend": "vllm-omni",
- "mode": args.mode,
- "model": args.model_path,
- "height": args.height,
- "width": args.width,
- "num_inference_steps": args.num_inference_steps,
- "init_time_s": init_time,
- "completed_requests": success,
- "failed_requests": failed,
- "total_gen_time_s": total_gen_time,
- "throughput_qps": success / total_gen_time if total_gen_time > 0 else 0,
- }
- if per_request_e2e_ms:
- per_request_s = np.array(per_request_e2e_ms) / 1000.0
- metrics["latency_mean"] = float(per_request_s.mean())
- metrics["latency_median"] = float(np.median(per_request_s))
- metrics["latency_p95"] = float(np.percentile(per_request_s, 95))
- metrics["latency_p99"] = float(np.percentile(per_request_s, 99))
- elif latencies:
- per_request = np.diff([0.0] + list(latencies))
- metrics["latency_mean"] = float(per_request.mean())
- metrics["latency_median"] = float(np.median(per_request))
- metrics["latency_p95"] = float(np.percentile(per_request, 95))
- metrics["latency_p99"] = float(np.percentile(per_request, 99))
- else:
- metrics["latency_mean"] = 0
- metrics["latency_median"] = 0
- metrics["latency_p95"] = 0
- metrics["latency_p99"] = 0
- if per_request_actual:
- all_keys = list(_TIMING_ORDER) + sorted(k for k in per_request_actual[0].keys() if k not in set(_TIMING_ORDER))
- stage_metrics = {}
- for key in all_keys:
- vals = [d.get(key, 0.0) for d in per_request_actual]
- stage_metrics[key] = {
- "mean": float(np.mean(vals)),
- "median": float(np.median(vals)),
- "p95": float(np.percentile(vals, 95)),
- }
- metrics["stage_durations"] = stage_metrics
- if args.output_file:
- with open(args.output_file, "w") as f:
- json.dump(metrics, f, indent=2)
- print(f"Metrics saved to {args.output_file}")
-
- omni.close()
- print("Done!")
-
-
-def main() -> None:
- parser = argparse.ArgumentParser(description="GLM-Image vLLM-Omni offline benchmark")
- parser.add_argument("--model-path", type=str, default="zai-org/GLM-Image")
- parser.add_argument("--deploy-config", type=str, default=None, help="Deploy config YAML")
- parser.add_argument("--mode", type=str, default="t2i", choices=["t2i", "i2i"])
- parser.add_argument("--dataset-path", type=str, default=None, help="Path to prompt.json")
- parser.add_argument("--num-prompts", type=int, default=10)
- parser.add_argument("--height", type=int, default=HEIGHT)
- parser.add_argument("--width", type=int, default=WIDTH)
- parser.add_argument("--num-inference-steps", type=int, default=NUM_INFERENCE_STEPS)
- parser.add_argument("--guidance-scale", type=float, default=GUIDANCE_SCALE)
- parser.add_argument("--seed", type=int, default=SEED)
- parser.add_argument("--output-dir", type=str, default="benchmarks/glm_image/vllm-omni/outputs")
- parser.add_argument("--output-file", type=str, default=None, help="JSON file for metrics")
- parser.add_argument("--stage-init-timeout", type=int, default=600)
- parser.add_argument(
- "--enable-diffusion-pipeline-profiler",
- action="store_true",
- help="Enable diffusion pipeline profiler for stage-level timing",
- )
- parser.add_argument(
- "--enable-ar-profiler",
- action="store_true",
- help="Enable AR stage profiler to include AR timing in stage_durations",
- )
- parser.add_argument(
- "--log-stats",
- action="store_true",
- help="Enable detailed per-request pipeline stats logging",
- )
- args = parser.parse_args()
- benchmark(args)
-
-
-if __name__ == "__main__":
- main()
diff --git a/benchmarks/qwen3-omni/README.md b/benchmarks/qwen3-omni/README.md
new file mode 100644
index 00000000000..de27c05c2c4
--- /dev/null
+++ b/benchmarks/qwen3-omni/README.md
@@ -0,0 +1,86 @@
+# Benchmarks Guide
+
+This README explains how to (1) prepare benchmark datasets and (2) run the provided Qwen3-Omni benchmarks.
+
+## 1) Prepare the dataset (SeedTTS top100)
+
+```bash
+cd benchmarks/build_dataset
+pip install gdown
+
+# Download SeedTTS test set from Google Drive
+gdown --id 1GlSjVfSHkW3-leKKBlfrjuuTGqQ_xaLP
+
+# Extract
+tar -xf seedtts_testset.tar
+
+# Copy metadata and extract top-100 prompts
+cp seedtts_testset/en/meta.lst meta.lst
+python extract_prompts.py -i meta.lst -o top100.txt -n 100
+
+# (Optional) clean up to save space
+rm -rf seedtts_testset seedtts_testset.tar meta.lst
+```
+
+Artifacts:
+- `benchmarks/build_dataset/top100.txt` — 100 text prompts (one per line).
+
+## 2) Run benchmarks
+
+All commands assume repo root (`vllm-omni`).
+
+### A. Transformers benchmark (offline, HF Transformers)
+
+```
+bash benchmarks/qwen3-omni/transformers/eval_qwen3_moe_omni_transformers.sh
+```
+
+What it does:
+- Runs `qwen3_omni_moe_transformers.py` over `top100.txt` with `--num-prompts 100`.
+- Outputs to `benchmarks/qwen3-omni/transformers/benchmark_results/`:
+  - `perf_stats.json` — aggregated & per-prompt TPS/latency (thinker/talker/code2wav/overall).
+  - `results.json` — per-prompt outputs and audio paths.
+  - `audio/` — ~100 generated `.wav` files.
+
+Key checks:
+- `overall_tps` and `*_tps_avg` should be non-zero and reasonably stable.
+- Investigate any 0/NaN or unusually low TPS / long-tail latency.
+
+### B. vLLM Omni end-to-end benchmark (pipeline)
+
+```
+bash benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh
+```
+
+What it does:
+- Runs `examples/offline_inference/qwen3_omni/end2end.py` with `--log-stats`.
+- Uses `benchmarks/build_dataset/top100.txt` and writes to:
+  - Logs: `benchmarks/qwen3-omni/vllm_omni/logs/`
+    - `omni_pipeline_text.orchestrator.stats.jsonl` — per-stage latency stats.
+    - `omni_pipeline_text.overall.stats.jsonl` — end-to-end latency/TPS.
+    - `omni_pipeline_text.stage{0,1,2}.log` — per-stage detailed logs/errors.
+  - Outputs: `benchmarks/qwen3-omni/vllm_omni/outputs/` — ~100 text and `.wav` files.
+
+Key checks:
+- Overall stats: end-to-end latency/TPS should be reasonable.
+- Orchestrator stats: per-stage latency should be stable; investigate long tails.
+- Stage logs: ensure no errors and no unusually slow stages.
+
+
+## Performance snapshot
+
+The chart below summarizes our measured Qwen3-Omni MoE end-to-end benchmark, comparing vLLM-Omni against HF Transformers. It shows the overall throughput advantage of vLLM-Omni. These are measured results from our own runs; use them as a reference point when evaluating or reproducing the benchmark.
+
+![vLLM-Omni vs HF](./vllm-omni-vs-hf.png)
+
+## Directory layout
+- `benchmarks/build_dataset/` — dataset prep utilities (e.g., SeedTTS top100).
+- `benchmarks/<task>/vllm_omni/` — vLLM-Omni pipeline benchmarks, logs, outputs.
+- Add new tasks under `benchmarks/<task>/...` with the same pattern: `transformers/`, `vllm_omni/`, task-specific README, and (optionally) dataset prep notes.
+- `benchmarks/<task>/vllm-omni-vs-hf.png` — current performance snapshot (overall throughput comparison).
+- `benchmarks/<task>/transformers/` — HF Transformers benchmarks (offline reference).
+
+## Troubleshooting
+- Make sure GPU/driver/FlashAttention2 requirements are met for the chosen model.
+- If downloads fail, confirm network access to Google Drive (`gdown`) and Hugging Face.
+- If audio files are missing, check for errors in stage logs or model generation.
diff --git a/benchmarks/qwen3-omni/transformers/eval_qwen3_moe_omni_transformers.sh b/benchmarks/qwen3-omni/transformers/eval_qwen3_moe_omni_transformers.sh
new file mode 100644
index 00000000000..bae514fab28
--- /dev/null
+++ b/benchmarks/qwen3-omni/transformers/eval_qwen3_moe_omni_transformers.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Qwen3-Omni Transformers Benchmark Evaluation Script
+# This script must be run from the vllm-omni root directory
+
+# Get the directory where this script is located
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Navigate to vllm-omni root directory (4 levels up from script location)
+VLLM_OMNI_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)"
+cd "$VLLM_OMNI_ROOT" || { echo "Error: Failed to navigate to vllm-omni directory"; exit 1; }
+
+echo "Working directory: $(pwd)"
+# Verify we're in the correct directory and run benchmark
+if [[ ! -f "benchmarks/qwen3-omni/transformers/qwen3_omni_moe_transformers.py" ]]; then
+ echo "Error: Not in vllm-omni root directory. Please run from vllm-omni folder."
+else
+ cd benchmarks/qwen3-omni/transformers
+
+ python qwen3_omni_moe_transformers.py --prompts-file ../../build_dataset/top100.txt --num-prompts 100
+
+ echo "Logs and outputs are saved to $(pwd)/benchmark_results:"
+ echo " - perf_stats.json Aggregated/per-prompt TPS and latency (thinker/talker/code2wav/overall)"
+ echo " - results.json Per-prompt outputs and audio paths"
+ echo " - audio/ Generated wav files, there should be 100 wav file generated"
+ echo "Key checks: overall_tps and *_tps_avg should be non-zero and stable; investigate 0/NaN or unusually low TPS/long-tail latency."
+fi
diff --git a/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_model.py b/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_model.py
new file mode 100644
index 00000000000..43b56f3e995
--- /dev/null
+++ b/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_model.py
@@ -0,0 +1,265 @@
+import time
+
+import torch
+from transformers import Qwen3OmniMoeForConditionalGeneration
+
+
+class Qwen3OmniMoeForConditionalGenerationWithLogging(Qwen3OmniMoeForConditionalGeneration):
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids: torch.Tensor | None = None,
+ speaker: str = "Ethan",
+ use_audio_in_video: bool = False,
+ return_audio: bool | None = None,
+ thinker_max_new_tokens: int = 1024,
+ thinker_eos_token_id: int = 151645,
+ talker_max_new_tokens: int = 4096,
+ talker_do_sample: bool = True,
+ talker_top_k: int = 50,
+ talker_top_p: float = 1.0,
+ talker_temperature: float = 0.9,
+ talker_repetition_penalty: float = 1.05,
+ **kwargs,
+ ):
+ total_t0 = time.time()
+ perf_stats = {
+ "thinker_tokens": 0,
+ "thinker_time_s": 0.0,
+ "thinker_tps": 0.0,
+ "talker_tokens": 0,
+ "talker_time_s": 0.0,
+ "talker_tps": 0.0,
+ "code2wav_tokens": 0,
+ "code2wav_time_s": 0.0,
+ "code2wav_tps": 0.0,
+ "total_tokens": 0,
+ "total_time_s": 0.0,
+ "total_tps": 0.0,
+ }
+ if return_audio and not self.has_talker:
+ raise ValueError(
+ "Cannot use talker when talker module not initialized. "
+ "Use `enable_talker` method or set enable_talker in config "
+ "to enable talker."
+ )
+ if return_audio is None:
+ return_audio = self.has_talker
+
+ shared_kwargs = {"use_audio_in_video": use_audio_in_video}
+ thinker_kwargs = {
+ "max_new_tokens": thinker_max_new_tokens,
+ "eos_token_id": thinker_eos_token_id,
+ }
+
+ talker_kwargs = {}
+ token2wav_kwargs = {}
+ if return_audio:
+ speaker_id = self.config.talker_config.speaker_id.get(speaker.lower())
+ if speaker_id is None:
+ raise NotImplementedError(f"Speaker {speaker} not implemented")
+ if input_ids.shape[0] != 1:
+ raise NotImplementedError("Qwen3-Omni currently does not support batched inference with audio output")
+ talker_suppressed_tokens = [
+ i
+ for i in range(
+ self.config.talker_config.text_config.vocab_size - 1024,
+ self.config.talker_config.text_config.vocab_size,
+ )
+ if i != self.config.talker_config.codec_eos_token_id
+ ] # Suppress additional special tokens, should not be predicted
+ talker_kwargs = {
+ "max_new_tokens": talker_max_new_tokens,
+ "do_sample": talker_do_sample,
+ "top_k": talker_top_k,
+ "top_p": talker_top_p,
+ "temperature": talker_temperature,
+ "eos_token_id": self.config.talker_config.codec_eos_token_id,
+ "repetition_penalty": talker_repetition_penalty,
+ "suppress_tokens": talker_suppressed_tokens,
+ "output_hidden_states": True,
+ "return_dict_in_generate": True,
+ }
+ token2wav_kwargs = {}
+
+ for key, value in kwargs.items():
+ if key.startswith("thinker_"):
+ thinker_kwargs[key[len("thinker_") :]] = value
+ elif key.startswith("talker_"):
+ talker_kwargs[key[len("talker_") :]] = value
+ elif key.startswith("token2wav_"):
+ token2wav_kwargs[key[len("token2wav_") :]] = value
+ # Process special input values
+ elif key == "feature_attention_mask":
+ thinker_kwargs[key] = value
+ talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1)
+ elif key in ("input_features", "attention_mask"):
+ thinker_kwargs[key] = value
+ # Put other key to shared kwargs
+ else:
+ shared_kwargs[key] = value
+
+ # Merge kwargs
+ for key, value in shared_kwargs.items():
+ if key not in thinker_kwargs:
+ thinker_kwargs[key] = value
+ if key not in talker_kwargs and key in ["image_grid_thw", "video_grid_thw", "video_second_per_grid"]:
+ talker_kwargs[key] = value
+ if key not in token2wav_kwargs:
+ token2wav_kwargs[key] = value
+
+ # 1. Generate from thinker module
+ generate_audio = return_audio and self.has_talker
+ if generate_audio:
+ thinker_kwargs["output_hidden_states"] = True
+ thinker_kwargs["return_dict_in_generate"] = True
+
+ t0 = time.time()
+ thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs)
+ t1 = time.time()
+ perf_stats["thinker_time_s"] = max(0.0, t1 - t0)
+ try:
+ prompt_len = int(input_ids.shape[1]) if input_ids is not None else 0
+ total_len = int(thinker_result.sequences.shape[-1])
+ thinker_out_len = max(0, total_len - prompt_len)
+ except Exception:
+ thinker_out_len = 0
+ perf_stats["thinker_tokens"] = thinker_out_len
+ perf_stats["thinker_tps"] = (
+ (thinker_out_len / perf_stats["thinker_time_s"]) if perf_stats["thinker_time_s"] > 0 else 0.0
+ )
+
+ if not generate_audio:
+ perf_stats["total_tokens"] = perf_stats["thinker_tokens"]
+ perf_stats["total_time_s"] = time.time() - total_t0
+ perf_stats["total_tps"] = (
+ (perf_stats["total_tokens"] / perf_stats["total_time_s"]) if perf_stats["total_time_s"] > 0 else 0.0
+ )
+ # attach stats to self
+ setattr(self, "_perf_stats_last", perf_stats)
+ if not hasattr(self, "_perf_stats_history"):
+ setattr(self, "_perf_stats_history", [])
+ self._perf_stats_history.append(perf_stats)
+ return thinker_result, None
+
+ # 2. Prepare talker input
+ thinker_embed = torch.cat([hidden_states[0] for hidden_states in thinker_result.hidden_states], dim=1).to(
+ self.talker.device
+ ) # [1 t d]
+ thinker_hidden = torch.cat(
+ [
+ hidden_states[self.config.talker_config.accept_hidden_layer]
+ for hidden_states in thinker_result.hidden_states
+ ],
+ dim=1,
+ ).to(self.talker.device) # [1 t d]
+
+ im_start_indexes = torch.cat(
+ (
+ torch.nonzero(input_ids[0] == self.config.im_start_token_id).squeeze(),
+ torch.tensor([thinker_result.sequences.shape[-1]], device=input_ids.device, dtype=input_ids.dtype),
+ ),
+ dim=-1,
+ ).to(self.talker.device) # Shape [n_starts + 1]; Take batch 0 since batched inference is not supported here.
+ multimodal_mask = (
+ (thinker_result.sequences == self.config.thinker_config.audio_token_id) |
+ (thinker_result.sequences == self.config.thinker_config.image_token_id) |
+ (thinker_result.sequences == self.config.thinker_config.video_token_id)
+ ).to(self.talker.device) # [1 t] # fmt: skip
+
+ talker_special_tokens = torch.tensor(
+ [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]],
+ device=self.thinker.device,
+ dtype=input_ids.dtype,
+ )
+ tts_bos_embed, tts_eos_embed, tts_pad_embed = (
+ self.talker.text_projection(self.thinker.get_input_embeddings()(talker_special_tokens))
+ .to(self.talker.device)
+ .chunk(3, dim=1)
+ ) # 3 * [1 1 d]
+
+ talker_input_embeds = [] # [1 t d]
+ talker_input_ids = []
+ # For every chatml parts
+ for i in range(len(im_start_indexes) - 1):
+ im_start_index = im_start_indexes[i]
+ segment_end_index = im_start_indexes[i + 1]
+ role_token = input_ids[0][im_start_index + 1]
+ # Talker should ignore thinker system prompt
+ if role_token == self.config.system_token_id:
+ continue
+ # Talker takes word embeddings for tokens and hidden state from `accept_hidden_layer` for multimodal inputs
+ elif role_token == self.config.user_token_id:
+ talker_user_part = self._get_talker_user_parts(
+ im_start_index, segment_end_index, multimodal_mask, thinker_hidden, thinker_embed
+ )
+ talker_input_embeds.append(talker_user_part)
+ talker_input_ids.append(thinker_result.sequences[:, im_start_index:segment_end_index])
+ # Take assistant output (for now)
+ elif role_token == self.config.assistant_token_id and i == len(im_start_indexes) - 2:
+ talker_assistant_embeds, talker_assistant_ids, trailing_text_hidden = self._get_talker_assistant_parts(
+ im_start_index,
+ segment_end_index,
+ speaker_id,
+ thinker_embed,
+ tts_pad_embed,
+ tts_bos_embed,
+ tts_eos_embed,
+ )
+ talker_input_embeds.append(talker_assistant_embeds)
+ talker_input_ids.append(talker_assistant_ids)
+ # History assistant output (ignore for now)
+ elif role_token == self.config.assistant_token_id and i != len(im_start_indexes) - 2:
+ continue
+ else:
+ raise AssertionError("Expect role id after <|im_start|> (assistant, user, system)")
+ talker_input_embed = torch.cat([embed.to(self.talker.device) for embed in talker_input_embeds], dim=1)
+ talker_input_id = torch.cat([embed.to(self.talker.device) for embed in talker_input_ids], dim=1)
+ t2 = time.time()
+ talker_result = self.talker.generate(
+ inputs_embeds=talker_input_embed,
+ trailing_text_hidden=trailing_text_hidden,
+ tts_pad_embed=tts_pad_embed,
+ talker_input_ids=talker_input_id, # Not use input_ids to prevent repetition penalty out of bound
+ **talker_kwargs,
+ )
+ t3 = time.time()
+ perf_stats["talker_time_s"] = max(0.0, t3 - t2)
+ talker_codes = (
+ torch.stack([hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None], dim=1)
+ .transpose(1, 2)
+ .to(self.code2wav.device)
+ )
+ try:
+ # codes shape: (B, num_quantizers, T). We log T as token length.
+ perf_stats["talker_tokens"] = int(talker_codes.shape[-1])
+ except Exception:
+ perf_stats["talker_tokens"] = 0
+ perf_stats["talker_tps"] = (
+ (perf_stats["talker_tokens"] / perf_stats["talker_time_s"]) if perf_stats["talker_time_s"] > 0 else 0.0
+ )
+ t4 = time.time()
+ talker_wavs = self.code2wav.chunked_decode(talker_codes, chunk_size=300, left_context_size=25).float()
+ t5 = time.time()
+ perf_stats["code2wav_time_s"] = max(0.0, t5 - t4)
+ perf_stats["code2wav_tokens"] = perf_stats["talker_tokens"] # same T, not times 16
+ perf_stats["code2wav_tps"] = (
+ (perf_stats["code2wav_tokens"] / perf_stats["code2wav_time_s"])
+ if perf_stats["code2wav_time_s"] > 0
+ else 0.0
+ )
+ perf_stats["total_tokens"] = perf_stats["thinker_tokens"] + perf_stats["talker_tokens"]
+ perf_stats["total_time_s"] = time.time() - total_t0
+ perf_stats["total_tps"] = (
+ (perf_stats["total_tokens"] / perf_stats["total_time_s"]) if perf_stats["total_time_s"] > 0 else 0.0
+ )
+ setattr(self, "_perf_stats_last", perf_stats)
+ if not hasattr(self, "_perf_stats_history"):
+ setattr(self, "_perf_stats_history", [])
+ self._perf_stats_history.append(perf_stats)
+ return thinker_result, talker_wavs.float()
+
+
+# Names exported by `from qwen3_omni_moe_model import *`.
+__all__ = [
+    "Qwen3OmniMoeForConditionalGenerationWithLogging",
+]
diff --git a/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_transformers.py b/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_transformers.py
new file mode 100644
index 00000000000..87d3de797b8
--- /dev/null
+++ b/benchmarks/qwen3-omni/transformers/qwen3_omni_moe_transformers.py
@@ -0,0 +1,275 @@
+import argparse
+import json
+import os
+
+import soundfile as sf
+from qwen3_omni_moe_model import Qwen3OmniMoeForConditionalGenerationWithLogging
+from qwen_omni_utils import process_mm_info
+from tqdm import tqdm
+from transformers import Qwen3OmniMoeProcessor
+
+MODEL_PATH = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+# MODEL_PATH = "Qwen/Qwen3-Omni-30B-A3B-Thinking"
+
+
+def load_prompts(prompts_file: str) -> list[str]:
+ """Load prompts from a text file, one prompt per line."""
+ prompts = []
+ with open(prompts_file, encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ prompts.append(line)
+ return prompts
+
+
+def run_benchmark(
+ model,
+ processor,
+ prompts: list[str],
+ output_dir: str = "benchmark_results",
+ speaker: str = "Ethan",
+ use_audio_in_video: bool = True,
+):
+ """
+ Run benchmark on a list of prompts and collect performance stats.
+
+ Args:
+ model: The Qwen3OmniMoe model
+ processor: The Qwen3OmniMoe processor
+ prompts: List of text prompts to process
+ output_dir: Directory to save results
+ speaker: Speaker voice for audio output
+ use_audio_in_video: Whether to use audio in video
+
+ Returns:
+ tuple: (aggregated_stats, results, audio_outputs)
+ - aggregated_stats: dict with aggregated performance statistics
+ - results: list of dicts with per-prompt results
+ - audio_outputs: list of audio tensors/arrays (or None if no audio)
+ """
+ os.makedirs(output_dir, exist_ok=True)
+ audio_dir = os.path.join(output_dir, "audio")
+ os.makedirs(audio_dir, exist_ok=True)
+
+ all_stats = []
+ results = []
+ audio_outputs = []
+
+ for idx, prompt in enumerate(tqdm(prompts, desc="Processing prompts")):
+ conversation = [
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": prompt}],
+ },
+ ]
+
+ # Preparation for inference
+ text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+ audios, images, videos = process_mm_info(conversation, use_audio_in_video=use_audio_in_video)
+ inputs = processor(
+ text=text,
+ audio=audios,
+ images=images,
+ videos=videos,
+ return_tensors="pt",
+ padding=True,
+ use_audio_in_video=use_audio_in_video,
+ )
+ inputs = inputs.to(model.device).to(model.dtype)
+
+ # Inference: Generation of the output text and audio
+ text_ids, audio = model.generate(
+ **inputs, speaker=speaker, thinker_return_dict_in_generate=True, use_audio_in_video=use_audio_in_video
+ )
+
+ # Decode output text
+ output_text = processor.batch_decode(
+ text_ids.sequences[:, inputs["input_ids"].shape[1] :],
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )[0]
+
+ # Collect performance stats
+ perf_stats = None
+ if hasattr(model, "_perf_stats_last"):
+ perf_stats = model._perf_stats_last.copy()
+ perf_stats["prompt_idx"] = idx
+ perf_stats["prompt"] = prompt
+ all_stats.append(perf_stats)
+
+ # Save audio and collect audio output
+ audio_path = None
+ audio_data = None
+ if audio is not None:
+ audio_data = audio.reshape(-1).detach().cpu().numpy()
+ audio_path = os.path.join(audio_dir, f"output_{idx:04d}.wav")
+ sf.write(
+ audio_path,
+ audio_data,
+ samplerate=24000,
+ )
+ audio_outputs.append(audio_data)
+ else:
+ audio_outputs.append(None)
+
+ # Save result
+ result = {
+ "idx": idx,
+ "prompt": prompt,
+ "output": output_text,
+ "audio_path": audio_path,
+ "perf_stats": perf_stats,
+ }
+ results.append(result)
+
+ # Aggregate statistics
+ aggregated_stats = aggregate_stats(all_stats)
+
+ # Save all results
+ results_path = os.path.join(output_dir, "results.json")
+ with open(results_path, "w", encoding="utf-8") as f:
+ json.dump(results, f, ensure_ascii=False, indent=2)
+
+ # Save aggregated stats
+ stats_path = os.path.join(output_dir, "perf_stats.json")
+ with open(stats_path, "w", encoding="utf-8") as f:
+ json.dump({"aggregated": aggregated_stats, "per_prompt": all_stats}, f, ensure_ascii=False, indent=2)
+
+ # Count saved audio files
+ num_audio_saved = sum(1 for a in audio_outputs if a is not None)
+ print(f"\nSaved {num_audio_saved} audio files to {audio_dir}/")
+
+ return aggregated_stats, results, audio_outputs
+
+
+def aggregate_stats(all_stats: list[dict]) -> dict:
+ """Aggregate performance statistics from multiple runs."""
+ if not all_stats:
+ return {}
+
+ keys = [
+ "thinker_tokens",
+ "thinker_time_s",
+ "thinker_tps",
+ "talker_tokens",
+ "talker_time_s",
+ "talker_tps",
+ "code2wav_tokens",
+ "code2wav_time_s",
+ "code2wav_tps",
+ "total_tokens",
+ "total_time_s",
+ "total_tps",
+ ]
+
+ aggregated = {
+ "num_samples": len(all_stats),
+ }
+
+ for key in keys:
+ values = [s.get(key, 0) for s in all_stats if key in s]
+ if values:
+ aggregated[f"{key}_sum"] = sum(values)
+ aggregated[f"{key}_avg"] = sum(values) / len(values)
+ aggregated[f"{key}_min"] = min(values)
+ aggregated[f"{key}_max"] = max(values)
+
+ # Calculate overall throughput
+ total_tokens = aggregated.get("total_tokens_sum", 0)
+ total_time = aggregated.get("total_time_s_sum", 0)
+ if total_time > 0:
+ aggregated["overall_tps"] = total_tokens / total_time
+
+ return aggregated
+
+
+def print_stats(stats: dict):
+ """Print performance statistics in a formatted way."""
+ print("\n" + "=" * 60)
+ print("Performance Statistics Summary")
+ print("=" * 60)
+
+ print(f"\nNumber of samples: {stats.get('num_samples', 0)}")
+
+ print("\n--- Thinker ---")
+ print(f" Total tokens: {stats.get('thinker_tokens_sum', 0):.0f}")
+ print(f" Total time: {stats.get('thinker_time_s_sum', 0):.2f}s")
+ print(f" Avg TPS: {stats.get('thinker_tps_avg', 0):.2f}")
+ print(f" Min TPS: {stats.get('thinker_tps_min', 0):.2f}")
+ print(f" Max TPS: {stats.get('thinker_tps_max', 0):.2f}")
+
+ print("\n--- Talker ---")
+ print(f" Total tokens: {stats.get('talker_tokens_sum', 0):.0f}")
+ print(f" Total time: {stats.get('talker_time_s_sum', 0):.2f}s")
+ print(f" Avg TPS: {stats.get('talker_tps_avg', 0):.2f}")
+ print(f" Min TPS: {stats.get('talker_tps_min', 0):.2f}")
+ print(f" Max TPS: {stats.get('talker_tps_max', 0):.2f}")
+
+ print("\n--- Code2Wav ---")
+ print(f" Total tokens: {stats.get('code2wav_tokens_sum', 0):.0f}")
+ print(f" Total time: {stats.get('code2wav_time_s_sum', 0):.2f}s")
+ print(f" Avg TPS: {stats.get('code2wav_tps_avg', 0):.2f}")
+ print(f" Min TPS: {stats.get('code2wav_tps_min', 0):.2f}")
+ print(f" Max TPS: {stats.get('code2wav_tps_max', 0):.2f}")
+
+ print("\n--- Overall ---")
+ print(f" Total tokens: {stats.get('total_tokens_sum', 0):.0f}")
+ print(f" Total time: {stats.get('total_time_s_sum', 0):.2f}s")
+ print(f" Overall TPS: {stats.get('overall_tps', 0):.2f}")
+ print(f" Avg TPS: {stats.get('total_tps_avg', 0):.2f}")
+ print(f" Min TPS: {stats.get('total_tps_min', 0):.2f}")
+ print(f" Max TPS: {stats.get('total_tps_max', 0):.2f}")
+
+ print("=" * 60 + "\n")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Qwen3-Omni Benchmark Script")
+ parser.add_argument(
+ "--prompts-file",
+ type=str,
+ default="benchmark/build_dataset/top100.txt",
+ help="Path to the prompts file (one prompt per line)",
+ )
+ parser.add_argument(
+ "--output-dir", type=str, default="benchmark_results", help="Directory to save benchmark results"
+ )
+ parser.add_argument("--model-path", type=str, default=MODEL_PATH, help="Path to the model")
+ parser.add_argument("--speaker", type=str, default="Ethan", help="Speaker voice for audio output")
+ parser.add_argument("--num-prompts", type=int, default=None, help="Number of prompts to process (default: all)")
+ args = parser.parse_args()
+
+ # Load model and processor
+ print(f"Loading model from {args.model_path}...")
+ model = Qwen3OmniMoeForConditionalGenerationWithLogging.from_pretrained(
+ args.model_path,
+ dtype="auto",
+ device_map="auto",
+ attn_implementation="flash_attention_2",
+ )
+ processor = Qwen3OmniMoeProcessor.from_pretrained(args.model_path)
+
+ # Benchmark mode
+ print(f"Loading prompts from {args.prompts_file}...")
+ prompts = load_prompts(args.prompts_file)
+
+ if args.num_prompts:
+ prompts = prompts[: args.num_prompts]
+
+ print(f"Running benchmark on {len(prompts)} prompts...")
+
+ aggregated_stats, results, audio_outputs = run_benchmark(
+ model=model,
+ processor=processor,
+ prompts=prompts,
+ output_dir=args.output_dir,
+ speaker=args.speaker,
+ )
+
+ print_stats(aggregated_stats)
+ print(f"\nResults saved to {args.output_dir}/")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmarks/qwen3-omni/vllm-omni-vs-hf.png b/benchmarks/qwen3-omni/vllm-omni-vs-hf.png
new file mode 100644
index 00000000000..e47079335be
Binary files /dev/null and b/benchmarks/qwen3-omni/vllm-omni-vs-hf.png differ
diff --git a/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh b/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh
new file mode 100644
index 00000000000..e4c83e97510
--- /dev/null
+++ b/benchmarks/qwen3-omni/vllm_omni/eval_qwen3_moe_omni.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Qwen3-Omni Benchmark Evaluation Script
+# This script must be run from the vllm-omni root directory
+
+# Get the directory where this script is located
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Navigate to vllm-omni root directory (4 levels up from script location)
+VLLM_OMNI_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)"
+cd "$VLLM_OMNI_ROOT" || { echo "Error: Failed to navigate to vllm-omni directory"; exit 1; }
+
+echo "Working directory: $(pwd)"
+
+# Verify we're in the correct directory and run benchmark
+if [[ ! -d "benchmarks/qwen3-omni/vllm_omni" ]]; then
+ echo "Error: Not in vllm-omni root directory. Please run from vllm-omni folder."
+else
+ log_dir=benchmarks/qwen3-omni/vllm_omni/logs
+ outputs_dir=benchmarks/qwen3-omni/vllm_omni/outputs
+ end2end_script_path=examples/offline_inference/qwen3_omni/end2end.py
+ build_dataset_path=benchmarks/build_dataset/top100.txt
+
+ python $end2end_script_path --output-wav $outputs_dir \
+ --query-type text \
+ --txt-prompts $build_dataset_path \
+ --log-stats \
+ --log-dir $log_dir
+ echo "Logs and outputs are saved in ${log_dir} and ${outputs_dir} respectively:"
+ echo " - omni_pipeline_text run dir/base name"
+ echo " - omni_pipeline_text.orchestrator.stats.jsonl orchestrator-stage latency stats"
+ echo " - omni_pipeline_text.overall.stats.jsonl overall latency/TPS stats"
+ echo " - omni_pipeline_text.stage0.log per-stage detailed logs"
+ echo " - omni_pipeline_text.stage1.log"
+ echo " - omni_pipeline_text.stage2.log"
+ echo "Key checks: overall.stats.jsonl for end-to-end latency/TPS; orchestrator.stats.jsonl for stable per-stage latency; stage*.log for errors or long tails."
+ echo " - outputs/ Generated txt and wav files, there should be 100 text and wav files generated respectively"
+fi
diff --git a/benchmarks/qwen3-tts/README.md b/benchmarks/qwen3-tts/README.md
new file mode 100644
index 00000000000..9c01f29aa9f
--- /dev/null
+++ b/benchmarks/qwen3-tts/README.md
@@ -0,0 +1,103 @@
+# Qwen3-TTS Benchmark
+
+Benchmarks for Qwen3-TTS text-to-speech models, comparing vLLM-Omni streaming serving against HuggingFace Transformers offline inference.
+
+## Prerequisites
+
+```bash
+pip install matplotlib aiohttp soundfile numpy tqdm
+pip install qwen_tts # for HF baseline
+```
+
+## Quick Start
+
+Run the full benchmark (vllm-omni + HF baseline) with a single command:
+
+```bash
+cd benchmarks/qwen3-tts
+bash run_benchmark.sh
+```
+
+Results (JSON + PNG plots) are saved to `results/`.
+
+### Common options
+
+```bash
+# Only vllm-omni (skip HF baseline)
+bash run_benchmark.sh --async-only
+
+# Only HF baseline
+bash run_benchmark.sh --hf-only
+
+# Use a different model (e.g. 1.7B)
+MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice bash run_benchmark.sh --async-only
+
+# Use a Voice Clone model
+MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-Base TASK_TYPE=Base bash run_benchmark.sh --async-only
+
+# Use bs16 config for higher throughput
+STAGE_CONFIG=vllm_omni/configs/qwen3_tts_bs16.yaml bash run_benchmark.sh --async-only
+
+# Custom GPU, prompt count, concurrency levels
+GPU_DEVICE=1 NUM_PROMPTS=20 CONCURRENCY="1 4" bash run_benchmark.sh
+```
+
+## Manual Steps
+
+### 1) Start the vLLM-Omni server
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python -m vllm_omni.entrypoints.cli.main serve \
+ "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice" \
+ --omni --host 127.0.0.1 --port 8000 \
+ --stage-configs-path benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml \
+ --trust-remote-code
+```
+
+### 2) Run online serving benchmark
+
+```bash
+python benchmarks/qwen3-tts/vllm_omni/bench_tts_serve.py \
+ --port 8000 \
+ --num-prompts 50 \
+ --max-concurrency 1 4 10 \
+ --config-name "async_chunk" \
+ --result-dir results/
+```
+
+### 3) Run HuggingFace baseline
+
+```bash
+python benchmarks/qwen3-tts/transformers/bench_tts_hf.py \
+ --model "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice" \
+ --num-prompts 50 \
+ --gpu-device 0 \
+ --result-dir results/
+```
+
+### 4) Generate comparison plots
+
+```bash
+python benchmarks/qwen3-tts/plot_results.py \
+ --results results/bench_async_chunk_*.json results/bench_hf_transformers_*.json \
+ --labels "vllm-omni" "hf_transformers" \
+ --output results/comparison.png
+```
+
+## Stage Configs
+
+| Config | max_num_seqs | Description |
+|--------|:------------:|-------------|
+| `vllm_omni/configs/qwen3_tts_bs1.yaml` | 1 | Single-request processing (lowest latency) |
+| `vllm_omni/configs/qwen3_tts_bs16.yaml` | 16 | High-throughput concurrent processing |
+
+All configs use a 2-stage pipeline (Talker -> Code2Wav) with `async_chunk` streaming enabled. The `SharedMemoryConnector` streams codec frames (25-frame chunks with 25-frame context overlap) between stages.
+
+The model is specified via the CLI `--model` flag (or `MODEL` env var), so the same configs work for both the 0.6B and 1.7B model variants.
+
+## Metrics
+
+- **TTFP (Time to First Packet)**: Time from request to the first audio chunk (streaming latency)
+- **E2E (End-to-End Latency)**: Total time from request to complete audio response
+- **RTF (Real-Time Factor)**: E2E latency / audio duration. RTF < 1.0 means faster-than-real-time synthesis
+- **Throughput**: Total audio seconds generated per wall-clock second
diff --git a/benchmarks/qwen3-tts/plot_results.py b/benchmarks/qwen3-tts/plot_results.py
new file mode 100644
index 00000000000..e750101e324
--- /dev/null
+++ b/benchmarks/qwen3-tts/plot_results.py
@@ -0,0 +1,254 @@
+"""Plot Qwen3-TTS benchmark results.
+
+Generates comparison bar charts similar to the async_chunk design doc:
+- TTFP (Time-to-First-Packet) across concurrency levels
+- E2E latency across concurrency levels
+- RTF (Real-Time Factor) across concurrency levels
+- Audio throughput across concurrency levels
+
+Usage:
+ # Compare two configs (async_chunk vs no_async_chunk):
+ python plot_results.py \
+ --results results/bench_async_chunk_*.json results/bench_no_async_chunk_*.json \
+ --labels "async_chunk" "no_async_chunk" \
+ --output results/qwen3_tts_benchmark.png
+
+ # Single config:
+ python plot_results.py \
+ --results results/bench_async_chunk_*.json \
+ --labels "async_chunk" \
+ --output results/qwen3_tts_benchmark.png
+"""
+
+import argparse
+import json
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def load_results(result_files: list[str]) -> list[list[dict]]:
+    """Parse each benchmark JSON file and return the payloads in input order.
+
+    Args:
+        result_files: Paths of the JSON files to read.
+
+    Returns:
+        One parsed JSON document per input path, in the same order.
+    """
+    # Path.read_text() opens, reads and closes the file in one call.
+    return [json.loads(Path(result_path).read_text()) for result_path in result_files]
+
+
+def plot_comparison(
+    all_results: list[list[dict]],
+    labels: list[str],
+    output_path: str,
+    title_prefix: str = "Qwen3-TTS",
+):
+    """Draw a 2x2 grid of grouped bar charts comparing several benchmark configs.
+
+    The panels show mean TTFP, mean E2E latency, mean RTF and audio throughput,
+    grouped by concurrency level with one bar per config.
+
+    Args:
+        all_results: One list of per-concurrency result dicts per config. Each dict
+            must provide "concurrency", "mean_ttfp_ms", "mean_e2e_ms", "mean_rtf"
+            and "audio_throughput".
+        labels: Legend label for each config, in the same order as ``all_results``.
+        output_path: Path the figure is saved to.
+        title_prefix: Text placed before "Performance Benchmark" in the figure title.
+    """
+    n_configs = len(all_results)
+
+    # Use the UNION of concurrency levels; a config that lacks a level gets a gap (None).
+    all_concurrencies = [set(r["concurrency"] for r in results) for results in all_results]
+    concurrencies = sorted(set.union(*all_concurrencies))
+
+    # Keep (label, values) pairs per metric instead of label-keyed dicts, so two
+    # configs that share a label cannot overwrite or extend each other's data.
+    metric_keys = ("mean_ttfp_ms", "mean_e2e_ms", "mean_rtf", "audio_throughput")
+    series = {key: [] for key in metric_keys}
+    for results, label in zip(all_results, labels):
+        conc_map = {r["concurrency"]: r for r in results}
+        rows = [conc_map.get(c) for c in concurrencies]
+        for key in metric_keys:
+            series[key].append((label, [row[key] if row else None for row in rows]))
+
+    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
+    fig.suptitle(f"{title_prefix} Performance Benchmark", fontsize=16, fontweight="bold")
+
+    x = np.arange(len(concurrencies))
+    # Ticks are 1.0 apart, so cap the total group width at 0.8 to stop bars from
+    # running into the neighbouring concurrency group when there are 3+ configs.
+    # For 1 and 2 configs this yields the previous widths (0.5 and 0.35).
+    width = min(0.35 if n_configs == 2 else 0.5, 0.8 / n_configs)
+    # Centre the group on the tick: offsets are symmetric around zero.
+    offsets = (np.arange(n_configs) - (n_configs - 1) / 2) * width
+
+    colors = ["#2196F3", "#FF5722", "#4CAF50", "#FFC107"]
+
+    def plot_metric(ax, metric_series, ylabel, title, fmt=".1f"):
+        for i, (label, values) in enumerate(metric_series):
+            # Missing levels are drawn as zero-height bars and get no value label.
+            plot_values = [v if v is not None else 0 for v in values]
+            color = colors[i % len(colors)]
+            bar = ax.bar(x + offsets[i], plot_values, width, label=label, color=color, alpha=0.85)
+            # The label gap above each bar scales with the tallest bar of this config.
+            max_val = max((v for v in values if v is not None), default=1)
+            for rect, val in zip(bar, values):
+                if val is not None and val > 0:
+                    ax.text(
+                        rect.get_x() + rect.get_width() / 2,
+                        rect.get_height() + max_val * 0.02,
+                        f"{val:{fmt}}",
+                        ha="center",
+                        va="bottom",
+                        fontsize=9,
+                        fontweight="bold",
+                    )
+        ax.set_xlabel("Concurrency", fontsize=12)
+        ax.set_ylabel(ylabel, fontsize=12)
+        ax.set_title(title, fontsize=13, fontweight="bold")
+        ax.set_xticks(x)
+        ax.set_xticklabels([str(c) for c in concurrencies])
+        ax.legend(fontsize=10)
+        ax.grid(axis="y", alpha=0.3)
+        ax.set_axisbelow(True)
+
+    plot_metric(axes[0, 0], series["mean_ttfp_ms"], "TTFP (ms)", "Time to First Audio Packet (TTFP)")
+    plot_metric(axes[0, 1], series["mean_e2e_ms"], "E2E Latency (ms)", "End-to-End Latency (E2E)")
+    plot_metric(axes[1, 0], series["mean_rtf"], "RTF", "Real-Time Factor (RTF)", fmt=".3f")
+    plot_metric(axes[1, 1], series["audio_throughput"], "Audio-sec / Wall-sec", "Audio Throughput", fmt=".2f")
+
+    plt.tight_layout()
+    plt.savefig(output_path, dpi=150, bbox_inches="tight")
+    print(f"Plot saved to {output_path}")
+    plt.close()
+
+
+def plot_single_summary(results: list[dict], label: str, output_path: str):
+    """Plot a three-panel summary (TTFP, E2E, RTF) for a single config.
+
+    Args:
+        results: Per-concurrency result dicts of one config; read keys are
+            "concurrency", the mean/median/p90/p99 TTFP and E2E values, and
+            the mean/median RTF.
+        label: Config name shown in the figure title.
+        output_path: Path the figure is saved to.
+    """
+    concurrencies = [r["concurrency"] for r in results]
+
+    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
+    fig.suptitle(f"Qwen3-TTS Benchmark - {label}", fontsize=15, fontweight="bold")
+
+    x = np.arange(len(concurrencies))
+    w = 0.2
+
+    def draw_panel(ax, bar_specs, bar_width, ylabel, title):
+        # Extract every series before drawing, so a missing key fails up front.
+        extracted = [[r[key] for r in results] for _, key, _, _ in bar_specs]
+        for (shift, _, legend_name, color), values in zip(bar_specs, extracted):
+            ax.bar(x + shift, values, bar_width, label=legend_name, color=color)
+        ax.set_xticks(x)
+        ax.set_xticklabels([str(c) for c in concurrencies])
+        ax.set_xlabel("Concurrency")
+        ax.set_ylabel(ylabel)
+        ax.set_title(title)
+        ax.legend(fontsize=9)
+        ax.grid(axis="y", alpha=0.3)
+
+    def percentile_specs(metric):
+        # Four bars per tick: mean, median, p90, p99, centred on the tick.
+        return [
+            (-1.5 * w, f"mean_{metric}_ms", "mean", "#2196F3"),
+            (-0.5 * w, f"median_{metric}_ms", "median", "#4CAF50"),
+            (0.5 * w, f"p90_{metric}_ms", "p90", "#FF9800"),
+            (1.5 * w, f"p99_{metric}_ms", "p99", "#F44336"),
+        ]
+
+    # TTFP breakdown
+    draw_panel(axes[0], percentile_specs("ttfp"), w, "TTFP (ms)", "Time to First Audio Packet")
+
+    # E2E breakdown
+    draw_panel(axes[1], percentile_specs("e2e"), w, "E2E Latency (ms)", "End-to-End Latency")
+
+    # RTF (mean and median only)
+    rtf_specs = [
+        (-0.15, "mean_rtf", "mean", "#2196F3"),
+        (0.15, "median_rtf", "median", "#4CAF50"),
+    ]
+    draw_panel(axes[2], rtf_specs, 0.3, "RTF", "Real-Time Factor")
+
+    plt.tight_layout()
+    plt.savefig(output_path, dpi=150, bbox_inches="tight")
+    print(f"Plot saved to {output_path}")
+    plt.close()
+
+
+def print_comparison_table(all_results: list[list[dict]], labels: list[str]):
+    """Print a markdown-formatted comparison table to stdout.
+
+    One row is printed per (metric, concurrency) pair and one column per config.
+    Concurrency levels are the union over all configs; a config without data for
+    a level shows "N/A" rather than a misleading 0. When exactly two configs are
+    given, a second table reports the relative improvement of the first over the
+    second, skipping rows where either side has no data.
+
+    Args:
+        all_results: One list of per-concurrency result dicts per config.
+        labels: Column label for each config, in the same order as ``all_results``.
+    """
+    # Build each concurrency -> result lookup once instead of inside the row loops.
+    conc_maps = [{r["concurrency"]: r for r in results} for results in all_results]
+    concurrencies = sorted(set().union(*conc_maps))
+
+    def format_cell(row, key, fmt):
+        value = row.get(key) if row else None
+        return "N/A" if value is None else f"{value:{fmt}}"
+
+    print("\n## Benchmark Results\n")
+    header = "| Metric | Concurrency |"
+    sep = "| --- | --- |"
+    for label in labels:
+        header += f" {label} |"
+        sep += " --- |"
+    print(header)
+    print(sep)
+
+    for metric, key, fmt in [
+        ("TTFP (ms)", "mean_ttfp_ms", ".1f"),
+        ("E2E (ms)", "mean_e2e_ms", ".1f"),
+        ("RTF", "mean_rtf", ".3f"),
+        ("Throughput (audio-s/s)", "audio_throughput", ".2f"),
+    ]:
+        for c in concurrencies:
+            row = f"| {metric} | {c} |"
+            for conc_map in conc_maps:
+                row += f" {format_cell(conc_map.get(c), key, fmt)} |"
+            print(row)
+
+    # Improvement calculation (only if 2 configs)
+    if len(all_results) == 2:
+        print(f"\n## Improvement ({labels[0]} vs {labels[1]})\n")
+        print("| Metric | Concurrency | Improvement |")
+        print("| --- | --- | --- |")
+        m0, m1 = conc_maps
+        for metric, key in [("TTFP", "mean_ttfp_ms"), ("E2E", "mean_e2e_ms"), ("RTF", "mean_rtf")]:
+            for c in concurrencies:
+                v0 = m0.get(c, {}).get(key)
+                v1 = m1.get(c, {}).get(key)
+                # Need both measurements and a positive baseline; a missing v0
+                # would otherwise be reported as a spurious +100% improvement.
+                if v0 is None or v1 is None or v1 <= 0:
+                    continue
+                pct = (v1 - v0) / v1 * 100
+                print(f"| {metric} | {c} | {pct:+.1f}% |")
+
+
+def parse_args():
+    """Build the command-line parser for the plotting script and parse ``sys.argv``.
+
+    Returns:
+        argparse.Namespace with ``results``, ``labels``, ``output`` and ``title``.
+    """
+    cli = argparse.ArgumentParser(description="Plot Qwen3-TTS benchmark results")
+
+    # --results and --labels are parallel, required, one-or-more string lists.
+    parallel_lists = (
+        ("--results", "Path(s) to result JSON files (one per config)"),
+        ("--labels", "Labels for each config (must match --results count)"),
+    )
+    for flag, help_text in parallel_lists:
+        cli.add_argument(flag, type=str, nargs="+", required=True, help=help_text)
+
+    cli.add_argument("--output", type=str, default="results/qwen3_tts_benchmark.png", help="Output image path")
+    cli.add_argument("--title", type=str, default="Qwen3-TTS", help="Title prefix for the plot")
+    return cli.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    # Validate explicitly: an `assert` is stripped under `python -O`, which would
+    # let a results/labels mismatch through to a confusing plot or crash later.
+    if len(args.results) != len(args.labels):
+        raise SystemExit(
+            "error: --results and --labels must have the same count "
+            f"(got {len(args.results)} result file(s) and {len(args.labels)} label(s))"
+        )
+
+    all_results = load_results(args.results)
+    print_comparison_table(all_results, args.labels)
+
+    # Create the output directory up front so savefig does not fail on a fresh checkout.
+    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
+
+    # A single config gets the percentile breakdown; several get the side-by-side comparison.
+    if len(all_results) == 1:
+        plot_single_summary(all_results[0], args.labels[0], args.output)
+    else:
+        plot_comparison(all_results, args.labels, args.output, title_prefix=args.title)
diff --git a/benchmarks/qwen3-tts/results/.gitignore b/benchmarks/qwen3-tts/results/.gitignore
new file mode 100644
index 00000000000..5b6759ef717
--- /dev/null
+++ b/benchmarks/qwen3-tts/results/.gitignore
@@ -0,0 +1,3 @@
+# Benchmark results are machine-specific - do not commit
+*
+!.gitignore
diff --git a/benchmarks/qwen3-tts/run_benchmark.sh b/benchmarks/qwen3-tts/run_benchmark.sh
new file mode 100755
index 00000000000..283b6b844c1
--- /dev/null
+++ b/benchmarks/qwen3-tts/run_benchmark.sh
@@ -0,0 +1,280 @@
+#!/bin/bash
+# Qwen3-TTS Benchmark Runner
+#
+# Compares vllm-omni streaming serving vs HuggingFace transformers offline inference.
+# Produces JSON results and comparison plots.
+#
+# Usage:
+# # Full comparison (vllm-omni + HF):
+# bash run_benchmark.sh
+#
+# # Only vllm-omni async_chunk config:
+# bash run_benchmark.sh --async-only
+#
+# # Only HuggingFace baseline:
+# bash run_benchmark.sh --hf-only
+#
+# # vllm-omni only (skip HF):
+# bash run_benchmark.sh --skip-hf
+#
+# # Custom settings:
+# GPU_DEVICE=1 NUM_PROMPTS=20 CONCURRENCY="1 4" bash run_benchmark.sh
+#
+# # Use 1.7B model:
+# MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice bash run_benchmark.sh --async-only
+#
+# # Use Voice Clone model
+# MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-Base TASK_TYPE=Base bash run_benchmark.sh --async-only
+#
+# # Use the batch_size=16 config:
+# STAGE_CONFIG=vllm_omni/configs/qwen3_tts_bs16.yaml bash run_benchmark.sh --async-only
+#
+# Environment variables:
+# GPU_DEVICE - GPU index to use (default: 0)
+# NUM_PROMPTS - Number of prompts per concurrency level (default: 50)
+# CONCURRENCY - Space-separated concurrency levels (default: "1 4 10")
+# MODEL - Model name (default: Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice)
+# PORT - Server port (default: 8000)
+# GPU_MEM_TALKER - gpu_memory_utilization for talker stage (default: 0.3)
+# GPU_MEM_CODE2WAV - gpu_memory_utilization for code2wav stage (default: 0.2)
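+# STAGE_CONFIG - Stage config YAML path, relative to this script's directory (default: vllm_omni/configs/qwen3_tts_bs1.yaml)
+# NUM_WARMUPS - Number of warmup requests passed to the benchmark scripts (default: 3)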
+# TASK_TYPE - Task type: CustomVoice, VoiceDesign, Base (default: CustomVoice)
+
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
+
+# Defaults (each one can be overridden through the environment)
+GPU_DEVICE="${GPU_DEVICE:-0}"
+NUM_PROMPTS="${NUM_PROMPTS:-50}"
+CONCURRENCY="${CONCURRENCY:-1 4 10}"
+MODEL="${MODEL:-Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice}"
+PORT="${PORT:-8000}"
+GPU_MEM_TALKER="${GPU_MEM_TALKER:-0.3}"
+GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV:-0.2}"
+NUM_WARMUPS="${NUM_WARMUPS:-3}"
+# STAGE_CONFIG is resolved relative to SCRIPT_DIR when the benchmark is launched.
+STAGE_CONFIG="${STAGE_CONFIG:-vllm_omni/configs/qwen3_tts_bs1.yaml}"
+RESULT_DIR="${SCRIPT_DIR}/results"
+TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
+TASK_TYPE="${TASK_TYPE:-CustomVoice}"
+
+# Parse args
+RUN_ASYNC=true
+RUN_HF=true
+for arg in "$@"; do
+    case "$arg" in
+        --async-only) RUN_HF=false ;;
+        --hf-only) RUN_ASYNC=false ;;
+        --skip-hf) RUN_HF=false ;;
+        *)
+            # Fail fast: a mistyped flag would otherwise be ignored and start the
+            # full (long-running) comparison instead of the intended subset.
+            echo "ERROR: unknown option '${arg}' (expected --async-only, --hf-only or --skip-hf)" >&2
+            exit 1
+            ;;
+    esac
+done
+
+# Contradictory flags (e.g. --async-only with --hf-only) would disable both runs.
+if [ "${RUN_ASYNC}" = false ] && [ "${RUN_HF}" = false ]; then
+    echo "ERROR: the given options disable both the vllm-omni and the HuggingFace benchmark" >&2
+    exit 1
+fi
+
+mkdir -p "${RESULT_DIR}"
+
+echo "============================================================"
+echo " Qwen3-TTS Benchmark"
+echo "============================================================"
+echo " GPU: ${GPU_DEVICE}"
+echo " Model: ${MODEL}"
+echo " Prompts: ${NUM_PROMPTS}"
+echo " Concurrency: ${CONCURRENCY}"
+echo " Port: ${PORT}"
+echo " Stage config: ${STAGE_CONFIG}"
+echo " Results: ${RESULT_DIR}"
+echo " Task type: ${TASK_TYPE}"
+echo "============================================================"
+
+# Prepare stage config with correct GPU device and memory settings
+prepare_config() {
+    # Copy a stage-config template into RESULT_DIR with the GPU device and the
+    # per-stage gpu_memory_utilization values patched in.
+    #   $1 - path of the template YAML
+    #   $2 - config name, used as the output file prefix
+    # Prints the path of the patched file on stdout (callers capture it with $(...)).
+    local config_template="$1"
+    local config_name="$2"
+    local output_path="${RESULT_DIR}/${config_name}_stage_config.yaml"
+
+    # Use sed to patch GPU device and memory utilization.
+    # - Dots are escaped so "0.3" cannot match e.g. "0x3" or "013".
+    # - The bare `t` ends processing of the current line once a substitution has
+    #   happened. Without it, GPU_MEM_TALKER=0.2 would rewrite the talker line to
+    #   0.2 and the next expression would overwrite it with GPU_MEM_CODE2WAV.
+    sed \
+        -e "s/devices: \"0\"/devices: \"${GPU_DEVICE}\"/g" \
+        -e "s/gpu_memory_utilization: 0\.3/gpu_memory_utilization: ${GPU_MEM_TALKER}/g" \
+        -e "t" \
+        -e "s/gpu_memory_utilization: 0\.2/gpu_memory_utilization: ${GPU_MEM_CODE2WAV}/g" \
+        "${config_template}" > "${output_path}"
+
+    echo "${output_path}"
+}
+
+# Start server and wait for it to be ready
+start_server() {
+    # Launch the vllm-omni server in the background and block until it answers
+    # on /v1/models.
+    #   $1 - stage config path passed to --stage-configs-path
+    #   $2 - config name, used in the server log file name
+    # Sets the global SERVER_PID. Returns 0 when the server is ready, 1 when the
+    # port is already taken, the process dies, or the wait times out.
+    local stage_config="$1"
+    local config_name="$2"
+    local log_file="${RESULT_DIR}/server_${config_name}_${TIMESTAMP}.log"
+    local health_url="http://127.0.0.1:${PORT}/v1/models"
+
+    # If something already answers on the port, the readiness probe below would
+    # succeed immediately and the benchmark would measure that other server.
+    if curl -sf "${health_url}" > /dev/null 2>&1; then
+        echo " ERROR: Port ${PORT} is already serving requests. Stop that server or set PORT."
+        return 1
+    fi
+
+    echo ""
+    echo "Starting server with config: ${config_name}"
+    echo " Stage config: ${stage_config}"
+    echo " Log file: ${log_file}"
+
+    VLLM_WORKER_MULTIPROC_METHOD=spawn \
+    CUDA_VISIBLE_DEVICES="${GPU_DEVICE}" \
+    python -m vllm_omni.entrypoints.cli.main serve "${MODEL}" \
+        --omni \
+        --host 127.0.0.1 \
+        --port "${PORT}" \
+        --stage-configs-path "${stage_config}" \
+        --stage-init-timeout 120 \
+        --trust-remote-code \
+        --disable-log-stats \
+        > "${log_file}" 2>&1 &
+
+    SERVER_PID=$!
+    echo " Server PID: ${SERVER_PID}"
+
+    # Wait for server to be ready
+    echo " Waiting for server to be ready..."
+    local max_wait=300
+    local waited=0
+    while [ ${waited} -lt ${max_wait} ]; do
+        # Check liveness first so a dead process is never reported as "ready".
+        if ! kill -0 ${SERVER_PID} 2>/dev/null; then
+            echo " ERROR: Server process died. Check log: ${log_file}"
+            tail -20 "${log_file}"
+            return 1
+        fi
+        if curl -sf "${health_url}" > /dev/null 2>&1; then
+            echo " Server is ready! (waited ${waited}s)"
+            return 0
+        fi
+        sleep 2
+        waited=$((waited + 2))
+    done
+
+    echo " ERROR: Server did not start within ${max_wait}s. Check log: ${log_file}"
+    kill ${SERVER_PID} 2>/dev/null || true
+    return 1
+}
+
+# Stop the server
+stop_server() {
+    # Terminate the background server recorded in SERVER_PID, then force-kill
+    # whatever lsof still reports on PORT. Safe to call repeatedly.
+
+    # Guard: nothing was started, or it has already been stopped.
+    [ -n "${SERVER_PID:-}" ] || return 0
+
+    echo " Stopping server (PID: ${SERVER_PID})..."
+    kill ${SERVER_PID} 2>/dev/null || true
+    wait ${SERVER_PID} 2>/dev/null || true
+
+    # Kill any remaining child processes on the port
+    local leftover_pids
+    leftover_pids=$(lsof -ti:${PORT} 2>/dev/null || true)
+    if [ -n "${leftover_pids}" ]; then
+        echo " Cleaning up remaining processes on port ${PORT}..."
+        echo "${leftover_pids}" | xargs kill -9 2>/dev/null || true
+    fi
+
+    echo " Server stopped."
+    # Clearing the PID turns the later EXIT-trap call into a no-op.
+    SERVER_PID=""
+}
+
+# Cleanup on exit
+trap stop_server EXIT
+
+# Run benchmark for a given config
+run_bench() {
+    # Benchmark one vllm-omni config end to end: patch the stage config, start
+    # the server, run the serving benchmark over all concurrency levels, then
+    # stop the server.
+    #   $1 - config name (labels logs and result files)
+    #   $2 - stage config template path
+    local config_name="$1"
+    local config_template="$2"
+    local patched_config
+    local level_args
+
+    echo ""
+    echo "============================================================"
+    echo " Benchmarking: ${config_name}"
+    echo "============================================================"
+
+    patched_config=$(prepare_config "${config_template}" "${config_name}")
+
+    start_server "${patched_config}" "${config_name}"
+
+    # Convert concurrency string to args. CONCURRENCY is deliberately unquoted
+    # so it word-splits into one value per level.
+    level_args="$(printf ' %s' ${CONCURRENCY})"
+
+    cd "${PROJECT_ROOT}"
+    # level_args is unquoted on purpose: each level must be its own argument.
+    python "${SCRIPT_DIR}/vllm_omni/bench_tts_serve.py" \
+        --host 127.0.0.1 \
+        --port "${PORT}" \
+        --num-prompts "${NUM_PROMPTS}" \
+        --max-concurrency ${level_args} \
+        --num-warmups "${NUM_WARMUPS}" \
+        --config-name "${config_name}" \
+        --result-dir "${RESULT_DIR}" \
+        --task-type "${TASK_TYPE}"
+
+    stop_server
+
+    # Allow GPU memory to settle
+    sleep 5
+}
+
+# Run vllm-omni benchmark (served through the async_chunk stage config)
+if [[ "${RUN_ASYNC}" == true ]]; then
+    run_bench "async_chunk" "${SCRIPT_DIR}/${STAGE_CONFIG}"
+fi
+
+# Run HuggingFace baseline benchmark (offline, no server involved)
+if [[ "${RUN_HF}" == true ]]; then
+    printf '%s\n' \
+        "" \
+        "============================================================" \
+        " Benchmarking: HuggingFace transformers (offline)" \
+        "============================================================"
+
+    cd "${PROJECT_ROOT}"
+    hf_bench_args=(
+        --model "${MODEL}"
+        --num-prompts "${NUM_PROMPTS}"
+        --num-warmups "${NUM_WARMUPS}"
+        --gpu-device "${GPU_DEVICE}"
+        --config-name "hf_transformers"
+        --result-dir "${RESULT_DIR}"
+        --task-type "${TASK_TYPE}"
+    )
+    python "${SCRIPT_DIR}/transformers/bench_tts_hf.py" "${hf_bench_args[@]}"
+
+    # Allow GPU memory to settle
+    sleep 5
+fi
+
+# Plot results
+echo ""
+echo "============================================================"
+echo " Generating plots..."
+echo "============================================================"
+
+# Arrays keep each path a single argument even if RESULT_DIR contains spaces.
+RESULT_FILES=()
+LABELS=()
+
+if [ "${RUN_ASYNC}" = true ]; then
+    # `|| true`: with `set -o pipefail`, a glob that matches nothing makes `ls`
+    # fail and `set -e` would abort here, before the -n guard below can run.
+    ASYNC_FILE=$(ls -t "${RESULT_DIR}"/bench_async_chunk_*.json 2>/dev/null | head -1 || true)
+    if [ -n "${ASYNC_FILE}" ]; then
+        RESULT_FILES+=("${ASYNC_FILE}")
+        LABELS+=("async_chunk")
+    fi
+fi
+
+if [ "${RUN_HF}" = true ]; then
+    HF_FILE=$(ls -t "${RESULT_DIR}"/bench_hf_transformers_*.json 2>/dev/null | head -1 || true)
+    if [ -n "${HF_FILE}" ]; then
+        RESULT_FILES+=("${HF_FILE}")
+        LABELS+=("hf_transformers")
+    fi
+fi
+
+if [ "${#RESULT_FILES[@]}" -gt 0 ]; then
+    python "${SCRIPT_DIR}/plot_results.py" \
+        --results "${RESULT_FILES[@]}" \
+        --labels "${LABELS[@]}" \
+        --output "${RESULT_DIR}/qwen3_tts_benchmark_${TIMESTAMP}.png"
+else
+    echo " No result files found in ${RESULT_DIR}; skipping plots."
+fi
+
+echo ""
+echo "============================================================"
+echo " Benchmark complete!"
+echo " Results: ${RESULT_DIR}"
+echo "============================================================"
diff --git a/benchmarks/qwen3-tts/transformers/bench_tts_hf.py b/benchmarks/qwen3-tts/transformers/bench_tts_hf.py
new file mode 100644
index 00000000000..ed04ee264c4
--- /dev/null
+++ b/benchmarks/qwen3-tts/transformers/bench_tts_hf.py
@@ -0,0 +1,301 @@
+"""Benchmark Qwen3-TTS using HuggingFace transformers (qwen_tts library).
+
+Measures E2E latency, RTF, and audio duration for offline (non-serving) inference.
+Results are saved in the same JSON format as bench_tts_serve.py for unified plotting.
+
+Usage:
+ python bench_tts_hf.py \
+ --model Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice \
+ --num-prompts 50 \
+ --num-warmups 3 \
+ --gpu-device 0 \
+ --result-dir results/
+"""
+
+import argparse
+import json
+import time
+from dataclasses import asdict, dataclass, field
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import soundfile as sf
+import torch
+
+# Fixed English prompt pool. run_benchmark() cycles through it (index modulo
+# length) for both warmup and measured requests, so any --num-prompts works.
+PROMPTS = [
+    "Hello, welcome to the voice synthesis benchmark test.",
+    "She said she would be here by noon, but nobody showed up.",
+    "The quick brown fox jumps over the lazy dog near the riverbank.",
+    "I can't believe how beautiful the sunset looks from up here on the mountain.",
+    "Please remember to bring your identification documents to the appointment tomorrow morning.",
+    "Have you ever wondered what it would be like to travel through time and visit ancient civilizations?",
+    "The restaurant on the corner serves the best pasta I have ever tasted in my entire life.",
+    "After the meeting, we should discuss the quarterly results and plan for the next phase.",
+    "Learning a new language takes patience, practice, and a genuine curiosity about other cultures.",
+    "The train leaves at half past seven, so we need to arrive at the station before then.",
+    "Could you please turn down the music a little bit, I'm trying to concentrate on my work.",
+    "It was a dark and stormy night when the old lighthouse keeper heard a knock at the door.",
+]
+
+# Reference clip and its transcript, passed to generate_voice_clone() for --task-type Base.
+REF_AUDIO = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"
+REF_TEXT = "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."
+# Style instruction passed to generate_voice_design() for --task-type VoiceDesign.
+INSTRUCT = "Speak in an incredulous tone, but with a hint of panic beginning to creep into your voice."
+
+
+@dataclass
+class BenchmarkResult:
+    """Aggregated statistics of one offline HF benchmark run.
+
+    Serialized with ``asdict`` into the result JSON. The module docstring states
+    the file uses the same format as bench_tts_serve.py for unified plotting.
+    """
+
+    config_name: str = ""  # label of the run, also used in the result file name
+    concurrency: int = 1  # always 1 for offline
+    num_prompts: int = 0  # number of requests attempted
+    completed: int = 0  # requests that produced timing data
+    failed: int = 0  # requests that raised an exception
+    duration_s: float = 0.0  # wall-clock time of the whole measured loop
+    # TTFP stats - not applicable for HF offline, set to E2E for compatibility
+    mean_ttfp_ms: float = 0.0
+    median_ttfp_ms: float = 0.0
+    std_ttfp_ms: float = 0.0
+    p90_ttfp_ms: float = 0.0
+    p95_ttfp_ms: float = 0.0
+    p99_ttfp_ms: float = 0.0
+    # E2E stats (ms)
+    mean_e2e_ms: float = 0.0
+    median_e2e_ms: float = 0.0
+    std_e2e_ms: float = 0.0
+    p90_e2e_ms: float = 0.0
+    p95_e2e_ms: float = 0.0
+    p99_e2e_ms: float = 0.0
+    # RTF stats (per-request generation time divided by generated audio duration)
+    mean_rtf: float = 0.0
+    median_rtf: float = 0.0
+    std_rtf: float = 0.0
+    p99_rtf: float = 0.0
+    # Audio stats
+    mean_audio_duration_s: float = 0.0
+    total_audio_duration_s: float = 0.0
+    audio_throughput: float = 0.0  # total audio seconds / duration_s
+    request_throughput: float = 0.0  # completed requests / duration_s
+    # Per-request details (one dict per completed request)
+    per_request: list = field(default_factory=list)
+
+
+def generate_audio(model, prompt: str, args):
+    """Synthesize ``prompt`` with the generation method matching ``args.task_type``.
+
+    "Base" uses voice cloning with the module-level reference clip and transcript,
+    "VoiceDesign" uses the module-level style instruction, and anything else
+    falls back to the custom-voice path with ``args.voice`` as the speaker.
+
+    Returns:
+        Whatever the selected model method returns; callers unpack it as ``(wavs, sr)``.
+    """
+    # Arguments shared by all three generation methods.
+    common = {"text": prompt, "language": args.language}
+    task = args.task_type
+
+    if task == "Base":
+        return model.generate_voice_clone(**common, ref_audio=REF_AUDIO, ref_text=REF_TEXT)
+    if task == "VoiceDesign":
+        return model.generate_voice_design(**common, instruct=INSTRUCT)
+    return model.generate_custom_voice(**common, speaker=args.voice)
+
+
+def run_benchmark(args):
+    """Run the offline HF benchmark and write the aggregated result to JSON.
+
+    Loads the model on ``cuda:<args.gpu_device>``, runs ``args.num_warmups``
+    unmeasured requests, then times ``args.num_prompts`` sequential requests with
+    GPU synchronization around each one. A request counts as failed only when
+    generation or duration computation raises. A failure to save the audio file
+    is reported as a warning, so ``completed + failed == num_prompts`` always holds.
+
+    Args:
+        args: Parsed CLI namespace (see ``parse_args``).
+
+    Returns:
+        The populated ``BenchmarkResult``. Statistics stay at 0.0 when no request completed.
+    """
+    from qwen_tts import Qwen3TTSModel
+
+    device = f"cuda:{args.gpu_device}"
+    print(f"Loading model: {args.model} on {device}")
+    model = Qwen3TTSModel.from_pretrained(
+        args.model,
+        device_map=device,
+        dtype=torch.bfloat16,
+    )
+    print("Model loaded.")
+
+    # Build prompt list (cycle through the fixed pool)
+    prompts = [PROMPTS[i % len(PROMPTS)] for i in range(args.num_prompts)]
+
+    # Warmup
+    if args.num_warmups > 0:
+        print(f"Warming up with {args.num_warmups} requests...")
+        for i in range(args.num_warmups):
+            p = PROMPTS[i % len(PROMPTS)]
+            wavs, sr = generate_audio(model, p, args)
+            # Sync GPU so warmup work cannot leak into the first measurement
+            torch.cuda.synchronize(device)
+        print("Warmup done.")
+
+    # Benchmark
+    print(f"Running {args.num_prompts} requests sequentially...")
+    e2e_times = []
+    rtfs = []
+    audio_durations = []
+    per_request = []
+    failed = 0
+
+    audio_dir = None
+    if args.save_audio:
+        audio_dir = Path(args.result_dir) / "audio_hf"
+        audio_dir.mkdir(parents=True, exist_ok=True)
+
+    total_start = time.perf_counter()
+
+    for i, prompt in enumerate(prompts):
+        try:
+            # Synchronize before and after so the timer covers exactly this request.
+            torch.cuda.synchronize(device)
+            st = time.perf_counter()
+
+            wavs, sr = generate_audio(model, prompt, args)
+
+            torch.cuda.synchronize(device)
+            elapsed = time.perf_counter() - st
+
+            # Compute audio duration
+            audio_samples = wavs[0]
+            if isinstance(audio_samples, torch.Tensor):
+                audio_samples = audio_samples.cpu().numpy()
+            audio_dur = len(audio_samples) / sr
+        except Exception as e:
+            print(f" [{i + 1}/{args.num_prompts}] FAILED: {e}")
+            failed += 1
+            continue
+
+        rtf = elapsed / audio_dur if audio_dur > 0 else 0.0
+
+        e2e_times.append(elapsed)
+        rtfs.append(rtf)
+        audio_durations.append(audio_dur)
+        per_request.append(
+            {
+                "e2e_ms": elapsed * 1000,
+                "ttfp_ms": elapsed * 1000,  # no streaming, TTFP = E2E
+                "rtf": rtf,
+                "audio_duration_s": audio_dur,
+                "prompt": prompt,
+            }
+        )
+
+        if audio_dir:
+            # The request is already recorded as completed; a failed save must
+            # not count it as failed as well.
+            try:
+                sf.write(str(audio_dir / f"output_{i:04d}.wav"), audio_samples, sr)
+            except Exception as e:
+                print(f" [{i + 1}/{args.num_prompts}] WARNING: could not save audio: {e}")
+
+        if (i + 1) % 10 == 0 or i == 0:
+            print(f" [{i + 1}/{args.num_prompts}] e2e={elapsed * 1000:.0f}ms rtf={rtf:.3f} audio={audio_dur:.2f}s")
+
+    total_duration = time.perf_counter() - total_start
+    completed = len(e2e_times)
+
+    # Compute stats
+    result = BenchmarkResult(
+        config_name=args.config_name,
+        concurrency=1,
+        num_prompts=args.num_prompts,
+        completed=completed,
+        failed=failed,
+        duration_s=total_duration,
+    )
+
+    if e2e_times:
+        e2e_ms = [t * 1000 for t in e2e_times]
+
+        result.mean_e2e_ms = float(np.mean(e2e_ms))
+        result.median_e2e_ms = float(np.median(e2e_ms))
+        result.std_e2e_ms = float(np.std(e2e_ms))
+        result.p90_e2e_ms = float(np.percentile(e2e_ms, 90))
+        result.p95_e2e_ms = float(np.percentile(e2e_ms, 95))
+        result.p99_e2e_ms = float(np.percentile(e2e_ms, 99))
+
+        # For HF offline, TTFP = E2E (no streaming)
+        result.mean_ttfp_ms = result.mean_e2e_ms
+        result.median_ttfp_ms = result.median_e2e_ms
+        result.std_ttfp_ms = result.std_e2e_ms
+        result.p90_ttfp_ms = result.p90_e2e_ms
+        result.p95_ttfp_ms = result.p95_e2e_ms
+        result.p99_ttfp_ms = result.p99_e2e_ms
+
+        result.mean_rtf = float(np.mean(rtfs))
+        result.median_rtf = float(np.median(rtfs))
+        result.std_rtf = float(np.std(rtfs))
+        result.p99_rtf = float(np.percentile(rtfs, 99))
+
+        result.mean_audio_duration_s = float(np.mean(audio_durations))
+        result.total_audio_duration_s = float(np.sum(audio_durations))
+        result.audio_throughput = result.total_audio_duration_s / total_duration
+        result.request_throughput = completed / total_duration
+        result.per_request = per_request
+
+    # Print summary in standardized performance template
+    W = 50
+    print("")
+    print(f"{'=' * W}")
+    print(f"{'Serving Benchmark Result':^{W}}")
+    print(f"{'=' * W}")
+    print(f"{'Successful requests:':<40}{completed:<10}")
+    print(f"{'Failed requests:':<40}{failed:<10}")
+    print(f"{'Maximum request concurrency:':<40}{1:<10}")
+    print(f"{'Benchmark duration (s):':<40}{total_duration:<10.2f}")
+    print(f"{'Request throughput (req/s):':<40}{result.request_throughput:<10.2f}")
+    print(f"{'-' * W}")
+    print(f"{'End-to-end Latency':^{W}}")
+    print(f"{'-' * W}")
+    print(f"{'Mean E2EL (ms):':<40}{result.mean_e2e_ms:<10.2f}")
+    print(f"{'Median E2EL (ms):':<40}{result.median_e2e_ms:<10.2f}")
+    print(f"{'P99 E2EL (ms):':<40}{result.p99_e2e_ms:<10.2f}")
+    print(f"{'=' * W}")
+    print(f"{'Audio Result':^{W}}")
+    print(f"{'=' * W}")
+    print(f"{'Total audio duration generated (s):':<40}{result.total_audio_duration_s:<10.2f}")
+    print(f"{'Audio throughput (audio duration/s):':<40}{result.audio_throughput:<10.2f}")
+    print(f"{'-' * W}")
+    print(f"{'Time to First Packet':^{W}}")
+    print(f"{'-' * W}")
+    print(f"{'Mean AUDIO_TTFP (ms):':<40}{result.mean_ttfp_ms:<10.2f}")
+    print(f"{'Median AUDIO_TTFP (ms):':<40}{result.median_ttfp_ms:<10.2f}")
+    print(f"{'P99 AUDIO_TTFP (ms):':<40}{result.p99_ttfp_ms:<10.2f}")
+    print(f"{'-' * W}")
+    print(f"{'Real Time Factor':^{W}}")
+    print(f"{'-' * W}")
+    print(f"{'Mean AUDIO_RTF:':<40}{result.mean_rtf:<10.3f}")
+    print(f"{'Median AUDIO_RTF:':<40}{result.median_rtf:<10.3f}")
+    print(f"{'P99 AUDIO_RTF:':<40}{result.p99_rtf:<10.3f}")
+    print(f"{'=' * W}")
+    print("")
+
+    # Save results (as a list with single concurrency=1 entry, matching serve format)
+    result_dir = Path(args.result_dir)
+    result_dir.mkdir(parents=True, exist_ok=True)
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    result_file = result_dir / f"bench_{args.config_name}_{timestamp}.json"
+
+    with open(result_file, "w") as f:
+        json.dump([asdict(result)], f, indent=2)
+    print(f"Results saved to {result_file}")
+
+    return result
+
+
+def parse_args():
+    """Define the CLI of the HuggingFace benchmark and parse ``sys.argv``.
+
+    Returns:
+        argparse.Namespace with model, num_prompts, num_warmups, gpu_device, voice,
+        language, task_type, config_name, result_dir and save_audio.
+    """
+    cli = argparse.ArgumentParser(description="Qwen3-TTS HuggingFace Benchmark")
+    cli.add_argument(
+        "--model", type=str, default="Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice", help="HuggingFace model name or path"
+    )
+
+    # Plain typed options without help text: (flag, type, default).
+    plain_options = (
+        ("--num-prompts", int, 50),
+        ("--num-warmups", int, 3),
+        ("--gpu-device", int, 0),
+        ("--voice", str, "Vivian"),
+        ("--language", str, "English"),
+    )
+    for flag, value_type, default_value in plain_options:
+        cli.add_argument(flag, type=value_type, default=default_value)
+
+    cli.add_argument("--task-type", type=str, default="CustomVoice", choices=["CustomVoice", "VoiceDesign", "Base"])
+    cli.add_argument(
+        "--config-name", type=str, default="hf_transformers", help="Label for this config (used in filenames)"
+    )
+    cli.add_argument("--result-dir", type=str, default="results")
+    cli.add_argument("--save-audio", action="store_true", help="Save generated audio files")
+    return cli.parse_args()
+
+
+if __name__ == "__main__":
+    # Script entry point: parse the CLI and run the benchmark with those settings.
+    run_benchmark(parse_args())
diff --git a/benchmarks/qwen3-tts/vllm_omni/bench_async_chunk.py b/benchmarks/qwen3-tts/vllm_omni/bench_async_chunk.py
new file mode 100644
index 00000000000..3497ae82152
--- /dev/null
+++ b/benchmarks/qwen3-tts/vllm_omni/bench_async_chunk.py
@@ -0,0 +1,301 @@
+"""Benchmark comparing async_chunk on vs off for Qwen3-TTS.
+
+Measures TTFP (Time-to-First-Packet), E2E latency, and RTF across
+concurrency levels for both async_chunk modes. Saves results as JSON.
+
+Usage:
+ # Run against a server already serving with a given config:
+ python bench_async_chunk.py \
+ --host 127.0.0.1 --port 8000 \
+ --config-name async_chunk_on \
+ --num-prompts 50 \
+ --max-concurrency 1 10 \
+ --result-dir results/
+"""
+
+import argparse
+import asyncio
+import json
+import time
+from dataclasses import asdict, dataclass, field
+from datetime import datetime
+from pathlib import Path
+
+import aiohttp
+import numpy as np
+from tqdm.asyncio import tqdm
+
+PROMPTS = [
+ "Hello, welcome to the voice synthesis benchmark test.",
+ "She said she would be here by noon, but nobody showed up.",
+ "The quick brown fox jumps over the lazy dog near the riverbank.",
+ "I can't believe how beautiful the sunset looks from up here on the mountain.",
+ "Please remember to bring your identification documents to the appointment tomorrow morning.",
+ "Have you ever wondered what it would be like to travel through time and visit ancient civilizations?",
+ "The restaurant on the corner serves the best pasta I have ever tasted in my entire life.",
+ "After the meeting, we should discuss the quarterly results and plan for the next phase.",
+ "Learning a new language takes patience, practice, and a genuine curiosity about other cultures.",
+ "The train leaves at half past seven, so we need to arrive at the station before then.",
+ "Could you please turn down the music a little bit, I'm trying to concentrate on my work.",
+ "It was a dark and stormy night when the old lighthouse keeper heard a knock at the door.",
+]
+
+
+@dataclass
+class RequestResult:
+ """Outcome and timings of one TTS request, as filled in by send_tts_request."""
+
+ success: bool = False  # set to True once an HTTP 200 response has been fully read
+ ttfp: float = 0.0  # seconds from request start to the first non-empty chunk
+ e2e: float = 0.0  # seconds from request start until the response was fully consumed
+ audio_bytes: int = 0  # total number of response bytes received
+ audio_duration: float = 0.0  # seconds of audio, derived from audio_bytes via pcm_bytes_to_duration
+ rtf: float = 0.0  # e2e / audio_duration; left at 0.0 when no audio was received
+ prompt: str = ""  # input text that was sent
+ error: str = ""  # HTTP status/body or exception text for a failed request
+
+
+@dataclass
+class BenchmarkResult:
+ """Aggregated statistics for one concurrency level, as produced by run_benchmark.
+
+ All latency/RTF statistics are computed over successful requests only and
+ stay at 0.0 when no request succeeded.
+ """
+
+ config_name: str = ""  # label assigned by main() from --config-name
+ concurrency: int = 0
+ num_prompts: int = 0
+ completed: int = 0  # number of successful requests
+ failed: int = 0  # number of unsuccessful requests
+ duration_s: float = 0.0  # wall-clock time of the measured run
+ # Time-to-first-packet statistics (milliseconds)
+ mean_ttfp_ms: float = 0.0
+ median_ttfp_ms: float = 0.0
+ std_ttfp_ms: float = 0.0
+ p90_ttfp_ms: float = 0.0
+ p95_ttfp_ms: float = 0.0
+ p99_ttfp_ms: float = 0.0
+ # End-to-end latency statistics (milliseconds)
+ mean_e2e_ms: float = 0.0
+ median_e2e_ms: float = 0.0
+ std_e2e_ms: float = 0.0
+ p90_e2e_ms: float = 0.0
+ p95_e2e_ms: float = 0.0
+ p99_e2e_ms: float = 0.0
+ # Real-time-factor statistics (e2e / audio duration)
+ mean_rtf: float = 0.0
+ median_rtf: float = 0.0
+ std_rtf: float = 0.0
+ # Audio volume and throughput
+ mean_audio_duration_s: float = 0.0
+ total_audio_duration_s: float = 0.0
+ audio_throughput: float = 0.0  # total_audio_duration_s / duration_s
+ request_throughput: float = 0.0  # successful requests / duration_s
+ # One dict per successful request (ttfp_ms, e2e_ms, rtf, audio_duration_s, prompt)
+ per_request: list = field(default_factory=list)
+
+
+def pcm_bytes_to_duration(num_bytes: int, sample_rate: int = 24000, sample_width: int = 2) -> float:
+    """Return the playback length, in seconds, of ``num_bytes`` of raw PCM audio.
+
+    Args:
+        num_bytes: Size of the PCM payload in bytes.
+        sample_rate: Samples per second.
+        sample_width: Bytes per sample.
+
+    No channel count is applied, so the data is treated as a single channel.
+    """
+    num_samples = num_bytes / sample_width
+    return num_samples / sample_rate
+
+
+async def send_tts_request(
+    session: aiohttp.ClientSession,
+    api_url: str,
+    prompt: str,
+    voice: str = "vivian",
+    language: str = "English",
+    stream: bool = True,
+    pbar: tqdm | None = None,
+) -> RequestResult:
+    """POST one TTS request and time the PCM response.
+
+    Args:
+        session: Open aiohttp session used to send the request.
+        api_url: Full URL of the speech endpoint.
+        prompt: Text to synthesize (sent as ``input``).
+        voice: Value of the ``voice`` payload field.
+        language: Value of the ``language`` payload field.
+        stream: Value of the ``stream`` payload field.
+        pbar: Optional progress bar, advanced by one for every finished
+            request, whether it succeeded or failed.
+
+    Returns:
+        A RequestResult. Non-200 responses and exceptions derived from
+        ``Exception`` are reported through ``success=False`` and ``error``
+        instead of being raised.
+    """
+    payload = {
+        "input": prompt,
+        "voice": voice,
+        "language": language,
+        "stream": stream,
+        "response_format": "pcm",
+    }
+
+    result = RequestResult(prompt=prompt)
+    st = time.perf_counter()
+
+    try:
+        async with session.post(api_url, json=payload) as response:
+            if response.status != 200:
+                result.error = f"HTTP {response.status}: {await response.text()}"
+                # Record the failed round-trip time, as the exception path does.
+                result.e2e = time.perf_counter() - st
+                return result
+
+            first_chunk = True
+            total_bytes = 0
+
+            async for chunk in response.content.iter_any():
+                # TTFP is taken at the first non-empty chunk only.
+                if first_chunk and len(chunk) > 0:
+                    result.ttfp = time.perf_counter() - st
+                    first_chunk = False
+                total_bytes += len(chunk)
+
+            result.e2e = time.perf_counter() - st
+            result.audio_bytes = total_bytes
+            result.audio_duration = pcm_bytes_to_duration(total_bytes)
+            # RTF is undefined without audio; leave it at its 0.0 default then.
+            if result.audio_duration > 0:
+                result.rtf = result.e2e / result.audio_duration
+            result.success = True
+
+    except Exception as e:
+        result.error = str(e)
+        result.e2e = time.perf_counter() - st
+    finally:
+        # Runs on every exit path, including the early HTTP-error return above,
+        # so the progress bar always reaches its total.
+        if pbar:
+            pbar.update(1)
+    return result
+
+
+async def run_benchmark(
+ host: str,
+ port: int,
+ num_prompts: int,
+ max_concurrency: int,
+ num_warmups: int = 3,
+ voice: str = "vivian",
+ language: str = "English",
+ stream: bool = True,
+) -> BenchmarkResult:
+ """Benchmark the /v1/audio/speech endpoint at a single concurrency level.
+
+ Sends ``num_warmups`` warmup requests whose results are discarded, then
+ ``num_prompts`` measured requests (prompts cycle through PROMPTS) with at
+ most ``max_concurrency`` in flight, prints a summary and returns the
+ aggregated BenchmarkResult. ``config_name`` is left at its default for
+ the caller to fill in.
+ """
+ api_url = f"http://{host}:{port}/v1/audio/speech"
+
+ # Cap total and per-host connections at the requested concurrency level.
+ connector = aiohttp.TCPConnector(limit=max_concurrency, limit_per_host=max_concurrency, keepalive_timeout=60)
+ session = aiohttp.ClientSession(connector=connector, timeout=aiohttp.ClientTimeout(total=600))
+
+ # Warmup requests run without a progress bar and their results are not kept.
+ if num_warmups > 0:
+ print(f" Warming up with {num_warmups} requests...")
+ warmup_tasks = [
+ send_tts_request(session, api_url, PROMPTS[i % len(PROMPTS)], voice, language, stream)
+ for i in range(num_warmups)
+ ]
+ await asyncio.gather(*warmup_tasks)
+ print(" Warmup done.")
+
+ # Cycle through PROMPTS until num_prompts requests have been built.
+ request_prompts = [PROMPTS[i % len(PROMPTS)] for i in range(num_prompts)]
+
+ print(f" Running {num_prompts} requests with concurrency={max_concurrency}...")
+ # The semaphore bounds how many requests are in flight at once.
+ semaphore = asyncio.Semaphore(max_concurrency)
+ pbar = tqdm(total=num_prompts, desc=f" concurrency={max_concurrency}")
+
+ async def limited_request(prompt):
+ async with semaphore:
+ return await send_tts_request(session, api_url, prompt, voice, language, stream, pbar)
+
+ # Wall-clock duration covers the measured requests only, not the warmup.
+ start_time = time.perf_counter()
+ tasks = [asyncio.create_task(limited_request(p)) for p in request_prompts]
+ results: list[RequestResult] = await asyncio.gather(*tasks)
+ duration = time.perf_counter() - start_time
+ pbar.close()
+
+ await session.close()
+
+ successful = [r for r in results if r.success]
+ failed = [r for r in results if not r.success]
+
+ bench = BenchmarkResult(
+ concurrency=max_concurrency,
+ num_prompts=num_prompts,
+ completed=len(successful),
+ failed=len(failed),
+ duration_s=duration,
+ )
+
+ # Statistics are computed over successful requests only.
+ if successful:
+ # Latencies are recorded in seconds and reported in milliseconds.
+ ttfps = [r.ttfp * 1000 for r in successful]
+ e2es = [r.e2e * 1000 for r in successful]
+ rtfs = [r.rtf for r in successful]
+ audio_durs = [r.audio_duration for r in successful]
+
+ bench.mean_ttfp_ms = float(np.mean(ttfps))
+ bench.median_ttfp_ms = float(np.median(ttfps))
+ bench.std_ttfp_ms = float(np.std(ttfps))
+ bench.p90_ttfp_ms = float(np.percentile(ttfps, 90))
+ bench.p95_ttfp_ms = float(np.percentile(ttfps, 95))
+ bench.p99_ttfp_ms = float(np.percentile(ttfps, 99))
+
+ bench.mean_e2e_ms = float(np.mean(e2es))
+ bench.median_e2e_ms = float(np.median(e2es))
+ bench.std_e2e_ms = float(np.std(e2es))
+ bench.p90_e2e_ms = float(np.percentile(e2es, 90))
+ bench.p95_e2e_ms = float(np.percentile(e2es, 95))
+ bench.p99_e2e_ms = float(np.percentile(e2es, 99))
+
+ bench.mean_rtf = float(np.mean(rtfs))
+ bench.median_rtf = float(np.median(rtfs))
+ bench.std_rtf = float(np.std(rtfs))
+
+ bench.mean_audio_duration_s = float(np.mean(audio_durs))
+ bench.total_audio_duration_s = float(np.sum(audio_durs))
+ # Seconds of audio produced per wall-clock second of the measured run.
+ bench.audio_throughput = bench.total_audio_duration_s / duration
+ bench.request_throughput = len(successful) / duration
+
+ # Per-request details for successful requests, kept for the JSON output.
+ bench.per_request = [
+ {
+ "ttfp_ms": r.ttfp * 1000,
+ "e2e_ms": r.e2e * 1000,
+ "rtf": r.rtf,
+ "audio_duration_s": r.audio_duration,
+ "prompt": r.prompt,
+ }
+ for r in successful
+ ]
+
+ print(f"\n{'=' * 60}")
+ print(f" Concurrency: {max_concurrency} | Completed: {bench.completed} | Failed: {bench.failed}")
+ print(f" Duration: {duration:.2f}s | Throughput: {bench.request_throughput:.2f} req/s")
+ print(
+ f" TTFP (ms): mean={bench.mean_ttfp_ms:.1f} median={bench.median_ttfp_ms:.1f}"
+ f" p90={bench.p90_ttfp_ms:.1f} p99={bench.p99_ttfp_ms:.1f}"
+ )
+ print(
+ f" E2E (ms): mean={bench.mean_e2e_ms:.1f} median={bench.median_e2e_ms:.1f}"
+ f" p90={bench.p90_e2e_ms:.1f} p99={bench.p99_e2e_ms:.1f}"
+ )
+ print(f" RTF: mean={bench.mean_rtf:.3f} median={bench.median_rtf:.3f}")
+ print(f" Throughput: {bench.audio_throughput:.2f} audio-sec/wall-sec")
+ print(f"{'=' * 60}\n")
+
+ # Show at most the first three errors, each truncated to 200 characters.
+ if failed:
+ for r in failed[:3]:
+ print(f" [ERROR] {r.error[:200]}")
+
+ return bench
+
+
+async def main(args):
+    """Run the benchmark at every requested concurrency level and save the results.
+
+    Args:
+        args: Parsed CLI namespace (see parse_args).
+
+    Returns:
+        The list of per-concurrency result dicts that was written to the JSON
+        file, so callers can use the results without re-reading the file.
+    """
+    all_results = []
+
+    # Levels run one after another so they do not compete for the server.
+    for concurrency in args.max_concurrency:
+        result = await run_benchmark(
+            host=args.host,
+            port=args.port,
+            num_prompts=args.num_prompts,
+            max_concurrency=concurrency,
+            num_warmups=args.num_warmups,
+            voice=args.voice,
+            language=args.language,
+            stream=args.stream,
+        )
+        result.config_name = args.config_name
+        all_results.append(asdict(result))
+
+    result_dir = Path(args.result_dir)
+    result_dir.mkdir(parents=True, exist_ok=True)
+    # The timestamp keeps repeated runs of one config from overwriting each other.
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    result_file = result_dir / f"bench_{args.config_name}_{timestamp}.json"
+
+    with open(result_file, "w") as f:
+        json.dump(all_results, f, indent=2)
+    print(f"Results saved to {result_file}")
+
+    return all_results
+
+
+def parse_args():
+    """Parse the command line of the async_chunk benchmark client.
+
+    Returns:
+        argparse.Namespace with ``host``, ``port``, ``num_prompts``,
+        ``max_concurrency`` (list of ints), ``num_warmups``, ``voice``,
+        ``language``, ``stream``, ``config_name`` and ``result_dir``.
+    """
+    cli = argparse.ArgumentParser(description="Qwen3-TTS async_chunk benchmark client")
+
+    # (flag, add_argument keyword arguments), registered in help-display order.
+    option_specs = (
+        ("--host", {"type": str, "default": "127.0.0.1"}),
+        ("--port", {"type": int, "default": 8000}),
+        ("--num-prompts", {"type": int, "default": 50}),
+        ("--max-concurrency", {"type": int, "nargs": "+", "default": [1, 10]}),
+        ("--num-warmups", {"type": int, "default": 3}),
+        ("--voice", {"type": str, "default": "vivian"}),
+        ("--language", {"type": str, "default": "English"}),
+        # Streaming is on by default; --no-stream writes False to the same dest.
+        ("--stream", {"action": "store_true", "default": True}),
+        ("--no-stream", {"dest": "stream", "action": "store_false"}),
+        ("--config-name", {"type": str, "default": "async_chunk_on"}),
+        ("--result-dir", {"type": str, "default": "results"}),
+    )
+    for flag, spec in option_specs:
+        cli.add_argument(flag, **spec)
+
+    return cli.parse_args()
+
+
+# Script entry point: parse the CLI options and drive the async benchmark to completion.
+if __name__ == "__main__":
+ args = parse_args()
+ asyncio.run(main(args))
diff --git a/benchmarks/qwen3-tts/vllm_omni/bench_tts_serve.py b/benchmarks/qwen3-tts/vllm_omni/bench_tts_serve.py
new file mode 100644
index 00000000000..96b904b0174
--- /dev/null
+++ b/benchmarks/qwen3-tts/vllm_omni/bench_tts_serve.py
@@ -0,0 +1,371 @@
+"""Benchmark client for Qwen3-TTS via /v1/audio/speech endpoint.
+
+Measures TTFP (Time-to-First-Packet), E2E latency, and RTF (Real-Time Factor)
+across configurable concurrency levels. Saves results as JSON for plotting.
+
+Usage:
+ python bench_tts_serve.py \
+ --host 127.0.0.1 --port 8000 \
+ --num-prompts 50 \
+ --max-concurrency 1 4 10 \
+ --result-dir results/
+"""
+
+import argparse
+import asyncio
+import json
+import time
+from dataclasses import asdict, dataclass, field
+from datetime import datetime
+from pathlib import Path
+
+import aiohttp
+import numpy as np
+from tqdm.asyncio import tqdm
+
+PROMPTS = [
+ "Hello, welcome to the voice synthesis benchmark test.",
+ "She said she would be here by noon, but nobody showed up.",
+ "The quick brown fox jumps over the lazy dog near the riverbank.",
+ "I can't believe how beautiful the sunset looks from up here on the mountain.",
+ "Please remember to bring your identification documents to the appointment tomorrow morning.",
+ "Have you ever wondered what it would be like to travel through time and visit ancient civilizations?",
+ "The restaurant on the corner serves the best pasta I have ever tasted in my entire life.",
+ "After the meeting, we should discuss the quarterly results and plan for the next phase.",
+ "Learning a new language takes patience, practice, and a genuine curiosity about other cultures.",
+ "The train leaves at half past seven, so we need to arrive at the station before then.",
+ "Could you please turn down the music a little bit, I'm trying to concentrate on my work.",
+ "It was a dark and stormy night when the old lighthouse keeper heard a knock at the door.",
+]
+REF_AUDIO = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"
+REF_TEXT = "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."
+INSTRUCT = "Speak in an incredulous tone, but with a hint of panic beginning to creep into your voice."
+
+
+@dataclass
+class RequestResult:
+ """Outcome and timings of one TTS request, as filled in by send_tts_request."""
+
+ success: bool = False
+ ttfp: float = 0.0 # Time to first audio packet (seconds)
+ e2e: float = 0.0 # End-to-end latency (seconds)
+ audio_bytes: int = 0 # Total audio bytes received
+ audio_duration: float = 0.0 # Audio duration in seconds (estimated from PCM)
+ rtf: float = 0.0 # Real-time factor = e2e / audio_duration
+ prompt: str = ""
+ error: str = ""
+
+
+@dataclass
+class BenchmarkResult:
+ """Aggregated statistics for one concurrency level, as produced by run_benchmark.
+
+ All latency/RTF statistics are computed over successful requests only and
+ stay at 0.0 when no request succeeded.
+ """
+
+ config_name: str = ""
+ concurrency: int = 0
+ num_prompts: int = 0
+ completed: int = 0
+ failed: int = 0
+ duration_s: float = 0.0
+ # TTFP stats (ms)
+ mean_ttfp_ms: float = 0.0
+ median_ttfp_ms: float = 0.0
+ std_ttfp_ms: float = 0.0
+ p90_ttfp_ms: float = 0.0
+ p95_ttfp_ms: float = 0.0
+ p99_ttfp_ms: float = 0.0
+ # E2E stats (ms)
+ mean_e2e_ms: float = 0.0
+ median_e2e_ms: float = 0.0
+ std_e2e_ms: float = 0.0
+ p90_e2e_ms: float = 0.0
+ p95_e2e_ms: float = 0.0
+ p99_e2e_ms: float = 0.0
+ # RTF stats
+ mean_rtf: float = 0.0
+ median_rtf: float = 0.0
+ std_rtf: float = 0.0
+ p99_rtf: float = 0.0
+ # Audio stats
+ mean_audio_duration_s: float = 0.0
+ total_audio_duration_s: float = 0.0
+ audio_throughput: float = 0.0 # audio_duration / wall_time
+ request_throughput: float = 0.0 # requests / second
+ # Per-request details
+ per_request: list = field(default_factory=list)
+
+
+def pcm_bytes_to_duration(num_bytes: int, sample_rate: int = 24000, sample_width: int = 2) -> float:
+    """Return the duration in seconds represented by ``num_bytes`` of raw PCM data.
+
+    ``sample_width`` is the number of bytes per sample and ``sample_rate`` the
+    number of samples per second; no channel count is applied.
+    """
+    return num_bytes / sample_width / sample_rate
+
+
+def create_payload(
+    prompt: str, task_type: str = "CustomVoice", voice: str = "vivian", language: str = "English"
+) -> dict:
+    """Build the JSON body for a streaming PCM request to the speech endpoint.
+
+    The common fields are always present; one task-specific group is added:
+    ``Base`` gets ``ref_audio``/``ref_text``, ``CustomVoice`` gets ``voice``,
+    ``VoiceDesign`` gets ``instructions``. Any other ``task_type`` yields only
+    the common fields.
+    """
+    body = dict(
+        input=prompt,
+        language=language,
+        stream=True,
+        response_format="pcm",
+        task_type=task_type,
+    )
+
+    match task_type:
+        case "Base":
+            body.update(ref_audio=REF_AUDIO, ref_text=REF_TEXT)
+        case "CustomVoice":
+            body["voice"] = voice
+        case "VoiceDesign":
+            body["instructions"] = INSTRUCT
+
+    return body
+
+
+async def send_tts_request(
+    session: aiohttp.ClientSession,
+    api_url: str,
+    prompt: str,
+    task_type: str = "CustomVoice",
+    voice: str = "vivian",
+    language: str = "English",
+    pbar: tqdm | None = None,
+) -> RequestResult:
+    """Send a streaming TTS request and measure latency metrics.
+
+    Args:
+        session: Open aiohttp session used to send the request.
+        api_url: Full URL of the speech endpoint.
+        prompt: Text to synthesize.
+        task_type: Task variant passed to create_payload.
+        voice: Voice name passed to create_payload.
+        language: Language passed to create_payload.
+        pbar: Optional progress bar, advanced by one for every finished
+            request, whether it succeeded or failed.
+
+    Returns:
+        A RequestResult. Non-200 responses and exceptions derived from
+        ``Exception`` are reported through ``success=False`` and ``error``
+        instead of being raised.
+    """
+    payload = create_payload(prompt, task_type, voice, language)
+
+    result = RequestResult(prompt=prompt)
+    st = time.perf_counter()
+
+    try:
+        async with session.post(api_url, json=payload) as response:
+            if response.status != 200:
+                result.error = f"HTTP {response.status}: {await response.text()}"
+                result.success = False
+                # Record the failed round-trip time, as the exception path does.
+                result.e2e = time.perf_counter() - st
+                return result
+
+            first_chunk = True
+            total_bytes = 0
+
+            async for chunk in response.content.iter_any():
+                # TTFP is taken at the first non-empty chunk only.
+                if first_chunk and len(chunk) > 0:
+                    result.ttfp = time.perf_counter() - st
+                    first_chunk = False
+                total_bytes += len(chunk)
+
+            result.e2e = time.perf_counter() - st
+            result.audio_bytes = total_bytes
+            result.audio_duration = pcm_bytes_to_duration(total_bytes)
+
+            # RTF is undefined without audio; leave it at its 0.0 default then.
+            if result.audio_duration > 0:
+                result.rtf = result.e2e / result.audio_duration
+            result.success = True
+
+    except Exception as e:
+        result.error = str(e)
+        result.success = False
+        result.e2e = time.perf_counter() - st
+    finally:
+        # Runs on every exit path, including the early HTTP-error return above,
+        # so the progress bar always reaches its total.
+        if pbar:
+            pbar.update(1)
+    return result
+
+
+async def run_benchmark(
+ host: str,
+ port: int,
+ num_prompts: int,
+ max_concurrency: int,
+ num_warmups: int = 3,
+ task_type: str = "CustomVoice",
+ voice: str = "vivian",
+ language: str = "English",
+) -> BenchmarkResult:
+ """Run benchmark at a given concurrency level.
+
+ Sends ``num_warmups`` warmup requests whose results are discarded, then
+ ``num_prompts`` measured requests (prompts cycle through PROMPTS) with at
+ most ``max_concurrency`` in flight, prints a summary and returns the
+ aggregated BenchmarkResult. ``config_name`` is left at its default for
+ the caller to fill in.
+ """
+ api_url = f"http://{host}:{port}/v1/audio/speech"
+
+ # Cap total and per-host connections at the requested concurrency level.
+ connector = aiohttp.TCPConnector(
+ limit=max_concurrency,
+ limit_per_host=max_concurrency,
+ keepalive_timeout=60,
+ )
+ session = aiohttp.ClientSession(
+ connector=connector,
+ timeout=aiohttp.ClientTimeout(total=600),
+ )
+
+ # Warmup (no progress bar; results are not kept)
+ if num_warmups > 0:
+ print(f" Warming up with {num_warmups} requests...")
+ warmup_tasks = []
+ for i in range(num_warmups):
+ prompt = PROMPTS[i % len(PROMPTS)]
+ warmup_tasks.append(send_tts_request(session, api_url, prompt, task_type, voice, language))
+ await asyncio.gather(*warmup_tasks)
+ print(" Warmup done.")
+
+ # Build request list by cycling through PROMPTS
+ request_prompts = [PROMPTS[i % len(PROMPTS)] for i in range(num_prompts)]
+
+ # Run benchmark; the semaphore bounds how many requests are in flight at once
+ print(f" Running {num_prompts} requests with concurrency={max_concurrency}...")
+ semaphore = asyncio.Semaphore(max_concurrency)
+ pbar = tqdm(total=num_prompts, desc=f" concurrency={max_concurrency}")
+
+ async def limited_request(prompt):
+ async with semaphore:
+ return await send_tts_request(session, api_url, prompt, task_type, voice, language, pbar)
+
+ # Wall-clock duration covers the measured requests only, not the warmup.
+ start_time = time.perf_counter()
+ tasks = [asyncio.create_task(limited_request(p)) for p in request_prompts]
+ results: list[RequestResult] = await asyncio.gather(*tasks)
+ duration = time.perf_counter() - start_time
+ pbar.close()
+
+ await session.close()
+
+ # Compute stats (over successful requests only)
+ successful = [r for r in results if r.success]
+ failed = [r for r in results if not r.success]
+
+ bench = BenchmarkResult(
+ concurrency=max_concurrency,
+ num_prompts=num_prompts,
+ completed=len(successful),
+ failed=len(failed),
+ duration_s=duration,
+ )
+
+ if successful:
+ ttfps = [r.ttfp * 1000 for r in successful] # convert to ms
+ e2es = [r.e2e * 1000 for r in successful]
+ rtfs = [r.rtf for r in successful]
+ audio_durs = [r.audio_duration for r in successful]
+
+ bench.mean_ttfp_ms = float(np.mean(ttfps))
+ bench.median_ttfp_ms = float(np.median(ttfps))
+ bench.std_ttfp_ms = float(np.std(ttfps))
+ bench.p90_ttfp_ms = float(np.percentile(ttfps, 90))
+ bench.p95_ttfp_ms = float(np.percentile(ttfps, 95))
+ bench.p99_ttfp_ms = float(np.percentile(ttfps, 99))
+
+ bench.mean_e2e_ms = float(np.mean(e2es))
+ bench.median_e2e_ms = float(np.median(e2es))
+ bench.std_e2e_ms = float(np.std(e2es))
+ bench.p90_e2e_ms = float(np.percentile(e2es, 90))
+ bench.p95_e2e_ms = float(np.percentile(e2es, 95))
+ bench.p99_e2e_ms = float(np.percentile(e2es, 99))
+
+ bench.mean_rtf = float(np.mean(rtfs))
+ bench.median_rtf = float(np.median(rtfs))
+ bench.std_rtf = float(np.std(rtfs))
+ bench.p99_rtf = float(np.percentile(rtfs, 99))
+
+ bench.mean_audio_duration_s = float(np.mean(audio_durs))
+ bench.total_audio_duration_s = float(np.sum(audio_durs))
+ # Seconds of audio produced per wall-clock second of the measured run.
+ bench.audio_throughput = bench.total_audio_duration_s / duration
+ bench.request_throughput = len(successful) / duration
+
+ # Per-request details for successful requests, kept for the JSON output.
+ bench.per_request = [
+ {
+ "ttfp_ms": r.ttfp * 1000,
+ "e2e_ms": r.e2e * 1000,
+ "rtf": r.rtf,
+ "audio_duration_s": r.audio_duration,
+ "prompt": r.prompt,
+ }
+ for r in successful
+ ]
+
+ # Print summary in standardized performance template
+ # W is the total table width; labels use 40 columns and values 10.
+ W = 50
+ print("")
+ print(f"{'=' * W}")
+ print(f"{'Serving Benchmark Result':^{W}}")
+ print(f"{'=' * W}")
+ print(f"{'Successful requests:':<40}{bench.completed:<10}")
+ print(f"{'Failed requests:':<40}{bench.failed:<10}")
+ print(f"{'Maximum request concurrency:':<40}{max_concurrency:<10}")
+ print(f"{'Benchmark duration (s):':<40}{duration:<10.2f}")
+ print(f"{'Request throughput (req/s):':<40}{bench.request_throughput:<10.2f}")
+ print(f"{'-' * W}")
+ print(f"{'End-to-end Latency':^{W}}")
+ print(f"{'-' * W}")
+ print(f"{'Mean E2EL (ms):':<40}{bench.mean_e2e_ms:<10.2f}")
+ print(f"{'Median E2EL (ms):':<40}{bench.median_e2e_ms:<10.2f}")
+ print(f"{'P99 E2EL (ms):':<40}{bench.p99_e2e_ms:<10.2f}")
+ print(f"{'=' * W}")
+ print(f"{'Audio Result':^{W}}")
+ print(f"{'=' * W}")
+ print(f"{'Total audio duration generated (s):':<40}{bench.total_audio_duration_s:<10.2f}")
+ print(f"{'Audio throughput (audio duration/s):':<40}{bench.audio_throughput:<10.2f}")
+ print(f"{'-' * W}")
+ print(f"{'Time to First Packet':^{W}}")
+ print(f"{'-' * W}")
+ print(f"{'Mean AUDIO_TTFP (ms):':<40}{bench.mean_ttfp_ms:<10.2f}")
+ print(f"{'Median AUDIO_TTFP (ms):':<40}{bench.median_ttfp_ms:<10.2f}")
+ print(f"{'P99 AUDIO_TTFP (ms):':<40}{bench.p99_ttfp_ms:<10.2f}")
+ print(f"{'-' * W}")
+ print(f"{'Real Time Factor':^{W}}")
+ print(f"{'-' * W}")
+ print(f"{'Mean AUDIO_RTF:':<40}{bench.mean_rtf:<10.3f}")
+ print(f"{'Median AUDIO_RTF:':<40}{bench.median_rtf:<10.3f}")
+ print(f"{'P99 AUDIO_RTF:':<40}{bench.p99_rtf:<10.3f}")
+ print(f"{'=' * W}")
+ print("")
+
+ # Show at most the first three errors, each truncated to 200 characters.
+ if failed:
+ for r in failed[:3]:
+ print(f" [ERROR] {r.error[:200]}")
+
+ return bench
+
+
+async def main(args):
+    """Benchmark every requested concurrency level in turn and persist the results.
+
+    Each level's BenchmarkResult is tagged with ``args.config_name``, converted
+    to a dict, and the whole list is written to a timestamped JSON file under
+    ``args.result_dir`` (created if missing).
+
+    Returns:
+        The list of per-concurrency result dicts that was written.
+    """
+    collected = []
+
+    for level in args.max_concurrency:
+        bench = await run_benchmark(
+            host=args.host,
+            port=args.port,
+            num_prompts=args.num_prompts,
+            max_concurrency=level,
+            num_warmups=args.num_warmups,
+            task_type=args.task_type,
+            voice=args.voice,
+            language=args.language,
+        )
+        bench.config_name = args.config_name
+        collected.append(asdict(bench))
+
+    # Write everything into one timestamped file so repeated runs never collide.
+    out_dir = Path(args.result_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    out_path = out_dir / f"bench_{args.config_name}_{stamp}.json"
+
+    with open(out_path, "w") as fh:
+        json.dump(collected, fh, indent=2)
+    print(f"Results saved to {out_path}")
+
+    return collected
+
+
+def parse_args():
+    """Parse the command line of the Qwen3-TTS serving benchmark client.
+
+    Returns:
+        argparse.Namespace with ``host``, ``port``, ``num_prompts``,
+        ``max_concurrency`` (list of ints), ``num_warmups``, ``task_type``,
+        ``voice``, ``language``, ``config_name`` and ``result_dir``.
+    """
+    cli = argparse.ArgumentParser(description="Qwen3-TTS Benchmark Client")
+    add = cli.add_argument
+
+    # Server endpoint
+    add("--host", type=str, default="127.0.0.1")
+    add("--port", type=int, default=8000)
+
+    # Load shape
+    add("--num-prompts", type=int, default=50, help="Number of prompts per concurrency level")
+    add("--max-concurrency", type=int, nargs="+", default=[1, 4, 10], help="Concurrency levels to test")
+    add("--num-warmups", type=int, default=3)
+
+    # Request content
+    add(
+        "--task-type",
+        type=str,
+        default="CustomVoice",
+        choices=["CustomVoice", "VoiceDesign", "Base"],
+    )
+    add("--voice", type=str, default="vivian")
+    add("--language", type=str, default="English")
+
+    # Output
+    add("--config-name", type=str, default="async_chunk", help="Label for this config (used in filenames)")
+    add("--result-dir", type=str, default="results")
+
+    return cli.parse_args()
+
+
+# Script entry point: parse the CLI options and drive the async benchmark to completion.
+if __name__ == "__main__":
+ args = parse_args()
+ asyncio.run(main(args))
diff --git a/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml
similarity index 64%
rename from vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml
rename to benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml
index 87843634cb7..ca441d286dd 100644
--- a/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml
+++ b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml
@@ -1,3 +1,5 @@
+# Qwen3-TTS max_num_seqs=1 config (streaming with async_chunk)
+# 2-stage pipeline: Talker -> Code2Wav
async_chunk: true
stage_args:
- stage_id: 0
@@ -5,85 +7,87 @@ stage_args:
is_comprehension: true
runtime:
devices: "0"
- max_batch_size: 1
engine_args:
- dtype: bfloat16
- model_stage: latent_generator
- model_arch: VoxCPMForConditionalGeneration
- # Optional persistent HF-compatible config dir for native VoxCPM models.
- hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
+ max_num_seqs: 1
+ model_stage: qwen3_tts
+ model_arch: Qwen3TTSTalkerForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: true
+ enforce_eager: false
trust_remote_code: true
- async_scheduling: false
+ async_scheduling: true
enable_prefix_caching: false
engine_output_type: latent
- gpu_memory_utilization: 0.75
+ gpu_memory_utilization: 0.3
distributed_executor_backend: "mp"
- max_num_batched_tokens: 4096
+ max_num_batched_tokens: 512
max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae_async_chunk
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
output_connectors:
to_stage_1: connector_of_shared_memory
default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
+ temperature: 0.9
+ top_k: 50
max_tokens: 4096
seed: 42
detokenize: false
- repetition_penalty: 1.0
- final_output: false
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
- stage_id: 1
stage_type: llm
runtime:
devices: "0"
- max_batch_size: 1
engine_args:
- dtype: float32
- model_stage: vae
- model_arch: VoxCPMForConditionalGeneration
- hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
+ max_num_seqs: 1
+ model_stage: code2wav
+ model_arch: Qwen3TTSCode2Wav
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
trust_remote_code: true
- async_scheduling: false
+ async_scheduling: true
enable_prefix_caching: false
engine_output_type: audio
- gpu_memory_utilization: 0.1
+ gpu_memory_utilization: 0.3
distributed_executor_backend: "mp"
max_num_batched_tokens: 8192
- max_model_len: 4096
+ max_model_len: 32768
engine_input_source: [0]
- input_connectors:
- from_stage_0: connector_of_shared_memory
final_output: true
final_output_type: audio
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ tts_args:
+ max_instructions_length: 500
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
- max_tokens: 1
+ max_tokens: 65536
seed: 42
detokenize: true
repetition_penalty: 1.0
runtime:
enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
extra:
shm_threshold_bytes: 65536
- codec_streaming: false
+ codec_streaming: true
connector_get_sleep_s: 0.01
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
edges:
- from: 0
to: 1
+ window_size: -1
diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml
new file mode 100644
index 00000000000..2cc5cf53532
--- /dev/null
+++ b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml
@@ -0,0 +1,94 @@
+# Qwen3-TTS max_num_seqs=16 config (streaming with async_chunk)
+# High-throughput concurrent request processing
+# 2-stage pipeline: Talker -> Code2Wav
+async_chunk: true
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ is_comprehension: true
+ runtime:
+ devices: "0"
+ engine_args:
+ max_num_seqs: 16
+ model_stage: qwen3_tts
+ model_arch: Qwen3TTSTalkerForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ enforce_eager: false
+ trust_remote_code: true
+ async_scheduling: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ gpu_memory_utilization: 0.3
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 4096
+ max_model_len: 4096
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: false
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 1
+ stage_type: llm
+ runtime:
+ devices: "0"
+ engine_args:
+ max_num_seqs: 16
+ model_stage: code2wav
+ model_arch: Qwen3TTSCode2Wav
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ gpu_memory_utilization: 0.2
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 16384
+ max_model_len: 32768
+ engine_input_source: [0]
+ final_output: true
+ final_output_type: audio
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ tts_args:
+ max_instructions_length: 500
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: true
+ repetition_penalty: 1.0
+
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 16
+
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ codec_streaming: true
+ connector_get_sleep_s: 0.01
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml
new file mode 100644
index 00000000000..5de107d4976
--- /dev/null
+++ b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml
@@ -0,0 +1,94 @@
+# Qwen3-TTS max_num_seqs=4 config (streaming with async_chunk)
+# Enables concurrent request processing
+# 2-stage pipeline: Talker -> Code2Wav
+async_chunk: true
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ is_comprehension: true
+ runtime:
+ devices: "0"
+ engine_args:
+ max_num_seqs: 4
+ model_stage: qwen3_tts
+ model_arch: Qwen3TTSTalkerForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ enforce_eager: false
+ trust_remote_code: true
+ async_scheduling: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ gpu_memory_utilization: 0.3
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 512
+ max_model_len: 4096
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: false
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 1
+ stage_type: llm
+ runtime:
+ devices: "0"
+ engine_args:
+ max_num_seqs: 4
+ model_stage: code2wav
+ model_arch: Qwen3TTSCode2Wav
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ gpu_memory_utilization: 0.2
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 8192
+ max_model_len: 32768
+ engine_input_source: [0]
+ final_output: true
+ final_output_type: audio
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ tts_args:
+ max_instructions_length: 500
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: true
+ repetition_penalty: 1.0
+
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 4
+
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ codec_streaming: true
+ connector_get_sleep_s: 0.01
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/benchmarks/qwen3-tts/vllm_omni/plot_async_chunk.py b/benchmarks/qwen3-tts/vllm_omni/plot_async_chunk.py
new file mode 100644
index 00000000000..dd03d9626d9
--- /dev/null
+++ b/benchmarks/qwen3-tts/vllm_omni/plot_async_chunk.py
@@ -0,0 +1,249 @@
+"""Plot TTFP comparison: async_chunk off vs on.
+
+Generates a bar chart with improvement arrows, matching the Qwen3-Omni
+async_chunk benchmark figure style.
+
+Usage:
+ python plot_async_chunk.py \
+ --off results/bench_async_chunk_off_*.json \
+ --on results/bench_async_chunk_on_*.json \
+ --output results/qwen3_tts_async_chunk_ttfp.png
+
+ # Also supports E2E and RTF metrics:
+ python plot_async_chunk.py \
+ --off results/bench_async_chunk_off_*.json \
+ --on results/bench_async_chunk_on_*.json \
+ --metric e2e \
+ --output results/qwen3_tts_async_chunk_e2e.png
+"""
+
+import argparse
+import json
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+METRIC_CONFIG = {
+ "ttfp": {
+ "key": "mean_ttfp_ms",
+ "ylabel": "TTFP (s)",
+ "title": "TTFP (Time to First Audio Packet) - Qwen3-TTS, by concurrency",
+ "to_seconds": True,
+ },
+ "e2e": {
+ "key": "mean_e2e_ms",
+ "ylabel": "E2E (s)",
+ "title": "E2E Latency - Qwen3-TTS, by concurrency",
+ "to_seconds": True,
+ },
+ "rtf": {
+ "key": "mean_rtf",
+ "ylabel": "RTF",
+ "title": "Real-Time Factor - Qwen3-TTS, by concurrency",
+ "to_seconds": False,
+ },
+}
+
+
+def load_results(path: str) -> list[dict]:
+    """Return the parsed JSON content of the file at ``path``, read explicitly as UTF-8."""
+    return json.loads(Path(path).read_text(encoding="utf-8"))
+
+
+def plot_ttfp_comparison(
+    off_results: list[dict],
+    on_results: list[dict],
+    metric: str,
+    output_path: str,
+    title_override: str | None = None,
+):
+    """Save a grouped bar chart comparing one metric with async_chunk off vs on.
+
+    Despite the name, any key of ``METRIC_CONFIG`` is accepted as ``metric``.
+    Pairs whose values are not both positive get no arrow/label, and the y-axis
+    is logarithmic except for ``"rtf"``.
+
+    Args:
+        off_results: Result records for async_chunk off; each needs
+            ``"concurrency"`` and the metric's key.
+        on_results: Result records for async_chunk on, same layout.
+        metric: Key into ``METRIC_CONFIG`` (``KeyError`` if unknown).
+        output_path: Image path; missing parent directories are created.
+        title_override: Replaces the configured title when truthy.
+    """
+    cfg = METRIC_CONFIG[metric]
+    key = cfg["key"]
+    scale = 1000.0 if cfg["to_seconds"] else 1.0  # ms -> s where configured
+
+    off_map = {r["concurrency"]: r for r in off_results}
+    on_map = {r["concurrency"]: r for r in on_results}
+    # Only concurrency levels present in both runs can be compared.
+    concurrencies = sorted(set(off_map) & set(on_map))
+    off_vals = [off_map[c][key] / scale for c in concurrencies]
+    on_vals = [on_map[c][key] / scale for c in concurrencies]
+
+    fig, ax = plt.subplots(figsize=(8, 6))
+    x = np.arange(len(concurrencies))
+    width = 0.3
+
+    ax.bar(x - width / 2, off_vals, width, label="async_chunk off", color="#87CEEB", edgecolor="none")
+    ax.bar(x + width / 2, on_vals, width, label="async_chunk on", color="#FFF8DC", edgecolor="#DDD8B8")
+
+    # Draw improvement arrows and labels
+    for xi, v_off, v_on in zip(x, off_vals, on_vals):
+        # The ratio needs a positive "on" value and a log axis cannot place
+        # non-positive values, so such pairs stay unannotated (plot_all_metrics
+        # already skips a non-positive "on" value) rather than showing "infx".
+        if v_off <= 0 or v_on <= 0:
+            continue
+        improvement = v_off / v_on
+
+        # Arrow from just below the off-bar top to just above the on-bar top
+        arrow_start = (xi - width / 2, v_off * 0.95)
+        arrow_end = (xi + width / 2, v_on * 1.05)
+        ax.annotate(
+            "",
+            xy=arrow_end,
+            xytext=arrow_start,
+            arrowprops=dict(arrowstyle="->", color="red", lw=1.5),
+        )
+
+        # Improvement label, centred between the two bars
+        ax.text(
+            xi,
+            arrow_start[1] + (v_off - v_on) * 0.15,
+            f"{improvement:.1f}x improvement",
+            ha="center",
+            va="bottom",
+            fontsize=10,
+            color="red",
+            fontweight="bold",
+        )
+
+    ax.set_title(title_override or cfg["title"], fontsize=13, fontweight="bold")
+    ax.set_ylabel(cfg["ylabel"], fontsize=12)
+    ax.set_xlabel("Max concurrency", fontsize=12)
+    ax.set_xticks(x)
+    ax.set_xticklabels([str(c) for c in concurrencies])
+    if metric != "rtf":  # RTF stays on a linear axis, matching plot_all_metrics
+        ax.set_yscale("log")
+    ax.legend(loc="upper left", fontsize=11)
+    ax.grid(axis="y", alpha=0.3, linestyle="--")
+    ax.set_axisbelow(True)
+
+    fig.tight_layout()
+    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+    fig.savefig(output_path, dpi=150, bbox_inches="tight")
+    print(f"Plot saved to {output_path}")
+    plt.close(fig)
+
+
+def plot_all_metrics(off_results: list[dict], on_results: list[dict], output_path: str):
+ """Generate a 1x3 subplot with TTFP, E2E, and RTF comparisons."""
+ off_map = {r["concurrency"]: r for r in off_results}
+ on_map = {r["concurrency"]: r for r in on_results}
+ concurrencies = sorted(set(off_map.keys()) & set(on_map.keys()))
+
+ fig, axes = plt.subplots(1, 3, figsize=(18, 6))
+ fig.suptitle("Qwen3-TTS: async_chunk on vs off", fontsize=15, fontweight="bold")
+
+ for ax, metric in zip(axes, ["ttfp", "e2e", "rtf"]):
+ cfg = METRIC_CONFIG[metric]
+ key = cfg["key"]
+ to_seconds = cfg["to_seconds"]
+
+ off_vals = []
+ on_vals = []
+ for c in concurrencies:
+ v_off = off_map[c][key]
+ v_on = on_map[c][key]
+ if to_seconds:
+ v_off /= 1000.0
+ v_on /= 1000.0
+ off_vals.append(v_off)
+ on_vals.append(v_on)
+
+ x = np.arange(len(concurrencies))
+ width = 0.3
+ ax.bar(x - width / 2, off_vals, width, label="async_chunk off", color="#87CEEB")
+ ax.bar(x + width / 2, on_vals, width, label="async_chunk on", color="#FFF8DC", edgecolor="#DDD8B8")
+
+ for i in range(len(concurrencies)):
+ if on_vals[i] > 0:
+ improvement = off_vals[i] / on_vals[i]
+ ax.annotate(
+ "",
+ xy=(x[i] + width / 2, on_vals[i] * 1.05),
+ xytext=(x[i] - width / 2, off_vals[i] * 0.95),
+ arrowprops=dict(arrowstyle="->", color="red", lw=1.5),
+ )
+ label_y = off_vals[i] * 0.85
+ ax.text(x[i], label_y, f"{improvement:.1f}x", ha="center", fontsize=10, color="red", fontweight="bold")
+
+ ax.set_title(cfg["title"].split(" - ")[0], fontsize=12, fontweight="bold")
+ ax.set_ylabel(cfg["ylabel"], fontsize=11)
+ ax.set_xlabel("Max concurrency", fontsize=11)
+ ax.set_xticks(x)
+ ax.set_xticklabels([str(c) for c in concurrencies])
+ if metric != "rtf":
+ ax.set_yscale("log")
+ ax.legend(fontsize=9)
+ ax.grid(axis="y", alpha=0.3, linestyle="--")
+ ax.set_axisbelow(True)
+
+ plt.tight_layout()
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
+ print(f"Plot saved to {output_path}")
+ plt.close()
+
+
+def print_table(off_results: list[dict], on_results: list[dict]):
+    """Print a Markdown table of async_chunk off vs on for each metric and shared concurrency."""
+    by_conc_off = {rec["concurrency"]: rec for rec in off_results}
+    by_conc_on = {rec["concurrency"]: rec for rec in on_results}
+    shared = sorted(by_conc_off.keys() & by_conc_on.keys())
+
+    print("\n## Benchmark Results: async_chunk off vs on\n")
+    print("| Metric | Concurrency | async_chunk off | async_chunk on | Improvement |")
+    print("| --- | --- | --- | --- | --- |")
+
+    # (label, result key, number format); rows are grouped by metric, then concurrency.
+    rows = (
+        ("TTFP (ms)", "mean_ttfp_ms", ".1f"),
+        ("E2E (ms)", "mean_e2e_ms", ".1f"),
+        ("RTF", "mean_rtf", ".3f"),
+        ("Throughput", "audio_throughput", ".2f"),
+    )
+    for label, field, spec in rows:
+        for level in shared:
+            before = by_conc_off[level].get(field, 0)  # missing keys count as 0
+            after = by_conc_on[level].get(field, 0)
+            # Throughput reports on/off; every other metric reports off/on.
+            num, den = (after, before) if field == "audio_throughput" else (before, after)
+            gain = f"{num / den:.1f}x" if den > 0 else "N/A"
+            print(f"| {label} | {level} | {before:{spec}} | {after:{spec}} | {gain} |")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Plot async_chunk comparison for Qwen3-TTS")
+ parser.add_argument("--off", type=str, required=True, help="JSON results for async_chunk off")
+ parser.add_argument("--on", type=str, required=True, help="JSON results for async_chunk on")
+ parser.add_argument("--metric", type=str, default="ttfp", choices=["ttfp", "e2e", "rtf", "all"])
+ parser.add_argument("--output", type=str, default="results/qwen3_tts_async_chunk.png")
+ parser.add_argument("--title", type=str, default=None, help="Custom title override")
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ off_results = load_results(args.off)
+ on_results = load_results(args.on)
+
+ print_table(off_results, on_results)
+
+ if args.metric == "all":
+ plot_all_metrics(off_results, on_results, args.output)
+ else:
+ plot_ttfp_comparison(off_results, on_results, args.metric, args.output, args.title)
diff --git a/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh b/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh
new file mode 100755
index 00000000000..61cf7757a9b
--- /dev/null
+++ b/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh
@@ -0,0 +1,167 @@
+#!/bin/bash
+# Qwen3-TTS async_chunk on vs off Benchmark
+#
+# Starts two servers (async_chunk on and off), benchmarks both,
+# and generates comparison plots.
+#
+# Usage:
+# bash run_async_chunk_benchmark.sh
+#
+# Environment variables:
+# GPU_DEVICE - GPU index (default: 0)
+# MODEL - Model path (default: Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice)
+# NUM_PROMPTS - Prompts per concurrency level (default: 50)
+# CONCURRENCY - Space-separated concurrency levels (default: "1 10")
+# PORT_ON - Port for async_chunk on server (default: 8000)
+# PORT_OFF - Port for async_chunk off server (default: 8001)
+
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
+cd "$PROJECT_ROOT"
+
+GPU_DEVICE="${GPU_DEVICE:-0}"
+MODEL="${MODEL:-Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice}"
+NUM_PROMPTS="${NUM_PROMPTS:-50}"
+CONCURRENCY="${CONCURRENCY:-1 10}"
+NUM_WARMUPS="${NUM_WARMUPS:-3}"
+PORT_ON="${PORT_ON:-8000}"
+PORT_OFF="${PORT_OFF:-8001}"
+RESULT_DIR="${SCRIPT_DIR}/results"
+TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
+
+STAGE_CONFIG_ON="vllm_omni/model_executor/stage_configs/qwen3_tts.yaml"
+STAGE_CONFIG_OFF="vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml"
+
+mkdir -p "${RESULT_DIR}"
+
+echo "============================================================"
+echo " Qwen3-TTS async_chunk Benchmark"
+echo "============================================================"
+echo " GPU: ${GPU_DEVICE}"
+echo " Model: ${MODEL}"
+echo " Prompts: ${NUM_PROMPTS}"
+echo " Concurrency: ${CONCURRENCY}"
+echo " Port (on/off): ${PORT_ON} / ${PORT_OFF}"
+echo " Results: ${RESULT_DIR}"
+echo "============================================================"
+
+cleanup() {
+    # EXIT trap: "${VAR:-}" guards a PID not assigned yet, so `set -u` cannot abort the cleanup.
+    echo "Cleaning up servers..."
+    local pid
+    for pid in "${PID_ON:-}" "${PID_OFF:-}"; do [ -z "$pid" ] || kill "$pid" 2>/dev/null || true; done
+    for pid in "${PID_ON:-}" "${PID_OFF:-}"; do [ -z "$pid" ] || wait "$pid" 2>/dev/null || true; done
+}
+trap cleanup EXIT
+
+wait_for_server() {
+    # Usage: wait_for_server PORT NAME. Polls http://localhost:PORT/health every 5 s
+    # and exits the script with status 1 after 300 s without a successful response.
+    local port=$1 name=$2 max_wait=300 elapsed=0
+    echo "Waiting for ${name} server on port ${port}..."
+    # -f: an HTTP error status (>= 400) from /health must not count as "ready".
+    while ! curl -sf "http://localhost:${port}/health" >/dev/null 2>&1; do
+        sleep 5
+        elapsed=$((elapsed + 5))
+        if [ "$elapsed" -ge "$max_wait" ]; then
+            echo "ERROR: ${name} server failed to start within ${max_wait}s"
+            exit 1
+        fi
+    done
+    echo "${name} server ready (${elapsed}s)"
+}
+
+# ---- Phase 1: Start async_chunk ON server ----
+echo ""
+echo "[Phase 1] Starting async_chunk ON server on port ${PORT_ON}..."
+CUDA_VISIBLE_DEVICES=${GPU_DEVICE} vllm-omni serve "${MODEL}" \
+ --stage-configs-path "${STAGE_CONFIG_ON}" \
+ --host 0.0.0.0 --port "${PORT_ON}" \
+ --trust-remote-code --enforce-eager --omni \
+ > "${RESULT_DIR}/server_on_${TIMESTAMP}.log" 2>&1 &
+PID_ON=$!
+
+wait_for_server "${PORT_ON}" "async_chunk_on"
+
+echo "[Phase 1] Benchmarking async_chunk ON..."
+# shellcheck disable=SC2086
+python "${SCRIPT_DIR}/bench_async_chunk.py" \
+ --host 127.0.0.1 --port "${PORT_ON}" \
+ --config-name "async_chunk_on" \
+ --num-prompts "${NUM_PROMPTS}" \
+ --max-concurrency ${CONCURRENCY} \
+ --num-warmups "${NUM_WARMUPS}" \
+ --result-dir "${RESULT_DIR}"
+
+echo "[Phase 1] Stopping async_chunk ON server..."
+kill "$PID_ON" 2>/dev/null || true
+wait "$PID_ON" 2>/dev/null || true
+sleep 5
+
+# ---- Phase 2: Start async_chunk OFF server ----
+echo ""
+echo "[Phase 2] Starting async_chunk OFF server on port ${PORT_OFF}..."
+CUDA_VISIBLE_DEVICES=${GPU_DEVICE} vllm-omni serve "${MODEL}" \
+ --stage-configs-path "${STAGE_CONFIG_OFF}" \
+ --host 0.0.0.0 --port "${PORT_OFF}" \
+ --trust-remote-code --enforce-eager --omni \
+ > "${RESULT_DIR}/server_off_${TIMESTAMP}.log" 2>&1 &
+PID_OFF=$!
+
+wait_for_server "${PORT_OFF}" "async_chunk_off"
+
+echo "[Phase 2] Benchmarking async_chunk OFF (non-streaming)..."
+# shellcheck disable=SC2086
+python "${SCRIPT_DIR}/bench_async_chunk.py" \
+ --host 127.0.0.1 --port "${PORT_OFF}" \
+ --config-name "async_chunk_off" \
+ --num-prompts "${NUM_PROMPTS}" \
+ --max-concurrency ${CONCURRENCY} \
+ --num-warmups "${NUM_WARMUPS}" \
+ --no-stream \
+ --result-dir "${RESULT_DIR}"
+
+echo "[Phase 2] Stopping async_chunk OFF server..."
+kill "$PID_OFF" 2>/dev/null || true
+wait "$PID_OFF" 2>/dev/null || true
+
+# ---- Phase 3: Plot results ----
+echo ""
+echo "[Phase 3] Generating plots..."
+
+# Find the latest result files. `|| true`: a no-match `ls` must not trip `set -eo pipefail` here, so the check below can report it.
+RESULT_ON=$(ls -t "${RESULT_DIR}"/bench_async_chunk_on_*.json 2>/dev/null | head -1 || true)
+RESULT_OFF=$(ls -t "${RESULT_DIR}"/bench_async_chunk_off_*.json 2>/dev/null | head -1 || true)
+
+if [ -z "$RESULT_ON" ] || [ -z "$RESULT_OFF" ]; then
+    echo "ERROR: Could not find result files. Check logs in ${RESULT_DIR}/"
+    exit 1
+fi
+
+echo " ON results: ${RESULT_ON}"
+echo " OFF results: ${RESULT_OFF}"
+
+# TTFP comparison (main figure)
+python "${SCRIPT_DIR}/plot_async_chunk.py" \
+ --off "${RESULT_OFF}" \
+ --on "${RESULT_ON}" \
+ --metric ttfp \
+ --output "${RESULT_DIR}/qwen3_tts_async_chunk_ttfp.png"
+
+# All metrics comparison
+python "${SCRIPT_DIR}/plot_async_chunk.py" \
+ --off "${RESULT_OFF}" \
+ --on "${RESULT_ON}" \
+ --metric all \
+ --output "${RESULT_DIR}/qwen3_tts_async_chunk_all.png"
+
+echo ""
+echo "============================================================"
+echo " Benchmark complete!"
+echo " Results: ${RESULT_DIR}/"
+echo " Plots:"
+echo " - ${RESULT_DIR}/qwen3_tts_async_chunk_ttfp.png"
+echo " - ${RESULT_DIR}/qwen3_tts_async_chunk_all.png"
+echo "============================================================"
diff --git a/benchmarks/tts/README.md b/benchmarks/tts/README.md
deleted file mode 100644
index 9e2fd35b1a5..00000000000
--- a/benchmarks/tts/README.md
+++ /dev/null
@@ -1,227 +0,0 @@
-# TTS Universal Benchmark
-
-A model-agnostic serving benchmark for TTS models in vllm-omni. One CLI
-(`bench_tts.py`) + one YAML registry (`model_configs.yaml`) drive perf and
-quality runs for every registered checkpoint: **Qwen3-TTS** (Base / CustomVoice)
-and **VoxCPM2** today, more to come.
-
-The same three task types — `voice_clone`, `default_voice`, `voice_design` —
-are wired into both the manual CLI and the DFX nightly CI matrix
-(`tests/dfx/perf/tests/test_tts.json`).
-
-## Quick start
-
-### 1. Start the server
-
-```bash
-vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base --omni --port 8000
-```
-
-The server auto-loads its Deploy YAML from `vllm_omni/deploy/qwen3_tts.yaml`
-(Pipeline + Deploy schema introduced in #2383). No `--stage-configs-path` or
-`--deploy-config` flag is needed for any registered model.
-
-### 2. Run the benchmark (`vllm bench serve --omni`)
-
-The primary, directly-controllable path. Copy-paste one of these and tweak
-any bench flag (sampling params, endpoint, extra body, warmups, etc.):
-
-#### voice_clone (Qwen3-TTS-Base, seed-tts dataset)
-
-```bash
-vllm bench serve --omni \
- --host 127.0.0.1 --port 8000 \
- --model Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --backend openai-audio-speech \
- --endpoint /v1/audio/speech \
- --dataset-name seed-tts \
- --dataset-path /path/to/seed-tts-eval \
- --seed-tts-locale en \
- --num-prompts 20 --num-warmups 2 \
- --extra-body '{"task_type":"Base"}' \
- --max-concurrency 1 --request-rate inf \
- --percentile-metrics ttft,e2el,audio_rtf,audio_ttfp,audio_duration \
- --save-result --result-dir ./results
-```
-
-#### default_voice (Qwen3-TTS-CustomVoice, bundled seed_tts_smoke)
-
-```bash
-vllm bench serve --omni \
- --host 127.0.0.1 --port 8000 \
- --model Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --backend openai-audio-speech \
- --endpoint /v1/audio/speech \
- --dataset-name seed-tts-text \
- --dataset-path benchmarks/build_dataset/seed_tts_smoke \
- --seed-tts-locale en \
- --num-prompts 20 --num-warmups 2 \
- --extra-body '{"voice":"Vivian","language":"English","task_type":"CustomVoice"}' \
- --max-concurrency 1 --request-rate inf \
- --percentile-metrics ttft,e2el,audio_rtf,audio_ttfp,audio_duration \
- --save-result --result-dir ./results
-```
-
-#### voice_design (Qwen3-TTS-CustomVoice, bundled seed_tts_design)
-
-```bash
-vllm bench serve --omni \
- --host 127.0.0.1 --port 8000 \
- --model Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --backend openai-audio-speech \
- --endpoint /v1/audio/speech \
- --dataset-name seed-tts-design \
- --dataset-path benchmarks/build_dataset/seed_tts_design \
- --seed-tts-locale en \
- --num-prompts 20 --num-warmups 2 \
- --extra-body '{"task_type":"VoiceDesign","language":"English"}' \
- --max-concurrency 1 --request-rate inf \
- --percentile-metrics ttft,e2el,audio_rtf,audio_ttfp,audio_duration \
- --save-result --result-dir ./results
-```
-
-#### Add WER / SIM / UTMOS to any of the above
-
-Append `--seed-tts-wer-eval` (and optionally `SEED_TTS_EVAL_DEVICE=cuda:0`
-in the env, per PR #2558). This triggers the seed-tts-eval protocol:
-Whisper-large-v3 ASR → WER, WavLM embeddings → SIM, balacoon/utmos → UTMOS.
-
-### 3. Convenience wrapper (`bench_tts.py`)
-
-If you're running the **canonical** configuration for a registered model,
-`bench_tts.py` loads the right defaults from `model_configs.yaml` and
-emits the exact `vllm bench serve --omni` command above — useful for
-concurrency sweeps and multi-task runs:
-
-```bash
-# Smallest smoke — 5 prompts, concurrency=1
-python benchmarks/tts/bench_tts.py \
- --model Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --task voice_clone \
- --dataset-path /path/to/seed-tts-eval \
- --concurrency 1 --num-prompts 5 \
- --output-dir ./results
-
-# Full concurrency sweep
-python benchmarks/tts/bench_tts.py \
- --model Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --task voice_clone \
- --dataset-path /path/to/seed-tts-eval \
- --concurrency 1 2 4 8 16 32 \
- --num-prompts 20 \
- --output-dir ./results
-
-# With WER / SIM / UTMOS quality eval (adds ASR + embedding compute)
-python benchmarks/tts/bench_tts.py \
- --model Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --task voice_clone \
- --dataset-path /path/to/seed-tts-eval \
- --wer-eval \
- --concurrency 4 --num-prompts 200 \
- --output-dir ./results
-```
-
-### 4. Plot a sweep
-
-```bash
-python benchmarks/tts/plot_results.py \
- --results ./results/*.json \
- --output ./results/curve.png
-```
-
-Outputs TTFP / RTF / throughput curves (and a markdown table) for every
-`(task, concurrency)` combination in the result set.
-
-## Task types
-
-| Task | Dataset | Request body | Checkpoints that support it |
-|-----------------|-------------------|-----------------------------------------------------|------------------------------------------|
-| `voice_clone` | `seed-tts` | `ref_audio` + `ref_text` + `task_type=Base` | `Qwen3-TTS-*-Base`, `VoxCPM2` |
-| `default_voice` | `seed-tts-text` | `voice=Vivian` + `task_type=CustomVoice` | `Qwen3-TTS-*-CustomVoice` |
-| `voice_design` | `seed-tts-design` | `instructions=` + `task_type=VoiceDesign` | `Qwen3-TTS-*-CustomVoice` |
-
-**`-CustomVoice` checkpoints do NOT ship `speaker_encoder` weights**, so
-voice_clone requests raise `ValueError` at model runtime. Use `-Base` for
-voice_clone.
-
-## Adding a new TTS model
-
-Drop an entry into `model_configs.yaml` — no Python changes required:
-
-```yaml
-models:
- /:
- supported_tasks: [voice_clone] # or default_voice / voice_design
- backend: openai-audio-speech # vllm bench serve backend
- endpoint: /v1/audio/speech # OpenAI-compatible endpoint
- task_extra_body: # merged into every request's body
- voice_clone:
- task_type: Base
-```
-
-Then add the model's Deploy YAML under `vllm_omni/deploy/.yaml`
-(Pipeline + Deploy schema) and it's immediately benchable.
-
-## Datasets
-
-| Dataset | Bundled? | Format | Source |
-|--------------------|----------|-------------------|----------------------------------------------------------------|
-| `seed-tts-design` | ✅ | 5-field meta.lst | `benchmarks/build_dataset/seed_tts_design/en/meta.lst` (20 prompts) |
-| `seed_tts_smoke` | ✅ | 4-field meta.lst | `benchmarks/build_dataset/seed_tts_smoke/en/meta.lst` (20 text-only) |
-| `seed-tts` | ❌ | 4-field meta.lst + WAVs | Google-Drive: [BytedanceSpeech/seed-tts-eval][seedtts] (~1.2 GB) |
-| `seed-tts-text` | ❌ | 4-field meta.lst | Same archive as `seed-tts` (wav column unused) |
-
-[seedtts]: https://github.com/BytedanceSpeech/seed-tts-eval
-
-For manual voice_clone / default_voice runs against the full corpus, follow
-`benchmarks/build_dataset/download_process_data_seedtts.md` and point
-`--dataset-path` at the extracted `seedtts_testset` directory.
-
-## DFX nightly CI
-
-`tests/dfx/perf/tests/test_tts.json` wires three perf regimes plus quality:
-
-| eval_phase | concurrency | purpose | Baseline metrics |
-|---------------|-------------|---------------------------------------------------------|-----------------------------------------|
-| `latency` | 1 | Single-request TTFP / RTF SLO | `median_audio_ttfp_ms`, `median_audio_rtf` |
-| `throughput` | 8 | Codec-batching cliff sentinel (PDF #272 concurrency≥8) | `median_audio_ttfp_ms`, `median_audio_rtf` |
-| `quality` | 4 | WER / SIM / UTMOS regression (disabled in CI by default)| `mean_audio_rtf` |
-
-Why `median_*` for latency/throughput and `mean_*` for quality: latency
-distributions have cold-start tails that drag the mean; quality aggregates
-over 200 prompts so single-request outliers don't matter.
-
-Quality entries are `enabled: false` in CI because seed-tts-eval is not
-staged in the Buildkite container (matches the precedent in
-PR #2558 — quality runs are manual / release-validation, not nightly).
-
-## Concurrency cliff regression sentinel
-
-Observed on H20-3e, Qwen3-TTS-1.7B (measured pre-merge on this branch):
-
-| Task | Model | c=1 | c=4 | **c=8** | c=16 | c=32 |
-|---------------|---------------|--------|--------|------------|--------|--------|
-| voice_clone | 1.7B-Base | RTF 0.15 / TTFP 165ms | 0.28 / 412ms | **0.49 / 1701ms** | 0.72 / 3355ms | 0.77 / 3772ms |
-| voice_design | 1.7B-CustomVoice | RTF 0.08 / TTFP 53ms | 0.11 / 154ms | **0.21 / 872ms** | 0.33 / 1801ms | 0.38 / 1989ms |
-
-Both models show a **4–6× TTFP jump from c=4 to c=8** while audio throughput
-saturates around c=4–8 — the codec-bs=1 bottleneck documented in
-vllm-project/vllm-omni#272. The `throughput` CI regime at c=8 is the
-sentinel for regressions in this area.
-
-## File layout
-
-```
-benchmarks/tts/
-├── README.md (this file)
-├── bench_tts.py CLI — serve-mode benchmark driver
-├── bench_voxcpm_offline.py CLI — offline VoxCPM benchmark (sync + streaming)
-├── plot_results.py Generate per-task / per-concurrency curves
-└── model_configs.yaml Model registry (supported tasks + extra body)
-```
-
-## Related
-
-- Upstream seed-tts-eval integration: vllm-project/vllm-omni#2558
-- Pipeline + Deploy schema: vllm-project/vllm-omni#2383
-- Concurrency cliff RFC: vllm-project/vllm-omni#272
diff --git a/benchmarks/tts/bench_tts.py b/benchmarks/tts/bench_tts.py
deleted file mode 100644
index ba82b1c9b7b..00000000000
--- a/benchmarks/tts/bench_tts.py
+++ /dev/null
@@ -1,308 +0,0 @@
-#!/usr/bin/env python3
-"""Universal TTS benchmark CLI for vllm-omni.
-
-Runs ``vllm bench serve --omni`` with model-aware defaults loaded from
-``model_configs.yaml``. Supports Qwen3-TTS, VoxCPM2, and any future TTS
-model registered in the config file -- no code changes needed to add models.
-
-Usage::
-
- python benchmarks/tts/bench_tts.py \\
- --model Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \\
- --task voice_clone \\
- --locale en \\
- --concurrency 1 4 \\
- --num-prompts 20 \\
- --dataset-path /path/to/seed-tts-eval \\
- --host localhost --port 8000
-
-See ``--help`` for full option list.
-"""
-
-from __future__ import annotations
-
-import argparse
-import json
-import math
-import os
-import subprocess
-import sys
-from datetime import datetime
-from pathlib import Path
-from typing import Any
-
-import yaml
-
-
-def _vllm_omni_bin() -> str:
- """Return the vllm-omni (or vllm) binary co-located with the current Python."""
- bin_dir = Path(sys.executable).parent
- for candidate in ("vllm-omni", "vllm"):
- p = bin_dir / candidate
- if p.is_file():
- return str(p)
- return "vllm-omni" # fall back and let the shell resolve it
-
-
-_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
-_SCRIPT_DIR = Path(__file__).resolve().parent
-_DEFAULT_MODEL_CONFIGS = _SCRIPT_DIR / "model_configs.yaml"
-
-# Maps task name to the dataset_name used with vllm bench serve
-_TASK_TO_DATASET: dict[str, str] = {
- "voice_clone": "seed-tts",
- "default_voice": "seed-tts-text",
- "voice_design": "seed-tts-design",
-}
-
-# Default design dataset path (bundled with the repo)
-_DEFAULT_DESIGN_DATASET_PATH = str(_REPO_ROOT / "benchmarks" / "build_dataset" / "seed_tts_design")
-
-
-def load_model_configs(path: Path) -> dict[str, Any]:
- """Load model registry from YAML file."""
- with open(path, encoding="utf-8") as f:
- data = yaml.safe_load(f)
- return data.get("models", {})
-
-
-def build_bench_args(
- *,
- host: str,
- port: int,
- model: str,
- task: str,
- model_cfg: dict[str, Any],
- locale: str,
- num_prompts: int,
- concurrency: int | None,
- dataset_path: str | None,
- wer_eval: bool,
- output_dir: str | None,
- result_filename: str | None,
- extra_cli_args: list[str],
-) -> list[str]:
- """Build the ``vllm bench serve --omni`` command for one (task, concurrency) run."""
- dataset_name = _TASK_TO_DATASET[task]
- backend: str = model_cfg["backend"]
- endpoint: str = model_cfg["endpoint"]
- task_extra_body: dict[str, Any] = (model_cfg.get("task_extra_body") or {}).get(task) or {}
-
- # Resolve dataset path
- if dataset_path:
- resolved_dataset_path = dataset_path
- elif task == "voice_design":
- resolved_dataset_path = _DEFAULT_DESIGN_DATASET_PATH
- else:
- resolved_dataset_path = None
-
- cmd = [
- _vllm_omni_bin(),
- "bench",
- "serve",
- "--omni",
- "--host",
- host,
- "--port",
- str(port),
- "--model",
- model,
- "--backend",
- backend,
- "--endpoint",
- endpoint,
- "--dataset-name",
- dataset_name,
- "--num-prompts",
- str(num_prompts),
- "--num-warmups",
- "2",
- "--percentile-metrics",
- "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- ]
-
- if resolved_dataset_path:
- cmd += ["--dataset-path", resolved_dataset_path]
-
- if locale:
- cmd += ["--seed-tts-locale", locale]
-
- if task_extra_body:
- cmd += ["--extra-body", json.dumps(task_extra_body, separators=(",", ":"))]
-
- if concurrency is not None:
- cmd += ["--max-concurrency", str(concurrency), "--request-rate", "inf"]
-
- if wer_eval:
- cmd.append("--seed-tts-wer-eval")
-
- if output_dir or result_filename:
- out_dir = output_dir or "."
- os.makedirs(out_dir, exist_ok=True)
- cmd += ["--save-result", "--result-dir", out_dir]
- if result_filename:
- cmd += ["--result-filename", result_filename]
-
- cmd += extra_cli_args
- return cmd
-
-
-def run_one_benchmark(cmd: list[str]) -> dict[str, Any] | None:
- """Run a single benchmark subprocess and return parsed JSON result if available."""
- print(f"\n{'=' * 60}")
- print("Running:", " ".join(cmd))
- print("=" * 60)
- result = subprocess.run(cmd, check=False)
- if result.returncode != 0:
- print(f"[bench_tts] WARNING: benchmark exited with code {result.returncode}")
- return None
- # If --save-result was used, find the result file
- try:
- result_dir_idx = cmd.index("--result-dir")
- result_dir = Path(cmd[result_dir_idx + 1])
- if "--result-filename" in cmd:
- fname_idx = cmd.index("--result-filename")
- result_file = result_dir / cmd[fname_idx + 1]
- else:
- # find most recently modified json
- jsons = sorted(result_dir.glob("result_*.json"), key=lambda p: p.stat().st_mtime)
- result_file = jsons[-1] if jsons else None
- if result_file and result_file.is_file():
- return json.loads(result_file.read_text(encoding="utf-8"))
- except (ValueError, IndexError, OSError):
- pass
- return None
-
-
-def print_summary_table(results: list[dict[str, Any]]) -> None:
- """Print a unified metrics table across all (task, concurrency) runs."""
- if not results:
- return
- header = (
- f"{'Task':<16} {'Concurrency':>11} {'RTF mean':>10} "
- f"{'TTFP (ms)':>10} {'Throughput':>12} {'WER':>7} {'SIM':>7} {'UTMOS':>7}"
- )
- print(f"\n{'=' * len(header)}")
- print("BENCHMARK SUMMARY")
- print("=" * len(header))
- print(header)
- print("-" * len(header))
- for r in results:
- task = r.get("_task", "?")
- conc = r.get("_concurrency", "?")
- rtf = r.get("mean_audio_rtf", float("nan"))
- ttfp = r.get("mean_audio_ttfp_ms", float("nan"))
- throughput = r.get("audio_throughput", float("nan"))
- wer = r.get("seed_tts_mean_wer", float("nan"))
- sim = r.get("seed_tts_mean_sim", float("nan"))
- utmos = r.get("seed_tts_mean_utmos", float("nan"))
-
- def fmt(v: float, digits: int = 3) -> str:
- return f"{v:.{digits}f}" if not math.isnan(v) else " n/a"
-
- print(
- f"{task:<16} {str(conc):>11} {fmt(rtf):>10} {fmt(ttfp, 0):>10} "
- f"{fmt(throughput):>12} {fmt(wer):>7} {fmt(sim):>7} {fmt(utmos):>7}"
- )
- print("=" * len(header))
-
-
-def main() -> None:
- """Entry point for the universal TTS benchmark CLI."""
- parser = argparse.ArgumentParser(
- description=__doc__,
- formatter_class=argparse.RawDescriptionHelpFormatter,
- )
- parser.add_argument(
- "--model", required=True, help="HuggingFace model ID (e.g. Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice)"
- )
- parser.add_argument("--task", default="all", help="Task type: voice_clone | default_voice | voice_design | all")
- parser.add_argument("--locale", default="en", choices=["en", "zh"])
- parser.add_argument("--concurrency", type=int, nargs="+", default=[1, 4], metavar="N")
- parser.add_argument(
- "--num-prompts",
- type=int,
- nargs="+",
- default=[20],
- metavar="N",
- help="Number of prompts per run. If one value, applied to all concurrency levels.",
- )
- parser.add_argument(
- "--dataset-path", default=None, help="Root of seed-tts-eval dataset (required for voice_clone/default_voice)"
- )
- parser.add_argument("--wer-eval", action="store_true", help="Enable WER/SIM/UTMOS quality eval")
- parser.add_argument("--output-dir", default=None, help="Directory to save result JSON files")
- parser.add_argument("--host", default="localhost")
- parser.add_argument("--port", type=int, default=8000)
- parser.add_argument("--model-configs", default=str(_DEFAULT_MODEL_CONFIGS), help="Path to model_configs.yaml")
- parser.add_argument("extra", nargs=argparse.REMAINDER, help="Extra args passed directly to vllm bench serve")
- args = parser.parse_args()
-
- model_configs = load_model_configs(Path(args.model_configs))
- if args.model not in model_configs:
- known = "\n ".join(model_configs.keys())
- print(f"[bench_tts] ERROR: model '{args.model}' not in model_configs.yaml.\nKnown models:\n {known}")
- sys.exit(1)
-
- model_cfg = model_configs[args.model]
- supported_tasks: list[str] = model_cfg.get("supported_tasks", [])
-
- tasks_to_run: list[str]
- if args.task == "all":
- tasks_to_run = supported_tasks
- elif args.task in supported_tasks:
- tasks_to_run = [args.task]
- else:
- print(
- f"[bench_tts] ERROR: task '{args.task}' not supported by {args.model}.\nSupported tasks: {supported_tasks}"
- )
- sys.exit(1)
-
- # Align num_prompts list with concurrency list
- num_prompts_list: list[int] = args.num_prompts
- if len(num_prompts_list) == 1:
- num_prompts_list = num_prompts_list * len(args.concurrency)
- elif len(num_prompts_list) != len(args.concurrency):
- print(
- f"[bench_tts] ERROR: --num-prompts ({len(num_prompts_list)} values) must be "
- f"length 1 or match --concurrency ({len(args.concurrency)} values)."
- )
- sys.exit(1)
-
- all_results: list[dict[str, Any]] = []
-
- for task in tasks_to_run:
- for concurrency, num_prompts in zip(args.concurrency, num_prompts_list):
- ts = datetime.now().strftime("%Y%m%d-%H%M%S")
- result_filename = f"bench_tts_{args.model.replace('/', '_')}_{task}_c{concurrency}_{ts}.json"
- cmd = build_bench_args(
- host=args.host,
- port=args.port,
- model=args.model,
- task=task,
- model_cfg=model_cfg,
- locale=args.locale,
- num_prompts=num_prompts,
- concurrency=concurrency,
- dataset_path=args.dataset_path,
- wer_eval=args.wer_eval,
- output_dir=args.output_dir,
- result_filename=result_filename,
- extra_cli_args=args.extra or [],
- )
- result = run_one_benchmark(cmd)
- if result is not None:
- result["_task"] = task
- result["_concurrency"] = concurrency
- all_results.append(result)
- # Persist the metadata so plot_results.py can pick it up.
- if args.output_dir and result_filename:
- result_path = Path(args.output_dir) / result_filename
- if result_path.is_file():
- result_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
-
- print_summary_table(all_results)
-
-
-if __name__ == "__main__":
- main()
diff --git a/benchmarks/tts/bench_voxcpm_offline.py b/benchmarks/tts/bench_voxcpm_offline.py
deleted file mode 100644
index 672b77f1495..00000000000
--- a/benchmarks/tts/bench_voxcpm_offline.py
+++ /dev/null
@@ -1,922 +0,0 @@
-"""Offline VoxCPM benchmark for vLLM Omni.
-
-Supports both:
-- sync one-shot (Omni.generate)
-- streaming (AsyncOmni.generate with async_chunk config)
-- text-only synthesis
-- voice cloning
-- text/clone batch inputs from txt or jsonl
-
-Usage::
-
- # Sync (default voice)
- python benchmarks/tts/bench_voxcpm_offline.py \\
- --model /path/to/VoxCPM \\
- --text "Hello world" \\
- --output-dir results/audio/
-
- # Streaming (async_chunk)
- python benchmarks/tts/bench_voxcpm_offline.py \\
- --model /path/to/VoxCPM \\
- --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \\
- --txt-prompts prompts.txt \\
- --output-dir results/audio/
-
- # Voice cloning batch via JSONL
- python benchmarks/tts/bench_voxcpm_offline.py \\
- --model /path/to/VoxCPM \\
- --jsonl-prompts prompts.jsonl \\
- --output-dir results/audio/
-"""
-
-from __future__ import annotations
-
-import asyncio
-import json
-import logging
-import os
-import tempfile
-import time
-import uuid
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any
-
-import torch
-from vllm.utils.argparse_utils import FlexibleArgumentParser
-
-from vllm_omni import AsyncOmni, Omni
-
-
-def _find_repo_root(start: Path) -> Path:
- """Walk up from ``start`` until a repo marker is found.
-
- Falls back to ``parents[2]`` for backwards compatibility if no marker hits
- (which can only happen in unusual checkouts — the tree should always have
- pyproject.toml + vllm_omni/ at the top level).
- """
- for candidate in [start, *start.parents]:
- if (candidate / "pyproject.toml").is_file() and (candidate / "vllm_omni").is_dir():
- return candidate
- return start.parents[2]
-
-
-REPO_ROOT = _find_repo_root(Path(__file__).resolve())
-DEFAULT_STAGE_ASYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm_async_chunk.yaml"
-DEFAULT_STAGE_SYNC = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml"
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass(frozen=True, slots=True)
-class PromptSpec:
- text: str
- label: str
- ref_audio: str | None = None
- ref_text: str | None = None
-
-
-def _require_soundfile():
- try:
- import soundfile as sf # type: ignore
- except ModuleNotFoundError as exc:
- raise RuntimeError(
- "soundfile is required to write VoxCPM benchmark WAV outputs. Install it with: pip install soundfile"
- ) from exc
- return sf
-
-
-def _build_prompt(
- args,
- *,
- text: str,
- ref_audio: str | None = None,
- ref_text: str | None = None,
- global_request_id: str | None = None,
-) -> dict[str, Any]:
- additional_information: dict[str, list[Any]] = {
- "text": [text],
- "cfg_value": [args.cfg_value],
- "inference_timesteps": [args.inference_timesteps],
- "min_len": [args.min_len],
- "max_new_tokens": [args.max_new_tokens],
- }
- if args.streaming_prefix_len is not None:
- additional_information["streaming_prefix_len"] = [args.streaming_prefix_len]
-
- if ref_audio:
- additional_information["ref_audio"] = [ref_audio]
- if ref_text:
- additional_information["ref_text"] = [ref_text]
- if global_request_id is not None:
- additional_information["global_request_id"] = [global_request_id]
-
- return {
- "prompt_token_ids": [1],
- "additional_information": additional_information,
- }
-
-
-def _extract_audio_tensor(mm: dict[str, Any]) -> torch.Tensor:
- audio = mm.get("audio", mm.get("model_outputs"))
- if audio is None:
- raise ValueError("No audio output found in multimodal output.")
- if isinstance(audio, list):
- parts = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio]
- audio = torch.cat(parts, dim=-1) if parts else torch.zeros(0)
- if not isinstance(audio, torch.Tensor):
- audio = torch.as_tensor(audio)
- return audio.float().cpu().reshape(-1)
-
-
-def _extract_sample_rate(mm: dict[str, Any]) -> int:
- sr_raw = mm.get("sr", 24000)
- if isinstance(sr_raw, list) and sr_raw:
- sr_raw = sr_raw[-1]
- if hasattr(sr_raw, "item"):
- return int(sr_raw.item())
- return int(sr_raw)
-
-
-def _emit_offline_metrics(
- *,
- request_id: str,
- elapsed_s: float,
- first_audio_elapsed: float | None,
- audio_duration_s: float,
-) -> None:
- metrics = {
- "request_id": request_id,
- "ttfp_ms": round(first_audio_elapsed * 1000.0, 3) if first_audio_elapsed is not None else None,
- "audio_duration_s": round(audio_duration_s, 6),
- "rtf": round(elapsed_s / audio_duration_s, 6) if audio_duration_s > 0 else None,
- }
- print(f"[OfflineMetrics] {metrics}")
-
-
-def _write_audio_tensor(output_path: Path, audio_tensor: Any, sample_rate: int) -> None:
- sf = _require_soundfile()
- if isinstance(audio_tensor, torch.Tensor):
- audio_np = audio_tensor.float().cpu().clamp(-1.0, 1.0).numpy()
- else:
- audio_np = torch.as_tensor(audio_tensor).float().cpu().clamp(-1.0, 1.0).numpy()
- sf.write(
- output_path,
- audio_np,
- sample_rate,
- format="WAV",
- subtype="PCM_16",
- )
-
-
-def _save_wav(mm: dict[str, Any], output_dir: Path, request_id: str) -> Path:
- output_dir.mkdir(parents=True, exist_ok=True)
- output_path = output_dir / f"output_{request_id}.wav"
- _write_audio_tensor(output_path, _extract_audio_tensor(mm), _extract_sample_rate(mm))
- return output_path
-
-
-def _iter_request_multimodal_outputs(request_output: Any):
- outputs = getattr(request_output, "outputs", None)
- if outputs:
- for output in outputs:
- mm = getattr(output, "multimodal_output", None)
- if isinstance(mm, dict):
- yield mm
-
- mm = getattr(request_output, "multimodal_output", None)
- if isinstance(mm, dict):
- yield mm
-
-
-def _read_non_empty_lines(path: str) -> list[str]:
- with open(path, encoding="utf-8") as f:
- return [line.strip() for line in f if line.strip()]
-
-
-def _load_prompt_specs(args) -> list[PromptSpec]:
- specs: list[PromptSpec] = []
-
- if args.txt_prompts is not None:
- texts = _read_non_empty_lines(args.txt_prompts)
- if not texts:
- raise ValueError(f"No prompts found in {args.txt_prompts}")
- for idx, text in enumerate(texts, start=1):
- specs.append(
- PromptSpec(
- text=text,
- label=f"item{idx:03d}",
- ref_audio=args.ref_audio,
- ref_text=args.ref_text,
- )
- )
- return specs
-
- if args.jsonl_prompts is not None:
- with open(args.jsonl_prompts, encoding="utf-8") as f:
- for line_no, raw_line in enumerate(f, start=1):
- line = raw_line.strip()
- if not line:
- continue
- try:
- item = json.loads(line)
- except json.JSONDecodeError as exc:
- raise ValueError(f"{args.jsonl_prompts}:{line_no} is not valid JSON: {exc}") from exc
- if not isinstance(item, dict):
- raise ValueError(f"{args.jsonl_prompts}:{line_no} must be a JSON object")
-
- text = item.get("text")
- if not isinstance(text, str) or not text.strip():
- raise ValueError(f"{args.jsonl_prompts}:{line_no} requires non-empty string field 'text'")
-
- ref_audio = item.get("ref_audio", args.ref_audio)
- ref_text = item.get("ref_text", args.ref_text)
- if (ref_audio is None) != (ref_text is None):
- raise ValueError(
- f"{args.jsonl_prompts}:{line_no} must provide both 'ref_audio' and 'ref_text' together"
- )
-
- specs.append(
- PromptSpec(
- text=text.strip(),
- label=f"item{len(specs) + 1:03d}",
- ref_audio=ref_audio,
- ref_text=ref_text,
- )
- )
-
- if not specs:
- raise ValueError(f"No prompts found in {args.jsonl_prompts}")
- return specs
-
- specs.append(
- PromptSpec(
- text=args.text,
- label="item001",
- ref_audio=args.ref_audio,
- ref_text=args.ref_text,
- )
- )
- return specs
-
-
-def _build_prompt_for_spec(args, spec: PromptSpec, *, global_request_id: str | None = None) -> dict[str, Any]:
- return _build_prompt(
- args,
- text=spec.text,
- ref_audio=spec.ref_audio,
- ref_text=spec.ref_text,
- global_request_id=global_request_id,
- )
-
-
-def _count_voice_clone_prompts(prompt_specs: list[PromptSpec]) -> int:
- return sum(1 for spec in prompt_specs if spec.ref_audio is not None)
-
-
-def _get_warmup_specs(prompt_specs: list[PromptSpec]) -> list[PromptSpec]:
- return prompt_specs[:1]
-
-
-def _extract_stream_finished(stage_output: Any) -> bool:
- request_output = getattr(stage_output, "request_output", None)
- request_finished = getattr(request_output, "finished", None)
- if request_finished is not None:
- return bool(request_finished)
- return bool(getattr(stage_output, "finished", False))
-
-
-def _build_profiled_stage_config(
- stage_configs_path: str,
- profiler_dir: str,
-) -> str:
- stage_config_path = Path(stage_configs_path)
- yaml_text = stage_config_path.read_text(encoding="utf-8")
- injected_lines: list[str] = []
- injected_count = 0
-
- for line in yaml_text.splitlines():
- injected_lines.append(line)
- if line.strip() != "engine_args:":
- continue
- indent = line[: len(line) - len(line.lstrip())]
- child_indent = indent + " "
- grandchild_indent = child_indent + " "
- injected_lines.extend(
- [
- f"{child_indent}profiler_config:",
- f'{grandchild_indent}profiler: "torch"',
- f'{grandchild_indent}torch_profiler_dir: "{profiler_dir}"',
- f"{grandchild_indent}torch_profiler_with_stack: true",
- ]
- )
- injected_count += 1
-
- if injected_count == 0:
- raise ValueError(f"No engine_args block found in stage config: {stage_configs_path}")
-
- tmp = tempfile.NamedTemporaryFile(
- mode="w",
- encoding="utf-8",
- delete=False,
- suffix=".yaml",
- prefix=f"{stage_config_path.stem}_profile_",
- )
- tmp.write("\n".join(injected_lines) + "\n")
- tmp.close()
- return tmp.name
-
-
-def parse_args():
- parser = FlexibleArgumentParser(
- description="Offline split-stage VoxCPM inference with vLLM Omni (auto sync/streaming by stage config)"
- )
- parser.add_argument(
- "--model",
- type=str,
- default=os.environ.get("VOXCPM_MODEL"),
- help="Local VoxCPM model directory. Defaults to $VOXCPM_MODEL.",
- )
- parser.add_argument(
- "--text",
- type=str,
- default="This is a split-stage VoxCPM synthesis example running on vLLM Omni.",
- help="Text to synthesize. Ignored when --txt-prompts or --jsonl-prompts is used.",
- )
- parser.add_argument(
- "--txt-prompts",
- type=str,
- default=None,
- help="Path to a .txt file with one synthesis text per line.",
- )
- parser.add_argument(
- "--jsonl-prompts",
- type=str,
- default=None,
- help=(
- "Path to a .jsonl file. Each line must contain at least {'text': ...}; "
- "clone rows can also set ref_audio/ref_text, and ref_text must be the "
- "real transcript of ref_audio."
- ),
- )
- parser.add_argument(
- "--ref-audio",
- type=str,
- default=None,
- help=(
- "Optional reference audio path for voice cloning. With --txt-prompts, "
- "the same reference is applied to every line."
- ),
- )
- parser.add_argument(
- "--ref-text",
- type=str,
- default=None,
- help=(
- "Real transcript of the reference audio. Placeholder text or mismatched "
- "text will usually produce noisy/electronic clone audio."
- ),
- )
- parser.add_argument(
- "--stage-configs-path",
- type=str,
- default=str(DEFAULT_STAGE_SYNC),
- help="Stage config YAML path. Routing is selected only from this path.",
- )
- parser.add_argument(
- "--cfg-value",
- type=float,
- default=2.0,
- help="Classifier-free guidance value for VoxCPM.",
- )
- parser.add_argument(
- "--inference-timesteps",
- type=int,
- default=10,
- help="Number of inference timesteps.",
- )
- parser.add_argument(
- "--min-len",
- type=int,
- default=2,
- help="Minimum generated token length.",
- )
- parser.add_argument(
- "--max-new-tokens",
- type=int,
- default=4096,
- help="Maximum generated token length.",
- )
- parser.add_argument(
- "--streaming-prefix-len",
- type=int,
- default=None,
- help="VoxCPM streaming window (optional, streaming mode only).",
- )
- parser.add_argument(
- "--output-dir",
- type=str,
- default=None,
- help="Directory for output WAV files.",
- )
- parser.add_argument(
- "--stage-init-timeout",
- type=int,
- default=600,
- help="Stage initialization timeout in seconds.",
- )
- parser.add_argument(
- "--log-stats",
- dest="log_stats",
- action="store_true",
- help="Enable vLLM Omni stats logging.",
- )
- parser.add_argument(
- "--no-log-stats",
- dest="log_stats",
- action="store_false",
- help="Disable vLLM Omni stats logging.",
- )
- parser.set_defaults(log_stats=True)
- parser.add_argument(
- "--num-runs",
- type=int,
- default=1,
- help="Number of full inference runs (same prompt each time). Default 1.",
- )
- parser.add_argument(
- "--warmup-runs",
- type=int,
- default=0,
- help=(
- "Optional number of warmup passes before measured runs. Warmup uses only "
- "the first prompt and does not save outputs."
- ),
- )
- parser.add_argument(
- "--enable-profiler",
- action="store_true",
- help=(
- "Enable torch profiler for the configured stages. A temporary profiled "
- "stage config is generated automatically."
- ),
- )
- parser.add_argument(
- "--profiler-dir",
- type=str,
- default=None,
- help="Directory for profiler traces. Defaults to /profiler when profiling is enabled.",
- )
- parser.add_argument(
- "--profiler-stages",
- type=int,
- nargs="*",
- default=None,
- help="Optional stage ids to profile. Defaults to all stages that have profiler_config.",
- )
- parser.add_argument(
- "--profiler-wait-seconds",
- type=float,
- default=30.0,
- help="Seconds to wait after stop_profile for trace files to flush.",
- )
- args = parser.parse_args()
-
- if not args.model:
- parser.error("--model is required unless $VOXCPM_MODEL is set")
- if args.txt_prompts is not None and args.jsonl_prompts is not None:
- parser.error("--txt-prompts and --jsonl-prompts are mutually exclusive")
- if (args.ref_audio is None) != (args.ref_text is None):
- parser.error("--ref-audio and --ref-text must be provided together")
- if args.num_runs < 1:
- parser.error("--num-runs must be >= 1")
- if args.warmup_runs < 0:
- parser.error("--warmup-runs must be >= 0")
- if args.output_dir is None:
- args.output_dir = (
- "output_audio_streaming" if _is_streaming_stage_config(args.stage_configs_path) else "output_audio"
- )
- if args.enable_profiler and args.profiler_dir is None:
- args.profiler_dir = str(Path(args.output_dir) / "profiler")
- try:
- args.prompt_specs = _load_prompt_specs(args)
- except ValueError as exc:
- parser.error(str(exc))
-
- return args
-
-
-def _is_streaming_stage_config(stage_configs_path: str) -> bool:
- cfg_name = Path(stage_configs_path).name.lower()
- return "async_chunk" in cfg_name
-
-
-async def _collect_streaming_audio(
- omni: AsyncOmni,
- args: Any,
- spec: PromptSpec,
- request_id: str,
- *,
- phase_label: str,
- prompt_index: int,
- prompt_count: int,
- print_prompt: bool = False,
-) -> tuple[torch.Tensor, int, float, float | None]:
- prompt = _build_prompt_for_spec(args, spec, global_request_id=request_id)
- delta_chunks: list[torch.Tensor] = []
- sample_rate = 24000
- chunk_i = 0
- prev_total_samples = 0
- t_start = time.perf_counter()
- first_audio_elapsed: float | None = None
-
- if print_prompt:
- print(f"---prompt---:{prompt}")
-
- async for stage_output in omni.generate(prompt, request_id=request_id):
- mm = getattr(stage_output, "multimodal_output", None)
- if not isinstance(mm, dict):
- ro = getattr(stage_output, "request_output", None)
- if ro is None:
- continue
- mm = getattr(ro, "multimodal_output", None)
- if not isinstance(mm, dict) and getattr(ro, "outputs", None):
- seq = ro.outputs[0]
- mm = getattr(seq, "multimodal_output", None)
- if not isinstance(mm, dict):
- continue
- sample_rate = _extract_sample_rate(mm)
- try:
- w = _extract_audio_tensor(mm)
- n = int(w.numel())
- if n == 0:
- continue
- finished = _extract_stream_finished(stage_output)
- if n > prev_total_samples:
- delta = w.reshape(-1)[prev_total_samples:]
- prev_total_samples = n
- elif finished and n == prev_total_samples:
- delta = w.reshape(-1)[:0]
- else:
- delta = w.reshape(-1)
- prev_total_samples += int(delta.numel())
- if int(delta.numel()) > 0:
- delta_chunks.append(delta)
- if first_audio_elapsed is None and int(delta.numel()) > 0:
- first_audio_elapsed = time.perf_counter() - t_start
- logger.info(
- "%s prompt=%d/%d chunk=%d delta_samples=%d buf_len=%d finished=%s",
- phase_label,
- prompt_index + 1,
- prompt_count,
- chunk_i,
- int(delta.numel()),
- n,
- finished,
- )
- chunk_i += 1
- except ValueError:
- if not _extract_stream_finished(stage_output):
- logger.debug("skip non-audio partial output chunk=%d", chunk_i)
-
- if not delta_chunks:
- raise RuntimeError("No audio chunks received; check stage config and logs.")
-
- audio_cat = torch.cat([c.reshape(-1) for c in delta_chunks], dim=0)
- elapsed = time.perf_counter() - t_start
- return audio_cat, sample_rate, elapsed, first_audio_elapsed
-
-
-async def _abort_streaming_residual_work(
- omni: AsyncOmni,
- request_id: str,
- *,
- settle_seconds: float = 0.1,
-) -> None:
- """Stop any late stage-0 work once the final audio has been collected."""
- await omni.engine.abort_async([request_id])
- if settle_seconds > 0:
- await asyncio.sleep(settle_seconds)
-
-
-async def _run_streaming_single(
- omni: AsyncOmni,
- args: Any,
- spec: PromptSpec,
- output_dir: Path,
- request_id: str,
- *,
- run_index: int,
- num_runs: int,
- prompt_index: int,
- prompt_count: int,
-) -> Path:
- audio_cat, sample_rate, elapsed, first_audio_elapsed = await _collect_streaming_audio(
- omni,
- args,
- spec,
- request_id,
- phase_label=f"run={run_index + 1}/{num_runs}",
- prompt_index=prompt_index,
- prompt_count=prompt_count,
- print_prompt=(run_index == 0 and prompt_index == 0),
- )
- await _abort_streaming_residual_work(omni, request_id)
- output_path = output_dir / f"output_run{run_index + 1}_{spec.label}.wav"
- _write_audio_tensor(output_path, audio_cat, sample_rate)
- audio_duration_s = float(audio_cat.numel()) / float(sample_rate) if sample_rate > 0 else 0.0
- ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else ""
- rtf_text = f", rtf={elapsed / audio_duration_s:.3f}" if audio_duration_s > 0 else ""
- print(
- f"Saved (streaming) run {run_index + 1}/{num_runs}, "
- f"prompt {prompt_index + 1}/{prompt_count}: {output_path} ({elapsed:.2f}s{ttfp_text}{rtf_text})"
- )
- _emit_offline_metrics(
- request_id=request_id,
- elapsed_s=elapsed,
- first_audio_elapsed=first_audio_elapsed,
- audio_duration_s=audio_duration_s,
- )
- return output_path
-
-
-async def _run_streaming_warmup(args, omni: AsyncOmni) -> None:
- if args.warmup_runs == 0:
- return
-
- warmup_specs = _get_warmup_specs(args.prompt_specs)
- print(
- f"Warmup: {args.warmup_runs} run(s) using the first prompt "
- f"({len(warmup_specs)} prompt(s)); outputs will be discarded."
- )
- for warmup_index in range(args.warmup_runs):
- t_warmup = time.perf_counter()
- tasks = []
- request_ids: list[str] = []
- for prompt_index, spec in enumerate(warmup_specs):
- request_id = f"warmup_stream_{warmup_index + 1}_{spec.label}_{uuid.uuid4().hex[:8]}"
- request_ids.append(request_id)
- tasks.append(
- _collect_streaming_audio(
- omni,
- args,
- spec,
- request_id,
- phase_label=f"warmup={warmup_index + 1}/{args.warmup_runs}",
- prompt_index=prompt_index,
- prompt_count=len(warmup_specs),
- )
- )
- results = await asyncio.gather(*tasks)
- for request_id in request_ids:
- await _abort_streaming_residual_work(omni, request_id)
- total_samples = sum(int(audio.numel()) for audio, _, _, _ in results)
- warmup_ttfps = [ttfp for _, _, _, ttfp in results if ttfp is not None]
- ttfp_text = f", ttfp={min(warmup_ttfps):.2f}s" if warmup_ttfps else ""
- print(
- f"Warmup (streaming) {warmup_index + 1}/{args.warmup_runs} finished: "
- f"{len(results)} prompt(s), {total_samples} sample(s) "
- f"({time.perf_counter() - t_warmup:.2f}s{ttfp_text})"
- )
-
-
-async def _run_streaming(args) -> list[Path]:
- output_dir = Path(args.output_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
-
- omni = AsyncOmni(
- model=args.model,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
-
- await _run_streaming_warmup(args, omni)
- profiler_started = False
- if args.enable_profiler:
- profile_prefix = f"voxcpm_streaming_{int(time.time())}"
- stages_text = args.profiler_stages if args.profiler_stages is not None else "all-configured"
- print(f"Starting profiler (streaming): stages={stages_text}, dir={args.profiler_dir}")
- await omni.start_profile(profile_prefix=profile_prefix, stages=args.profiler_stages)
- profiler_started = True
- t_total = time.perf_counter()
- total_elapsed = 0.0
- paths: list[Path] = []
- prompt_specs: list[PromptSpec] = args.prompt_specs
- try:
- for run in range(args.num_runs):
- for prompt_index, spec in enumerate(prompt_specs):
- request_id = f"stream_{run + 1}_{spec.label}_{uuid.uuid4().hex[:8]}"
- paths.append(
- await _run_streaming_single(
- omni,
- args,
- spec,
- output_dir,
- request_id,
- run_index=run,
- num_runs=args.num_runs,
- prompt_index=prompt_index,
- prompt_count=len(prompt_specs),
- )
- )
- total_elapsed = time.perf_counter() - t_total
- finally:
- if profiler_started:
- print("Stopping profiler (streaming)...")
- await omni.stop_profile(stages=args.profiler_stages)
- if args.profiler_wait_seconds > 0:
- print(f"Waiting {args.profiler_wait_seconds:.1f}s for profiler traces to flush...")
- await asyncio.sleep(args.profiler_wait_seconds)
-
- print(
- f"All streaming runs finished: {args.num_runs} run(s), "
- f"{len(prompt_specs)} prompt(s), {len(paths)} file(s) in {total_elapsed:.2f}s total"
- )
- return paths
-
-
-def _run_sync(args) -> list[Path]:
- output_dir = Path(args.output_dir)
-
- omni = Omni(
- model=args.model,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
-
- def _run_sync_single(
- spec: PromptSpec,
- *,
- request_prefix: str,
- save_outputs: bool,
- run_index: int | None = None,
- ) -> tuple[list[Path], int, float | None, float, float, str]:
- global_request_id = f"{request_prefix}_{spec.label}"
- prompt = _build_prompt_for_spec(args, spec, global_request_id=global_request_id)
- if save_outputs and run_index == 0 and spec.label == "item001":
- print(f"---prompt---:{prompt}")
-
- saved_paths: list[Path] = []
- output_count = 0
- first_audio_elapsed: float | None = None
- total_audio_duration_s = 0.0
- metrics_request_id = global_request_id
- t_start = time.perf_counter()
- for stage_outputs in omni.generate(prompt):
- request_output = stage_outputs.request_output
- if request_output is None:
- continue
- request_output_id = getattr(request_output, "request_id", None)
- if isinstance(request_output_id, str) and request_output_id:
- metrics_request_id = request_output_id
- for j, mm in enumerate(_iter_request_multimodal_outputs(request_output)):
- output_count += 1
- if first_audio_elapsed is None:
- try:
- audio_tensor = _extract_audio_tensor(mm)
- if int(audio_tensor.numel()) > 0:
- first_audio_elapsed = time.perf_counter() - t_start
- total_audio_duration_s += float(audio_tensor.numel()) / float(_extract_sample_rate(mm))
- except ValueError:
- pass
- else:
- try:
- audio_tensor = _extract_audio_tensor(mm)
- total_audio_duration_s += float(audio_tensor.numel()) / float(_extract_sample_rate(mm))
- except ValueError:
- pass
- if not save_outputs:
- continue
- save_stem = f"run{run_index + 1}_{spec.label}" if j == 0 else f"run{run_index + 1}_{spec.label}_{j}"
- saved_paths.append(_save_wav(mm, output_dir, save_stem))
-
- if output_count == 0:
- raise RuntimeError("No output from Omni.generate")
- elapsed_s = time.perf_counter() - t_start
- return saved_paths, output_count, first_audio_elapsed, elapsed_s, total_audio_duration_s, metrics_request_id
-
- if args.warmup_runs:
- warmup_specs = _get_warmup_specs(args.prompt_specs)
- print(
- f"Warmup: {args.warmup_runs} run(s) using the first prompt "
- f"({len(warmup_specs)} prompt(s)); outputs will be discarded."
- )
- for warmup_index in range(args.warmup_runs):
- t_warmup = time.perf_counter()
- _, output_count, first_audio_elapsed, elapsed_s, audio_duration_s, _ = _run_sync_single(
- warmup_specs[0],
- request_prefix=f"warmup_sync{warmup_index + 1}",
- save_outputs=False,
- )
- ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else ""
- rtf_text = f", rtf={elapsed_s / audio_duration_s:.3f}" if audio_duration_s > 0 else ""
- print(
- f"Warmup (sync) {warmup_index + 1}/{args.warmup_runs} finished: "
- f"{output_count} output(s) ({time.perf_counter() - t_warmup:.2f}s{ttfp_text}{rtf_text})"
- )
-
- profiler_started = False
- if args.enable_profiler:
- profile_prefix = f"voxcpm_sync_{int(time.time())}"
- stages_text = args.profiler_stages if args.profiler_stages is not None else "all-configured"
- print(f"Starting profiler (sync): stages={stages_text}, dir={args.profiler_dir}")
- omni.start_profile(profile_prefix=profile_prefix, stages=args.profiler_stages)
- profiler_started = True
-
- t_total = time.perf_counter()
- total_elapsed = 0.0
- saved_paths: list[Path] = []
- prompt_specs: list[PromptSpec] = args.prompt_specs
- try:
- for run in range(args.num_runs):
- t_run = time.perf_counter()
- run_paths: list[Path] = []
- for prompt_index, spec in enumerate(prompt_specs):
- prompt_paths, _, first_audio_elapsed, elapsed_s, audio_duration_s, metrics_request_id = (
- _run_sync_single(
- spec,
- request_prefix=f"sync_run{run + 1}_{prompt_index + 1:03d}",
- save_outputs=True,
- run_index=run,
- )
- )
- run_paths.extend(prompt_paths)
- ttfp_text = f", ttfp={first_audio_elapsed:.2f}s" if first_audio_elapsed is not None else ""
- rtf_text = f", rtf={elapsed_s / audio_duration_s:.3f}" if audio_duration_s > 0 else ""
- print(
- f"Saved (sync) run {run + 1}/{args.num_runs}, "
- f"prompt {prompt_index + 1}/{len(prompt_specs)}: {len(prompt_paths)} file(s){ttfp_text}{rtf_text}"
- )
- _emit_offline_metrics(
- request_id=metrics_request_id,
- elapsed_s=elapsed_s,
- first_audio_elapsed=first_audio_elapsed,
- audio_duration_s=audio_duration_s,
- )
-
- saved_paths.extend(run_paths)
- print(
- f"Run {run + 1}/{args.num_runs} finished: {len(run_paths)} file(s) ({time.perf_counter() - t_run:.2f}s)"
- )
- for path in run_paths:
- print(f" {path}")
-
- total_elapsed = time.perf_counter() - t_total
- finally:
- if profiler_started:
- print("Stopping profiler (sync)...")
- omni.stop_profile(stages=args.profiler_stages)
- if args.profiler_wait_seconds > 0:
- print(f"Waiting {args.profiler_wait_seconds:.1f}s for profiler traces to flush...")
- time.sleep(args.profiler_wait_seconds)
-
- print(
- f"All sync runs finished: {args.num_runs} run(s), "
- f"{len(prompt_specs)} prompt(s), {len(saved_paths)} file(s) in {total_elapsed:.2f}s total"
- )
- return saved_paths
-
-
-def main(args) -> int:
- logging.basicConfig(level=logging.INFO)
- profiled_stage_config_path: str | None = None
- original_stage_config_path = args.stage_configs_path
- if args.enable_profiler:
- Path(args.profiler_dir).mkdir(parents=True, exist_ok=True)
- profiled_stage_config_path = _build_profiled_stage_config(
- args.stage_configs_path,
- str(Path(args.profiler_dir).resolve()),
- )
- args.stage_configs_path = profiled_stage_config_path
-
- is_streaming = _is_streaming_stage_config(args.stage_configs_path)
- voice_clone_count = _count_voice_clone_prompts(args.prompt_specs)
- print(f"Model: {args.model}")
- print(f"Stage config: {original_stage_config_path}")
- print(f"Route: {'streaming' if is_streaming else 'sync'} (from stage-configs-path)")
- print(f"Prompt count: {len(args.prompt_specs)}")
- print("Batch mode: sequential (aligned with native VoxCPM)")
- print(f"Warmup runs: {args.warmup_runs}")
- print(f"Voice cloning prompts: {voice_clone_count}/{len(args.prompt_specs)}")
- if args.enable_profiler:
- print(f"Profiler: enabled (dir={args.profiler_dir}, stages={args.profiler_stages or 'all-configured'})")
- print(f"Profiled stage config: {args.stage_configs_path}")
- if voice_clone_count:
- print("Voice cloning note: --ref-text/ref_text must match the spoken content of the reference audio.")
- print(f"Num runs: {args.num_runs}")
- try:
- if is_streaming:
- asyncio.run(_run_streaming(args))
- else:
- _run_sync(args)
- finally:
- if profiled_stage_config_path is not None and os.path.exists(profiled_stage_config_path):
- os.unlink(profiled_stage_config_path)
- return 0
-
-
-if __name__ == "__main__":
- os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
- raise SystemExit(main(parse_args()))
diff --git a/benchmarks/tts/model_configs.yaml b/benchmarks/tts/model_configs.yaml
deleted file mode 100644
index 83b25370538..00000000000
--- a/benchmarks/tts/model_configs.yaml
+++ /dev/null
@@ -1,39 +0,0 @@
-# Universal TTS benchmark model registry.
-# Maps HuggingFace model ID → supported tasks + per-task extra body fields.
-# To add a new TTS model: add an entry here. No code changes required.
-#
-# The server auto-loads its Deploy YAML from vllm_omni/deploy/.yaml via
-# the Pipeline + Deploy schema introduced in #2383, so no stage_config path
-# is tracked here.
-
-models:
- # -CustomVoice checkpoints lack speaker_encoder weights, so voice_clone is
- # NOT supported (an attempt raises ValueError from _extract_speaker_embedding
- # at model runtime). Use -Base for voice_clone.
- Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice:
- supported_tasks: [default_voice, voice_design]
- backend: openai-audio-speech
- endpoint: /v1/audio/speech
- task_extra_body:
- default_voice:
- voice: Vivian
- language: English
- task_type: CustomVoice
- voice_design:
- task_type: VoiceDesign
- language: English
-
- Qwen/Qwen3-TTS-12Hz-1.7B-Base:
- supported_tasks: [voice_clone]
- backend: openai-audio-speech
- endpoint: /v1/audio/speech
- task_extra_body:
- voice_clone:
- task_type: Base
-
- openbmb/VoxCPM2:
- supported_tasks: [voice_clone]
- backend: openai-audio-speech
- endpoint: /v1/audio/speech
- task_extra_body:
- voice_clone: {}
diff --git a/benchmarks/tts/plot_results.py b/benchmarks/tts/plot_results.py
deleted file mode 100644
index f19c613209a..00000000000
--- a/benchmarks/tts/plot_results.py
+++ /dev/null
@@ -1,324 +0,0 @@
-"""Plot universal TTS benchmark results.
-
-Reads JSON files saved by ``bench_tts.py`` (via ``vllm bench serve --omni``)
-and generates comparison bar charts grouped by task type.
-
-Metrics plotted:
-- AUDIO_TTFP (mean audio time-to-first-packet, ms)
-- E2EL (mean end-to-end latency, ms)
-- Audio RTF (mean real-time factor)
-- Audio throughput (audio-seconds / wall-second)
-
-Quality metrics (WER / SIM / UTMOS) are printed in a table when present.
-
-Usage::
-
- # Single run — one JSON per task, all in results/
- python benchmarks/tts/plot_results.py \\
- --results results/bench_tts_*.json \\
- --output results/tts_benchmark.png
-
- # Compare two runs (e.g. async_chunk on vs off)
- python benchmarks/tts/plot_results.py \\
- --results run_a/bench_tts_*.json \\
- --results run_b/bench_tts_*.json \\
- --labels "async_chunk_on" "async_chunk_off" \\
- --output results/comparison.png
-"""
-
-from __future__ import annotations
-
-import argparse
-import json
-import math
-from pathlib import Path
-
-import matplotlib.pyplot as plt
-import numpy as np
-
-# ---------------------------------------------------------------------------
-# JSON loading
-# ---------------------------------------------------------------------------
-
-
-def load_run(paths: list[str]) -> list[dict]:
- """Load and merge all JSON files for one run into a flat list of records.
-
- Each record is expected to have at least ``_concurrency`` (int) and
- ``_task`` (str) keys injected by ``bench_tts.py``. Records that come
- from a file that contains a list are flattened.
- """
- records: list[dict] = []
- for p in paths:
- raw = json.loads(Path(p).read_text(encoding="utf-8"))
- if isinstance(raw, list):
- records.extend(raw)
- elif isinstance(raw, dict):
- records.append(raw)
- return records
-
-
-def _get(record: dict, key: str) -> float:
- v = record.get(key, float("nan"))
- if v is None or (isinstance(v, float) and math.isnan(v)):
- return float("nan")
- try:
- return float(v)
- except (TypeError, ValueError):
- return float("nan")
-
-
-# ---------------------------------------------------------------------------
-# Plotting helpers
-# ---------------------------------------------------------------------------
-
-
-def _bar_group(
- ax: plt.Axes,
- x: np.ndarray,
- data_per_label: dict[str, list[float]],
- width: float,
- colors: list[str],
- ylabel: str,
- title: str,
- concurrency_labels: list[str],
- fmt: str = ".1f",
-) -> None:
- n = len(data_per_label)
- offsets = np.linspace(-(n - 1) * width / 2, (n - 1) * width / 2, n) if n > 1 else [0.0]
-
- for i, (label, values) in enumerate(data_per_label.items()):
- plot_vals = [0.0 if math.isnan(v) else v for v in values]
- bar = ax.bar(x + offsets[i], plot_vals, width, label=label, color=colors[i % len(colors)], alpha=0.85)
- max_val = max((v for v in values if not math.isnan(v)), default=1.0)
- for rect, val in zip(bar, values):
- if not math.isnan(val) and val > 0:
- ax.text(
- rect.get_x() + rect.get_width() / 2,
- rect.get_height() + max_val * 0.02,
- f"{val:{fmt}}",
- ha="center",
- va="bottom",
- fontsize=8,
- fontweight="bold",
- )
-
- ax.set_xlabel("Concurrency", fontsize=11)
- ax.set_ylabel(ylabel, fontsize=11)
- ax.set_title(title, fontsize=12, fontweight="bold")
- ax.set_xticks(x)
- ax.set_xticklabels(concurrency_labels)
- ax.legend(fontsize=9)
- ax.grid(axis="y", alpha=0.3)
- ax.set_axisbelow(True)
-
-
-COLORS = ["#2196F3", "#FF5722", "#4CAF50", "#FFC107", "#9C27B0"]
-
-
-# ---------------------------------------------------------------------------
-# Comparison plot (multiple labels / runs)
-# ---------------------------------------------------------------------------
-
-
-def plot_comparison(
- all_runs: list[list[dict]],
- labels: list[str],
- output_path: str,
- task_filter: str | None = None,
- title_prefix: str = "TTS",
-) -> None:
- """One 2×2 subplot per task found in the data."""
- # Determine tasks to plot
- tasks: list[str] = []
- for run in all_runs:
- for r in run:
- t = r.get("_task", "unknown")
- if t not in tasks:
- tasks.append(t)
- if task_filter:
- tasks = [t for t in tasks if t == task_filter]
-
- n_tasks = len(tasks)
- if n_tasks == 0:
- print("[plot_results] No tasks found in data.")
- return
-
- fig, axes_grid = plt.subplots(n_tasks, 4, figsize=(18, 4.5 * n_tasks))
- fig.suptitle(f"{title_prefix} Benchmark", fontsize=15, fontweight="bold")
-
- # Ensure axes_grid is always 2D
- if n_tasks == 1:
- axes_grid = [axes_grid]
-
- for row_idx, task in enumerate(tasks):
- # Collect concurrencies across all runs for this task
- all_concs: set[int] = set()
- for run in all_runs:
- for r in run:
- if r.get("_task") == task:
- c = r.get("_concurrency")
- if c is not None:
- all_concs.add(int(c))
- concurrencies = sorted(all_concs)
- x = np.arange(len(concurrencies))
- conc_labels = [str(c) for c in concurrencies]
-
- def _series(run: list[dict], metric_key: str) -> list[float]:
- conc_map = {int(r["_concurrency"]): r for r in run if r.get("_task") == task and "_concurrency" in r}
- return [_get(conc_map.get(c, {}), metric_key) for c in concurrencies]
-
- metrics = [
- ("mean_audio_ttfp_ms", "TTFP (ms)", "Time-to-First-Packet", ".0f"),
- ("mean_e2el_ms", "E2E Latency (ms)", "End-to-End Latency", ".0f"),
- ("mean_audio_rtf", "RTF", "Real-Time Factor (RTF)", ".3f"),
- ("audio_throughput", "audio-s / wall-s", "Audio Throughput", ".2f"),
- ]
-
- axes_row = axes_grid[row_idx]
- for col_idx, (key, ylabel, subtitle, fmt) in enumerate(metrics):
- data_per_label = {lbl: _series(run, key) for lbl, run in zip(labels, all_runs)}
- _bar_group(
- axes_row[col_idx],
- x,
- data_per_label,
- width=0.3 if len(labels) > 1 else 0.5,
- colors=COLORS,
- ylabel=ylabel,
- title=f"{task} — {subtitle}",
- concurrency_labels=conc_labels,
- fmt=fmt,
- )
-
- plt.tight_layout()
- Path(output_path).parent.mkdir(parents=True, exist_ok=True)
- plt.savefig(output_path, dpi=150, bbox_inches="tight")
- print(f"Plot saved to {output_path}")
- plt.close()
-
-
-# ---------------------------------------------------------------------------
-# Markdown comparison table
-# ---------------------------------------------------------------------------
-
-
-def print_comparison_table(all_runs: list[list[dict]], labels: list[str]) -> None:
- tasks: list[str] = []
- for run in all_runs:
- for r in run:
- t = r.get("_task", "unknown")
- if t not in tasks:
- tasks.append(t)
-
- perf_metrics = [
- ("TTFP (ms)", "mean_audio_ttfp_ms", ".1f"),
- ("E2E (ms)", "mean_e2el_ms", ".1f"),
- ("RTF", "mean_audio_rtf", ".3f"),
- ("Throughput (a-s/s)", "audio_throughput", ".2f"),
- ]
- quality_metrics = [
- ("WER (%)", "seed_tts_mean_wer", ".1f"),
- ("SIM", "seed_tts_mean_sim", ".3f"),
- ("UTMOS", "seed_tts_mean_utmos", ".2f"),
- ]
-
- for task in tasks:
- all_concs: set[int] = set()
- for run in all_runs:
- for r in run:
- if r.get("_task") == task:
- c = r.get("_concurrency")
- if c is not None:
- all_concs.add(int(c))
- concurrencies = sorted(all_concs)
-
- print(f"\n## {task}\n")
- col_header = "| Metric | Concurrency |" + "".join(f" {lbl} |" for lbl in labels)
- sep = "| --- | --- |" + " --- |" * len(labels)
- print(col_header)
- print(sep)
-
- for metric, key, fmt in perf_metrics + quality_metrics:
- for c in concurrencies:
- row = f"| {metric} | {c} |"
- for run in all_runs:
- conc_map = {
- int(r["_concurrency"]): r for r in run if r.get("_task") == task and "_concurrency" in r
- }
- val = _get(conc_map.get(c, {}), key)
- row += f" {val:{fmt}} |" if not math.isnan(val) else " n/a |"
- print(row)
-
- # Improvement column (2-run comparison only)
- if len(all_runs) == 2:
- print(f"\n### Improvement ({labels[0]} vs {labels[1]})\n")
- print("| Metric | Concurrency | Change |")
- print("| --- | --- | --- |")
- for metric, key, _ in perf_metrics:
- for c in concurrencies:
- conc_map0 = {
- int(r["_concurrency"]): r for r in all_runs[0] if r.get("_task") == task and "_concurrency" in r
- }
- conc_map1 = {
- int(r["_concurrency"]): r for r in all_runs[1] if r.get("_task") == task and "_concurrency" in r
- }
- v0 = _get(conc_map0.get(c, {}), key)
- v1 = _get(conc_map1.get(c, {}), key)
- if not math.isnan(v0) and not math.isnan(v1) and v1 > 0:
- pct = (v1 - v0) / v1 * 100
- print(f"| {metric} | {c} | {pct:+.1f}% |")
-
-
-# ---------------------------------------------------------------------------
-# CLI
-# ---------------------------------------------------------------------------
-
-
-def parse_args() -> argparse.Namespace:
- parser = argparse.ArgumentParser(
- description=__doc__,
- formatter_class=argparse.RawDescriptionHelpFormatter,
- )
- parser.add_argument(
- "--results",
- type=str,
- nargs="+",
- action="append",
- required=True,
- metavar="FILE",
- help="JSON result file(s) for one run. Repeat --results for multiple runs to compare.",
- )
- parser.add_argument(
- "--labels",
- type=str,
- nargs="+",
- default=None,
- help="Label for each --results group (must match the number of --results groups).",
- )
- parser.add_argument("--output", type=str, default="results/tts_benchmark.png", help="Output image path.")
- parser.add_argument("--title", type=str, default="TTS", help="Title prefix for the plot.")
- parser.add_argument("--task", type=str, default=None, help="Filter to a single task (e.g. voice_clone).")
- return parser.parse_args()
-
-
-def main() -> None:
- args = parse_args()
-
- # args.results is a list-of-lists due to action="append"
- all_runs: list[list[dict]] = [load_run(group) for group in args.results]
- n_runs = len(all_runs)
-
- labels: list[str]
- if args.labels:
- if len(args.labels) != n_runs:
- raise SystemExit(f"--labels count ({len(args.labels)}) must match --results groups ({n_runs})")
- labels = args.labels
- else:
- labels = [f"run{i + 1}" for i in range(n_runs)]
-
- print_comparison_table(all_runs, labels)
- plot_comparison(all_runs, labels, args.output, task_filter=args.task, title_prefix=args.title)
-
-
-if __name__ == "__main__":
- main()
diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci
index 9cbf89d0b79..24ce39bafd7 100644
--- a/docker/Dockerfile.ci
+++ b/docker/Dockerfile.ci
@@ -7,7 +7,7 @@ COPY . .
# Install system dependencies
RUN apt-get update && \
- apt-get install -y espeak-ng git jq && \
+ apt-get install -y espeak-ng ffmpeg git sox libsox-fmt-all jq && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
diff --git a/docker/Dockerfile.cuda b/docker/Dockerfile.cuda
deleted file mode 100644
index 28e10f4fb85..00000000000
--- a/docker/Dockerfile.cuda
+++ /dev/null
@@ -1,22 +0,0 @@
-ARG BASE_IMAGE=vllm/vllm-openai:v0.19.0
-FROM ${BASE_IMAGE}
-
-ARG COMMON_WORKDIR=/app
-
-WORKDIR ${COMMON_WORKDIR}
-
-# Step 1: Setup - Install system dependencies
-RUN apt-get update && \
- apt-get install -y git jq && \
- apt-get clean && \
- rm -rf /var/lib/apt/lists/*
-
-RUN mkdir -p ${COMMON_WORKDIR}/vllm-omni
-
-# Step 2: Copy vllm-omni code and install
-COPY . ${COMMON_WORKDIR}/vllm-omni
-RUN cd ${COMMON_WORKDIR}/vllm-omni && uv pip install --python "$(python3 -c 'import sys; print(sys.executable)')" --no-cache-dir "."
-
-RUN ln -sf /usr/bin/python3 /usr/bin/python
-
-ENTRYPOINT []
diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index a54aa3b7933..bfbb060bcb5 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -18,10 +18,8 @@ ARG COMMON_WORKDIR=/app
WORKDIR ${COMMON_WORKDIR}
# Step 1: Setup - Install system dependencies
-# Need to include ffmpeg because vllm rocm upstream docker image
-# does not include it.
RUN apt-get update && \
- apt-get install -y espeak-ng ffmpeg git jq && \
+ apt-get install -y espeak-ng ffmpeg git sox libsox-fmt-all jq && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
@@ -41,24 +39,6 @@ RUN if [ "${USE_NIGHTLY_BUILD}" = "1" ]; then \
# Step 3: Copy vllm-omni code and install without uv
RUN mkdir -p ${COMMON_WORKDIR}/vllm-omni
COPY . ${COMMON_WORKDIR}/vllm-omni
-
-# This is a workaround to ensure pytest exits with the correct status code in CI tests.
-RUN printf '%s\n' \
- 'import os' \
- '' \
- '_exit_code = 1' \
- '' \
- 'def pytest_sessionfinish(session, exitstatus):' \
- ' global _exit_code' \
- ' _exit_code = int(exitstatus)' \
- '' \
- 'def pytest_unconfigure(config):' \
- ' import sys' \
- ' sys.stdout.flush()' \
- ' sys.stderr.flush()' \
- ' os._exit(_exit_code)' \
- > ${COMMON_WORKDIR}/vllm-omni/conftest.py
-
RUN cd ${COMMON_WORKDIR}/vllm-omni && uv pip install --python "$(python3 -c 'import sys; print(sys.executable)')" --no-cache-dir ".[dev]" --no-build-isolation
RUN ln -sf /usr/bin/python3 /usr/bin/python
diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu
index 25d5d0c800e..17f1aebf0d0 100644
--- a/docker/Dockerfile.xpu
+++ b/docker/Dockerfile.xpu
@@ -15,7 +15,9 @@ RUN apt clean && apt-get update -y && \
apt-get install -y --no-install-recommends --fix-missing \
curl \
espeak-ng \
+ ffmpeg \
git \
+ libsndfile1 \
libsm6 \
libxext6 \
libgl1 \
diff --git a/docs/.nav.yml b/docs/.nav.yml
index b44e8e6b5a8..a4939961e89 100644
--- a/docs/.nav.yml
+++ b/docs/.nav.yml
@@ -10,7 +10,6 @@ nav:
- Image Generation: serving/image_generation_api.md
- Image Edit: serving/image_edit_api.md
- Text to Speech: serving/speech_api.md
- - Streaming Video Input: serving/video_stream_api.md
- Examples:
- examples/README.md
- Offline Inference:
@@ -35,7 +34,6 @@ nav:
- Online Serving:
- BAGEL-7B-MoT: user_guide/examples/online_serving/bagel.md
- vLLM-Omni Helm Chart: user_guide/examples/online_serving/chart-helm.md
- - Diffusers Backend Adapter Example: user_guide/examples/online_serving/diffusers_pipeline_adapter.md
- Fish Speech S2 Pro: user_guide/examples/online_serving/fish_speech.md
- GLM-Image Online Serving: user_guide/examples/online_serving/glm_image.md
- Image-To-Image: user_guide/examples/online_serving/image_to_image.md
@@ -66,8 +64,6 @@ nav:
- FP8: user_guide/diffusion/quantization/fp8.md
- Int8: user_guide/diffusion/quantization/int8.md
- GGUF: user_guide/diffusion/quantization/gguf.md
- - Attention Backends: user_guide/diffusion/attention_backends.md
- - Frame Interpolation: user_guide/diffusion/frame_interpolation.md
- Parallelism:
- Overview: user_guide/diffusion/parallelism/overview.md
- CFG Parallel: user_guide/diffusion/parallelism/cfg_parallel.md
@@ -84,6 +80,7 @@ nav:
- Developer Guide:
- General:
- contributing/README.md
+ - pr_reviewer.md
- glob: contributing/*
flatten_single_child_sections: true
- Model Implementation:
@@ -100,16 +97,14 @@ nav:
- design/feature/disaggregated_inference.md
- design/feature/ray_based_execution.md
- design/feature/omni_connectors/
- - design/feature/prefix_caching.md
- design/feature/cfg_parallel.md
- - design/feature/expert_parallel.md
- design/feature/sequence_parallel.md
- design/feature/tensor_parallel.md
- design/feature/vae_parallel.md
- design/feature/hsdp.md
- design/feature/cache_dit.md
- design/feature/teacache.md
- - design/feature/async_chunk.md
+ - design/feature/async_chunk_design.md
- design/feature/vae_parallel.md
- design/feature/diffusion_step_execution.md
- Module Design:
diff --git a/docs/api/README.md b/docs/api/README.md
index 0147f19e126..f65cbb525d9 100644
--- a/docs/api/README.md
+++ b/docs/api/README.md
@@ -5,7 +5,7 @@
Main entry points for vLLM-Omni inference and serving.
- [vllm_omni.entrypoints.async_omni.AsyncOmni][]
-- [vllm_omni.engine.cfg_companion_tracker.CfgCompanionTracker][]
+- [vllm_omni.entrypoints.cfg_companion_tracker.CfgCompanionTracker][]
- [vllm_omni.entrypoints.cli.benchmark.base.OmniBenchmarkSubcommandBase][]
- [vllm_omni.entrypoints.cli.benchmark.main.OmniBenchmarkSubcommand][]
- [vllm_omni.entrypoints.cli.benchmark.serve.OmniBenchmarkServingSubcommand][]
diff --git a/docs/assets/WeChat.jpg b/docs/assets/WeChat.jpg
index 83252b7569d..28956a12099 100644
Binary files a/docs/assets/WeChat.jpg and b/docs/assets/WeChat.jpg differ
diff --git a/docs/cli/serve.md b/docs/cli/serve.md
index 035fa056731..47a873b7211 100644
--- a/docs/cli/serve.md
+++ b/docs/cli/serve.md
@@ -1,59 +1,5 @@
# vllm-omni serve
-## Stage-based CLI quickstart
-
-The stage-based CLI is designed for deployments that require launching each pipeline stage in an isolated process
-(e.g., across separate operating system processes, distinct GPUs, or distributed hosts).
-
-- For **migrated models** that utilize the bundled deployment YAML configurations located in
- `vllm_omni/deploy/`, the `--deploy-config` flag is only required to override the default configuration. By default, executing `vllm serve MODEL --omni ...`
- automatically loads the bundled deployment configuration.
-- For **legacy models** utilizing configuration files located in
- `vllm_omni/model_executor/stage_configs/`, the `--stage-configs-path` parameter remains mandatory.
-
-Example: Initializing Stage 0 (Orchestrator and API Server):
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --port 8091 \
- --stage-id 0 \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-Example: Initializing a Headless Worker Stage (Stage 1):
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 1 \
- --headless \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-When utilizing a custom deployment YAML based on the new schema, append `--deploy-config /path/to/override.yaml` to each command execution. Conversely, for legacy models, substitute this parameter with `--stage-configs-path /path/to/stage_configs.yaml`.
-
-In the standard execution paradigm, the `--stage-overrides` argument is utilized to apply stage-specific configurations from a single CLI command.
-However, under the **stage-based CLI** paradigm, where each process strictly encapsulates a single stage, it is recommended to specify tuning parameters directly via discrete command-line flags for the respective stage, rather than constructing a composite `--stage-overrides` JSON string.
-
-For example, as an alternative to the following composite configuration:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --stage-overrides '{"1": {"gpu_memory_utilization": 0.5}}'
-```
-
-the stage-based CLI permits the direct initialization of Stage 1 with explicit parameters:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 1 \
- --headless \
- --gpu-memory-utilization 0.5 \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
diff --git a/docs/configuration/README.md b/docs/configuration/README.md
index 390176e9cea..b5761a7f1bc 100644
--- a/docs/configuration/README.md
+++ b/docs/configuration/README.md
@@ -6,7 +6,7 @@ For options within a vLLM Engine. Please refer to [vLLM Configuration](https://d
Currently, the main options are maintained by stage configs for each model.
-For a specific example, see the [Qwen2.5-Omni deploy config](gh-file:vllm_omni/deploy/qwen2_5_omni.yaml). The matching frozen pipeline topology lives at [vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py](gh-file:vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py).
+For a specific example, please refer to the [Qwen2.5-omni stage config](stage_configs/qwen2_5_omni.yaml).
For introduction, please check [Introduction for stage config](./stage_configs.md)
diff --git a/docs/configuration/pd_disaggregation.md b/docs/configuration/pd_disaggregation.md
index 9196bdb0240..1cf6189e603 100644
--- a/docs/configuration/pd_disaggregation.md
+++ b/docs/configuration/pd_disaggregation.md
@@ -11,7 +11,7 @@ deployment-specific values usually change per environment:
- connector backend and connector ports
- connector IPs or bootstrap addresses
-Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml)
+Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml)
and copy it to your own file, for example `qwen3_omni_pd.yaml`. Then apply the
changes below.
@@ -145,13 +145,19 @@ Compared with the default Qwen3-Omni config:
```yaml
runtime:
enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
edges:
- from: 0
to: 1
+ window_size: -1
- from: 1
to: 2
+ window_size: -1
- from: 2
to: 3
+ window_size: -1
```
## 4. Launch with your custom config
diff --git a/docs/configuration/stage_configs.md b/docs/configuration/stage_configs.md
index 45bacfb7893..95c42afcc70 100644
--- a/docs/configuration/stage_configs.md
+++ b/docs/configuration/stage_configs.md
@@ -3,210 +3,7 @@
In vLLM-Omni, the target model is separated into multiple stages, which are processed by different LLMEngines, DiffusionEngines or other types of engines. Depending on different types of stages, such as Autoregressive (AR) stage or Diffusion transformer (DiT) stage, each can choose corresponding schedulers, model workers to load with the Engines in a plug-in fashion.
!!! note
- Default deploy config YAMLs (for example, `vllm_omni/deploy/qwen2_5_omni.yaml`, `vllm_omni/deploy/qwen3_omni_moe.yaml`, and `vllm_omni/deploy/qwen3_tts.yaml`) are bundled and loaded automatically when neither `--stage-configs-path` nor `--deploy-config` is provided — the model registry resolves the right pipeline + deploy YAML by `model_type`. The bundled defaults have been verified on 1xH100 for Qwen2.5-Omni and 2xH100 for Qwen3-Omni. Models that have not yet migrated to the new schema continue to use the legacy `vllm_omni/model_executor/stage_configs/.yaml` files via `--stage-configs-path`.
-
-## New deploy schema reference
-
-The new deploy schema lives under `vllm_omni/deploy/` and is paired with a frozen `PipelineConfig` registered by the model's `pipeline.py`. Each deploy YAML has these top-level fields:
-
-| Field | Type | Required | Default | Description |
-|-------|------|----------|---------|-------------|
-| `base_config` | str (path) | optional | — | Overlay parent (relative or absolute). `stages:` / `platforms:` deep-merged by stage_id; other scalars overlay-wins. Intended for user-authored overlays; prod yamls stay flat. |
-| `async_chunk` | bool | optional | `true` | Enable chunked streaming between stages. Pin to `false` if the pipeline runs end-to-end. |
-| `connectors` | dict | optional | `null` | Named connector specs (`{name, extra}`). Referenced by each stage's `input_connectors` / `output_connectors`. See [Connector schema](#connector-schema). |
-| `edges` | list | optional | `null` | Explicit edge list for the KV transfer graph. Auto-derived from stage inputs if omitted. |
-| `stages` | list | required | — | Per-stage engine args + wiring (see [Stage fields](#stage-fields)). |
-| `platforms` | dict | optional | `null` | Keyed by `npu` / `rocm` / `xpu`, each contains a `stages:` list with per-platform overrides applied on top of the CUDA defaults. |
-| `pipeline` | str | optional | `null` | Override the auto-detected pipeline registry key (used for structural variants like `qwen2_5_omni_thinker_only`). |
-| `trust_remote_code` | bool | optional | `true` | **Pipeline-wide.** Trust HF remote code on model load; applies to every stage. |
-| `distributed_executor_backend` | str \| null | optional | `null` | **Pipeline-wide.** Distributed executor backend forwarded to vLLM (`"mp"`, `"ray"`, `"external_launcher"`). If omitted, vLLM auto-selects backend from runtime topology. |
-| `dtype` | str \| null | optional | `null` | **Pipeline-wide.** Model dtype for every stage. |
-| `quantization` | str \| null | optional | `null` | **Pipeline-wide.** Quantization method for every stage. |
-| `enable_prefix_caching` | bool | optional | `false` | **Pipeline-wide.** Prefix cache toggle applied to every stage. |
-| `enable_chunked_prefill` | bool \| null | optional | `null` | **Pipeline-wide.** Chunked prefill toggle applied to every stage. |
-| `data_parallel_size` | int | optional | `1` | **Pipeline-wide.** DP degree for every stage. |
-| `pipeline_parallel_size` | int | optional | `1` | **Pipeline-wide.** PP degree for every stage. |
-
-Note: for diffusion path, `distributed_executor_backend` currently defaults to
-`mp`, and `ray` / `external_launcher` are not fully supported yet.
-
-### Stage fields
-
-Each entry under `stages:` accepts any `StageDeployConfig` field directly (no nested `engine_args:`). Only fields whose value legitimately varies across stages live here; pipeline-wide settings (trust_remote_code, distributed_executor_backend, dtype, quantization, prefix/chunked prefill, DP/PP sizes) are declared at the top level and applied to every stage. Unknown keys fall through to `engine_extras:` and are forwarded to the engine.
-
-| Field | Type | Required | Default | Description |
-|-------|------|----------|---------|-------------|
-| `stage_id` | int | required | — | Stage identity; matched against `PipelineConfig.stages[*].stage_id`. |
-| `max_num_seqs` | int | optional | `64` | Max concurrent sequences per stage. |
-| `gpu_memory_utilization` | float | optional | `0.9` | Per-stage memory budget. |
-| `tensor_parallel_size` | int | optional | `1` | TP degree for this stage. |
-| `enforce_eager` | bool | optional | `false` | Disable CUDA graphs. |
-| `max_num_batched_tokens` | int | optional | `32768` | Prefill budget. |
-| `max_model_len` | int \| null | optional | `null` | Per-stage context length (auto-sets `VLLM_ALLOW_LONG_MAX_MODEL_LEN=1` when larger than HF default). |
-| `async_scheduling` | bool \| null | optional | `null` | Per-stage async scheduling toggle. |
-| `devices` | str | optional | `"0"` | `CUDA_VISIBLE_DEVICES`-style device list. |
-| `output_connectors` | dict \| null | optional | `null` | Keyed by `to_stage_`; values are names registered under top-level `connectors:`. |
-| `input_connectors` | dict \| null | optional | `null` | Keyed by `from_stage_`; values are names registered under top-level `connectors:`. |
-| `default_sampling_params` | dict \| null | optional | `null` | Baseline sampling params. Deep-merged with pipeline `sampling_constraints` (pipeline wins). |
-| `engine_extras` | dict | optional | `{}` | Catch-all for keys not listed above; deep-merged across overlays. Also carries per-stage overrides of pipeline-wide settings (e.g. stage-specific `dtype`). |
-
-### Connector schema
-
-Each entry under top-level `connectors:` follows this shape:
-
-```yaml
-connectors:
- :
- name: # required — class registered in vllm_omni.distributed
- extra: # optional — forwarded to the connector's __init__
- :
- ...
-```
-
-| Connector class | Use case | `extra` keys |
-|-----------------|----------|--------------|
-| `SharedMemoryConnector` | Same-host KV transfer between stages (default for bundled YAMLs). | `shm_threshold_bytes` (int, default `65536`). |
-| `MooncakeStoreConnector` | Cross-host KV transfer over TCP. Required for multi-node deployments. | `host`, `metadata_server`, `master`, `segment` (int bytes), `localbuf` (int bytes), `proto` (`"tcp"` / `"rdma"`). |
-
-A stage references a connector by name in its `input_connectors` / `output_connectors`:
-
-```yaml
-connectors:
- shm:
- name: SharedMemoryConnector
-
-stages:
- - stage_id: 0
- output_connectors: {to_stage_1: shm}
- - stage_id: 1
- input_connectors: {from_stage_0: shm}
-```
-
-### CLI flags introduced in this refactor
-
-| Flag | Description |
-|------|-------------|
-| `--deploy-config PATH` | Load a new-schema deploy YAML. Takes precedence over `--stage-configs-path`. **Optional** — when omitted, the bundled `vllm_omni/deploy/.yaml` is auto-loaded by the model registry. |
-| `--stage-overrides JSON` | Per-stage JSON overrides, e.g. `'{"0":{"gpu_memory_utilization":0.5}}'`. Per-stage values always win over global flags. |
-| `--async-chunk` / `--no-async-chunk` | Flip the deploy YAML's `async_chunk:` bool. Unset (default) leaves the YAML value in force. |
-| `--stage-configs-path` | **Deprecated.** Accepts legacy `stage_args` yamls and (auto-detected) new deploy yamls; emits a deprecation warning. Migrate to `--deploy-config`. To be removed in a follow-up PR. |
-
-### Stage-Based CLI Paradigm
-
-The stage-based CLI paradigm facilitates the execution of discrete pipeline stages within isolated processes:
-
-- **Stage 0** typically encapsulates the orchestrator and the primary API server. Invocation requires `--stage-id 0`,
- `--omni-master-address`, `--omni-master-port`, and standard port declarations (e.g., `--port`).
-- **Worker Stages** operate without a distinct API server (i.e., using `--headless`), are assigned sequential `--stage-id` identifiers, and must reference the corresponding
- `--omni-master-address` and `--omni-master-port` parameters to successfully register with Stage 0.
-
-For migrated architectures, the system automatically resolves and loads the bundled deployment YAML. Consequently, the primary execution path
-does **not** necessitate the explicit definition of `--deploy-config`:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --port 8091 \
- --stage-id 0 \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 1 \
- --headless \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-When instantiating a custom deployment YAML conforming to the updated schema, append the `--deploy-config /path/to/override.yaml` directive
-to all node invocations. For legacy architectures (e.g., BAGEL) configured via deprecated `stage_args:` schemas, continue to specify the relevant configuration via `--stage-configs-path /path/to/config.yaml`.
-
-In the context of standard initialization architectures, utilizing the `--stage-overrides` parameter operates as the optimal methodology
-for delineating stage-specific tuning from the CLI interface:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --stage-overrides '{"1": {"gpu_memory_utilization": 0.5}}'
-```
-
-Conversely, in the context of the **stage-based CLI** paradigm, given that each execution process exclusively instantiates a single pipeline stage, configuration override attributes
-can be defined uniformly via explicit CLI flags on the corresponding instantiation command, rendering composite `--stage-overrides` JSON strings unnecessary:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 1 \
- --headless \
- --gpu-memory-utilization 0.5 \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-### Precedence
-
-From highest to lowest:
-
-1. Per-stage flags (`--stage-overrides` JSON, `--stage--` if registered)
-2. Explicit global CLI flags (`--gpu-memory-utilization 0.85`, etc.)
-3. Platform section (`platforms.npu.stages`, etc.) on top of the base `stages:`
-4. Overlay YAML (via `base_config:`) on top of the base YAML
-5. Parser defaults
-
-### Worked override example
-
-Starting from the bundled `vllm_omni/deploy/qwen3_omni_moe.yaml`:
-
-```yaml
-# vllm_omni/deploy/qwen3_omni_moe.yaml (excerpt)
-async_chunk: true
-stages:
- - stage_id: 0
- gpu_memory_utilization: 0.9
- max_num_seqs: 32
- - stage_id: 1
- gpu_memory_utilization: 0.7
- max_num_seqs: 16
-```
-
-A user-authored overlay that inherits the base and overrides only stage 1:
-
-```yaml
-# my_overrides.yaml
-base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml
-stages:
- - stage_id: 1
- gpu_memory_utilization: 0.5 # smaller GPU
-```
-
-Launched with both an explicit global flag and a per-stage override:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --deploy-config my_overrides.yaml \
- --max-model-len 16384 \
- --stage-overrides '{"0": {"max_num_seqs": 8}}'
-```
-
-Within the stage-based CLI paradigm, equivalent configuration parameters can inherently be passed directly
-as command-line arguments to the designated single-stage process instantiation:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 0 \
- --max-num-seqs 8 \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-Effective config per stage after the merge:
-
-| Stage | Field | Final value | Source |
-|-------|-------|-------------|--------|
-| 0 | `gpu_memory_utilization` | `0.9` | base YAML (overlay didn't touch stage 0) |
-| 0 | `max_num_seqs` | `8` | per-stage CLI (`--stage-overrides`) — wins over base `32` |
-| 0 | `max_model_len` | `16384` | global CLI |
-| 1 | `gpu_memory_utilization` | `0.5` | overlay YAML — wins over base `0.7` |
-| 1 | `max_num_seqs` | `16` | base YAML (overlay didn't touch this field) |
-| 1 | `max_model_len` | `16384` | global CLI |
-| 2 | (all defaults) | — | base YAML (no overrides apply) |
+ Default stage config YAMLs (for example, `vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml` and `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml`) are bundled and loaded automatically when `stage_configs_path` is not provided. They have been verified to work on 1xH100 for Qwen2.5-Omni and 2xH100 for Qwen3-Omni.
Therefore, as a core part of vLLM-Omni, the stage configs for a model have several main functions:
@@ -216,14 +13,9 @@ Therefore, as a core part of vLLM-Omni, the stage configs for a model have sever
- Input and output dependencies for each stage.
- Default input parameters.
-To override specific parameters, explicitly inject the customized configuration schema
-in both online and offline instantiation flows. Prioritize the `--deploy-config` flag
-when loading the new-schema deploy YAML schemas, reserving the `--stage-configs-path` parameter
-exclusively to maintain compatibility with legacy `stage_args` YAML constructs.
-
-Examples:
+To modify part of the default configuration, pass a custom stage configs file as an input argument, in both online and offline modes, as shown in the examples below:
-For offline (Assume necessary dependencies have been imported):
+For offline (assuming the necessary dependencies have been imported):
```python
model_name = "Qwen/Qwen2.5-Omni-7B"
omni = Omni(model=model_name, stage_configs_path="/path/to/custom_stage_configs.yaml")
@@ -231,13 +23,7 @@ omni = Omni(model=model_name, stage_configs_path="/path/to/custom_stage_configs.
For online serving:
```bash
-vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 --deploy-config /path/to/deploy_config.yaml
-```
-
-Legacy online serving:
-
-```bash
-vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+vllm serve Qwen/Qwen2.5-Omni-7B --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
```
!!! important
We are actively iterating on the definition of stage configs, and we welcome all feedbacks from both community users and developers to help us shape the development!
@@ -249,7 +35,7 @@ stage_args:
- stage_id: 0 # mark the unique id for each stage
runtime: # The disaggregated configuration
process: true # Run this stage in a separate process
- devices: "0" # Logical device index for this stage (mapped through CUDA_VISIBLE_DEVICES / ASCEND_RT_VISIBLE_DEVICES if set)
+ devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
engine_args: # Engine arguments for a certain engine
model_stage: thinker
max_num_seqs: 1
@@ -328,12 +114,16 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
-
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
edges:
- from: 0 # thinker → talker: trigger only after receiving full input (-1)
to: 1
+ window_size: -1
- from: 1 # talker → code2wav: trigger only after receiving full input (-1)
to: 2
+ window_size: -1
```
@@ -365,9 +155,7 @@ Default: `true`
#### `runtime.devices`
-Logical device indices for this stage, specified as a string. Values are **logical indices** (`0`, `1`, `2`, ...) — not physical GPU IDs — and are mapped through the platform's visibility env var (`CUDA_VISIBLE_DEVICES` on CUDA, `ASCEND_RT_VISIBLE_DEVICES` on NPU) before being applied via `torch.cuda.set_device()` (or the equivalent).
-
-Example: if `CUDA_VISIBLE_DEVICES=0,2,4` is set in the environment, then `devices: "0"` selects physical GPU 0 (the first visible), `devices: "1"` selects physical GPU 2, and `devices: "0,1"` makes physical GPUs 0 and 2 available to the stage. If no visibility env var is set, logical and physical IDs coincide.
+Visible devices for this stage, specified as a string. This controls which GPU devices are available to the stage process, similar to setting `CUDA_VISIBLE_DEVICES` or using `torch.cuda.set_device()`. For example, `"0"` uses GPU 0, `"1"` uses GPU 1, and `"0,1"` makes both GPUs 0 and 1 visible.
Default: `"0"`
diff --git a/docs/configuration/stage_configs/qwen2_5_omni.yaml b/docs/configuration/stage_configs/qwen2_5_omni.yaml
new file mode 100644
index 00000000000..690577b84a8
--- /dev/null
+++ b/docs/configuration/stage_configs/qwen2_5_omni.yaml
@@ -0,0 +1,94 @@
+# stage config for running qwen2.5-omni with AsyncOmniEngine + Orchestrator runtime.
+stage_args:
+ - stage_id: 0
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: true # Now we only support eager mode
+ trust_remote_code: true
+ engine_output_type: latent
+ enable_prefix_caching: false
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+ - stage_id: 1
+ runtime:
+ process: true
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+ default_sampling_params:
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+ stop_token_ids: [8294]
+ - stage_id: 2
+ runtime:
+ process: true
+ devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ gpu_memory_utilization: 0.15
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ engine_input_source: [1]
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
+ edges:
+ - from: 0 # thinker → talker: trigger only after receiving full input (-1)
+ to: 1
+ window_size: -1
+ - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
+ to: 2
+ window_size: -1
diff --git a/docs/contributing/ci/CI_5levels.md b/docs/contributing/ci/CI_5levels.md
index 3baa7ff8828..967d0cc6d72 100644
--- a/docs/contributing/ci/CI_5levels.md
+++ b/docs/contributing/ci/CI_5levels.md
@@ -86,8 +86,7 @@ Through five levels (L1-L5) and common (Common) specifications, the system clari
/tests/e2e/online_serving/test_{model_name}_expansion.py
/tests/e2e/offline_inference/test_{model_name}_expansion.py
Performance:
- /tests/dfx/perf/tests/test_qwen_omni.json (Omni), test_tts.json (TTS),
- and /tests/dfx/perf/tests/test_{diffusion_model}_vllm_omni.json (Diffusion)
+ /tests/dfx/perf/tests/test.json
Doc Test:
tests/example/online_serving/test_{model_name}.py
tests/example/offline_inference/test_{model_name}.py
@@ -105,8 +104,7 @@ Through five levels (L1-L5) and common (Common) specifications, the system clari
Depends on reality
Stability:
- /tests/dfx/stability/tests/test_qwen3_omni.json
- /tests/dfx/stability/tests/test_wan22.json
+ /tests/dfx/stability/tests/test.json
Reliability:
tests/e2e/reliability/test_{model_name}.py
@@ -232,7 +230,8 @@ vllm_omni/ tests/
│ ├── test_qwen3_omni_expansion.py
│ ├── test_mimo_audio.py
│ ├── test_image_gen_edit.py
- │ └── test_images_generations_lora.py
+ │ ├── test_images_generations_lora.py
+ │ └── stage_configs/
└── offline_inference/ ✅
├── test_qwen2_5_omni.py
├── test_qwen3_omni.py
@@ -243,17 +242,16 @@ vllm_omni/ tests/
├── test_zimage_tensor_parallel.py
├── test_cache_dit.py
├── test_teacache.py
- ├── test_stable_audio_expansion.py
+ ├── test_stable_audio_model.py
├── test_diffusion_cpu_offload.py
├── test_diffusion_layerwise_offload.py
├── test_diffusion_lora.py
├── test_sequence_parallel.py
- └── stage_configs/ (legacy schema, still
- ├── bagel_*.yaml present for unmigrated
- └── npu/, rocm/, etc. models)
-
-# Migrated models (qwen3_omni_moe, qwen2_5_omni, qwen3_tts) live under
-# vllm_omni/deploy/ instead — see docs/configuration/stage_configs.md.
+ └── stage_configs/
+ ├── qwen2_5_omni_ci.yaml
+ ├── qwen3_omni_ci.yaml
+ ├── bagel_*.yaml
+ └── npu/, rocm/, etc.
```
@@ -273,7 +271,7 @@ Before entering specific testing levels, the project establishes two common spec
L1 and L2 level testing form the foundation of the quality assurance system. L1 level testing focuses on verifying the internal logic correctness of code units (e.g., functions, classes), ensuring each independent component behaves as designed.
-L2 level testing builds upon L1 by introducing GPU resources and verifying that the end-to-end (E2E) process of the model in basic deployment scenarios is smooth. For example, it uses dummy models to confirm that core interfaces like the inference pipeline, output format, and streaming response work properly. The common goal of these two levels is to provide developers with rapid feedback, discovering and fixing issues early in the development cycle.
+L2 level testing builds upon L1 by introducing GPU resources and verifying that the end-to-end (E2E) process of the model in basic deployment scenarios is smooth. For example, it uses dummy models to confirm that core interfaces like the inference pipeline, output format, and streaming response work properly. The common goal of these two levels is to provide developers with rapid feedback, discovering and fixing issues early in the development cycle.
@@ -419,13 +417,13 @@ L3 level testing executes after code is merged into the main branch. Its core pu
**Explanation**:
- @pytest.mark.advanced_model: Marks the test as L3 merge level, indicating deep validation with real models. @pytest.mark.full_model: Marks L4 nightly-only suites (e.g. `test_*_expansion.py`, doc examples).
+ @pytest.mark.advanced_model: Marks the test as L3 or L4 level, indicating that this test case performs deep validation, using real models for performance, integration, and accuracy testing. This forms a "basic-advanced" correspondence with the core_model mark at the L2 level.
@pytest.mark.core_model: Marks the test as L1 or L2 level, indicating that this test case validates the basic functionality of the core model. It uses mock weights and only checks if the relevant interface functions correctly.
@pytest.mark.parametrize: A parameterization decorator that allows abstracting test data into parameters, enabling reuse of the same test logic across different data configurations. indirect=True indicates that parameters will be passed to the fixture for processing.
- **Notes**: If you believe the test case only needs to execute basic run logic at the PR-level CI, you can mark it only with @pytest.mark.core_model. If you believe it only needs to execute deep validation at merge (L3), use @pytest.mark.advanced_model. For L4 nightly-only expansion and doc-example tests, use @pytest.mark.full_model with `--run-level full_model`. If the test case needs both basic run and deep validation, mark with @pytest.mark.core_model and the appropriate L3/L4 marker (`advanced_model` and/or `full_model`).
+ **Notes**: If you believe the test case only needs to execute basic run logic at the PR-level CI, you can mark it only with @pytest.mark.core_model. If you believe it only needs to execute deep validation run logic at the merge or nightly level, you can mark it only with @pytest.mark.advanced_model. If you believe the test case needs to accommodate both basic run and deep validation test logic, you should mark it with both @pytest.mark.core_model and @pytest.mark.advanced_model.
**2.4.2 Test Function Definition and Documentation**
@@ -517,11 +515,9 @@ L3 level testing executes after code is merged into the main branch. Its core pu
**Single Request**: The comment clearly states this is a single-request completion test. For concurrent testing, it can be extended to multiple requests using request_num = n.
- **Implicit Validation**: The `send_omni_request` and `send_diffusion_request` methods internally includes validation logic dynamically selected based on the --run-level parameter: core_model performs basic validation, while advanced_model and full_model perform deep validation.
-
-- ***Run Command (L3 merge)***: `pytest -s -v /tests/e2e/online_serving/test_{model_name}.py -m advanced_model --run-level=advanced_model`
+ **Implicit Validation**: The `send_omni_request` and `send_diffusion_request` methods internally include validation logic dynamically selected based on the `--run-level` parameter: core_model performs basic validation, while advanced_model performs deep validation.
-- ***Run Command (L4 nightly expansion)***: `pytest -s -v /tests/e2e/online_serving/test_{model_name}_expansion.py -m full_model --run-level=full_model`
+- ***Run Command***: `pytest -s -v /tests/e2e/online_serving/test_{model_name}.py -m advanced_model --run-level=advanced_model`
## Chapter 3: L4 Level Testing - Full Functionality, Performance, and Documentation Testing
@@ -534,13 +530,13 @@ L4 level testing is a comprehensive quality audit before a version release. It e
### 3.2 Testing Content and Scope
- ***Full Functionality Testing***: Executes all test cases defined in `test_{model_name}_expansion.py`, covering all implemented features, positive flows, boundary conditions, and exception handling.
-- ***Performance Testing***: Uses `tests/dfx/perf/tests/test_qwen_omni.json`, `tests/dfx/perf/tests/test_tts.json`, and diffusion configs in the form `tests/dfx/perf/tests/test_*_vllm_omni.json` (passed to `run_benchmark.py` via `--test-config-file`) to drive performance testing tools for stress, load, and endurance tests, collecting metrics like throughput, response time, and resource utilization.
+- ***Performance Testing***: Uses the `tests/dfx/perf/tests/test.json` configuration file to drive performance testing tools for stress, load, and endurance tests, collecting metrics like throughput, response time, and resource utilization.
- ***Documentation Testing***: Verifies whether the example code provided to users is runnable and its results match the description.
### 3.3 Test Directory and Execution Files
- ***Functional Testing***: Same directories as L3.
-- ***Performance Test Configuration***: `tests/dfx/perf/tests/test_qwen_omni.json`, `tests/dfx/perf/tests/test_tts.json`, and diffusion configs `tests/dfx/perf/tests/test_*_vllm_omni.json` (e.g. `test_qwen_image_vllm_omni.json`)
+- ***Performance Test Configuration***: `tests/dfx/perf/tests/test.json`
- ***Documentation Example Tests***:
- - `tests/example/online_serving/test_{model_name}.py`
- `tests/example/offline_inference/test_{model_name}.py`
@@ -575,12 +571,12 @@ L5 level testing focuses on the performance of model services under ***long-runn
### 4.2 Testing Content and Scope
-- ***Long-term Stability (Stability) Testing***: Uses JSON under `tests/dfx/stability/tests/` (for example `test_qwen3_omni.json` and `test_wan22.json`) to run the service under moderate load for an extended period (e.g., over 12 hours), monitoring whether metrics like memory/VRAM usage, response time, and throughput degrade over time, and whether the service process remains stable.
+- ***Long-term Stability (Stability) Testing***: Uses the `tests/dfx/stability/tests/test.json` configuration to run the service under moderate load for an extended period (e.g., over 12 hours), monitoring whether metrics like memory/VRAM usage, response time, and throughput degrade over time, and whether the service process remains stable.
- ***Reliability Testing***: Uses `tests/e2e/reliability/test_{model_name}.py` to actively simulate various fault and abnormal scenarios, such as: dependent service interruption, abnormal input data, network flicker, hardware resource preemption, etc., to verify the system's fault tolerance, self-healing, and graceful degradation capabilities.
### 4.3 Test Directory and Execution Files
-- ***Stability Test Configuration***: `tests/dfx/stability/tests/test_qwen3_omni.json`, `tests/dfx/stability/tests/test_wan22.json` (one JSON per model / runner family)
+- ***Stability Test Configuration***: `tests/dfx/stability/tests/test.json`
- ***Reliability Test Suite***: `tests/e2e/reliability/test_{model_name}.py`
### 4.4 Execution Method and Example
@@ -591,7 +587,7 @@ L5 level testing focuses on the performance of model services under ***long-runn
Test Examples
-When you want to add L5-level stability test cases, add or extend the appropriate JSON file under `tests/dfx/stability/tests/` (for example `test_qwen3_omni.json` for Omni bench traffic, or `test_wan22.json` for diffusion `/v1/videos` workloads). The following illustrates the Qwen3-Omni shape:
+When you want to add L5-level stability test cases, you can refer to the following format for case addition in `tests/dfx/stability/tests/test.json`:
```json
{
@@ -662,7 +658,7 @@ All other optional parameters follow the same rules as the in Chapter 3.4.
-- - ***Stability***: `pytest -s -v tests/dfx/stability/scripts/test_stability_qwen3_omni.py` or `pytest -s -v tests/dfx/stability/scripts/test_stability_wan22.py` (or add `test_stability_.py` alongside a matching JSON config)
+- ***Stability***: `pytest -s -v tests/dfx/stability/scripts/test_{model_name}.py`
- ***Reliability***: `pytest -s -v tests/e2e/reliability/test_{model_name}.py`
## Summary
diff --git a/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md b/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md
index e1309b1adeb..69d6ad82871 100644
--- a/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md
+++ b/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md
@@ -37,10 +37,10 @@ Currently all the features are available in online serving mode. Hence, only nee
**Code Style**
- Validation: test that the multimodal output files of your model have the correct shapes. `OpenAIClientHandler.send_diffusion_request` should have taken care of this.
-- Test marks: always add `full_model` and `diffusion` for L4 nightly `test_*_expansion.py` cases. Add GPU-related marks if needed. Ref: [Markers for Tests](https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/tests_markers/).
+- Test marks: always add `advanced_model` and `diffusion`. Add GPU-related marks if needed. Ref: [Markers for Tests](https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/tests_markers/).
- To maximize code reuse, you may refer to
- `tests/conftest.py` for `omni_server` (running server in subprocess) and `openai_client` fixtures (sending requests and validating output), `generate_synthetic_image` and `assert_XXX_valid` helper.
- - `tests/helpers/mark.py` for `@hardware_test(...)` and `hardware_marks`.
+ - `tests/utils.py` for `@hardware_test(...)` and `hardware_marks`.
- [Parametrizing tests (pytest doc)](https://docs.pytest.org/en/stable/example/parametrize.html) to reuse test function implementation for different cases.
- Doc: add a concise docstring for each test function.
- Reference L4 test implementation: [tests/e2e/online_serving/test_qwen_image_edit_expansion.py](https://github.com/vllm-project/vllm-omni/blob/main/tests/e2e/online_serving/test_qwen_image_edit_expansion.py).
diff --git a/docs/contributing/ci/test_examples/l4_performance_tests.inc.md b/docs/contributing/ci/test_examples/l4_performance_tests.inc.md
index f1f3073dc52..8093e1459f5 100644
--- a/docs/contributing/ci/test_examples/l4_performance_tests.inc.md
+++ b/docs/contributing/ci/test_examples/l4_performance_tests.inc.md
@@ -1,4 +1,4 @@
-When you want to add L4-level ***performance test*** cases, you can refer to the following format for case addition in `tests/dfx/perf/tests/test_qwen_omni.json`, `tests/dfx/perf/tests/test_tts.json`, or diffusion configs such as `tests/dfx/perf/tests/test_*_vllm_omni.json` (selected via `pytest ... run_benchmark.py --test-config-file `):
+When you want to add L4-level ***performance test*** cases, you can refer to the following format for case addition in `tests/dfx/perf/tests/test.json`:
```JSON
{
diff --git a/docs/contributing/ci/test_guide.md b/docs/contributing/ci/test_guide.md
index 018c47b053f..425f24332c2 100644
--- a/docs/contributing/ci/test_guide.md
+++ b/docs/contributing/ci/test_guide.md
@@ -42,63 +42,32 @@ Our test scripts use the pytest framework. First, please use `git clone https://
```
The latest test commands for various test suites can be found in the [pipeline](https://github.com/vllm-project/vllm-omni/blob/main/.buildkite/test-ready.yml).
-=== "L3 level"
+=== "L3 level & L4 level"
```bash
+ cd tests
pytest -s -v -m "advanced_model" --run-level=advanced_model
```
- If you only want to run a specific test case, you can use:
- ```bash
- pytest -s -v test_xxxx.py --run-level=advanced_model
- ```
- If you only want to run specific test cases on a particular platform, you can use:
- ```bash
- pytest -s -v -m "advanced_model and distributed_cuda and L4" --run-level=advanced_model
- ```
- The latest L3 test commands for various test suites can be found in the [pipeline](https://github.com/vllm-project/vllm-omni/blob/main/.buildkite/test-merge.yml).
-
-
-=== "L4 level"
-
+ If you only want to run L3 test cases, you can use:
```bash
- cd tests
- pytest -s -v -m "full_model" --run-level=full_model
+ pytest -s -v e2e/ --ignore-glob='*expansion.py' -m "advanced_model" --run-level=advanced_model
```
If you only want to run a specific test case, you can use:
```bash
- pytest -s -v test_xxxx.py --run-level=full_model
+ pytest -s -v test_xxxx.py --run-level=advanced_model
```
If you only want to run specific test cases on a particular platform, you can use:
```bash
- pytest -s -v -m "full_model and distributed_cuda and L4" --run-level=full_model
+ pytest -s -v -m "core_model and distributed_cuda and L4" --run-level=core_model
```
- Note: To run performance tests (defaults to ``test_qwen_omni.json``; use ``--test-config-file tests/dfx/perf/tests/test_tts.json`` for TTS):
+ Note: To run performance tests, use:
```bash
- pytest -s -v tests/dfx/perf/scripts/run_benchmark.py
+ pytest -s -v perf/scripts/run_benchmark.py
```
- The latest L4 (nightly) test commands use the `full_model` marker and `--run-level full_model` (see [test-nightly.yml](https://github.com/vllm-project/vllm-omni/blob/main/.buildkite/test-nightly.yml) and [test-nightly-diffusion.yml](https://github.com/vllm-project/vllm-omni/blob/main/.buildkite/test-nightly-diffusion.yml)). Example:
-
- ```bash
- cd tests
- pytest -s -v -m "full_model and omni and H100" --run-level=full_model
- ```
-
-=== "L5 level"
-
- L5 includes stability and reliability testing. Typical commands:
- ```bash
- cd tests
-
- # Stability: Qwen3-Omni
- pytest -s -v dfx/stability/scripts/test_stability_qwen3_omni.py
-
- # Stability: Wan2.2 (v1/videos diffusion benchmark loop)
- pytest -s -v dfx/stability/scripts/test_stability_wan22.py
-
- ```
+ The latest L3 test commands for various test suites can be found in the [pipeline](https://github.com/vllm-project/vllm-omni/blob/main/.buildkite/test-merge.yml).
- The latest L5 commands for CI can be found in the [pipeline](https://github.com/vllm-project/vllm-omni/blob/main/.buildkite/test-ready.yml).
+ The latest L4 test commands for various test suites can be found in the [pipeline](https://github.com/vllm-project/vllm-omni/blob/main/.buildkite/test-nightly.yml).
You can find more information about markers in the documentation: [marker doc](./tests_markers.md)
diff --git a/docs/contributing/ci/tests_markers.md b/docs/contributing/ci/tests_markers.md
index 6130541a617..7c1ba1c73bd 100644
--- a/docs/contributing/ci/tests_markers.md
+++ b/docs/contributing/ci/tests_markers.md
@@ -8,8 +8,7 @@ Defined in `pyproject.toml`:
| Marker | Description |
| ------------------ | --------------------------------------------------------- |
| `core_model` | L1&L2 tests (run in each PR) |
-| `advanced_model` | L3 tests (run on each merge to main) |
-| `full_model` | L4 tests (run nightly) |
+| `advanced_model` | L3&L4 level tests (run in each merge or nightly) |
| `diffusion` | Diffusion model tests |
| `omni` | Omni model tests |
| `cache` | Cache backend tests |
@@ -39,7 +38,7 @@ Defined in `pyproject.toml`:
### Example usage for markers
```python
-from tests.helpers.mark import hardware_test
+from tests.utils import hardware_test
@pytest.mark.core_model
@pytest.mark.omni
@@ -54,7 +53,7 @@ def test_video_to_audio()
### Decorator: `@hardware_test`
-This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. The `hardware_test` decorator in `vllm-omni/tests/helpers/mark.py` performs the following actions:
+This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. The `hardware_test` decorator in `vllm-omni/tests/utils.py` performs the following actions:
1. **Applies platform and resource markers**
Adds the appropriate pytest markers for each specified hardware platform (e.g., `cuda`, `rocm`, `xpu`, `npu`) and resource type (e.g., `L4`, `H100`, `MI325`, `B60`, `A2`, `A3`).
@@ -106,7 +105,7 @@ This decorator is intended to make hardware-aware, cross-platform test authoring
`hardware_marks` returns a list of pytest mark objects with the same signature as `@hardware_test`. Use it when you need more flexibility, such as attaching hardware marks to individual `pytest.param` entries rather than an entire test function.
```python
-from tests.helpers.mark import hardware_marks
+from tests.utils import hardware_marks
MULTI_CARD_MARKS = hardware_marks(
res={"cuda": "H100", "rocm": "MI325", "npu": "A2"}, num_cards=2
@@ -134,9 +133,9 @@ If you want to add support for a new platform (e.g., "tpu" for a new accelerator
"distributed_tpu: Tests that require multiple TPU devices",
]
```
-2. **Implement a marker construction function for your platform** in `vllm-omni/tests/helpers/mark.py`:
+2. **Implement a marker construction function for your platform** in `vllm-omni/tests/utils.py`:
```python
- # In vllm-omni/tests/helpers/mark.py
+ # In vllm-omni/tests/utils.py
def tpu_marks(*, res: str, num_cards: int):
test_platform = pytest.mark.tpu
@@ -176,4 +175,4 @@ If you want to add support for a new platform (e.g., "tpu" for a new accelerator
- Plug into `hardware_marks`
- You're done: tests using `@hardware_test` or `hardware_marks` with your platform now automatically get the correct markers, distribution, and isolation!
-See code in `vllm-omni/tests/helpers/mark.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`).
+See code in `vllm-omni/tests/utils.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`).
diff --git a/docs/contributing/ci/tests_style.md b/docs/contributing/ci/tests_style.md
index 3a8cb0f127c..8b10cf4cc1c 100644
--- a/docs/contributing/ci/tests_style.md
+++ b/docs/contributing/ci/tests_style.md
@@ -135,7 +135,8 @@ vllm_omni/ tests/
│ ├── test_qwen3_omni_expansion.py
│ ├── test_mimo_audio.py
│ ├── test_image_gen_edit.py
- │ └── test_images_generations_lora.py
+ │ ├── test_images_generations_lora.py
+ │ └── stage_configs/
└── offline_inference/ ✅
├── test_qwen2_5_omni.py
├── test_qwen3_omni.py
@@ -146,18 +147,17 @@ vllm_omni/ tests/
├── test_zimage_tensor_parallel.py
├── test_cache_dit.py
├── test_teacache.py
- ├── test_stable_audio_expansion.py
+ ├── test_stable_audio_model.py
├── test_diffusion_cpu_offload.py
├── test_diffusion_layerwise_offload.py
├── test_diffusion_lora.py
├── test_sequence_parallel.py
├── test_qwen_image_edit_expansion.py
- └── stage_configs/ (legacy schema, still present
- ├── bagel_*.yaml for unmigrated models)
+ └── stage_configs/
+ ├── qwen2_5_omni_ci.yaml
+ ├── qwen3_omni_ci.yaml
+ ├── bagel_*.yaml
└── npu/, rocm/, etc.
-
-# Migrated models (qwen3_omni_moe, qwen2_5_omni, qwen3_tts) live under
-# vllm_omni/deploy/ instead — see docs/configuration/stage_configs.md.
examples/ tests
│ └── examples
├── online_serving/ → ├── online_serving/
@@ -221,13 +221,14 @@ from pathlib import Path
import openai
import pytest
-from tests.helpers.media import (
- convert_audio_bytes_to_text,
+from tests.conftest import (
+ OmniServer,
+ convert_audio_to_text,
cosine_similarity_text,
+ dummy_messages_from_mix_data,
generate_synthetic_video,
+ merge_base64_and_convert_to_text,
)
-from tests.helpers.runtime import OmniServer, dummy_messages_from_mix_data
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
from vllm_omni.platforms import current_omni_platform
# Edit: model name and stage config path
@@ -235,7 +236,7 @@ models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
#If you use the default configuration file, you can directly use the following address.
def get_default_config():
- return get_deploy_config_path("ci/qwen3_omni_moe.yaml")
+ return str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
#If you need to modify the configuration file, you can use modify_stage_config.
def get_chunk_config():
@@ -404,7 +405,7 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> N
# PURPOSE: Verify text and audio outputs convey the same information
# CUSTOMIZATION: Adjust similarity threshold (0.9) based on accuracy requirements
assert audio_data is not None, "No audio output is generated"
- audio_content = convert_audio_bytes_to_text(audio_data)
+ audio_content = merge_base64_and_convert_to_text(audio_data)
print(f"text content is: {text_content}")
print(f"audio content is: {audio_content}")
similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
@@ -427,7 +428,7 @@ from pathlib import Path
import pytest
from vllm.assets.video import VideoAsset
-from tests.helpers.mark import hardware_test
+from tests.utils import hardware_test
from ..multi_stages.conftest import OmniRunner
# Optional: set process start method for workers
diff --git a/docs/contributing/model/adding_diffusion_model.md b/docs/contributing/model/adding_diffusion_model.md
index 6d5782a6e3c..dfa550173cf 100644
--- a/docs/contributing/model/adding_diffusion_model.md
+++ b/docs/contributing/model/adding_diffusion_model.md
@@ -802,7 +802,7 @@ omni = Omni(model="your-model", enable_layerwise_offload=True)
```python
class WanTransformer3DModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["blocks"] # Attribute name containing transformer blocks
+ _layerwise_offload_blocks_attr = "blocks" # Attribute name containing transformer blocks
def __init__(self):
self.blocks = nn.ModuleList([...]) # Transformer blocks
@@ -813,16 +813,16 @@ class WanTransformer3DModel(nn.Module):
---
-### Diffusion Pipeline Profiler (Performance Profiling)
+### Diffusion Timing (Performance Profiling)
When adapting a new diffusion model, it is often useful to analyze the latency of key components such as text encoding, diffusion denoising, and VAE decoding.
vLLM-Omni provides a timing utility via `DiffusionPipelineProfilerMixin` to help developers quickly identify performance bottlenecks.
!!! info
- `DiffusionPipelineProfilerMixin` is different from using `torch.profiler` for diffusion models, as introduced in this [tutorial](https://github.com/vllm-project/vllm-omni/blob/main/docs/contributing/profiling.md). `DiffusionPipelineProfilerMixin` only prints the timing information of multiple functions (such as `vae.decode`), while `torch.profiler` saves detailed GPU/CPU computation time, call/execution steps.
+ `DiffusionPipelineProfilerMixin` is different from using `torch.profiler` for diffusion models, as introduced in this [tutorial](https://github.com/vllm-project/vllm-omni/blob/main/docs/contributing/profiling.md#3-profiling-diffusion-models). `DiffusionPipelineProfilerMixin` only prints the timing information of multiple functions (such as `vae.decode`), while `torch.profiler` saves detailed GPU/CPU computation time, call/execution steps.
This tool automatically measures the execution time of selected pipeline modules and prints the results in the logs.
-**Enabling Diffusion Pipeline Profiler**
+**Enabling Diffusion Timing**
Enable timing by setting:
@@ -843,7 +843,7 @@ If not specified, the default targets are used:
**Adding DiffusionPipelineProfilerMixin to a Pipeline**
To enable timing support in your pipeline, inherit from DiffusionPipelineProfilerMixin.
```python
-from vllm_omni.diffusion.profiler import DiffusionPipelineProfilerMixin
+from vllm_omni.diffusion.utils.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
class YourModelPipeline(nn.Module, DiffusionPipelineProfilerMixin):
# Optional: Specify custom timing targets
@@ -862,9 +862,7 @@ class YourModelPipeline(nn.Module, DiffusionPipelineProfilerMixin):
...
# initialize timing profiler
- self.setup_diffusion_pipeline_profiler(
- enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
- )
+ self.setup_diffusion_pipeline_profiler(enable_diffusion_pipeline_profiler)
```
The mixin dynamically wraps selected methods and records their execution time during inference.
@@ -908,9 +906,9 @@ tokenizer.forward
When enabled, timing logs appear like this:
```
-[DiffusionPipelineProfiler] text_encoder.forward took 0.018s
-[DiffusionPipelineProfiler] diffuse took 2.412s
-[DiffusionPipelineProfiler] vae.decode took 0.063s
+[DiffusionTiming] text_encoder.forward took 0.018s
+[DiffusionTiming] diffuse took 2.412s
+[DiffusionTiming] vae.decode took 0.063s
```
These measurements help identify bottlenecks during model adaptation and optimization
diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md
index 1eaff10596c..a0619e33811 100644
--- a/docs/contributing/model/adding_omni_model.md
+++ b/docs/contributing/model/adding_omni_model.md
@@ -313,7 +313,7 @@ The registry uses lazy loading, so the model class is imported only when needed.
## Stage Configuration
-Create a YAML configuration file in `vllm_omni/deploy/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml).
+Create a YAML configuration file in `vllm_omni/model_executor/stage_configs/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml).
### Key Configuration Fields
@@ -408,17 +408,18 @@ Understanding the data structures is crucial for implementing stage transitions:
**Input to your function:**
- `stage_list[source_stage_id].engine_outputs`: List of `EngineCoreOutput` objects
-- - Each contains `outputs`: List of `RequestOutput` objects
- - Each `RequestOutput` has:
-- - - `token_ids`: Generated token IDs
- - `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc.These are the hidden states or intermediate outputs from the model's forward pass
- - `prompt_token_ids`: Original prompt token IDs
+ - Each contains `outputs`: List of `RequestOutput` objects
+ - Each `RequestOutput` has:
+ - `token_ids`: Generated token IDs
+ - `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc.
+ - These are the hidden states or intermediate outputs from the model's forward pass
+ - `prompt_token_ids`: Original prompt token IDs
**Output from your function:**
- Must return `list[OmniTokensPrompt]` where each `OmniTokensPrompt` contains:
-- - `prompt_token_ids`: List[int] - Token IDs for the next stage
- - `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states)
- - `multi_modal_data`: Optional multimodal data if needed
+ - `prompt_token_ids`: List[int] - Token IDs for the next stage
+ - `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states)
+ - `multi_modal_data`: Optional multimodal data if needed
### How Model Outputs Are Stored
@@ -613,7 +614,7 @@ For a complete reference implementation, see:
- **Thinker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py`
- **Talker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py`
- **Code2Wav**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py`
-- **Stage config**: `vllm_omni/deploy/qwen3_omni_moe.yaml`
+- **Stage config**: `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml`
- **Input processors**: `vllm_omni/model_executor/stage_input_processors/qwen3_omni.py`
- **Registry**: `vllm_omni/model_executor/models/registry.py`
- **Testing**: `vllm_omni/tests/e2e/offline_inference/test_qwen3_omni.py`
diff --git a/docs/contributing/model/adding_tts_model.md b/docs/contributing/model/adding_tts_model.md
index 34fd2dbb503..e48ae5049ff 100644
--- a/docs/contributing/model/adding_tts_model.md
+++ b/docs/contributing/model/adding_tts_model.md
@@ -1,93 +1,20 @@
# Adding a TTS Model
-This guide walks through adding a new TTS model to vLLM-Omni. Two patterns are
-supported:
-
-- **Two-stage pipeline** (e.g. Qwen3-TTS, Fish Speech): an AR code-predictor stage
- feeds an audio decoder stage via the `async_chunk` framework. This is the standard
- pattern for maximum streaming performance.
-- **Single-stage AR model** (e.g. MOSS-TTS-Nano): the model runs entirely inside one
- AR worker and streams audio chunks directly from its own `inference_stream()` generator.
-
-Qwen3-TTS is used as the reference for the two-stage pattern. For the single-stage
-pattern, refer to MOSS-TTS-Nano.
+This guide walks through adding a new TTS model to vLLM-Omni, using **Qwen3-TTS**
+as a reference. Qwen3-TTS demonstrates the standard two-stage TTS pipeline and the
+key optimizations all TTS models in this repo should follow.
## Table of Contents
1. [Overview](#overview)
-2. [Cross-Cutting Invariants](#cross-cutting-invariants)
-3. [Directory Structure](#directory-structure)
-4. [Step-by-Step Implementation](#step-by-step-implementation)
-5. [Key Components](#key-components)
-6. [Model Registration](#model-registration)
-7. [Stage Configuration](#stage-configuration)
-8. [Stage Input Processors](#stage-input-processors)
-9. [Online Serving Integration](#online-serving-integration)
-10. [Single-Stage Models](#single-stage-models)
-11. [Testing](#testing)
-12. [Pre-commit and DCO](#pre-commit-and-dco)
-13. [Summary](#summary)
-
-## Cross-Cutting Invariants
-
-These rules apply to every TTS model regardless of architecture (AR vs AR+diffusion,
-single-stage vs two-stage, codec-based vs VAE-based). Each has surfaced as a silent
-bug in a shipped PR — check them at the end of every phase, not just at the start.
-
-**I1. Streaming output contract.** Pick one per-step semantics for `forward()` and
-document it in the docstring:
-
-- *Delta*: yield only new audio samples produced this step. Preferred — linear cost.
-- *Cumulative*: re-decode from step 0 every call. O(N²); only acceptable when the
- codec exposes no streaming decode.
-
-If you choose delta, audit the full chain: `forward()` returns the new chunk →
-`_consolidate_multimodal_tensors()` in `vllm_omni/engine/output_processor.py`
-concatenates the audio key into a single tensor at finish → streaming consumers
-receive per-step chunks, offline consumers receive the concatenated tensor. A
-mismatch (consolidator skips the key with `continue`, or consumers expect a list
-but receive a tensor) is invisible in offline RTF benchmarks — users hear replays
-or truncation only under live playback.
-
-**I2. Multimodal output consumer hygiene.** `outputs[0].outputs[0].multimodal_output[key]`
-can be `Tensor`, `list[Tensor]` (pre-consolidation snapshot), `np.ndarray`, or
-scalar. In every test, example, and benchmark:
-
-- Never write `dict.get("a") or dict.get("b")` on tensor values — Python evaluates
- the tensor's truthiness and raises `Boolean value of Tensor with more than one
- value is ambiguous`. Use explicit `if x is None` chains.
-- Defensively handle the list form:
- `if isinstance(x, list): x = torch.cat([t.reshape(-1) for t in x], dim=0)`.
-- Assert `shape` / `dtype` / `duration` explicitly — do not rely on truthiness for
- presence checks.
-
-**I3. Hot-loop GPU discipline.** Inside any per-step model loop (AR decode,
-diffusion solver, CFM Euler step, per-frame vocoder):
-
-- No `tensor.item()`, `.cpu()`, or `.tolist()` — each triggers a GPU→CPU sync; a
- 10-step × 60-frame × 4-op loop creates 2400 syncs per request.
-- Prefer `dst.copy_(src)` over `dst.fill_(src.item())` for scalar-into-buffer writes.
-- Whole-model `torch.compile(Model.forward, fullgraph=False)` usually outperforms
- per-submodule compile — fewer dispatch boundaries, larger fusion regions. Measure
- before choosing granularity.
-- No Python control flow that depends on tensor values; use `torch.where` or masking.
-
-Profile before optimizing.
-
-**I4. Validation pyramid.** Offline RTF alone is necessary but not sufficient. A
-new TTS model must pass all three levels:
-
-| Layer | Catches | Tool |
-|-------|---------|------|
-| Offline RTF / duration | Throughput regressions, missing audio, wrong sample rate | `end2end.py`, pytest e2e |
-| Browser streaming playback | Delta-vs-cumulative bugs, chunk boundary glitches, TTFP regressions | Gradio demo over `/v1/audio/speech?stream=true` |
-| Concurrent requests | Per-request state leaks, codec window round-robin gaps | `max_num_seqs>1` smoke with 4+ parallel prompts |
-
-**I5. Per-request state belongs to the request.** If the model caches anything
-across `forward()` calls (streaming generators, codec buffers, sliding-window pads,
-CUDA graph state), key it by `info.get("_omni_req_id")` and free the entry on
-request finish. A shared buffer silently corrupts audio across concurrent requests —
-the symptom is crosstalk or truncation under load, nothing in single-request tests.
+2. [Directory Structure](#directory-structure)
+3. [Step-by-Step Implementation](#step-by-step-implementation)
+4. [Key Components](#key-components)
+5. [Model Registration](#model-registration)
+6. [Stage Configuration](#stage-configuration)
+7. [Stage Input Processors](#stage-input-processors)
+8. [Testing](#testing)
+9. [Summary](#summary)
## Overview
@@ -101,7 +28,7 @@ and can be placed on different devices. Qwen3-TTS has two stages:
Each stage is a separate model class configured independently via YAML. The two stages
are connected by the `async_chunk` framework, which enables inter-stage streaming for
-low first-packet latency (see [Async Chunk Design](../../design/feature/async_chunk.md)).
+low first-packet latency (see [Async Chunk Design](../../design/feature/async_chunk_design.md)).
### Without async_chunk (batch mode)
@@ -193,18 +120,8 @@ vllm_omni/model_executor/stage_configs/
| `models/qwen3_tts/qwen3_tts.py` | Unified model class |
| `models/qwen3_tts/qwen3_tts_code_predictor_vllm.py` | Stage 0 - optimized AR |
| `models/qwen3_tts/qwen3_tts_code2wav.py` | Stage 1 - decoder |
-| `deploy/qwen3_tts.yaml` (new schema) | Deploy config (async_chunk enabled) — paired with `models/qwen3_tts/pipeline.py` for the frozen topology |
-
-> **Chunked vs end-to-end modes**: `qwen3_tts` registers a single
-> pipeline whose stage 1 declares alternate processor functions — an
-> `async_chunk_process_next_stage_input_func` (per-chunk streaming, used
-> when `deploy.async_chunk=True`) and a `sync_process_input_func`
-> (batch-end, used when `deploy.async_chunk=False`). The loader selects
-> one at merge time based on the bool, so `--no-async-chunk` alone
-> switches modes — no variant yaml or variant pipeline registration is
-> needed. Pipelines that only make sense in one mode (e.g.
-> `qwen3_omni_moe` is always chunked) can keep using the unconditional
-> `custom_process_*` fields.
+| `stage_configs/qwen3_tts.yaml` | Stage config (async_chunk enabled) |
+| `stage_configs/qwen3_tts_batch.yaml` | Batch mode config |
| `stage_input_processors/qwen3_tts.py` | Stage transition processors |
## Step-by-Step Implementation
@@ -629,302 +546,6 @@ Recommended test cases for a new TTS model:
Reference test: `tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py`
-### E2E Online Serving Tests (`tests/e2e/online_serving/test_<model>.py`)
-
-The `omni_server` fixture in `tests/conftest.py` is **module-scoped**. Each distinct
-`OmniServerParams` id in the same test file forces the fixture to tear the server
-down and spawn a new one mid-module. A few rules that save real CI debugging time:
-
-- **Prefer a single `OmniServerParams` set per file.** If you need to exercise two
- deploy variants (e.g. `model.yaml` and `model_async_chunk.yaml`), either use one
- variant and exercise streaming via request args, or split into two test files so
- each file does exactly one server lifecycle. Mid-module teardown/restart is the
- fragile path and surfaces startup races first.
-- **Never depend on server-side fetching of external URLs** for reference audio or
- other fixture data. CI runners (and China-hosted dev boxes) routinely fail on
- SSL/DNS for public URLs. Inline the payload as a `data:audio/wav;base64,...`
- ref_audio value — the serving layer accepts both forms.
-- **Don't roll your own readiness probe.** The harness already waits for HTTP 200
- on `/health` before releasing the server to the test. If your model needs extra
- warmup signals, expose them through `/health` rather than adding `time.sleep(...)`
- inside the test. (Bare TCP `connect_ex` probes were insufficient; see
- `tests/conftest.py` `OmniServer.wait_for_ready`.)
-- **Use `core_model` marker + H100 hardware_test** to match the `test-ready.yml`
- pipeline so your test is picked up by the `ready` label, not only nightly.
-
-## Online Serving Integration
-
-To expose your model through the `/v1/audio/speech` OpenAI-compatible endpoint, add
-**all five** of the following integration points to
-`vllm_omni/entrypoints/openai/serving_speech.py` in a **single commit**. Adding them
-piecemeal causes partial-integration failures that are hard to debug.
-
-### 1. Stage constant
-
-Near the top of the file, alongside the other `_*_TTS_MODEL_STAGES` constants:
-
-```python
-_YOUR_MODEL_TTS_MODEL_STAGES = {"your_model_stage_key"}
-```
-
-### 2. Union into `_TTS_MODEL_STAGES`
-
-Add to the `_TTS_MODEL_STAGES` set union:
-
-```python
-_TTS_MODEL_STAGES: set[str] = (
- ...
- | _YOUR_MODEL_TTS_MODEL_STAGES
-)
-```
-
-### 3. Model type detection
-
-In `_detect_tts_model_type()`, add before the final `return None`:
-
-```python
-if model_stage in _YOUR_MODEL_TTS_MODEL_STAGES:
- return "your_model"
-```
-
-### 4. Request validation dispatch
-
-In `_validate_tts_request()`, add before the fallback `return`:
-
-```python
-if self._tts_model_type == "your_model":
- return self._validate_your_model_request(request)
-```
-
-### 5. Validation and parameter-builder methods
-
-Add two new methods:
-
-```python
-def _validate_your_model_request(
- self, request: OpenAICreateSpeechRequest
-) -> str | None:
- """Validate YourModel request. Returns an error string or None."""
- if not request.input or not request.input.strip():
- return "Input text cannot be empty"
- return None
-
-def _build_your_model_params(
- self, request: OpenAICreateSpeechRequest
-) -> dict[str, Any]:
- """Build additional_information dict for YourModel."""
- params: dict[str, Any] = {"text": [request.input]}
- if request.voice is not None:
- params["voice"] = [request.voice]
- # Add any other model-specific fields here
- return params
-```
-
-Then wire `_build_your_model_params` into the request-dispatch block in
-`_create_tts_request()` (search for the equivalent `_build_*_params` call for an
-existing model to find the right location). If the model supports voice cloning
-(`ref_audio` → `prompt_audio_path`, `ref_text` → `prompt_text`), add those mappings
-here too — follow any existing `_build_<model>_params` in `serving_speech.py` (e.g.
-`_build_moss_tts_params` for the voice-cloning variant) for the pattern.
-
-> **Two dispatch patterns coexist:** Fish Speech uses a `self._is_fish_speech` boolean
-> checked *before* `elif self._is_tts`. All newer models use the `_tts_model_type`
-> string pattern shown above. For new models, always use the string pattern — do not
-> add new `_is_*` boolean flags.
-
-> **Note on unused variables:** Only extract parameters in `_build_your_model_params`
-> that you actually pass to the model's generate / `inference_stream` call. Extracting
-> a variable without forwarding it will trigger a `ruff F841` pre-commit failure.
-
-### Merge conflicts
-
-`serving_speech.py` is modified by every new model PR and is the most common source of
-rebase conflicts. When rebasing onto `main` and a conflict appears here, the resolution
-is always to **keep both** the upstream model's additions and your own — never discard
-either side. After resolving:
-
-```bash
-git add vllm_omni/entrypoints/openai/serving_speech.py
-git rebase --continue
-```
-
-## Single-Stage Models
-
-Some TTS models (e.g. MOSS-TTS-Nano) do not use a two-stage pipeline. Instead the
-entire AR LM and audio decoder run inside a single AR worker, streaming audio chunks
-directly from the model's own generator.
-
-### Directory structure
-
-```
-vllm_omni/model_executor/models/your_model_name/
- __init__.py
- modeling_your_model_name.py # unified class: load_weights + forward + streaming
-
-vllm_omni/model_executor/stage_configs/your_model_name.yaml
-```
-
-No stage input processor is needed.
-
-### Stage config
-
-Use a single stage with `worker_type: ar`. The `is_comprehension: true` field and the
-top-level `async_chunk: false` are required — omitting them causes silent
-misclassification in the serving layer. Set `max_num_seqs` to at least 4 for
-concurrent production use.
-
-```yaml
-# stage_configs/your_model_name.yaml
-async_chunk: false
-
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true # required for serving_speech.py dispatch
- runtime:
- devices: "0"
- engine_args:
- model_stage: your_model_stage_key
- model_arch: YourModelForCausalLM
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- engine_output_type: audio
- max_num_seqs: 4 # min 4 for concurrent requests; default 1 causes gaps
- final_output: true
- final_output_type: audio
-```
-
-### Generator-based streaming pattern
-
-This is the MOSS-TTS-Nano pattern, distinct from VoxCPM2's vLLM-native AR pattern
-(see `plan/voxcpm2_native_ar_design.md` for that variant). Load model weights in
-`load_weights()` (not `__init__`) so vLLM finishes distributed initialisation before
-any CUDA allocations. Stream via a per-request generator stored in an instance dict:
-
-```python
-class YourModelForCausalLM(nn.Module):
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- self._lm: nn.Module | None = None # populated in load_weights()
- self._stream_gens: dict[str, Any] = {} # request_key → generator
-
- def load_weights(self, weights):
- # Load self._lm here, after vLLM distributed init
- ...
-
- def forward(
- self,
- input_ids,
- positions,
- intermediate_tensors=None,
- inputs_embeds=None,
- runtime_additional_information: list[dict] | None = None, # one dict per request
- **kwargs,
- ) -> OmniOutput:
- infos = runtime_additional_information or [{}]
- # Return empty output during dummy/profiling calls
- if not runtime_additional_information or all(i.get("_is_dummy") for i in infos):
- self._ar_emit_stop_token = True
- return OmniOutput(...)
-
- outputs, last_flags = [], []
- for info in infos:
- request_key = str(info.get("_omni_req_id", "0")) # set by vLLM, not user code
- if request_key not in self._stream_gens:
- self._stream_gens[request_key] = self._create_stream_gen(info)
- try:
- chunk, is_last = next(self._stream_gens[request_key])
- except StopIteration:
- chunk, is_last = torch.zeros(0), True
- if is_last:
- del self._stream_gens[request_key]
- outputs.append(chunk)
- last_flags.append(is_last)
-
- self._ar_emit_stop_token = all(last_flags)
- return OmniOutput(multimodal_outputs={"model_outputs": outputs, "is_last": last_flags})
-
- def _create_stream_gen(self, info: dict):
- """Yield (waveform_tensor, is_last) from the model's inference_stream().
-
- Handle both incremental ("audio" events) and batch ("result" event) models:
- some upstream implementations emit one "result" event with the full waveform
- instead of incremental "audio" events. Both paths must be covered.
- """
- for event in self._lm.inference_stream(...):
- if event["type"] == "audio":
- yield event["waveform"], False
- elif event["type"] == "result":
- # Fallback for models that don't emit incremental audio events
- yield event.get("waveform", torch.zeros(0)), True
- return
- yield torch.zeros(0), True
-
- def compute_logits(self, hidden_states, sampling_metadata):
- # Emit EOS only when the last chunk has been yielded so the AR
- # scheduler ends the request at the right time.
- ...
-```
-
-For an in-tree reference, look for any single-stage AR model under
-`vllm_omni/model_executor/models/` (for example
-`moss_tts_nano/modeling_moss_tts_nano.py` once its integration has landed).
-
-## Pre-commit and DCO
-
-All contributions must pass the pre-commit checks and the Developer Certificate of
-Origin (DCO) sign-off before merging.
-
-### Running pre-commit
-
-Install the hooks once with `pre-commit install`. Then run before committing:
-
-```bash
-pre-commit run --files \
- vllm_omni/model_executor/models/your_model_name/*.py \
- vllm_omni/entrypoints/openai/serving_speech.py \
- vllm_omni/model_executor/models/registry.py \
- tests/e2e/offline_inference/test_your_model_name.py \
- tests/e2e/online_serving/test_your_model_name.py
-```
-
-When pre-commit **modifies files**, it exits with a non-zero code but the reformatting
-is correct. Stage the modified files and commit again — do not revert the changes.
-
-Common failures and fixes:
-
-| Check | Cause | Fix |
-|-------|-------|-----|
-| `ruff F841` | Local variable assigned but never used | Remove the extraction or forward it to the model call |
-| `ruff E402` | Module-level import not at top of file | Move import to the top-level import block |
-| `ruff format` | Line length, spacing, or quote style | Accept the auto-fix, stage, and re-commit |
-
-### DCO sign-off
-
-Every commit must carry a `Signed-off-by` trailer. Use the `-s` flag when committing:
-
-```bash
-git commit -s -m "feat(your-model): add YourModel TTS support"
-```
-
-Or configure git to add it automatically:
-
-```bash
-git config format.signOff true
-```
-
-To fix a missing sign-off on the most recent commit:
-
-```bash
-git commit --amend -s --no-edit
-git push origin your-branch --force-with-lease
-```
-
-> The DCO check verifies that the commit author email matches the `Signed-off-by` email.
-> Make sure `git config user.email` is set to the address associated with your GitHub
-> account before committing.
-
## Adding a Model Recipe
After implementing and testing your model, add a model recipe to the
@@ -936,19 +557,15 @@ for the expected format.
Adding a TTS model to vLLM-Omni involves:
-1. **Create model directory** with AR stage, decoder stage, and unified class (two-stage)
- or a single unified class with generator-based streaming (single-stage)
+1. **Create model directory** with AR stage, decoder stage, and unified class
2. **AR stage** - use vLLM's native decoder layers with fused QKV; do not wrap HF directly
3. **Decoder stage** - thin wrapper around your audio decoder; implement `chunked_decode_streaming()`
4. **Unified class** - dispatches on `model_stage`; same structure as `Qwen3TTSModelForGeneration`
5. **Register** all stage classes in `registry.py`
-6. **YAML configs** - provide both batch and `async_chunk` variants (two-stage), or a single-stage AR config
-7. **Stage input processor** - buffer Stage 0 outputs and forward in chunks of 25 (two-stage only)
-8. **Online serving** - add all 5 integration points to `serving_speech.py` in one commit
-9. **Tests** - cover single request, batching, and streaming
-10. **Pre-commit + DCO** - run `pre-commit` before pushing; sign every commit with `git commit -s`
-11. **Model recipe** - add to [vllm-project/recipes](https://github.com/vllm-project/recipes)
-12. **Invariants** - re-check I1–I5 (streaming contract, consumer hygiene, hot-loop discipline, validation pyramid, per-request state) at the end of every phase
+6. **YAML configs** - provide both batch and `async_chunk` variants
+7. **Stage input processor** - buffer Stage 0 outputs and forward in chunks of 25
+8. **Tests** - cover single request, batching, and `async_chunk` streaming
+9. **Model recipe** - add to [vllm-project/recipes](https://github.com/vllm-project/recipes)
### Qwen3-TTS Reference Files
@@ -957,12 +574,11 @@ Adding a TTS model to vLLM-Omni involves:
| `models/qwen3_tts/qwen3_tts.py` | Unified model class |
| `models/qwen3_tts/qwen3_tts_code_predictor_vllm.py` | AR stage with vLLM fused ops |
| `models/qwen3_tts/qwen3_tts_code2wav.py` | Decoder stage with `chunked_decode_streaming()` |
-| `models/qwen3_tts/pipeline.py` | Frozen pipeline topology (registered at import time) |
-| `deploy/qwen3_tts.yaml` | Deploy config (user-editable, async_chunk + SharedMemoryConnector) |
+| `stage_configs/qwen3_tts.yaml` | Stage configuration |
| `stage_input_processors/qwen3_tts.py` | Stage transition processors |
For more information, see:
- [Architecture Overview](../../design/architecture_overview.md)
-- [Async Chunk Design](../../design/feature/async_chunk.md)
+- [Async Chunk Design](../../design/feature/async_chunk_design.md)
- [Stage Configuration Guide](../../configuration/stage_configs.md)
diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md
index e1dbc8234b0..7a2e64f1312 100644
--- a/docs/contributing/profiling.md
+++ b/docs/contributing/profiling.md
@@ -1,286 +1,216 @@
-# Profiling Diffusion Models
+# Profiling vLLM-Omni
-> **Warning:** Profiling is for development and debugging only. It adds significant overhead and should not be enabled in production.
+> **Warning:** Profiling incurs significant overhead. Use only for development and debugging, never in production.
-Diffusion profiling supports two backends through `profiler_config`:
+vLLM-Omni uses the PyTorch Profiler to analyze performance across both **multi-stage omni-modality models** and **diffusion models**.
-- `torch`: detailed CPU/CUDA traces, operator tables, and optional memory snapshots
-- `cuda`: low-overhead CUDA range control for NVIDIA Nsight Systems (`nsys`)
+### 1. Configure Profiling in the Stage YAML
-## 1. Configure `profiler_config`
-
-Use `profiler_config` to enable profiling for a diffusion model. For diffusion usage, pass it directly to `Omni(...)` or `vllm serve`.
-
-Minimal torch-profiler config:
+Enable profiling by adding `profiler_config` under `engine_args` for the stage(s) you want to profile in your stage config YAML:
```yaml
-profiler_config:
- profiler: torch
- torch_profiler_dir: ./perf
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ engine_args:
+ # ... other engine args ...
+ profiler_config:
+ profiler: torch
+ torch_profiler_dir: ./perf
```
-Supported fields:
-
| Field | Description |
|---|---|
-| `profiler` | Profiler backend. Supported values: `torch`, `cuda`. Use `torch` for `trace.json`, Excel operator tables, and optional memory snapshots. Use `cuda` for Nsight Systems only. |
-| `torch_profiler_dir` | Output directory for torch-profiler artifacts. Required when `profiler: torch`. |
-| `torch_profiler_use_gzip` | Compress `trace_rank*.json` into `trace_rank*.json.gz`. |
-| `torch_profiler_record_shapes` | Record input shapes and add a `by_shape` sheet to `ops_rank*.xlsx`. |
-| `torch_profiler_with_stack` | Record call stacks, add a `by_stack` sheet to `ops_rank*.xlsx`, and export `stacks_cpu_rank*.txt` and `stacks_cuda_rank*.txt`. |
-| `torch_profiler_with_memory` | Enable memory profiling and attempt to dump `memory_snapshot_rank*.pickle`. The pickle is only generated when the current backend supports memory history and snapshot APIs. |
-| `torch_profiler_with_flops` | Enable FLOPs collection in `torch.profiler`. This does not add a separate output file. |
-| `torch_profiler_dump_cuda_time_total` | Export an additional text summary `profiler_out_.txt` sorted by `self_cuda_time_total`. |
-| `delay_iterations` | Number of worker iterations to skip before profiling starts. |
-| `max_iterations` | Maximum number of worker iterations to capture before auto-stop. |
-| `wait_iterations` | Torch-profiler wait iterations before warmup. |
-| `warmup_iterations` | Torch-profiler warmup iterations. |
-| `active_iterations` | Torch-profiler active iterations. |
-
-### Minimal configurations by output
-
-Only collect trace output:
+| `profiler` | Profiler backend to use. Currently supports `torch`. |
+| `torch_profiler_dir` | Directory where trace files are saved. Created automatically if it doesn't exist. |
-```python
-profiler_config = {
- "profiler": "torch",
- "torch_profiler_dir": "./perf",
-}
-```
+> **Tip:** Only enable `profiler_config` on stages you actually need to profile. Stages without it will not start a profiler, keeping overhead minimal.
-Outputs:
+### 2. Profiling Omni-Modality Models
-- `trace_rank*.json`
-- `ops_rank*.xlsx` with a `summary` sheet
+**Selective Stage Profiling**
-Collect compressed trace output:
+Profiling only the stages you need is highly recommended, as it avoids producing overly large trace files:
```python
-profiler_config = {
- "profiler": "torch",
- "torch_profiler_dir": "./perf",
- "torch_profiler_use_gzip": True,
-}
-```
-
-Outputs:
+# Profile all stages
+omni_llm.start_profile()
-- `trace_rank*.json.gz`
-- `ops_rank*.xlsx` with a `summary` sheet
+# Only profile Stage 1
+omni_llm.start_profile(stages=[1])
-Collect trace and full operator tables:
-
-```python
-profiler_config = {
- "profiler": "torch",
- "torch_profiler_dir": "./perf",
- "torch_profiler_record_shapes": True,
- "torch_profiler_with_stack": True,
-}
+# Profile Stage 0 (Thinker) and Stage 2 (Audio Decoder) for Qwen-Omni
+omni_llm.start_profile(stages=[0, 2])
```
-Outputs:
+> **Important:** Always pass the same `stages` list to both `start_profile()` and `stop_profile()`. If you omit `stages` from `stop_profile()`, it defaults to stopping all stages — including ones that were never started — which will produce errors.
-- `trace_rank*.json`
-- `ops_rank*.xlsx` with `summary`, `by_shape`, and `by_stack`
-- `stacks_cpu_rank*.txt`
-- `stacks_cuda_rank*.txt`
-
-Collect trace, operator tables, and memory snapshots:
+**Python Usage**: Wrap your generation logic with `start_profile()` and `stop_profile()`.
```python
-profiler_config = {
- "profiler": "torch",
- "torch_profiler_dir": "./perf",
- "torch_profiler_record_shapes": True,
- "torch_profiler_with_stack": True,
- "torch_profiler_with_memory": True,
-}
-```
+profiler_stages = [0] # Only profile the stages you need
-Outputs:
+# 1. Start profiling
+omni_llm.start_profile(stages=profiler_stages)
-- `trace_rank*.json`
-- `ops_rank*.xlsx` with `summary`, `by_shape`, and `by_stack`
-- `stacks_cpu_rank*.txt`
-- `stacks_cuda_rank*.txt`
-- `memory_snapshot_rank*.pickle` when supported by the current backend
+# Initialize generator
+omni_generator = omni_llm.generate(prompts, sampling_params_list, py_generator=args.py_generator)
-### Full torch-profiler configuration
+total_requests = len(prompts)
+processed_count = 0
-If you want to enable the commonly used torch-profiler options together:
+# Main Processing Loop
+for stage_outputs in omni_generator:
-```python
-profiler_config = {
- "profiler": "torch",
- "torch_profiler_dir": "./perf",
- "torch_profiler_use_gzip": False,
- "torch_profiler_record_shapes": True,
- "torch_profiler_with_stack": True,
- "torch_profiler_with_memory": True,
- "torch_profiler_with_flops": False,
- "torch_profiler_dump_cuda_time_total": False,
- "delay_iterations": 0,
- "max_iterations": 0,
- "wait_iterations": 0,
- "warmup_iterations": 0,
- "active_iterations": 0,
-}
-```
+ # ... [Output processing logic for text/audio would go here] ...
-## 2. Profiling Diffusion with PyTorch Profiler
+ # Update count to track when to stop profiling
+ processed_count += len(stage_outputs.request_output)
-Single-stage diffusion models use `start_profile()` / `stop_profile()` controls. The profiler only writes artifacts after profiling has been started and then stopped.
+ # 2. Check if all requests are done to stop the profiler safely
+ if profiler_enabled and processed_count >= total_requests:
+ print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...")
-```python
-from vllm_omni import Omni
-
-omni = Omni(
- model="Wan-AI/Wan2.2-I2V-A14B-Diffusers",
- profiler_config={
- "profiler": "torch",
- "torch_profiler_dir": "./perf",
- },
-)
-
-omni.start_profile()
-...
-omni.stop_profile()
-```
+ # Stop the profiler while workers are still active
+ # Pass the same stages list used in start_profile()
+ omni_llm.stop_profile(stages=profiler_stages)
-For diffusion offline example scripts under `examples/offline_inference/`, pass `--profiler-config` as a JSON object. The script enables profiling when this argument is set and wraps generation with `start_profile()` / `stop_profile()`.
+ # Wait for traces to flush to disk
+ print("[Info] Waiting 30s for workers to write trace files to disk...")
+ time.sleep(30)
+ print("[Info] Trace export wait time finished.")
+
+omni_llm.close()
+```
-Example:
+**CLI Usage** (using `end2end.py`):
```bash
-python examples/offline_inference/image_to_video/image_to_video.py \
- --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \
- --image input.jpg \
- --prompt "A cat playing with yarn" \
- --profiler-config '{
- "profiler": "torch",
- "torch_profiler_dir": "./perf",
- "torch_profiler_record_shapes": true,
- "torch_profiler_with_stack": true
- }'
-```
+# Profile only Stage 0 (Thinker)
+python end2end.py --output-wav output_audio \
+ --query-type text --enable-profiler --profiler-stages 0
-Examples:
+# Profile Stage 0 and Stage 2
+python end2end.py --output-wav output_audio \
+ --query-type text --enable-profiler --profiler-stages 0 2
-1. [Image edit example](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py)
-2. [Image to video example](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video)
+# Profile all stages (omit --profiler-stages)
+python end2end.py --output-wav output_audio \
+ --query-type text --enable-profiler
+```
-## 3. Profiling Diffusion with Nsight Systems (`nsys`)
+**Examples**:
-For Nsight Systems, use `profiler: cuda` and wrap the process with `nsys profile`.
+1. **Qwen2.5-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen2_5_omni/end2end.py)
-```bash
-nsys profile \
- --trace-fork-before-exec=true \
- --cuda-graph-trace=node \
- --capture-range=cudaProfilerApi \
- --capture-range-end=repeat \
- -o diffusion_trace \
- python image_to_video.py ...
-```
+2. **Qwen3-Omni**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/qwen3_omni/end2end.py)
-The Python process being profiled must create the diffusion engine with:
+### 3. Profiling Diffusion Models
-```python
-profiler_config = {"profiler": "cuda"}
+Diffusion profiling is end-to-end, capturing encoding, denoising loops, and decoding. Standalone diffusion scripts use `--profiler-dir` to enable profiling.
+
+**CLI Usage:**
+```bash
+python image_to_video.py \
+ --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \
+ --image qwen-bear.png \
+ --prompt "A cat playing with yarn, smooth motion" \
+ --profiler-dir ./perf \
+ \
+ # Minimize Spatial Dimensions (Optional but helpful):
+ # Drastically reduces memory usage so the profiler doesn't
+ # crash due to overhead, though for accurate performance
+ # tuning you often want target resolutions.
+ --height 48 \
+ --width 64 \
+ \
+ # Minimize Temporal Dimension (Frames):
+ # Video models process 3D tensors (Time, Height, Width).
+ # Reducing frames to the absolute minimum (2) keeps the
+ # tensor size small, ensuring the trace file doesn't become
+ # multi-gigabytes in size.
+ --num-frames 2 \
+ \
+ # Minimize Iteration Loop (Steps):
+ # This is the most critical setting for profiling.
+ # Diffusion models run the same loop X times.
+ # Profiling 2 steps gives you the exact same performance
+ # data as 50 steps, but saves minutes of runtime and
+ # prevents the trace viewer from freezing.
+ --num-inference-steps 2 \
+ \
+ --guidance-scale 5.0 \
+ --guidance-scale-high 6.0 \
+ --boundary-ratio 0.875 \
+ --flow-shift 12.0 \
+ --fps 16 \
+ --output i2v_output.mp4
```
-Then call `start_profile()` before the requests you want to capture and `stop_profile()` after them. The diffusion worker processes open and close the CUDA capture range themselves, so `nsys` sees the actual GPU work instead of only the parent process.
+> **Note:** For diffusion stages within a multi-stage omni pipeline, use `profiler_config` in the stage YAML instead (see Section 1).
-## 4. Profiling Online Serving
+**Examples**:
-When `profiler_config.profiler` is set for a diffusion model, the server exposes:
+1. **Qwen image edit**: [https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py](https://github.com/vllm-project/vllm-omni/blob/main/examples/offline_inference/image_to_image/image_edit.py)
-- `POST /start_profile`
-- `POST /stop_profile`
+2. **Wan-AI/Wan2.2-I2V-A14B-Diffusers**: [https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video](https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/image_to_video)
-### Start the server
+### 4. Profiling Online Serving
-Single-stage diffusion serving with torch profiler:
+When `profiler_config` is set in the stage YAML, the server automatically exposes `/start_profile` and `/stop_profile` HTTP endpoints.
+**1. Start the server** with a stage YAML that has `profiler_config` enabled:
```bash
-vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \
- --omni \
- --port 8091 \
- --profiler-config '{
- "profiler": "torch",
- "torch_profiler_dir": "/tmp/vllm_profile_wan22_i2v",
- "torch_profiler_with_stack": true,
- "torch_profiler_with_flops": false,
- "torch_profiler_use_gzip": true,
- "torch_profiler_dump_cuda_time_total": false,
- "torch_profiler_record_shapes": true,
- "torch_profiler_with_memory": true,
- "delay_iterations": 0,
- "max_iterations": 0,
- "wait_iterations": 0,
- "warmup_iterations": 0,
- "active_iterations": 0
- }'
+vllm serve Qwen/Qwen2.5-Omni-7B \
+ --omni \
+ --stage-configs-path qwen2_5_omni.yaml \
+ --port 8091
```
-Single-stage diffusion serving with Nsight Systems:
+Or, for single-stage diffusion models:
```bash
-nsys profile \
- --trace-fork-before-exec=true \
- --cuda-graph-trace=node \
- --capture-range=cudaProfilerApi \
- --capture-range-end=repeat \
- -o serving_trace \
- vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \
- --omni \
- --port 8091 \
- --profiler-config '{"profiler": "cuda"}'
+vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers --omni --port 8091 --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
```
-### Control capture
-
-Example profiling flow for an online Qwen-Image request:
-
+**2. Start profiling** by sending a POST request:
```bash
-# Start profiling.
+# Profile all stages that have profiler_config set
curl -X POST http://localhost:8091/start_profile
-# Send a Qwen-Image generation request while profiling is active.
-curl http://localhost:8091/v1/images/generations \
- -H "Content-Type: application/json" \
- -d '{
- "model": "Qwen/Qwen-Image",
- "prompt": "A red vintage bicycle parked beside a quiet canal at sunset"
- }'
-
-# Stop profiling and flush profiler artifacts.
-curl -X POST http://localhost:8091/stop_profile
+# Profile specific stages only
+curl -X POST http://localhost:8091/start_profile \
+ -H "Content-Type: application/json" \
+ -d '{"stages": [0]}'
```
-## 5. Diffusion Pipeline Profiler
+**3. Send your inference requests** as normal while the profiler is running.
-For lightweight per-stage pipeline timing such as `vae.decode` or `diffuse`, see [Diffusion Pipeline Profiler](model/adding_diffusion_model.md#diffusion-pipeline-profiler-performance-profiling). That utility logs stage durations only and does not generate torch-profiler artifacts such as `trace.json`, Excel tables, or memory snapshots.
+**4. Stop profiling** and collect traces:
+```bash
+# Stop all stages
+curl -X POST http://localhost:8091/stop_profile
+
+# Stop specific stages (must match the stages you started)
+curl -X POST http://localhost:8091/stop_profile \
+ -H "Content-Type: application/json" \
+ -d '{"stages": [0]}'
+```
-## 6. Analyze Results
+Trace files are written to the `torch_profiler_dir` specified in your stage YAML (or in `--profiler-config` for single-stage diffusion models).
-Torch-profiler output:
+> **Important:** Always stop the same stages you started. Stopping a stage that was never started will produce errors.
-- Chrome/Perfetto trace: `trace_rank*.json` or `trace_rank*.json.gz`
-- Excel workbook: `ops_rank*.xlsx` with `summary`, and optional `by_shape` / `by_stack` sheets
-- Stack exports: `stacks_cpu_rank*.txt` and `stacks_cuda_rank*.txt` when stack capture is enabled
-- Memory snapshot: `memory_snapshot_rank*.pickle` when memory capture is enabled and supported by the backend
-- Optional CUDA-time text summary: `profiler_out_.txt` when `torch_profiler_dump_cuda_time_total` is enabled
+### 5. Analyzing Traces
-CUDA profiler / Nsight Systems output:
+Output files are saved to the `torch_profiler_dir` specified in your stage YAML config.
-- `.nsys-rep` report files written by `nsys -o ...`
+**Output**
+**Chrome Trace** (`.json.gz`): Visual timeline of kernels and stages. Open in Perfetto UI.
-Recommended viewers:
+**Viewing Tools:**
-- [Perfetto](https://ui.perfetto.dev/) for torch traces
-- `nsys stats .nsys-rep` for CLI summaries
-- Nsight Systems GUI for CUDA kernel timelines
+- [Perfetto](https://ui.perfetto.dev/) (recommended)
+- `chrome://tracing` (Chrome only)
-For upstream background on the underlying vLLM profiling infrastructure, see the [vLLM profiling guide](https://docs.vllm.ai/en/stable/contributing/profiling/).
+**Note**: vLLM-Omni reuses the PyTorch Profiler infrastructure from vLLM. See the official vLLM profiler documentation: [vLLM Profiling Guide](https://docs.vllm.ai/en/stable/contributing/profiling/)
diff --git a/docs/design/feature/async_chunk.md b/docs/design/feature/async_chunk_design.md
similarity index 80%
rename from docs/design/feature/async_chunk.md
rename to docs/design/feature/async_chunk_design.md
index 57b4209b8df..202ef0e18e8 100644
--- a/docs/design/feature/async_chunk.md
+++ b/docs/design/feature/async_chunk_design.md
@@ -1,4 +1,4 @@
-# Async Chunk
+# Async Chunk Design
## Table of Contents
@@ -19,7 +19,7 @@ The `async_chunk` feature enables asynchronous, chunked processing of data acros
For qwen3-omni:
- **Thinker → Talker**: Per decode step (typically chunk_size=1)
-- **Talker → Code2Wav**: Accumulated to `codec_chunk_frames` (default=25) before sending. During the initial phase, a dynamic initial chunk size (IC) is automatically selected based on server load to reduce TTFP. Use the per-request `initial_codec_chunk_frames` API field to override.
+- **Talker → Code2Wav**: Accumulated to `codec_chunk_frames` (default=25) before sending. During the initial phase, a dynamic initial chunk size (IC) is automatically selected based on server load to reduce TTFA. Use the per-request `initial_codec_chunk_frames` API field to override.
- **Code2Wav**: Streaming decode with code2wav chunk_size
With `async_chunk`:
@@ -75,85 +75,26 @@ Enabling **async_chunk** (False→True) sharply reduces time-to-first-audio (TTF
## Architecture
+### Data Flow
-### Async Chunk Pipeline Overview
-
-The following diagram illustrates the **Async Chunk Architecture** for multi-stage models (e.g., Qwen3-Omni with Thinker → Talker → Code2Wav), showing how data flows through the 4-stage pipeline with parallel processing and dual-stream output:
-
+#### Sequential Flow
-
-
+
+
-**Diagram Legend:**
-
-| Step | Stage Type | Description |
-|------|-----------|------------|
-| `prefill` | Initialization | Context processing, KV cache initialization |
-| `decode` | Autoregressive | Token-by-token generation in AR stages |
-| `codes` | Audio Encoding | RVQ codec codes from Talker stage |
-| `output` | Final Output | Text chunks or audio waveforms |
-
-### Data Flow
-
-#### Stage 0: Thinker (Multimodal Understanding + Text Generation)
-- **Prefill**: Processes multimodal input (text/image/audio/video), initializes KV cache
-- **Decode Loop**: Generates text tokens autoregressively
-- **Chunk Triggers**: Each decode step (typically `chunk_size=1`) can trigger downstream processing
-- **Dual Output**:
- - **Text Stream**: `text_0`, `text_1`, `text_2`... `text_n` streamed to output
- - **Hidden States**: Passed to Talker stage for audio synthesis
-
-#### Stage 1: Talker (Text → RVQ Audio Codes)
-- **Prefill**: Receives hidden states from Thinker as semantic condition
-- **Decode Loop**: Generates RVQ codec codes autoregressively
-- **Accumulation**: Codes accumulate to `codec_chunk_frames` (default=25) before forwarding
-- **Dynamic IC**: Initial chunk size auto-selected based on server load to optimize TTFP
-- **Output**: `codes` blocks (chunk 0, 1, ... n) sent to Code2Wav
-
-#### Stage 2: Code2Wav (Vocoder Decoder)
-- **Non-Autoregressive**: Processes RVQ codes in parallel batches
-- **Streaming Decode**: Converts codes to audio waveforms chunk-by-chunk
-- **Batching**: Supports batched inference for multiple concurrent requests
-- **Output**: Audio segments `audio_0`, `audio_1`, ... `audio_n`
-
-#### Stage 3: Output (Dual Stream)
-- **Text Streaming**: `text_0` → `text_1` → `text_2` → ... (user sees response in real-time)
-- **Audio Streaming**: `audio_0` → `audio_1` → ... (user hears audio progressively)
-
-### Execution Timeline
+#### Async Chunk Flow
-```
-Timeline: Parallel vs Sequential
-
-Sequential (async_chunk=false):
-[Thinker: ████████████████████] (2.0s)
- [Talker: ████████████████████] (3.0s)
- [Code2Wav: ████] (1.0s)
-Total: 6.0s, TTFP: 6.0s
-
-Async Chunk (async_chunk=true):
-[Thinker: ████░░░░████░░░░████] (2.0s, streaming)
- [Talker: ░░████░░░░████░░] (3.0s, parallel)
- [Code2Wav: ░░░░████░░] (1.0s, batched)
-Total: ~3.5s, TTFP: ~0.5s
-
-█ = Active computation ░ = Waiting/idle
-```
-
-#### Sequential Flow (for comparison)
-
-
+
+
-In sequential mode, each stage must wait for the previous stage to complete entirely before starting.
-
-### Async Chunk System Architecture
+### Async Chunk Architecture
diff --git a/docs/design/feature/cfg_parallel.md b/docs/design/feature/cfg_parallel.md
index c73a87749f5..64decbe9560 100644
--- a/docs/design/feature/cfg_parallel.md
+++ b/docs/design/feature/cfg_parallel.md
@@ -25,9 +25,7 @@ In standard Classifier-Free Guidance, each diffusion step requires two forward p
1. **Positive/Conditional**: Guided by the text prompt
2. **Negative/Unconditional**: Typically using empty or negative prompt
-Some models require 3 or more CFG branches (see [N-Branch CFG](#n-branch-cfg-3-branches)).
-
-CFG-Parallel eliminates this bottleneck by distributing the forward passes across different GPU ranks, allowing them to execute simultaneously rather than sequentially.
+CFG-Parallel eliminates this bottleneck by distributing the two forward passes across different GPU ranks, allowing them to execute simultaneously rather than sequentially.
### Architecture
@@ -35,11 +33,9 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic.
| Method | Purpose | Automatic Behavior |
|--------|---------|-------------------|
-| [`predict_noise_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with 2-branch CFG | Detects parallel mode, distributes computation, gathers results |
-| [`predict_noise_with_multi_branch_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with N-branch CFG | Round-robin dispatches N branches across M GPUs |
+| [`predict_noise_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Predict noise with CFG | Detects parallel mode, distributes computation, gathers results |
| [`scheduler_step_maybe_with_cfg()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Step scheduler | All ranks step locally (no broadcast needed) |
-| [`combine_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine 2-branch predictions | Applies CFG formula with optional normalization |
-| [`combine_multi_branch_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine N-branch predictions | Override for custom multi-branch combine logic |
+| [`combine_cfg_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Combine positive/negative | Applies CFG formula with optional normalization |
| [`predict_noise()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Forward pass wrapper | Override for custom transformer calls |
| [`cfg_normalize_function()`](https://docs.vllm.ai/projects/vllm-omni/en/latest/api/vllm_omni/diffusion/distributed/cfg_parallel/) | Normalize CFG output | Override for custom normalization |
@@ -61,22 +57,6 @@ vLLM-omni provides `CFGParallelMixin` that encapsulates all CFG parallel logic.
- All ranks compute the scheduler step locally — no broadcast needed because `predict_noise_maybe_with_cfg` already ensures all ranks have identical noise predictions after `all_gather` + local combine.
-### N-Branch CFG (3+ branches)
-
-Some models require more than 2 CFG branches. For example, Bagel and OmniGen2 use 3 branches, DreamID Omni uses 4 branches.
-
-`predict_noise_with_multi_branch_cfg()` handles these by automatically dispatching N branches across M GPUs using round-robin (rule: branch `i` → rank `i % M`):
-
-| Branches (N) | GPUs (M) | Dispatch |
-|:---:|:---:|:---|
-| 3 | 2 | `[[0, 2], [1]]` |
-| 3 | 3 | `[[0], [1], [2]]` |
-| 4 | 2 | `[[0, 2], [1, 3]]` |
-| 4 | 3 | `[[0, 3], [1], [2]]` |
-| 4 | 4 | `[[0], [1], [2], [3]]` |
-
-When a rank handles multiple branches, it runs them sequentially. After `all_gather`, all ranks execute `combine_multi_branch_cfg_noise()` locally, producing identical results.
-
---
## Step-by-Step Implementation
@@ -118,7 +98,6 @@ class YourModelPipeline(nn.Module, CFGParallelMixin):
- `positive_kwargs`: transformer arguments for conditional (text-guided) prediction
- `negative_kwargs`: transformer arguments for unconditional prediction (set to `None` if CFG disabled)
- For image editing pipelines, add `output_slice=image_seq_len` to extract the generative image portion
-- For models with 3+ CFG branches, see [Multi-Branch CFG](#multi-branch-cfg-3-branches) in the Customization section
### Step 2: Call `diffuse`
@@ -192,42 +171,20 @@ class LongCatImagePipeline(nn.Module, CFGParallelMixin):
```
-### Multi-Branch CFG (3+ branches)
-
-For models with 3 or more CFG branches, use `predict_noise_with_multi_branch_cfg()` instead of `predict_noise_maybe_with_cfg()`, and override `combine_multi_branch_cfg_noise()` for custom combine logic. This interface also works for standard 2-branch CFG — just pass 2 branches in `branches_kwargs`.
+### Override `combine_cfg_noise()` for Multi-Output Models
-**Example (3-branch with dual guidance scale):**
+When `predict_noise()` returns a tuple (e.g., video + audio), the default `combine_cfg_noise()` applies CFG to every element. Override it to apply different logic per element — for example, CFG on video but positive-only on audio:
```python
-class YourMultiBranchPipeline(nn.Module, CFGParallelMixin):
- def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False):
- text_scale = true_cfg_scale["text"]
- image_scale = true_cfg_scale["image"]
- pos, ref, uncond = predictions
- return uncond + image_scale * (ref - uncond) + text_scale * (pos - ref)
-
- def diffuse(self, ...):
- for i, t in enumerate(timesteps):
- positive_kwargs = {...} # conditional prompt
- ref_neg_kwargs = {...} # negative prompt + reference
- uncond_kwargs = {...} # unconditional
-
- noise_pred = self.predict_noise_with_multi_branch_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale={"text": text_guidance_scale, "image": image_guidance_scale},
- branches_kwargs=[positive_kwargs, ref_neg_kwargs, uncond_kwargs],
- )
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
-
- return latents
+class MyVideoAudioPipeline(nn.Module, CFGParallelMixin):
+ def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize):
+ (video_pos, audio_pos) = positive_noise_pred
+ (video_neg, audio_neg) = negative_noise_pred
+ video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize)
+ return (video_combined, audio_pos) # audio: positive only, no CFG
```
-### Override Combine Functions
-
-There are two combine functions for different scenarios:
-
-- **`combine_cfg_noise()`** — Used by `predict_noise_maybe_with_cfg()`. Override when `predict_noise()` returns a tuple (e.g., video + audio) and you need per-element CFG logic.
-- **`combine_multi_branch_cfg_noise()`** — Used by `predict_noise_with_multi_branch_cfg()`. Override to implement custom multi-branch combine formulas (see [Multi-Branch CFG](#multi-branch-cfg-3-branches) above).
+This also requires `predict_noise()` to return a tuple (see [Override predict_noise](#override-predict_noise-for-custom-transformer-calls) above).
### Implement a Composite Scheduler for Multi-Output Models
@@ -346,5 +303,4 @@ Adding CFG-Parallel support:
1. ✅ **Create mixin** - Inherit from `CFGParallelMixin` and implement `diffuse()` method
2. ✅ **(Optional) Customize** - Override `predict_noise()` or `cfg_normalize_function()` for custom behavior
-3. ✅ **(Optional) Multi-branch** - For 3+ branch models, use `predict_noise_with_multi_branch_cfg()` and override `combine_multi_branch_cfg_noise()`
-4. ✅ **Test** - Verify with `--cfg-parallel-size 2` (or 3/4 for multi-branch) and compare performance
+3. ✅ **Test** - Verify with `--cfg-parallel-size 2` and compare performance
diff --git a/docs/design/feature/expert_parallel.md b/docs/design/feature/expert_parallel.md
deleted file mode 100644
index e05eec33613..00000000000
--- a/docs/design/feature/expert_parallel.md
+++ /dev/null
@@ -1,221 +0,0 @@
-# Expert Parallel
-
-This section describes how to add Expert Parallel (EP) to a diffusion transformer that uses Mixture-of-Experts (MoE) layers.
-We use **HunyuanImage3.0** as the reference implementation.
-
----
-
-## Table of Contents
-
-- [Overview](#overview)
-- [Step-by-Step Implementation](#step-by-step-implementation)
-- [Testing](#testing)
-- [Reference Implementations](#reference-implementations)
-- [Summary](#summary)
-
----
-
-## Overview
-
-### What is Expert Parallel?
-
-**Expert Parallel** is a parallelism strategy in Mixture-of-Experts (MoE) models that distributes different expert networks across distinct computational devices. Each device holds and computes only a subset of experts (local experts), with tokens dispatched to and gathered from remote devices via collective communication operations (e.g., All-to-All, All-Gather).
-
-| Backend | Description |
-|---------|-------------|
-| `allgather_reducescatter` | Default backend based on allgather/reducescatter primitives, suitable for general EP+DP deployments.|
-
-## Configuration
-
-Enable EP by setting the `--enable-expert-parallel` flag. The EP size is automatically calculated as:
-
-```text
-EP_SIZE = TP_SIZE × SP_SIZE × CFG_SIZE × DP_SIZE
-```
-
-
-Where:
-
-- `TP_SIZE`: Tensor parallel size
-- `SP_SIZE`: Sequence parallel size
-- `CFG_SIZE`: Classifier-free guidance parallel size
-- `DP_SIZE`: Data parallel size
-- `EP_SIZE`: Expert parallel size (computed automatically)
-
-Note:
-- Expert parallelism is only applicable to Mixture-of-Experts (MoE) models.
-- The EP group is created **per pipeline stage**, meaning it includes all ranks that participate in model parallelism except pipeline parallelism.
-- The underlying communication pattern for expert parallelism is **All-to-All** among the ranks in the EP group.
-
-For example, consider a configuration with `TP=2`, `SP=1`, `CFG=2`, and `DP=4` (total 2×1×2×4 = 16 GPUs).
-
-- Expert layers are handled by an EP group of size 16.
-
-- Attention layers use tensor parallelism of size 2 within each of the 8 DP groups (because `DP×CFG×SP = 4×2×1 = 8` groups, each containing the 2 TP ranks). Inside each such group, the attention weights are sharded across the 2 GPUs.
-
-
-## Step-by-Step Implementation
-
-### Step 1: Configure Expert Parallelism Settings
-
-Calculate local experts per rank:
-
-```
-ep_size = 8 # Expert Parallel size (typically equals TP size)
-num_experts = 64
-num_local_experts = num_experts // ep_size # 8 experts per card
-
-# Check divisibility
-assert num_experts % ep_size == 0, "Experts must be divisible by EP size"
-```
-
-### Step 2: Use Sparse MoE Block to enable EP routing.
-
-Example:
-```
-from vllm.model_executor.layers.linear import ReplicatedLinear
-class HunYuanSparseMoeBlock(nn.Module):
- def __init__(
- self,
- config: PretrainedConfig,
- layer_id: int = -1,
- prefix: str = "",
- ):
- super().__init__()
- self.tp_size = get_tensor_model_parallel_world_size()
- self.n_routed_experts = config.num_experts # 64
-
- # Calculate local experts per rank (key for EP)
- if self.tp_size > self.n_routed_experts:
- raise ValueError(f"TP size {self.tp_size} > experts {self.n_routed_experts}")
-
- # Routing gate (replicated on all ranks, computes scores for all tokens to all experts)
- self.gate = ReplicatedLinear(
- config.hidden_size,
- config.num_experts,
- bias=False,
- quant_config=None,
- prefix=f"{prefix}.gate",
- )
-
- # EP expert layer (factory loads platform-specific implementation)
- self.experts = HunyuanFusedMoE(...)
-```
-**Key Points:**
-- gate is **ReplicatedLinear** (replicated on all ranks)
-- experts is created via **HunyuanFusedMoE factory**, which automatically handles EP dispatch
-
-### Step 3: Initialize EP Runtime
-
-Initialize the EP communication context before model loading.
-```
-from vllm.utils.import_utils import resolve_obj_by_qualname
-# Call during __init__ or model loading
-op_name = "hunyuan_fused_moe"
-
-# Prepare EP runtime: establish communication groups, assign local expert indices, init _expert_map
-current_omni_platform.prepare_diffusion_op_runtime(op_name)
-
-# Factory automatically resolves platform implementation (GPU: FusedMoE / NPU: AscendFusedMoE)
-impl = resolve_obj_by_qualname(
- current_omni_platform.get_diffusion_model_impl_qualname(op_name)
-)
-```
-
-### Step 4: Expert Weight Mapping & Loading
-
-Each rank loads only the expert weights assigned to its local allocation.
-```
-# Get expert parameter mapping (different per rank)
-expert_mapping = HunyuanFusedMoE.make_expert_params_mapping(
- model=self,
- ckpt_gate_proj_name="gate_proj",
- ckpt_down_proj_name="down_proj",
- ckpt_up_proj_name="up_proj",
- num_experts=64,
- num_redundant_experts=0,
-)
-# Returns: [(param_name, weight_name, expert_id, shard_id), ...]
-# Note: Each rank only contains mappings for its local expert_ids
-
-# Filter non-local experts during loading
-for name, loaded_weight in weights:
- if "mlp.experts" in name:
- # Parse expert_id from weight name (implementation needed)
- expert_id = parse_expert_id_from_name(name)
- local_expert_start = (ep_rank) * num_local_experts
- local_expert_end = (ep_rank + 1) * num_local_experts
-
- if not (local_expert_start <= expert_id < local_expert_end):
- continue # Skip non-local expert weights
-```
-### Step 5: Forward Pass with EP
-
-Example (MoE Forward):
-```
-def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- orig_shape = hidden_states.shape
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-
- # 1. Global routing computation (all tokens, all expert scores)
- # hidden_states: [num_tokens, hidden_dim] (full tensor)
- router_logits, _ = self.gate(hidden_states) # [num_tokens, num_experts]
-
- # 2. EP dispatch and compute (HunyuanFusedMoE handles all_to_all internally)
- # - Dispatch: Send tokens to target ranks based on router_logits
- # - Local Compute: Each rank processes only its num_local_experts
- # - Combine: Results returned to original token positions
- final_hidden_states = self.experts(
- hidden_states=hidden_states,
- router_logits=router_logits,
- )
-
- # 3. Add shared expert output (not EP, computed on all ranks)
- if self.shared_mlp is not None:
- shared_out = self.shared_mlp(hidden_states)
- final_hidden_states = final_hidden_states + shared_out
-
- # 4. Tensor Parallel All-Reduce (synchronize across TP group)
- if self.tp_size > 1:
- final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
- final_hidden_states
- )
-
- return final_hidden_states.view(orig_shape)
-```
-
-## Testing
-After adding Expert Parallel support, test via command line:
-```bash
-cd examples/offline_inference/text_to_image
-python text_to_image.py \
- --model Your-org/your-model \
- --prompt "a cup of coffee on the table" \
- --output "ep_enabled.png" \
- --num-inference-steps 50 \
- --guidance-scale 5.0 \
- --tensor-parallel-size 8 \
- --seed 1234 \
- --enable-expert-parallel
-```
-
-vLLM‑Omni currently focuses on core diffusion model inference acceleration, so the Expert Parallel implementation includes only the basic multi‑GPU expert sharding functionality (enabled via --enable-expert-parallel). Advanced features such as communication backend selection (--all2all-backend), load balancing (--enable-eplb and its configuration), and multi‑node deployment belong to the extended capabilities of the main vLLM project and have not yet been integrated into Omni.
-
-## Reference Implementations
-
-Complete examples in the codebase:
-
-| Model | Path | Pattern | Notes |
-|-------|------|---------|-------|
-| **HunyuanImage3.0** | `vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py` | Standard EP | Full implementation with validation |
-| **EP Tests** | `vllm-omni/tests/e2e/offline_inference/test_expert_parallel.py` | E2E testing | EP correctness and performance |
-| **Constraint Tests** | `vllm-omni/tests/diffusion/models/hunyuan_image3/test_hunyuan_fused_moe.py` | Unit testing | Validation logic |
-
----
-## Summary
-
-Adding Expert Parallel support to diffusion model:
-
-1. **Identify MoE layers** - Locate the router and expert networks in each transformer block.
-2. **Validate EP constraints** – Ensure num_experts is divisible by expert_parallel_size.
-3. **Test** - Run with enable-expert-parallel, check memory reduction, speedup, and output quality against single‑GPU baseline.
diff --git a/docs/design/feature/omni_connectors/mooncake_transfer_engine_connector.md b/docs/design/feature/omni_connectors/mooncake_transfer_engine_connector.md
index 306a0620b4b..798644b96ff 100644
--- a/docs/design/feature/omni_connectors/mooncake_transfer_engine_connector.md
+++ b/docs/design/feature/omni_connectors/mooncake_transfer_engine_connector.md
@@ -33,8 +33,8 @@ runtime:
zmq_port: 50051 # ZMQ base port (see "Port Offset Scheme" below)
protocol: "rdma" # "rdma" or "tcp"
device_name: "" # RDMA device (e.g., "mlx5_0"), empty for auto-detect
- memory_pool_size: 4294967296 # 4 GB (CPU); use 2147483648 (2 GB) for GPU
- memory_pool_device: "cpu" # "cpu" for pinned memory (recommended), "cuda" for GPUDirect RDMA
+ memory_pool_size: 2147483648 # 2GB memory pool
+ memory_pool_device: "cpu" # "cpu" for pinned memory, "cuda" for GPUDirect RDMA
```
Wire stages to the connector:
@@ -64,8 +64,8 @@ stage_args:
| Parameter | Default | Description |
|---|---|---|
-| `memory_pool_size` | 4 GB (CPU) / 2 GB (GPU) | Total size of the RDMA-registered memory pool in bytes. Recommended 4 GB for CPU pinned memory; 2 GB for GPU VRAM to conserve device memory. |
-| `memory_pool_device` | `"cpu"` | `"cpu"`: pinned host memory (recommended, works on all topologies). `"cuda"`: GPU VRAM for GPUDirect RDMA (requires NIC-GPU direct PCIe connectivity, PIX topology). |
+| `memory_pool_size` | 1 GB | Total size of the RDMA-registered memory pool in bytes. |
+| `memory_pool_device` | `"cpu"` | `"cpu"`: pinned host memory (recommended). `"cuda"`: GPU VRAM for GPUDirect RDMA (requires NIC-GPU direct PCIe connectivity). |
### Networking
@@ -107,10 +107,10 @@ receiver_connect = remote_side_channel_port + tp_rank
## Memory Pool Modes
-| Mode | Config | Recommended Pool Size | Data Flow | Best For |
-|---|---|---|---|---|
-| CPU Pinned | `memory_pool_device: "cpu"` | 4 GB | GPU → CPU pool → RDMA → CPU pool → GPU | Most hardware topologies (recommended) |
-| GPUDirect | `memory_pool_device: "cuda"` | 2 GB | GPU → GPU pool → RDMA (NIC reads GPU BAR1) → GPU pool | NIC-GPU direct PCIe (PIX topology) |
+| Mode | Config | Data Flow | Best For |
+|---|---|---|---|
+| CPU Pinned | `memory_pool_device: "cpu"` | GPU → CPU pool → RDMA → CPU pool → GPU | Most hardware topologies (recommended) |
+| GPUDirect | `memory_pool_device: "cuda"` | GPU → GPU pool → RDMA (NIC reads GPU BAR1) → GPU pool | NIC-GPU direct PCIe (PIX topology) |
> **Note**: GPUDirect RDMA requires the NIC and GPU to share a direct PCIe
> switch (PIX topology). On systems where they are connected via PXB or NODE,
diff --git a/docs/design/feature/prefix_caching.md b/docs/design/feature/prefix_caching.md
deleted file mode 100644
index ebad8b69106..00000000000
--- a/docs/design/feature/prefix_caching.md
+++ /dev/null
@@ -1,164 +0,0 @@
-# Automatic Prefix Caching in Omni Models
-
-
----
-
-## Table of Contents
-
-- [Overview](#overview)
-- [High-Level Approach](#high-level-approach)
-- [Example](#example)
-- [What About Multimodal Inputs?](#what-about-multimodal-inputs)
-
----
-
-### Overview
-
-Prefix caching in the context of kv-cache management is a useful optimization for avoiding redundant computations. The main idea is that we store portions of the kv-cache from processed requests, so that we can reuse them if incoming requests have the same prefix as previous requests.
-
-vLLM manages the kv-cache as blocks, which represent a span of tokens of a fixed length. Blocks are hashable by the content that they contain, which typically means the tokens within the span, but also could be influenced by other factors, e.g., LoRA and multimodal data.
-
-vLLM implements automatic prefix caching for managing its kv-cache, which is best understood by reading the design document [here](https://docs.vllm.ai/en/latest/design/prefix_caching/). vLLM-Omni builds on top of the prefix caching mechanism in a noninvasive way to allow caching between stages in Omni pipelines. This typically means for a given stage we aim to support caching for the following:
-
-- The last hidden states produced by the stage
-- Model / stage specific multimodal data
-
-!!! note "Note 1"
- This document describes vLLM-Omni's mechanism for caching tensor outputs that are meant to be passed between stages, when requests have common prefixes, similar to the way in which vLLM has prefix caching for the kv-cache. This works in conjunction with vLLM's multimodal encoder caching, but is distinct. See the final section for a concrete example for how they tie together in practice.
-
-### High-Level Approach
-!!! note "Note 2"
- Prior to reading this section, it's recommended to take a look at the design documents in vLLM for [Automatic Prefix Caching](https://docs.vllm.ai/en/latest/features/automatic_prefix_caching/), which will make some of the concepts more clear.
-
-The main focus of vLLM-Omni's approach to prefix caching stage outputs is to build on vLLM's prefix caching in the least invasive way possible while minimizing impact for cache misses, and consuming a minimal amount of GPU memory. To understand the implementation, there are a few important things to note:
-
-- Between stages, device tensors are generally moved to CPU; this is important since we're just caching the outputs of stages, so it is okay to keep the entire cache on the CPU.
-
-- For a tensor to be considered cacheable, the first dimension (currently) needs to be the same as the token count, as it allows us to reuse block/slot mappings for our externally maintained tensor caches. This allows us to dynamically discover the tensors to be marked as cacheable outputs in each Omni model without having to explicitly specify cacheable output field names in every model.
-
-With this in mind, consider the set of blocks in a 2D layout, where the row represents the index of blocks being considered, and the columns represent the slots corresponding to tokens within each block. Since we know the `num_blocks` and `block_size` from our kv cache config, if we want to cache a tensor with feature size `D`, we can preallocate a CPU tensor of size `(num_blocks, block_size, D)`, and use the same block index and slot mapping to retrieve the corresponding feature vector.
-
-
-### Example
-!!! note "Note 3"
- Prefix caching in vLLM-Omni currently is only supported on AutoRegressive stages with one kv-cache group. It can be enabled/disabled per-stage via the `enable_prefix_caching` parameter in the model's stage config.
-
-The way in which vLLM-Omni ties into vLLM's prefix caching is best understood by example. Say that we have the following:
-
-- `num_blocks=8`
-- `block_size=4`
-- `hidden_size=2`
-- A stage specific multimodal output tensor named `mm_feature` with feature dimension `16`
-
-The prefix cache flow is then outlined below.
-
-1. When the model is initialized, we can determine the `hidden_size` from the `ModelConfig`, and allocate a cache of size `(num_blocks, block_size, hidden_size)`.
-
-2. Say we process the request `The quick brown fox was tired and slept beneath the shady tree`, which is 12 tokens and evenly divides into 3 blocks as shown below.
-
-```
- [ The quick brown fox ] [ was tired and slept ] [beneath the shady tree ]
-Block 1: |<--- block tokens ---->|
-Block 2: |<------- prefix ------>| |<--- block tokens --->|
-Block 3: |<------------------ prefix -------------------->| |<--- block tokens ---->|
-```
-
-When the request processes, we inspect the multimodal outputs and identify the `mm_feature` tensor, which will be of shape `(seq_len, feature_dim)`, i.e., `(12, 16)` in this example. We note that the first axis is dependent on the `seq_len` and add a new cache_tensor of shape `(num_blocks, block_size, feature_dim)` to our multimodal cache for tensors.
-
-
-3. If we lay out the cache as a 2D tensor of shape (`num_blocks`, `block_size`), we'll have something like the following:
-
-```
-0: [ The quick brown fox ]
-1: [ was tired and slept ]
-2: [beneath the shady tree ]
-3: [EMPTY]
-...
-7: [EMPTY]
-```
-
-Or, if we flatten it down to 1D,
-```
-0: The
-1: quick
-2: brown
-3: fox
-...
-11: tree
-12: [EMPTY]
-...
-```
-
-which we can think of as row indices into the hidden states tensor if we view it as the 2D shape `(num_blocks x block_size, feature_dim)`. That is, the analogous flattened (from 3D -> 2D) mapping of the cache for hidden states becomes the following.
-```
-0:
-1:
-2:
-3:
-...
-11:
-12: [EMPTY]
-...
-```
-
-Similarly, for the multimodal outputs cache, the flattened coordinates are the same, but the `mm_feature` maps to vectors of length `16` instead of the hidden size of `2`. Note that in practice, we may have multiple multimodal output tensors per forward pass, which may have different names and different feature dimensions.
-
-
-4. Now, say that we receive a new request `The quick brown fox jumped over the dog`.
-
-```
- [ The quick brown fox ] [ jumped over the dog ]
-Block 1: |<--- block tokens ---->|
-Block 2: |<------- prefix ------>| |<--- block tokens --->|
-```
-
-Here, we will have a cache hit for `Block 1` which will be detected by vLLM based on the hash of the first block when it's handling the prefix caching on the kv-cache. As a result, when we get the output from the scheduler, we will see that `num_computed_tokens=4` (corresponding to the cached first block), and we only need to process the remaining 4 new tokens in the new prefill.
-
-Since we have the block indices / slot mappings from the kv cache manager, we can simply mirror the mappings and leverage the same indices for the cached hidden states and multimodal outputs. This allows us to look up the correct tensors from our externally maintained 3D caches.
-
-```
-0: [ The quick brown fox ] < already in the cache
-1: [ was tired and slept ]
-2: [beneath the shady tree ]
-3: [ jumped over the dog ] < added on the second request
-4: [EMPTY]
-...
-7: [EMPTY]
-...
-```
-
-Finally, to pass the full hidden states and multimodal outputs to the next stage, we simply concatenate the cached contents with the corresponding new tensors computed from the current forward call.
-
-
-### What About Multimodal Inputs?
-It's also useful to consider the case about how Omni prefix caching is handled when we have multimodal inputs that don't cleanly end on block boundaries, as well as how this works with multimodal encoder caching in vLLM. For example:
-
-```
- [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 foo ]
-Block 1: |<--- block tokens ---->|
-Block 2: |<------- prefix ------>| |<--- block tokens --->|
-```
-
-In this case, only `Block 1` will have outputs stored in the prefix tensor cache, because vLLM does not store partial blocks. This may appear to be a problem at first glance, because the multimodal input is fragmented across a new block that wasn't cached.
-
-In reality, this isn't a big problem for correctness, because vLLM also maintains an encoder cache for multimodal inputs. In other words, after the first pass, we'll have the following:
-
-- The Block 1 hash, which is used for prefix caching
-- The hash describing the image data starting at position 0 and with length 6
-- In vLLM's encoder cache, a mapping from the image hash above to the encoder output
-
-
-To understand what happens, say we get the following input as a second request:
-```
- [ Im0 Im1 Im2 Im3 ] [ Im4 Im5 bar baz ]
-Block 1: |<--- block tokens ---->|
-Block 2: |<------- prefix ------>| |<--- block tokens --->|
-```
-
-First, the scheduler will check for a prefix cache hit, which we will see on `Block 1`. As a result, we will have 4 tokens marked as precomputed, and only see the remaining 4 tokens in the following prefill.
-
-Because we have multimodal data in a scheduled span that isn't fully precomputed, we still need to call the visual encoder. However, since we have the image hash and encoder cache, we will retrieve the encoder outputs for `Im4` and `Im5` as we create the multimodal embeddings.
-
-When we pass our multimodal tensors to the language model component in the same stage, we'll then expect the same outputs, because the prefix caching behaviors in vLLM-Omni / vLLM match, so the LLM will use vLLM's KV cache manager's prefix caching to correctly handle the attention information for `Block 1` while calculating the outputs for `Block 2`, giving us the correct results for processing `Block 2` with the context of `Block 1`.
-
-Finally, we look up the output hidden states/multimodal tensors corresponding to the prefix cache hit `Block 1` and concatenate it with the forward pass result to get the final result, which is expected to be identical to the full hidden states when prefix caching is disabled.
diff --git a/docs/design/feature/teacache.md b/docs/design/feature/teacache.md
index 8577cff1f05..9fa315cee77 100644
--- a/docs/design/feature/teacache.md
+++ b/docs/design/feature/teacache.md
@@ -326,41 +326,9 @@ for prompt in tqdm(prompts, desc="Collecting data"):
# Estimate coefficients
coeffs = estimator.estimate(poly_order=4)
-print(f"Estimated coefficients: {coeffs}")
+print(f"Estimated coefficients: {coeffs.tolist()}")
```
-Note: some models may require the vLLM context and config to be initialized to initialize vLLM modules. To this end, you may need a workaround like the following to be able to run coefficient estimation.
-```python
-from vllm_omni.diffusion.forward_context import set_forward_context
-from vllm_omni.diffusion.distributed.parallel_state import (
- init_distributed_environment,
- initialize_model_parallel,
-)
-from vllm.config import VllmConfig
-...
-
-if __name__ == "__main__":
- os.environ["MASTER_ADDR"] = "localhost"
- os.environ["MASTER_PORT"] = "8192"
- os.environ["LOCAL_RANK"] = "0"
- os.environ["RANK"] = "0"
- os.environ["WORLD_SIZE"] = "1"
-
- vllm_config = VllmConfig()
- init_distributed_environment()
- initialize_model_parallel()
-
- # NOTE: you may have to pass an initialized OmniDiffusionConfig as a kwarg
- # here to make current sp checks happy; if this is the case, just create one
- # .from_kwargs() with the model name to get around this check for now,
- # since your estimator subclass should handle the actual model configuration.
- #
- # This will be cleaned up in the future
- with set_forward_context(vllm_config):
-
-```
-
-
**Data Statistics Guide:**
| Metric | Good Range | Warning Signs |
diff --git a/docs/design/feature/vae_parallel.md b/docs/design/feature/vae_parallel.md
index e330b41a68f..9009ece72a5 100644
--- a/docs/design/feature/vae_parallel.md
+++ b/docs/design/feature/vae_parallel.md
@@ -1,15 +1,14 @@
# VAE Patch Parallelism
This document describes how to add **VAE Patch Parallelism** support to a diffusion model.
-We use **Qwen-Image** as the reference implementation for decode parallel, and **Wan2.2** for encode parallel.
+We use **Qwen-Image** as the reference implementation.
---
## Table of Contents
- [Overview](#overview)
-- [Step-by-Step Implementation (Decode)](#step-by-step-implementation-decode)
-- [Encode Parallel Implementation](#encode-parallel-implementation)
+- [Step-by-Step Implementation](#step-by-step-implementation)
- [Testing](#testing)
- [Reference Implementations](#reference-implementations)
- [Summary](#summary)
@@ -20,13 +19,13 @@ We use **Qwen-Image** as the reference implementation for decode parallel, and *
### What is Vae Patch parallel?
-**VAE Patch Parallelism** is an acceleration technique for both **encoding** and **decoding**. Instead of processing the entire tensor at once, the tensor is:
+**VAE Patch Parallelism** is a decoding acceleration technique. Instead of decoding the entire latent tensor at once, the latent tensor is:
+ Split into multiple spatial tiles
+ Distributed across multiple ranks
-+ Encoded/Decoded in parallel
++ Decoded in parallel
+ Merged to reconstruct the final output
@@ -36,17 +35,10 @@ This approach:
+ Reduces peak memory usage per device
-+ Accelerates encoding/decoding latency
-
-### When to Use Encode vs Decode Parallel
-
-| Operation | Use Case | Example |
-|-----------|----------|---------|
-| **Decode Parallel** | Text-to-Image, Text-to-Video | Latent → Image/Video |
-| **Encode Parallel** | Image-to-Video (I2V) | Image → Latent (for conditioning) |
++ Reduces decoding latency
### Architecture
-We introduce **DistributedVaeExecutor** as the core component responsible for distributed VAE encoding/decoding.
+We introduce **DistributedVaeExecutor** as the core component responsible for distributed VAE decoding.
The executor is model-agnostic and accepts three function parameters:
@@ -92,7 +84,7 @@ Therefore:
+ Merge must perform blending to avoid seams
-## Step-by-Step Implementation (Decode)
+## Step-by-Step Implementation
### Step 1: Implement DistributedAutoencoderKLQwenImage
`QwenImagePipeline` use `AutoencoderKLQwenImage` for vae, so implement a distributed version:
@@ -213,14 +205,14 @@ def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid
We need to override tiled_decode, the main logic is:
+ check distributed is enabled
+ select split/exec/merge
-+ Invoke self.distributed_executor.execute to decode
++ Invoke self.distributed_decoder.execute to decode
```
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):
if not self.is_distributed_enabled():
return super().tiled_decode(z, return_dict=return_dict)
logger.info("Decode run with distributed executor")
- result = self.distributed_executor.execute(
+ result = self.distributed_decoder.execute(
z,
DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge),
broadcast_result=True,
@@ -251,166 +243,6 @@ class YourModelPipeline(nn.Module):
+ ).to(self.device)
```
-## Encode Parallel Implementation
-
-For models that require VAE encoding (e.g., Image-to-Video), you can also parallelize the encode operation. We use **Wan2.2** as the reference implementation.
-
-### Step 1: Implement encode_tile_split
-
-Similar to decode, split the input tensor into tiles. Key considerations:
-
-+ **Patchify handling**: If the model uses `patch_size`, scale tile parameters accordingly
-+ **Temporal chunking**: Video VAEs may have temporal compression (e.g., 4x)
-
-```python
-def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
- _, _, num_frames, height, width = x.shape
- encode_spatial_compression_ratio = self.spatial_compression_ratio
-
- # Scale tile parameters for patchified coordinate system
- tile_sample_min_height = self.tile_sample_min_height
- tile_sample_min_width = self.tile_sample_min_width
- tile_sample_stride_height = self.tile_sample_stride_height
- tile_sample_stride_width = self.tile_sample_stride_width
-
- if self.config.patch_size is not None:
- # When input is patchified, scale tile parameters accordingly
- encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size
- tile_sample_min_height = tile_sample_min_height // self.config.patch_size
- tile_sample_min_width = tile_sample_min_width // self.config.patch_size
- tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
- tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
-
- latent_height = height // encode_spatial_compression_ratio
- latent_width = width // encode_spatial_compression_ratio
-
- tile_latent_min_height = tile_sample_min_height // encode_spatial_compression_ratio
- tile_latent_min_width = tile_sample_min_width // encode_spatial_compression_ratio
- tile_latent_stride_height = tile_sample_stride_height // encode_spatial_compression_ratio
- tile_latent_stride_width = tile_sample_stride_width // encode_spatial_compression_ratio
-
- blend_height = tile_latent_min_height - tile_latent_stride_height
- blend_width = tile_latent_min_width - tile_latent_stride_width
-
- tiletask_list = []
- # Use temporal compression ratio from config instead of hardcoding
- temporal_compression = self.config.scale_factor_temporal
-
- for i in range(0, height, tile_sample_stride_height):
- for j in range(0, width, tile_sample_stride_width):
- time_list = []
- frame_range = 1 + (num_frames - 1) // temporal_compression
- for k in range(frame_range):
- if k == 0:
- tile = x[:, :, :1, i : i + tile_sample_min_height, j : j + tile_sample_min_width]
- else:
- tile = x[
- :, :,
- 1 + temporal_compression * (k - 1) : 1 + temporal_compression * k,
- i : i + tile_sample_min_height,
- j : j + tile_sample_min_width,
- ]
- time_list.append(tile)
- tiletask_list.append(
- TileTask(len(tiletask_list), (i // tile_sample_stride_height, j // tile_sample_stride_width),
- time_list, workload=time_list[0].shape[3] * time_list[0].shape[4])
- )
-
- grid_spec = GridSpec(
- split_dims=(3, 4),
- grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1),
- tile_spec={
- "latent_height": latent_height, "latent_width": latent_width,
- "blend_height": blend_height, "blend_width": blend_width,
- "tile_latent_stride_height": tile_latent_stride_height,
- "tile_latent_stride_width": tile_latent_stride_width,
- },
- output_dtype=self.dtype,
- )
- return tiletask_list, grid_spec
-```
-
-### Step 2: Implement encode_tile_exec
-
-```python
-def encode_tile_exec(self, task: TileTask) -> torch.Tensor:
- """Encode a single sample tile into latent space."""
- self.clear_cache()
- time = []
- for k, tile in enumerate(task.tensor):
- self._enc_conv_idx = [0]
- encoded = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
- encoded = self.quant_conv(encoded)
- time.append(encoded)
- result = torch.cat(time, dim=2)
- self.clear_cache()
- return result
-```
-
-### Step 3: Implement encode_tile_merge
-
-```python
-def encode_tile_merge(
- self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec
-) -> torch.Tensor:
- """Merge encoded tiles into a full latent tensor."""
- grid_h, grid_w = grid_spec.grid_shape
- result_rows = []
- for i in range(grid_h):
- result_row = []
- for j in range(grid_w):
- tile = coord_tensor_map[(i, j)]
- if i > 0:
- tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_height"])
- if j > 0:
- tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_width"])
- result_row.append(tile[:, :, :,
- : grid_spec.tile_spec["tile_latent_stride_height"],
- : grid_spec.tile_spec["tile_latent_stride_width"]])
- result_rows.append(torch.cat(result_row, dim=-1))
-
- enc = torch.cat(result_rows, dim=3)[
- :, :, :, : grid_spec.tile_spec["latent_height"], : grid_spec.tile_spec["latent_width"]
- ]
- return enc
-```
-
-### Step 4: Override tiled_encode method
-
-Override `tiled_encode` instead of `encode`. The parent's `_encode()` handles patchify before calling `tiled_encode()`, so input `x` is already patchified.
-
-```python
-def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
- """
- Encode using distributed VAE executor.
-
- Note: x is already patchified by parent's _encode() before calling this method.
- """
- if not self.is_distributed_enabled():
- return super().tiled_encode(x)
-
- self.clear_cache()
- result = self.distributed_executor.execute(
- x,
- DistributedOperator(
- split=self.encode_tile_split,
- exec=self.encode_tile_exec,
- merge=self.encode_tile_merge,
- ),
- broadcast_result=True, # Latents needed by all ranks for diffusion
- )
- self.clear_cache()
- return result
-```
-
-**Key differences from decode parallel:**
-
-| Aspect | Decode Parallel | Encode Parallel |
-|--------|-----------------|-----------------|
-| `broadcast_result` | Often `False` (only rank 0 needs output) | `True` (all ranks need latents for diffusion) |
-| Patchify | Applied in merge (unpatchify) | Handled by parent `_encode()` before `tiled_encode()` |
-| Temporal chunking | Frame-by-frame | Chunk-based (e.g., 1 + 4n frames) |
-
## Testing
Verify numerical consistency between:
+ vae_patch_parallel_size = 1
@@ -440,20 +272,18 @@ When vae_patch_parallel_size is larger than the DiT world size, it will automati
Complete examples in the codebase:
-| Model | Path | Decode Parallel | Encode Parallel |
-|-------|------|-----------------|-----------------|
-| **Z-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py` | ✅ | ❌ |
-| **Wan2.2** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py` | ✅ | ✅ |
-| **Qwen-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py` | ✅ | ❌ |
+| Model | Path | Notes |
+|-------|------|-------|
+| **Z-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py` | Distributed AutoencoderKL |
+| **Wan2.2** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py` | Distributed AutoencoderKLWan |
+| **Qwen-Image** | `vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py` | Distributed AutoencoderKLQwenImage |
---
## Summary
-Adding VAE Patch Parallel support to diffusion model:
+Adding VAE Patch Parallel support to a diffusion model:
-1. **Implement Distributed VAE** - Inherit from base VAE class and `DistributedVaeMixin`
-2. **Decode Parallel** - Refactor `tiled_decode` into `tile_split`/`tile_exec`/`tile_merge`
-3. **Encode Parallel** (optional) - Implement `encode_tile_split`/`encode_tile_exec`/`encode_tile_merge` for I2V models
-4. **Change VAE model in pipeline** - Use the distributed version
-5. **Test** - Verify numerical consistency with `vae_patch_parallel_size=1` vs `N`
+1. **Implement Distributed Vae** - mainly copy from `diffusers` tiled_decode, and refactor into split/exec/merge
+2. **Change vae model in pipeline to Distributed Vae**
+3. **Test** - Verify output quality with `vae_patch_parallel_size=N` against `vae_patch_parallel_size=1`
diff --git a/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png b/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png
deleted file mode 100644
index 15112d5862a..00000000000
Binary files a/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png and /dev/null differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png
deleted file mode 100644
index 2f0615f77bb..00000000000
Binary files a/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png and /dev/null differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png
deleted file mode 100644
index 62d8bc79b6b..00000000000
Binary files a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png and /dev/null differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png
deleted file mode 100644
index 5838b45319e..00000000000
Binary files a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png and /dev/null differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png
deleted file mode 100644
index 24be814b7e9..00000000000
Binary files a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png and /dev/null differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png
deleted file mode 100644
index c8df58ebcdf..00000000000
Binary files a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png and /dev/null differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png
deleted file mode 100644
index 2d1a04e9c2c..00000000000
Binary files a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png and /dev/null differ
diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png
deleted file mode 100644
index e598b543431..00000000000
Binary files a/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png and /dev/null differ
diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png
deleted file mode 100644
index 54452013eb4..00000000000
Binary files a/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png and /dev/null differ
diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png
deleted file mode 100644
index 04c5ad7396a..00000000000
Binary files a/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png and /dev/null differ
diff --git a/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png b/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png
deleted file mode 100644
index d93ba0b2af5..00000000000
Binary files a/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png and /dev/null differ
diff --git a/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png b/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png
deleted file mode 100644
index 04087b5910f..00000000000
Binary files a/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png and /dev/null differ
diff --git a/docs/design/figures/omni/Summary_RTF_vs_features.png b/docs/design/figures/omni/Summary_RTF_vs_features.png
deleted file mode 100644
index c2c8ad40834..00000000000
Binary files a/docs/design/figures/omni/Summary_RTF_vs_features.png and /dev/null differ
diff --git a/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png b/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png
deleted file mode 100644
index 3dcc1c55379..00000000000
Binary files a/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png and /dev/null differ
diff --git a/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png b/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png
deleted file mode 100644
index 9a5b6c9bdaf..00000000000
Binary files a/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png
deleted file mode 100644
index 68f0ef17e88..00000000000
Binary files a/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png
deleted file mode 100644
index 44be96e96da..00000000000
Binary files a/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png
deleted file mode 100644
index 2e5d1482bd7..00000000000
Binary files a/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png
deleted file mode 100644
index 04d8f0bac53..00000000000
Binary files a/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png
deleted file mode 100644
index eb85ec0dd4f..00000000000
Binary files a/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png
deleted file mode 100644
index 6f0e0e2529d..00000000000
Binary files a/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png
deleted file mode 100644
index 89ea30a8643..00000000000
Binary files a/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png
deleted file mode 100644
index 2b207b88987..00000000000
Binary files a/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png
deleted file mode 100644
index f5f7ad72c8f..00000000000
Binary files a/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png
deleted file mode 100644
index 6f8c1da4a5b..00000000000
Binary files a/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png
deleted file mode 100644
index b0fe1d02a9d..00000000000
Binary files a/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png and /dev/null differ
diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png
deleted file mode 100644
index 008ba9bf78f..00000000000
Binary files a/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png and /dev/null differ
diff --git a/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png b/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png
deleted file mode 100644
index 7c65aa11770..00000000000
Binary files a/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png and /dev/null differ
diff --git a/docs/design/figures/tts/Summary_mean_rtf_vs_features.png b/docs/design/figures/tts/Summary_mean_rtf_vs_features.png
deleted file mode 100644
index 71bb2c54680..00000000000
Binary files a/docs/design/figures/tts/Summary_mean_rtf_vs_features.png and /dev/null differ
diff --git a/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png b/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png
deleted file mode 100644
index cef2546d6fe..00000000000
Binary files a/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png and /dev/null differ
diff --git a/docs/design/qwen3_omni_tts_performance_optimization.md b/docs/design/qwen3_omni_tts_performance_optimization.md
deleted file mode 100644
index 2f18a1b1bc0..00000000000
--- a/docs/design/qwen3_omni_tts_performance_optimization.md
+++ /dev/null
@@ -1,539 +0,0 @@
-# Speech Generation on vLLM-Omni: Performance Optimizations for Qwen3-Omni and Qwen3-TTS
-
-## Summary
-
-vLLM-Omni supports end-to-end serving for speech-generating models, including both **Qwen3-Omni** (multimodal understanding + speech) and **Qwen3-TTS** (text-to-speech). Despite their different architectures, both models share the same multi-stage pipeline design and benefit from the same set of stacked optimizations:
-
-1. **Batching** improves GPU utilization stage by stage and increases overall throughput.
-2. **CUDA Graph** reduces CPU launch overhead and decode-time jitter on stable shapes.
-3. **Async Chunk and Streaming Output** overlap compute and communication across stages and emit audio incrementally, improving both TTFP and E2E.
-
-### Model architectures
-
-**Qwen3-Omni** is a native multimodal model that understands text, audio, image, and video inputs, and generates both text and speech outputs. Its pipeline has three stages:
-
-- **Thinker**: multimodal understanding and text generation
-- **Talker (+ Talker-MTP / code predictor path)**: converts semantic/text representations into codec tokens
-- **Code2Wav**: decodes codec tokens into waveform audio
-
-**Qwen3-TTS** is a lightweight, high-quality text-to-speech model. Its pipeline has two stages:
-
-- **Talker (AR decoder)**: auto-regressively generates codec tokens from text input
-- **Code2Wav (vocoder)**: decodes codec tokens into waveform audio
-
-The optimizations described in this post apply to both models. We present results for each side by side.
-
-### vLLM-Omni vs HF Transformers
-
-Compared with **HF Transformers** (offline, single request), vLLM-Omni with the full optimization stack delivers dramatically lower latency and higher efficiency for both models.
-
-**Qwen3-Omni** (A100):
-
-
-
-| Metric | vLLM-Omni | HF Transformers | Improvement |
-| --- | --- | --- | --- |
-| E2E latency (s) | 23.78 | 336.10 | ~93% reduction |
-| TTFP (s) | 0.934 | 336.10 | ~99.7% reduction |
-| RTF | 0.32 | 3.776 | ~91% reduction (~12× faster) |
-
-- **E2E latency**: 23.78 s vs 336.10 s - **~93%** reduction
-- **TTFP**: 0.934 s vs 336.10 s - **~99.7%** reduction
-- **RTF**: 0.32 vs 3.776 - **~91%** reduction (~12x faster)
-
-**Qwen3-TTS** (H200, concurrency 1):
-
-
-
-| Metric | vLLM-Omni | HF Transformers | Improvement |
-| --- | --- | --- | --- |
-| E2E latency (ms) | 941 | 15,513 | ~94% reduction |
-| TTFP (ms) | 64 | 15,513 | ~99.6% reduction (242× faster) |
-| RTF | 0.16 | 2.64 | ~94% reduction (~16.5× faster) |
-
-- **E2E latency**: 941 ms vs 15,513 ms - **~94%** reduction
-- **TTFP**: 64 ms vs 15,513 ms - **~99.6%** reduction (242x faster)
-- **RTF**: 0.16 vs 2.64 - **~94%** reduction (~16.5x faster)
-
-### Stacked optimization summary
-
-Each optimization stacks on the previous one. The summary plots below show the cumulative effect at each step, with one line per concurrency level (1, 4, 10).
-
-**Qwen3-Omni** (A100):
-
-
-
-- **E2EL reduction**: ~74% at concurrency 10 (410,054 ms -> 104,901 ms); ~90% at concurrency 1 (426,529 ms -> 41,216 ms)
-- **TTFP reduction**: ~96% at concurrency 10 (409,705 ms -> 16,482 ms); ~99.7% at concurrency 1 (426,078 ms -> 1,164 ms)
-- **RTF reduction**: ~74% at concurrency 10 (2.83 -> 0.74); ~90% at concurrency 1 (2.08 -> 0.21)
-
-**Qwen3-TTS** (H200):
-
-
-
-- **E2EL reduction**: ~85% at concurrency 10 (12,141 ms -> 1,767 ms); ~29% at concurrency 1 (1,323 ms -> 941 ms)
-- **TTFP reduction**: ~96.5% at concurrency 10 (12,141 ms -> 425 ms); ~95% at concurrency 1 (1,323 ms -> 64 ms)
-- **RTF reduction**: ~86% at concurrency 10 (2.19 -> 0.31); ~30% at concurrency 1 (0.23 -> 0.16)
-
-**Benchmark environment:**
-
-| | Qwen3-Omni | Qwen3-TTS |
-| --- |-----------------------------| --- |
-| **GPU** | A100 | H200 |
-| **Model** | Qwen3-Omni-30B-A3B-Instruct | Qwen3-TTS-12Hz-1.7B-CustomVoice |
-| **vLLM** | v0.17.0 | v0.18.0 |
-| **vllm-omni** | commit 199f7832 | v0.18.0rc2 |
-| **CUDA** | 12.9 | 12.8 |
-
-This post walks through each optimization in the same order they are typically enabled in practice, then ends with deployment playbooks for both models.
-
----
-
-## Pipeline Batching
-
-### How stage-wise batching works
-
-For both Qwen3-Omni and Qwen3-TTS, batching is a pipeline-level optimization:
-
-- Requests are grouped per stage using `runtime.max_batch_size`
-- Each stage executes batch inference with its own scheduler/worker
-- Stage outputs are routed to downstream stages with per-request mapping preserved
-
-**Batching strategy by stage:** The understanding and decode stages (Thinker for Omni, Talker for both) use **continuous batching**: requests can join and leave the batch over time. Code2Wav uses **static batching**: once a batch is formed, the stage runs the whole batch before starting the next. This matches the decode pattern of Code2Wav and keeps implementation simple while still improving throughput.
-
-### Batching results (Baseline vs. Batch)
-
-Batching alone greatly reduces E2EL and RTF across all concurrencies. The biggest gains appear at high concurrency where requests share GPU resources.
-
-**Qwen3-Omni** (A100):
-
-
-
-| Metric | Concurrency | Baseline | + Batch | Improvement |
-| --- | --- | --- | --- | --- |
-| E2EL (ms) | 1 | 426,529 | 307,719 | 1.4× |
-| E2EL (ms) | 4 | 407,213 | 376,934 | 1.1× |
-| E2EL (ms) | 10 | 410,054 | 234,844 | 1.7× |
-| TTFP (ms) | 1 | 426,078 | 307,262 | 1.4× |
-| TTFP (ms) | 4 | 406,843 | 376,466 | 1.1× |
-| TTFP (ms) | 10 | 409,705 | 234,557 | 1.7× |
-| RTF | 1 | 2.08 | 1.51 | 1.4× |
-| RTF | 4 | 2.55 | 1.83 | 1.4× |
-| RTF | 10 | 2.83 | 2.28 | 1.2× |
-
-At concurrency 10, E2EL drops from ~410 s to ~235 s; at concurrency 1, from ~427 s to ~308 s.
-
-**Qwen3-TTS** (H200):
-
-
-
-| Metric | Concurrency | Baseline | + Batch | Improvement |
-| --- | --- | --- | --- | --- |
-| E2EL (ms) | 1 | 1,323 | 1,339 | 1.0× |
-| E2EL (ms) | 4 | 5,171 | 1,471 | 3.5× |
-| E2EL (ms) | 10 | 12,141 | 1,705 | 7.1× |
-| RTF | 1 | 0.230 | 0.234 | 1.0× |
-| RTF | 4 | 0.908 | 0.255 | 3.6× |
-| RTF | 10 | 2.186 | 0.292 | 7.5× |
-| Throughput (audio-s/wall-s) | 10 | 3.99 | 33.53 | 8.4× |
-
-At concurrency 10, batching alone brings Qwen3-TTS RTF from 2.19 (slower than realtime) down to 0.29 (faster than realtime), and throughput from 4.0 to 33.5 audio-sec/wall-sec.
-
----
-
-## CUDA Graph on the Critical Decode Path
-
-### Why CUDA Graph helps here
-
-In decode-heavy serving, repeatedly launching many small kernels from CPU can become a visible overhead. CUDA Graph reduces this overhead by capturing and replaying stable execution graphs.
-
-In stage configs, this is represented by `enforce_eager: false` for stages where graph capture is desired (Thinker/Talker), while Code2Wav keeps eager mode depending on stage behavior.
-
-### CUDA Graph results on top of batching
-
-**Qwen3-Omni** (A100):
-
-
-
-| Metric | Concurrency | Batch | + CUDA Graph | Improvement |
-| --- | --- | --- | --- | --- |
-| E2EL (ms) | 1 | 307,719 | 61,613 | 5.0× |
-| E2EL (ms) | 4 | 376,934 | 79,019 | 4.8× |
-| E2EL (ms) | 10 | 234,844 | 126,867 | 1.9× |
-| TTFP (ms) | 1 | 307,262 | 61,257 | 5.0× |
-| TTFP (ms) | 4 | 376,466 | 78,634 | 4.8× |
-| TTFP (ms) | 10 | 234,557 | 126,534 | 1.9× |
-| RTF | 1 | 1.51 | 0.32 | 4.7× |
-| RTF | 4 | 1.83 | 0.43 | 4.3× |
-| RTF | 10 | 2.28 | 0.90 | 2.5× |
-
-For the larger Qwen3-Omni model (30B-A3B), CUDA Graph provides a significant improvement. At concurrency 1, E2EL drops from ~308 s to ~62 s; at concurrency 10, from ~235 s to ~127 s.
-
-**Qwen3-TTS** (H200):
-
-
-
-| Metric | Concurrency | Batch | + CUDA Graph | Improvement |
-| --- | --- | --- | --- | --- |
-| E2EL (ms) | 1 | 1,339 | 733 | 1.8× |
-| E2EL (ms) | 4 | 1,471 | 987 | 1.5× |
-| E2EL (ms) | 10 | 1,705 | 1,197 | 1.4× |
-| RTF | 1 | 0.234 | 0.124 | 1.9× |
-| RTF | 10 | 0.292 | 0.203 | 1.4× |
-| Throughput (audio-s/wall-s) | 10 | 33.53 | 47.15 | 1.4× |
-
-At concurrency 1, CUDA Graph reduces E2EL from 1,339 ms to 733 ms and RTF from 0.234 to 0.124 - nearly a 2x improvement. The benefit is consistent across all concurrency levels.
-
----
-
-## Async Chunk and Streaming Output: Earlier Audio and Cross-Stage Overlap
-
-### Why this step matters for first-packet latency
-
-Two mechanisms work together to improve user-visible latency:
-
-- **Streaming output**: audio streaming emits audio chunks as soon as they are decoded (lower **TTFP**). Without streaming, the client waits for larger buffers or end-of-sequence.
-- **Async chunk** is the main enabler for *earlier* audio: instead of handing off whole-request results between stages, each stage forwards **chunks** so the next stage can start as soon as the first chunk is ready. For Omni: Thinker -> Talker forwards hidden-state chunks; for both: Talker -> Code2Wav forwards codec chunks; Code2Wav decodes and emits packets incrementally. This **overlaps compute and communication** across stages and directly reduces time-to-first-audio-packet (TTFP) and end-to-end latency (E2EL).
-
-So in practice: streaming output defines *how* bytes are sent to the client; async chunk defines *when* the pipeline can produce the first bytes.
-
-**Dependency between the two:** Async chunk and audio streaming output are mutually dependent. Without async chunk, **audio streaming output cannot truly take effect**. Without audio streaming output, async chunk's **TTFP advantage is not fully realized**: the client would still wait for larger buffers or end-of-sequence instead of hearing the first packet as soon as it is ready. We therefore recommend enabling **both** on top of batching + CUDA Graph; the benchmarks in this post use both.
-
-### Results: Batch + CUDA Graph vs. Batch + CUDA Graph + Async Chunk + Streaming Output
-
-**Qwen3-Omni** (A100):
-
-
-
-| Metric | Concurrency | Batch + CG | + Async Chunk | Improvement |
-| --- | --- | --- | --- | --- |
-| E2EL (ms) | 1 | 61,613 | 41,216 | 1.5× |
-| E2EL (ms) | 4 | 79,019 | 67,584 | 1.2× |
-| E2EL (ms) | 10 | 126,867 | 104,901 | 1.2× |
-| TTFP (ms) | 1 | 61,257 | 1,164 | 53× |
-| TTFP (ms) | 4 | 78,634 | 3,152 | 24.9× |
-| TTFP (ms) | 10 | 126,534 | 16,482 | 7.7× |
-| RTF | 1 | 0.32 | 0.21 | 1.5× |
-| RTF | 4 | 0.43 | 0.34 | 1.3× |
-| RTF | 10 | 0.90 | 0.74 | 1.2× |
-
-Enabling both brings TTFP down sharply (concurrency 1: 61,257 ms -> 1,164 ms, **~98% reduction**; concurrency 4: 78,634 ms -> 3,152 ms, **~96% reduction**). E2EL and RTF also improve at every concurrency.
-
-**Qwen3-TTS** (H200):
-
-
-
-| Metric | Concurrency | Batch + CG | + Async Chunk | Improvement |
-| --- | --- | --- | --- | --- |
-| TTFP (ms) | 1 | 733 | **64** | **11.5×** |
-| TTFP (ms) | 4 | 987 | **119** | **8.3×** |
-| TTFP (ms) | 10 | 1,197 | **425** | **2.8×** |
-| E2EL (ms) | 1 | 733 | 941 | 0.8× |
-| E2EL (ms) | 10 | 1,197 | 1,767 | 0.7× |
-| RTF | 1 | 0.124 | 0.160 | 0.8× |
-| RTF | 10 | 0.203 | 0.314 | 0.6× |
-
-The TTFP improvement is the headline result for both models. For Qwen3-TTS at concurrency 1, users hear the first audio in **64 ms** instead of 733 ms - an **11.5x reduction**. For Qwen3-Omni at concurrency 1, TTFP drops from 61 s to 1.2 s - a **53x reduction**.
-
-### Why E2EL and RTF are higher with async chunk (TTS)
-
-The table above shows that enabling async chunk + streaming *increases* E2EL and RTF for TTS compared to CUDA Graph alone. This is expected - the two configurations optimize for fundamentally different metrics:
-
-- **CUDA Graph (no async chunk)** generates the entire audio end-to-end before returning. No chunking overhead, so total compute is minimized.
-- **Async Chunk + Streaming** splits the pipeline into incremental chunks, adding overhead from chunked transport, context overlap in Code2Wav (`codec_left_context_frames=25`), and smaller effective batch sizes per chunk.
-
-**The tradeoff is intentional.** Async chunk trades ~30% higher total compute for **11x faster time-to-first-audio**. For interactive applications (voice assistants, chatbots), TTFP determines perceived responsiveness. For offline batch processing, CUDA Graph without async chunk is the better choice.
-
----
-
-## TTS-Specific: Code Predictor Re-prefill + `torch.compile`
-
-Qwen3-TTS has a **code predictor** - a small 5-layer transformer that generates residual codebook tokens (groups 1 through Q-1) autoregressively. Each AR step operates on very short sequences (2 to ~16 tokens).
-
-The naive approach uses a KV cache for this small transformer, similar to the main Talker. But the KV cache machinery (block tables, slot mappings, paged attention) introduces significant overhead relative to the tiny model. Two optimizations replace that:
-
-### Re-prefill (stateless forward, no KV cache)
-
-Instead of maintaining a KV cache across steps, the code predictor **re-feeds the full growing sequence** at each AR step using `F.scaled_dot_product_attention`. With sequences of at most ~16 tokens through 5 layers, the O(T^2) attention cost is negligible - and removing the KV cache machinery (block table management, `set_forward_context`, slot mapping) saves far more time than it costs.
-
-### `torch.compile` on the code predictor forward
-
-The 5-layer transformer forward pass launches ~60 small CUDA kernels per step. `torch.compile(mode="default", dynamic=True)` fuses these into fewer kernels via Inductor:
-
-```python
-self._compiled_model_fwd = torch.compile(
- self.model.forward,
- mode="default", # no Inductor CUDA graphs, avoids conflict with vLLM's CUDAGraphWrapper
- dynamic=True, # sequence length grows each step (2, 3, ..., num_groups+1)
-)
-```
-
-`mode="default"` is used instead of `mode="reduce-overhead"` to avoid conflicts with vLLM's own CUDA graph capture on the main Talker model. `dynamic=True` handles the growing sequence length without recompilation.
-
-These optimizations are always-on in the current codebase - all Qwen3-TTS benchmark results in this post include them.
-
----
-
-## TTS-Specific: Dynamic Initial Chunk for Faster First Audio
-
-In the async chunk pipeline, the standard `codec_chunk_frames` is 25 (each chunk = ~2 seconds of audio at 12 Hz). Waiting for 25 frames before forwarding the first chunk to Code2Wav adds unnecessary TTFP. The **initial codec chunk** optimization sends a smaller first chunk so Code2Wav can start decoding earlier.
-
-**Dynamic initial chunk sizing (default behavior):**
-
-Rather than using a fixed initial chunk size, vLLM-Omni dynamically selects it based on current server load. The initial chunk size is chosen from power-of-2 steps [2, 4, 8, 16] based on load factor (`active_requests / max_batch_size`):
-
-| Server load | Initial chunk frames | Rationale |
-| --- | --- | --- |
-| Low (e.g. 1/10 active) | **2** (~167 ms of audio) | Minimize TTFP when there's headroom |
-| Medium (e.g. 5/10 active) | **4-8** | Balance TTFP vs decode efficiency |
-| High (e.g. 10/10 active) | **16** | Larger first chunk to amortize decode cost |
-
-After the initial chunk, all subsequent chunks use the standard `codec_chunk_frames` (25) size.
-
-**How it works in the pipeline:**
-
-1. Talker generates codec tokens auto-regressively
-2. The stage input processor checks current load and picks an initial chunk size (e.g. **2 frames** at low load)
-3. After that many frames, the first chunk is forwarded to Code2Wav
-4. Code2Wav decodes this small chunk and emits the first audio packet
-5. Subsequent chunks use the standard 25-frame size for efficient batch decoding
-
-**Per-request override:** Clients can also set a fixed initial chunk size via the API:
-
-```json
-{"initial_codec_chunk_frames": 2}
-```
-
-This overrides the dynamic calculation for that request.
-
-**Config (server-side):**
-
-```yaml
-runtime:
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- codec_streaming: true
- codec_chunk_frames: 25 # standard chunk size (~2s of audio)
- codec_left_context_frames: 25
- # initial chunk is computed dynamically by default
- # set initial_codec_chunk_frames: 2 to force a fixed value
-```
-
-The 64 ms TTFP result reported above for Qwen3-TTS at concurrency 1 uses the dynamic initial chunk, which picks `initial_codec_chunk_frames=2` at low load. At higher concurrency the dynamic sizing increases the initial chunk to maintain decode efficiency.
-
----
-
-## Live Demo: Streaming TTS over WebSocket
-
-vLLM-Omni supports real-time streaming audio output for Qwen3-TTS over WebSocket ([PR #1719](https://github.com/vllm-project/vllm-omni/pull/1719)). With `stream_audio: true`, the server sends chunked PCM audio frames as they are generated, so clients can start playback before full sentence synthesis completes.
-
-The WebSocket protocol uses `audio.start` / binary PCM chunks / `audio.done` framing per sentence:
-
-```json
-// Client sends:
-{"type":"session.config","voice":"Vivian","response_format":"pcm","stream_audio":true}
-{"type":"input.text","text":"Hello world. This is a streaming demo."}
-{"type":"input.done"}
-
-// Server streams back per sentence:
-{"type":"audio.start","sentence_index":0,"sentence_text":"Hello world.","format":"pcm","sample_rate":24000}
-
-
-...
-{"type":"audio.done","sentence_index":0,"total_bytes":96000,"error":false}
-{"type":"audio.start","sentence_index":1,"sentence_text":"This is a streaming demo.","format":"pcm","sample_rate":24000}
-
-...
-{"type":"audio.done","sentence_index":1,"total_bytes":72000,"error":false}
-{"type":"session.done","total_sentences":2}
-```
-
-VIDEO
-
----
-
-## Deployment Playbook
-
-### Qwen3-Omni
-
-#### 1) Serve with the default 3-stage config
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
- --omni \
- --port 8091
-```
-
-Notes:
-
-- `runtime.max_batch_size` controls stage-level batching.
-- Thinker/Talker commonly use `enforce_eager: false` for CUDA Graph paths.
-- Code2Wav often remains eager (`enforce_eager: true`) depending on runtime behavior.
-
-#### 2) Enable async chunk
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
- --omni \
- --port 8091 \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
-```
-
-#### 3) Key config knobs
-
-```yaml
-async_chunk: true
-stage_args:
- - stage_id: 0 # thinker
- runtime:
- max_batch_size: 64
- engine_args:
- enforce_eager: false
- max_num_batched_tokens: 32768
- custom_process_next_stage_input_func: >-
- vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
-
- - stage_id: 1 # talker
- runtime:
- max_batch_size: 64
- engine_args:
- enforce_eager: false
- max_num_batched_tokens: 32768
- custom_process_next_stage_input_func: >-
- vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
-
- - stage_id: 2 # code2wav
- runtime:
- max_batch_size: 64
- engine_args:
- enforce_eager: true
- max_num_batched_tokens: 51200
-```
-
-#### Reproduce Qwen3-Omni benchmarks
-
-```bash
-vllm bench serve \
- --dataset-name random \
- --port ${PORT} \
- --model ${MODEL_PATH} \
- --endpoint /v1/chat/completions \
- --backend openai-chat-omni \
- --max-concurrency ${MAX_CONCURRENCY} \
- --num-prompts ${NUM_PROMPTS} \
- --random-input-len 2500 \
- --ignore-eos \
- --percentile-metrics ttft,tpot,itl,e2el,audio_ttfp,audio_rtf \
- --random-output-len 900 \
- --extra_body '{"modalities": ["text","audio"]}'
-```
-
-### Qwen3-TTS
-
-#### 1) Serve with async chunk (recommended)
-
-```bash
-vllm-omni serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --omni \
- --port 8000
-```
-
-The default config (`qwen3_tts.yaml`) enables the full optimization stack:
-
-- Batching with `max_batch_size: 10` on the Talker stage
-- CUDA Graph on the Talker (`enforce_eager: false`)
-- Async chunk with streaming transport
-
-#### 2) Serve without async chunk (for comparison)
-
-```bash
-vllm-omni serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --omni \
- --port 8000 \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
-```
-
-#### 3) Key config knobs
-
-```yaml
-async_chunk: true
-stage_args:
- - stage_id: 0 # Talker (AR decoder)
- runtime:
- max_batch_size: 10
- engine_args:
- enforce_eager: false
- max_num_batched_tokens: 512
- custom_process_next_stage_input_func: >-
- vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
-
- - stage_id: 1 # Code2Wav (vocoder)
- runtime:
- max_batch_size: 1
- engine_args:
- enforce_eager: true
- max_num_batched_tokens: 8192
-
-runtime:
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- codec_streaming: true
- codec_chunk_frames: 25
- codec_left_context_frames: 25
-```
-
-#### Reproduce Qwen3-TTS benchmarks
-
-```bash
-GPU_DEVICE=0 \
-MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
-NUM_PROMPTS=50 \
-CONCURRENCY="1 4 10" \
-bash benchmarks/qwen3-tts/vllm_omni/run_stacked_benchmark.sh
-```
-
-This cycles through four configs (Baseline -> + Batch -> + CUDA Graph -> + Async Chunk + Streaming), benchmarks each at the specified concurrency levels, and generates all comparison figures automatically.
diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md
index b4eec162d31..41aa48c1735 100644
--- a/docs/features/sleep_mode.md
+++ b/docs/features/sleep_mode.md
@@ -1,243 +1,39 @@
-# Sleep Mode & ACK Protocol
+# Sleep Mode
-vLLM-Omni’s **Sleep Mode** allows you to temporarily release most GPU memory used by a model—such as model weights and key-value (KV) caches—**without stopping the server or unloading the Docker container**.
+vLLM-Omni’s **Sleep Mode** allows you to temporarily release most GPU memory used by a model—such as model weights and key-value (KV) caches (for autoregressive models)—**without stopping the server or unloading the Docker container**.
-This feature is inherited from [vLLM’s Sleep Mode](https://blog.vllm.ai/2025/10/26/sleep-mode.html) and extended with the **Omni ACK Protocol** to support multi-stage pipelines and heterogeneous hardware backends (NVIDIA, AMD, Intel, Huawei). It is especially useful in **RLHF**, **dynamic model switching**, or **cost-saving scenarios**.
+This feature is inherited from [vLLM’s Sleep Mode](https://blog.vllm.ai/2025/10/26/sleep-mode.html), which provides zero-reload model switching for multi-model serving.
+
+It is especially useful in **RLHF**, **training**, or **cost-saving scenarios**, where GPU resources must be freed between inference workloads.
---
-## 1. Feature Documentation
+## Omni Model
-### Overview
-Omni Sleep Mode provides a mechanism to "sleep" specific model stages. When a stage enters sleep, its physical VRAM is reclaimed by the system, while the process state is preserved for rapid "wake-up" without full re-initialization.
+Omni models inherit this feature from vLLM's Sleep Mode.
-### Sleep Levels
-We support two levels of hibernation to balance recovery speed and memory efficiency:
+This means:
-| Level | Name | Mechanism | Recovery Speed | Memory Freed |
-| :--- | :--- | :--- | :--- | :--- |
-| **Level 1** | **Weight Offloading** | Offloads weights to Host CPU RAM. | **Fast** (DMA) | Substantial |
-| **Level 2** | **Full De-mapping** | Physically releases memory pages via VRAM scavenging. | **Moderate** | **Maximum** (up to 95%+) |
+- Both Level 1 and Level 2 sleep are supported, allowing both model weights and the KV cache to be released and reset.
-### Supported Platforms
+## Diffusion Model Extension
-Omni Sleep Mode is optimized for high-performance computing backends:
+We added Sleep Mode support for **diffusion models**, which previously lacked this functionality.
+In diffusion pipelines, this currently only offloads **model weight memory**, as these models typically do not use KV caches.
-* **NVIDIA**: Supported via Virtual Memory Management (VMM).
-* **AMD (ROCm)**: Fully supported with physical page de-mapping.
-* **Intel XPU**: Supported via Level Zero memory management.
-* **Huawei NPU**: Supported via Ascend memory scavenging.
+This means:
-### Hardware Requirements
-* **Memory Considerations**: System RAM must be sufficient to hold offloaded weights during sleep.
-* **TP Support**: Tensor Parallel groups synchronize sleep/wake transitions across all workers.
+- Diffusion models can now enter Level 1 sleep.
+- Pipeline states (e.g., noise schedulers, buffers) remain intact after waking.
+- Useful for releasing VRAM between image generation or training cycles.
---
+## Enable sleep mode
+To enable sleep mode, set `enable_sleep_mode` to `True` in `engine_args`.
-## 2. Usage Examples
-
-### Python API Example
-You can programmatically control the lifecycle of stages using the `AsyncOmni` engine.
+Example:
```python
-
-import asyncio
-from vllm_omni.entrypoints.async_omni import AsyncOmni
-
-async def run_sleep_demo():
- # 1. initialization
- engine = AsyncOmni(
- model="ByteDance-Seed/BAGEL-7B-MoT",
- enable_sleep_mode=True
- )
-
- # 2. sleep mode level2
- acks = await engine.sleep(stage_ids=[0], level=2)
- print(f"Freed {acks[0].freed_bytes / 1024**3:.2f} GiB on Stage 0")
-
- # 3. wake up
- await engine.wake_up(stage_ids=[0])
-
-if __name__ == "__main__":
- asyncio.run(run_sleep_demo())
-
-```
-
-### server command Example
-Start the server with sleep mode enabled:
-
-The first method
-
-```
-
-vllm serve ByteDance-Seed/BAGEL-7B-MoT \
---omni \
---enable-sleep-mode \
---trust-remote-code \
---gpu-memory-utilization 0.7
-
-```
-
-The second method
-
-```
-
-python3 -m vllm_omni.entrypoints.openai.api_server \
- --model ByteDance-Seed/BAGEL-7B-MoT \
- --omni \
- --enable-sleep-mode \
- --trust-remote-code \
---gpu-memory-utilization 0.7
-
-```
-
-
-
-
-### Test Scenarios & Commands
-
-#### Scenario 1: LLM Engine Sleep
-
-Objective: Verify VRAM reclamation for Stage 0 (Thinker).
-
-Trigger sleep (Level 1 or Level 2) via client:
-
+omni = Omni(model=...,enable_sleep_mode=True)
```
-
-curl -X POST http://localhost:8000/v1/omni/sleep \
- -H "Content-Type: application/json" \
- -d '{"stage_ids": [0], "level": 2}'
-
-```
-
-Tip: Open a new terminal and run rocm-smi or nvidia-smi or to observe the immediate drop in VRAM usage.
-
-
-
-#### Scenario 2: Diffusion Sleep
-Objective: Verify VRAM reclamation for Stage 1 (Diffusion).
-
-Trigger sleep (Level 1 or Level 2) via client:
-
-```
-
-curl -X POST http://localhost:8000/v1/omni/sleep \
- -H "Content-Type: application/json" \
- -d '{"stage_ids": [1], "level": 2}'
-
-```
-
-
-
-#### Scenario 3: Multi-Stage Coordinated Stress Test
-Objective: Test concurrent sleep and rapid wake-up across multiple stages.
-
-Concurrent Sleep (Stage 0 & 1):
-
-```
-
-curl -X POST http://localhost:8000/v1/omni/sleep \
- -H "Content-Type: application/json" \
- -d '{"stage_ids": [0, 1], "level": 2}'
-
-```
-
-
-Rapid Wake-up:
-
-```
-
-curl -X POST http://localhost:8000/v1/omni/wakeup \
- -H "Content-Type: application/json" \
- -d '{"stage_ids": [0, 1]}'
-
-```
-
-
-#### Scenario 4: Full Lifecycle Memory Audit & Functional Integrity
-Objective: Audit the complete flow from Sleep to Wake-up followed by an Inference validation.
-
-Check Initial State: Observe baseline VRAM usage.
-
-Trigger Deep Sleep (Level 2):
-
-```
-
-curl -X POST http://localhost:8000/v1/omni/sleep \
- -H "Content-Type: application/json" \
- -d '{"stage_ids": [0], "level": 2}'
-
-```
-
-Wake-up Model:
-
-```
-
-curl -X POST http://localhost:8000/v1/omni/wakeup \
- -H "Content-Type: application/json" \
- -d '{"stage_ids": [0]}'
-
-```
-
-Verify Functional Integrity (Inference):
-Ensure the model still generates valid output after reloading weights.
-
-```
-
-curl -X POST http://localhost:8000/v1/images/generations \
- -H "Content-Type: application/json" \
- -d '{
- "prompt": "A huge swimming pool, with many people swimming.",
- "model": "ByteDance-Seed/BAGEL-7B-MoT",
- "response_format": "b64_json",
- "extra_body": {"sampling_params": {"num_inference_steps": 4, "seed": 42}}
- }' > post.json
-
-```
-
-
-
-
-## 3. API Reference
-
-
-### Methods
-
-| Method | Arguments | Return Type | Description |
-| :--- | :--- | :--- | :--- |
-| **sleep** | `stage_ids: List[int], level: int` | `List[OmniACK]` | Triggers hibernation for specified stages. |
-| **wake_up** | `stage_ids: List[int]` | `List[OmniACK]` | Reloads weights and re-maps memory. |
-
-
-
-### OmniACK Dataclass Fields
-
-| Field | Type | Description |
-| :--- | :--- | :--- |
-| **task_id** | `str` | Unique identifier for the operation. |
-| **status** | `str` | `SUCCESS` or `ERROR`. |
-| **stage_id** | `int` | The ID of the stage that responded. |
-| **rank** | `int` | The rank ID within the Tensor Parallel group. |
-| **freed_bytes** | `int` | Actual amount of physical VRAM reclaimed. |
-| **metadata** | `dict` | Additional platform-specific metrics. |
-
-Metadata Field Analysis
-The metadata field is a dynamic dictionary containing hardware-specific telemetry and audit data, primarily used for verifying memory reclamation on various backends (e.g., AMD ROCm, NVIDIA CUDA).
-
-```
-"metadata": {
- "source": "Platform_AMD_Instinct_MI300X",
- "total_freed_gib": "78.57",
- "rank_residual_gib": "2.07"
-}
-```
-
-#### Core Utility:
-**VRAM Reclamation Audit (total_freed_gib)**: Converts raw freed_bytes into human-readable GiB. It serves as the primary metric to verify that Level 2 sleep has successfully purged model weights from VRAM.
-
-**Residual & Fragmentation Monitoring (rank_residual_gib)**: Reports the remaining VRAM footprint after memory de-mapping. A low residual value (e.g., 2.07 GiB) confirms a successful "clean" state, ensuring the device is ready for high-memory co-located tasks like training or diffusion pipelines.
-
-**Backend Traceability (source)**: Identifies the underlying hardware driver or audit source. This is critical for debugging synchronization issues in multi-stage, distributed environments.
-
-**Performance Analytics (Roadmap)**: Future updates will include latency_ms (context-switch overhead) and cuda_graph_recalled (graph engine status) to optimize performance in high-frequency sleep/wake scenarios.
diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md
index 89562c53c51..353fbe1c073 100644
--- a/docs/getting_started/installation/README.md
+++ b/docs/getting_started/installation/README.md
@@ -6,5 +6,4 @@ vLLM-Omni supports the following hardware platforms:
- [NVIDIA CUDA](gpu.md)
- [AMD ROCm](gpu.md)
- [Intel XPU](gpu.md)
- - [MThreads MUSA](gpu.md)
- [NPU](npu.md)
diff --git a/docs/getting_started/installation/gpu.md b/docs/getting_started/installation/gpu.md
index d08f134b5d6..297c3666169 100644
--- a/docs/getting_started/installation/gpu.md
+++ b/docs/getting_started/installation/gpu.md
@@ -22,10 +22,6 @@ vLLM-Omni is a Python library that supports the following GPU variants. The libr
--8<-- "docs/getting_started/installation/gpu/xpu.inc.md:requirements"
-=== "MThreads MUSA"
-
- --8<-- "docs/getting_started/installation/gpu/musa.inc.md:requirements"
-
## Set up using Python
### Create a new Python environment
@@ -48,10 +44,6 @@ Note: Pre-built wheels are currently available for vLLM-Omni 0.11.0rc1, 0.12.0rc
--8<-- "docs/getting_started/installation/gpu/xpu.inc.md:pre-built-wheels"
-=== "MThreads MUSA"
-
- --8<-- "docs/getting_started/installation/gpu/musa.inc.md:pre-built-wheels"
-
[](){ #build-from-source }
### Build wheel from source
@@ -68,10 +60,6 @@ Note: Pre-built wheels are currently available for vLLM-Omni 0.11.0rc1, 0.12.0rc
--8<-- "docs/getting_started/installation/gpu/xpu.inc.md:build-wheel-from-source"
-=== "MThreads MUSA"
-
- --8<-- "docs/getting_started/installation/gpu/musa.inc.md:build-wheel-from-source"
-
## Set up using Docker
### Pre-built images
@@ -88,10 +76,6 @@ Note: Pre-built wheels are currently available for vLLM-Omni 0.11.0rc1, 0.12.0rc
--8<-- "docs/getting_started/installation/gpu/xpu.inc.md:pre-built-images"
-=== "MThreads MUSA"
-
- --8<-- "docs/getting_started/installation/gpu/musa.inc.md:pre-built-images"
-
### Build your own docker image
=== "AMD ROCm"
@@ -101,7 +85,3 @@ Note: Pre-built wheels are currently available for vLLM-Omni 0.11.0rc1, 0.12.0rc
=== "Intel XPU"
--8<-- "docs/getting_started/installation/gpu/xpu.inc.md:build-docker"
-
-=== "MThreads MUSA"
-
- --8<-- "docs/getting_started/installation/gpu/musa.inc.md:build-docker"
diff --git a/docs/getting_started/installation/gpu/musa.inc.md b/docs/getting_started/installation/gpu/musa.inc.md
deleted file mode 100644
index a7cbc848f58..00000000000
--- a/docs/getting_started/installation/gpu/musa.inc.md
+++ /dev/null
@@ -1,65 +0,0 @@
-# --8<-- [start:requirements]
-
-- GPU: Moore Threads GPU with MUSA SDK installed (validated on MTT S5000)
-
-# --8<-- [end:requirements]
-# --8<-- [start:set-up-using-python]
-
-vLLM-Omni for MUSA requires building from source. Pre-built wheels are not currently available.
-
-!!! note
- MUSA platform requires vLLM-MUSA to be installed first.
-
-# --8<-- [start:pre-built-wheels]
-
-# --8<-- [end:pre-built-wheels]
-
-# --8<-- [start:build-wheel-from-source]
-
-#### Prerequisites
-
-- **MUSA SDK**: Download from [MUSA SDK Download](https://developer.mthreads.com/sdk/download/musa)
-- **torchada**: CUDA→MUSA compatibility layer for PyTorch (`pip install torchada`)
-- **mthreads-ml-py**: MTML Python bindings (`pip install mthreads-ml-py`)
-- **MATE**: MUSA AI Tensor Engine ([GitHub](https://github.com/MooreThreads/mate))
-
-#### Installation of vLLM-MUSA
-
-```bash
-git clone https://github.com/MooreThreads/vllm-musa.git
-cd vllm-musa
-git checkout v0.18.0-dev
-pip install . --no-build-isolation -v
-```
-
-#### Installation of vLLM-Omni
-
-```bash
-git clone https://github.com/vllm-project/vllm-omni.git
-cd vllm-omni
-VLLM_OMNI_TARGET_DEVICE=musa pip install -e . --no-build-isolation
-```
-
-For Gradio demos:
-
-```bash
-pip install -e '.[demo]' --no-build-isolation
-```
-
-#### Environment Variables
-
-```bash
-export MUSA_VISIBLE_DEVICES=0,1
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
-export VLLM_MUSA_CUSTOM_OP_USE_NATIVE=false
-```
-
-# --8<-- [end:build-wheel-from-source]
-
-# --8<-- [start:build-docker]
-
-# --8<-- [end:build-docker]
-
-# --8<-- [start:pre-built-images]
-
-# --8<-- [end:pre-built-images]
diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md
index 5dfea8d2ffe..1a683d174f7 100644
--- a/docs/getting_started/installation/gpu/rocm.inc.md
+++ b/docs/getting_started/installation/gpu/rocm.inc.md
@@ -26,7 +26,7 @@ uv pip install vllm-omni
# Optional if want to run Qwen3 TTS
uv pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm
-uv pip install onnxruntime-rocm
+uv pip install onnxruntime-rocm sox
```
# --8<-- [end:pre-built-wheels]
diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py
index e3cfb1b6a86..0aed44a0c65 100644
--- a/docs/mkdocs/hooks/generate_argparse.py
+++ b/docs/mkdocs/hooks/generate_argparse.py
@@ -121,7 +121,6 @@ def add_parser(self, name, **kwargs):
"FlexibleArgumentParser": _FlexibleArgumentParser,
"make_arg_parser": lambda parser: parser, # no-op for doc
"_ensure_vllm_platform": lambda: None, # no-op for doc
- "nullify_stage_engine_defaults": lambda parser: None, # no-op for doc
"VLLM_SUBCMD_PARSER_EPILOG": "",
"logger": logger,
"DummySubparsers": DummySubparsers,
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index feea969e51f..d611c0311c5 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -36,10 +36,7 @@ th {
| `LTX2ImageToVideoPipeline` | LTX-2-I2V | `Lightricks/LTX-2` | ✅︎ | ✅︎ | | |
| `LTX2TwoStagesPipeline` | LTX-2-T2V | `rootonchair/LTX-2-19b-distilled` | ✅︎ | ✅︎ | | |
| `LTX2ImageToVideoTwoStagesPipeline` | LTX-2-I2V | `rootonchair/LTX-2-19b-distilled` | ✅︎ | ✅︎ | | |
-| `LTX23Pipeline` | LTX-2.3-T2V | `dg845/LTX-2.3-Diffusers` | ✅︎ | ✅︎ | | |
-| `LTX23ImageToVideoPipeline` | LTX-2.3-I2V | `dg845/LTX-2.3-Diffusers` | ✅︎ | ✅︎ | | |
| `HeliosPipeline`, `HeliosPyramidPipeline` | Helios | `BestWishYsh/Helios-Base`, `BestWishYsh/Helios-Mid`, `BestWishYsh/Helios-Distilled` | ✅︎ | ✅︎ | ✅︎ | |
-| `MagiHumanPipeline` | MagiHuman | `SII-GAIR/daVinci-MagiHuman-Base-1080p` | ✅︎ | ✅︎ | | |
| `OvisImagePipeline` | Ovis-Image | `OvisAI/Ovis-Image` | ✅︎ | ✅︎ | | ✅︎ |
| `LongcatImagePipeline` | LongCat-Image | `meituan-longcat/LongCat-Image` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `LongCatImageEditPipeline` | LongCat-Image-Edit | `meituan-longcat/LongCat-Image-Edit` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
@@ -49,21 +46,18 @@ th {
| `Flux2KleinPipeline` | FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B`, `black-forest-labs/FLUX.2-klein-9B` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `FluxKontextPipeline` | FLUX.1-Kontext-dev | `black-forest-labs/FLUX.1-Kontext-dev` | ✅︎ | ✅︎ | | |
| `FluxPipeline` | FLUX.1-dev | `black-forest-labs/FLUX.1-dev` | ✅︎ | ✅︎ | | ✅︎ |
-| `FluxPipeline` | FLUX.1-schnell | `black-forest-labs/FLUX.1-schnell` | ✅︎ | ✅︎ | | ✅︎ |
| `OmniGen2Pipeline` | OmniGen2 | `OmniGen2/OmniGen2` | ✅︎ | ✅︎ | | ✅︎ |
| `StableAudioPipeline` | Stable-Audio-Open | `stabilityai/stable-audio-open-1.0` | ✅︎ | ✅︎ | | ✅︎ |
| `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-CustomVoice | `Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-VoiceDesign | `Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3TTSForConditionalGeneration` | Qwen3-TTS-12Hz-1.7B-Base | `Qwen/Qwen3-TTS-12Hz-0.6B-Base` | ✅︎ | ✅︎ | ✅︎ | ✅︎ |
| `NextStep11Pipeline` | NextStep-1.1 | `stepfun-ai/NextStep-1.1` | ✅︎ | ✅︎ | | ✅︎ |
-| `MiMoAudioModel` | MiMo-Audio-7B-Instruct | `XiaomiMiMo/MiMo-Audio-7B-Instruct` | ✅︎ | ✅︎ | | |
-| `MiMoV2ASRForCausalLM` | MiMo-V2.5-ASR | `XiaomiMiMo/MiMo-V2.5-ASR` | ✅︎ | ✅︎ | | |
+| `MiMoAudioForConditionalGeneration` | MiMo-Audio-7B-Instruct | `XiaomiMiMo/MiMo-Audio-7B-Instruct` | ✅︎ | ✅︎ | | |
| `Flux2Pipeline` | FLUX.2-dev | `black-forest-labs/FLUX.2-dev` | ✅︎ | ✅︎ | | |
| `FishSpeechSlowARForConditionalGeneration` | Fish Speech S2 Pro | `fishaudio/s2-pro` | ✅︎ | ✅︎ | | |
| `DreamIDOmniPipeline` | DreamID-Omni | `XuGuo699/DreamID-Omni` | ✅︎ | ✅︎ | | |
| `HunyuanVideo15Pipeline` | HunyuanVideo-1.5-T2V | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v`, `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v` | ✅︎ | ✅︎ | | |
| `HunyuanVideo15ImageToVideoPipeline` | HunyuanVideo-1.5-I2V | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_i2v`, `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_i2v` | ✅︎ | ✅︎ | | |
| `VoxtralTTSForConditionalGeneration` | Voxtral TTS | `mistralai/Voxtral-4B-TTS-2603` | ✅︎ | ✅︎ | | |
-|`DyninOmniForConditionalGeneration` | Dynin-Omni | `snu-aidas/Dynin-Omni` | ✅︎ | | | |
✅︎ indicates the model is supported on that backend. Empty cells mean not listed as supported on that backend.
diff --git a/docs/pr_reviewer.md b/docs/pr_reviewer.md
new file mode 100644
index 00000000000..ad32355328b
--- /dev/null
+++ b/docs/pr_reviewer.md
@@ -0,0 +1,255 @@
+# vLLM-Omni PR Reviewer
+
+## Overview
+
+The vLLM-Omni PR Reviewer is an automated code review bot powered by the GLM-4.7 AI model. It helps maintain code quality by providing intelligent feedback on pull requests.
+
+## Features
+
+- **Intelligent Code Analysis**: Leverages GLM-4.7 for understanding code context and providing meaningful feedback
+- **Comprehensive Reviews**: Covers code quality, architecture, security, testing, and documentation
+- **Structured Output**: Provides well-formatted reviews with clear sections and actionable suggestions
+- **Rate Limiting**: Built-in cooldown mechanism to prevent excessive API usage
+- **Retry Logic**: Automatic retries with exponential backoff for transient API failures
+- **Defensive Parsing**: Robust validation of API responses to handle malformed data
+- **Cost Control**: Only repository members/collaborators/owners can trigger reviews
+
+## How to Use
+
+### Triggering a Review
+
+To trigger an automated PR review, mention the bot in a PR comment:
+
+```
+@vllm-omni-reviewer please review
+```
+
+Or include in your PR description:
+
+```
+@vllm-omni-reviewer
+```
+
+The bot will automatically review your changes and post a detailed comment.
+
+## What Gets Reviewed
+
+- **vLLM Architecture Compatibility**: Ensures changes align with vLLM's design patterns
+- **Multi-modal Integration**: Reviews audio, vision, and text processing implementations
+- **Performance Implications**: Analyzes impact on inference latency and throughput
+- **Code Quality**: Checks Python best practices, type hints, and documentation
+- **Security Considerations**: Identifies potential security vulnerabilities
+- **Testing Coverage**: Recommends additional test cases when needed
+
+## Review Output
+
+The bot posts a structured review comment with:
+
+- **Overview**: Brief summary of the PR's purpose
+- **Critical Issues (Must Fix)**: Blocking issues that need to be addressed
+- **Important Issues (Should Fix)**: Significant concerns that should be resolved
+- **Minor Issues & Suggestions**: Small improvements and optional suggestions
+- **Positive Aspects**: Highlights well-implemented features
+- **Performance Considerations**: Analysis of performance impact
+- **Testing Recommendations**: Suggestions for additional tests
+- **Overall Assessment**: Final recommendation (Approve/Request Changes/Needs Major Work)
+
+## Rate Limiting and Cooldown
+
+The bot includes a cooldown mechanism to prevent excessive API usage:
+
+- **Default cooldown**: 5 minutes between reviews per PR
+- **Configurable**: Can be adjusted via `PR_REVIEWER_COOLDOWN_MINUTES` environment variable
+- **Smart detection**: Checks for previous bot comments before starting a review
+
+If you trigger a review within the cooldown period, the bot will log a message and skip the review.
+
+## Architecture
+
+```
+┌─────────────────┐
+│ PR Comment │
+│ @vllm-omni- │
+│ reviewer │
+└────────┬────────┘
+ │
+ ▼
+┌─────────────────────────────────┐
+│ GitHub Actions Workflow │
+│ (.github/workflows/ │
+│ pr-reviewer.yml) │
+│ │
+│ - Python 3.11 │
+│ - requests==2.31.0 │
+│ - pyyaml==6.0.1 │
+└────────┬────────────────────────┘
+ │
+ ▼
+┌─────────────────────────────────┐
+│ PR Reviewer Script │
+│ (.github/scripts/ │
+│ pr_reviewer.py) │
+│ │
+│ 1. Check cooldown │
+│ 2. Fetch PR details & diff │
+│ 3. Build review prompt │
+│ 4. Call GLM-4.7 API │
+│ (with retry logic) │
+│ 5. Validate response │
+│ 6. Post review comment │
+└────────┬────────────────────────┘
+ │
+ ▼
+┌─────────────────────────────────┐
+│ GLM-4.7 API │
+│ (open.bigmodel.cn) │
+└─────────────────────────────────┘
+```
+
+1. **GitHub Actions Workflow** (`.github/workflows/pr-reviewer.yml`): Triggers on @mention
+2. **Python Script** (`.github/scripts/pr_reviewer.py`): Fetches PR data and calls GLM-4.7 API
+3. **GLM-4.7 API**: Provides intelligent code analysis
+
+## Testing
+
+### Testing the PR Reviewer Bot
+
+To test the PR reviewer bot before deploying to production:
+
+1. **Create a test PR** - Make a small, safe change (e.g., documentation update)
+2. **Open the PR** - Create a pull request with a descriptive title
+3. **Trigger the review** - Comment `@vllm-omni-reviewer` on the PR
+4. **Monitor results** - Check the Actions tab for workflow execution logs
+
+### Running Unit Tests
+
+The bot includes comprehensive unit tests that can be run locally:
+
+```bash
+# Run all tests
+pytest .github/tests/test_pr_reviewer.py -v
+
+# Run specific test
+pytest .github/tests/test_pr_reviewer.py::TestCheckTrigger -v
+
+# Run with coverage
+pytest .github/tests/test_pr_reviewer.py --cov=.github/scripts/pr_reviewer.py --cov-report=term-missing
+```
+
+### What to Look For
+
+When testing, verify that:
+- [ ] The workflow triggers on the `@vllm-omni-reviewer` comment
+- [ ] The cooldown mechanism works correctly
+- [ ] The GLM API call completes without errors (with retry if needed)
+- [ ] A review comment is posted to the PR
+- [ ] The review content is meaningful and well-structured
+- [ ] The cost is within the expected range (0.50-5 CNY)
+
+### Safe Test Changes
+
+For testing, consider making these types of safe changes:
+- Documentation updates (like adding this Testing section)
+- Comment improvements
+- README enhancements
+- Non-functional file additions
+
+### Example Test PR
+
+A good test PR might:
+- Update a documentation file
+- Add explanatory comments
+- Improve code formatting
+- Fix a minor typo
+
+These changes are safe to merge if the test is successful and won't affect functionality.
+
+## Troubleshooting
+
+### Bot Doesn't Respond
+
+1. **Check permissions** - Verify you have Owner/Member/Collaborator access
+2. **Check Actions tab** - Look for workflow execution and view logs
+3. **Check cooldown** - If another review was posted recently, wait for the cooldown period
+4. **Check API key** - Ensure `GLM_API_KEY` is configured in repository secrets
+
+### API Errors
+
+If the GLM API call fails:
+- Check the Actions tab for detailed error logs
+- Verify the `GLM_API_KEY` secret is correctly configured
+- Ensure sufficient API quota is available
+- The bot will automatically retry up to 3 times with exponential backoff
+
+### Review Seems Truncated
+
+If the review appears incomplete:
+- Large diffs may be truncated at 100,000 characters
+- Check the logs for truncation warnings
+- Consider breaking large PRs into smaller chunks
+
+## Configuration
+
+### Required Secrets
+
+The following secret must be configured in the repository settings:
+
+- `GLM_API_KEY` - Your GLM (BigModel) API key for accessing the GLM-4.7 API
+
+To add the secret:
+1. Go to repository Settings → Secrets and variables → Actions
+2. Click "New repository secret"
+3. Name: `GLM_API_KEY`
+4. Value: Your GLM API key
+
+### Optional Configuration
+
+The following optional environment variables can be set in the workflow file:
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `GLM_API_URL` | `https://open.bigmodel.cn/api/paas/v4/chat/completions` | GLM API endpoint |
+| `GLM_MODEL` | `glm-4.7` | Model to use for reviews |
+| `PR_REVIEWER_COOLDOWN_MINUTES` | `5` | Cooldown period between reviews |
+| `PR_REVIEWER_MAX_RETRIES` | `3` | Maximum API retry attempts |
+| `PR_REVIEWER_RETRY_DELAY` | `1.0` | Base delay for retry backoff (seconds) |
+| `PR_REVIEWER_MAX_DIFF_SIZE` | `100000` | Maximum diff size before truncation |
+
+### Workflow Customization
+
+The workflow can be customized in `.github/workflows/pr-reviewer.yml`:
+- Change Python version (default: 3.11)
+- Adjust timeout value (default: 10 minutes)
+- Modify trigger conditions
+- Add additional dependencies
+
+## Code Quality
+
+The PR reviewer script follows vllm-omni coding standards:
+
+- **Type hints**: All functions have complete type hints following mypy strict mode
+- **Logging**: Uses Python's logging module for structured logging
+- **Testing**: Comprehensive unit tests with pytest
+- **Pre-commit**: Script is checked by pre-commit hooks (flake8)
+
+## Cost Estimate
+
+| Component | Cost |
+|-----------|------|
+| GitHub Actions (public repo) | Free |
+| GLM API (glm-4.7) | ~0.50-5 CNY per PR (varies by size) |
+| Total (20 PRs/month) | ~10-100 CNY/month (~$2-15 USD) |
+
+## Contributing
+
+To improve the PR reviewer bot:
+
+1. Edit `.github/scripts/pr_reviewer.py` for logic changes
+2. Edit `.github/workflows/pr-reviewer.yml` for workflow changes
+3. Add tests to `.github/tests/test_pr_reviewer.py`
+4. Run `pre-commit run --files .github/scripts/pr_reviewer.py` to check code quality
+5. Test thoroughly with a test PR before deploying to production
+
+## License
+
+This bot is part of the vLLM-Omni project and follows the same license terms.
diff --git a/docs/serving/image_edit_api.md b/docs/serving/image_edit_api.md
index 79303e1a690..d254ac06ad7 100644
--- a/docs/serving/image_edit_api.md
+++ b/docs/serving/image_edit_api.md
@@ -104,8 +104,6 @@ Content-Type: multipart/form-data
| `guidance_scale` | float | model defaults | Classifier-free guidance scale (typically 0.0-20.0) |
| `true_cfg_scale` | float | model defaults | True CFG scale (model-specific parameter, may be ignored if not supported) |
| `seed` | integer | null | Random seed for reproducibility |
-| `reference_image` | string or array | null | Reference image for inpainting |
-| `mask_image` | string or array | null | Mask for inpainting (white areas will be inpainted) |
### Response Format
diff --git a/docs/serving/speech_api.md b/docs/serving/speech_api.md
index 8f78d6a2001..ecbe8d9ac98 100644
--- a/docs/serving/speech_api.md
+++ b/docs/serving/speech_api.md
@@ -15,17 +15,28 @@ Each server instance runs a single model (specified at startup via `vllm serve <
```bash
# Qwen3-TTS: CustomVoice model (predefined speakers)
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
--enforce-eager
# Fish Speech S2 Pro
-vllm serve fishaudio/s2-pro --omni --port 8091
+vllm-omni serve fishaudio/s2-pro \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml \
+ --omni \
+ --port 8091 \
+ --trust-remote-code \
+ --enforce-eager \
+ --gpu-memory-utilization 0.9
# Voxtral TTS
-vllm serve mistralai/Voxtral-4B-TTS-2603 --omni --port 8091
+vllm serve mistralai/Voxtral-4B-TTS-2603 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxtral_tts.yaml \
+ --omni \
+ --port 8091 \
+ --trust-remote-code \
+ --enforce-eager
```
### Generate Speech
@@ -289,7 +300,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \
```bash
# Start server with VoiceDesign model first
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
- --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -311,7 +322,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \
```bash
# Start server with Base model first
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -506,16 +517,15 @@ for result in response.json()["results"]:
All items are fanned out to `generate()` concurrently. The engine's stage worker automatically batches them up to the configured `max_batch_size` and queues the rest — no client-side throttling needed.
-For best throughput, set both stages' `max_num_seqs` to ≥4 via `--stage-overrides`:
+For best throughput, use a batch-optimized stage config with `max_batch_size > 1`:
```bash
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --omni --port 8091 --trust-remote-code --enforce-eager \
- --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},
- "1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml \
+ --omni --port 8091 --trust-remote-code --enforce-eager
```
-The bundled `qwen3_tts.yaml` uses `max_num_seqs: 1` (single request) on both stages. Bumping to 4 yields roughly 4× throughput on the talker and lets stage 1 batch chunks across in-flight requests.
+The default `qwen3_tts.yaml` uses `max_batch_size: 1` (single request). The `qwen3_tts_batch.yaml` config sets `max_batch_size: 4` for ~4x throughput.
## Supported Models
@@ -607,7 +617,7 @@ Enable debug logging:
```bash
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
diff --git a/docs/serving/video_stream_api.md b/docs/serving/video_stream_api.md
deleted file mode 100644
index 88f74affca9..00000000000
--- a/docs/serving/video_stream_api.md
+++ /dev/null
@@ -1,93 +0,0 @@
-# Streaming Video Input API
-
-vLLM-Omni provides a WebSocket API for streaming video frames and optional audio chunks into Qwen3-Omni, then asking questions over the buffered session context.
-
-Each server instance runs a single model specified at startup with `vllm serve --omni`.
-
-## Quick Start
-
-### Start the Server
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
- --deploy-config vllm_omni/deploy/qwen3_omni.yaml \
- --omni \
- --port 8000 \
- --trust-remote-code
-```
-
-### Run the Example Client
-
-```bash
-python examples/online_serving/qwen3_omni/streaming_video_client.py \
- --url ws://localhost:8000/v1/video/chat/stream \
- --video /path/to/video.mp4 \
- --query "Describe what is happening in the video."
-```
-
-## API Reference
-
-### Endpoint
-
-```text
-WebSocket /v1/video/chat/stream
-```
-
-### Protocol
-
-| Direction | Type | Required fields | Description |
-|-----------|------|-----------------|-------------|
-| Client -> Server | `session.config` | none | First message. Configures output modalities, frame sampling, EVS, and prompts. |
-| Client -> Server | `video.frame` | `data` | Base64 JPEG/PNG frame. |
-| Client -> Server | `audio.chunk` | `data` | Base64 PCM16 16 kHz mono audio bytes. |
-| Client -> Server | `video.query` | `text` | Ask a question over the buffered frames and audio. |
-| Client -> Server | `video.done` | none | End the WebSocket session. |
-| Server -> Client | `response.start` | none | Query generation started. |
-| Server -> Client | `response.text.delta` | `delta` | Incremental text output. |
-| Server -> Client | `response.text.done` | `text` | Final text output for the query. |
-| Server -> Client | `response.audio.delta` | `data`, `format` | Incremental generated audio, base64 WAV. |
-| Server -> Client | `response.audio.done` | none | Audio output finished. |
-| Server -> Client | `session.done` | none | Session closed. |
-| Server -> Client | `error` | `message` | Recoverable protocol or generation error. |
-
-### `session.config` Fields
-
-| Field | Type | Default | Description |
-|-------|------|---------|-------------|
-| `model` | string or null | null | Optional model name. Usually omitted because the server hosts one model. |
-| `modalities` | list[string] | `["text", "audio"]` | Output modalities. Use `["text"]`, `["audio"]`, or both. |
-| `num_frames` | integer, 1-128 | `4` | Number of buffered frames sampled for each query. |
-| `max_frames` | integer, 1-256 | `50` | Maximum retained frame buffer size. Oldest frames are evicted first. |
-| `system_prompt` | string or null | null | Optional custom system prompt. |
-| `use_audio_in_video` | bool | `true` | Include streamed audio chunks in multimodal video understanding when audio is present. |
-| `sampling_params_list` | list or null | null | Optional per-stage sampling parameter overrides. |
-| `enable_frame_filter` | bool | `true` | Enable EVS near-duplicate frame filtering. |
-| `frame_filter_threshold` | float, 0.0-1.0 | `0.95` | EVS similarity threshold. Higher keeps more frames; lower drops more near-duplicates. |
-
-### Legacy Aliases
-
-The server accepts these legacy field names and rewrites them before validation. New clients should send the canonical names above.
-
-| Legacy field | Canonical field |
-|--------------|-----------------|
-| `num_sample_frames` | `num_frames` |
-| `evs_enabled` | `enable_frame_filter` |
-| `evs_threshold` | `frame_filter_threshold` |
-
-### Environment Variables
-
-| Variable | Values | Default | Description |
-|----------|--------|---------|-------------|
-| `VLLM_VIDEO_ASYNC_CHUNK` | `on`, `off` | `on` | Wire-level streaming switch. `off` buffers server-side deltas and emits coalesced outputs at the end of a query. |
-| `VLLM_VIDEO_AUDIO_DELTA_MODE` | `fast`, `slow` | `fast` | Audio delta extraction strategy. `fast` emits only newly produced chunks; `slow` recomputes from accumulated audio and exists for A/B verification. |
-
-## EVS Semantics
-
-EVS compares downsampled frames and drops near-duplicate frames before they enter the session frame buffer. `frame_filter_threshold` controls retention: higher values are more permissive and keep more frames; lower values are more aggressive and drop more similar frames.
-
-## Known Limitations
-
-- Session KV reuse and incremental prefill are not implemented in this PR. Each `video.query` rebuilds the model prompt from the retained frame and audio buffers.
-- Back-to-back short replies can still expose an engine-layer scheduler race. The PR notes an observed workaround of at least 200 ms idle between turns when clients repeatedly see idle timeouts.
-- If the audio buffer exceeds the server limit, the server emits `Audio buffer overflow` and clears the currently buffered audio for the session.
-- The API is intended for Qwen3-Omni streaming video understanding; other models may not support the same multimodal processor arguments.
diff --git a/docs/source/architecture/async-chunk-architecture.png b/docs/source/architecture/async-chunk-architecture.png
index 7b3e95e4df9..249de53bfe3 100644
Binary files a/docs/source/architecture/async-chunk-architecture.png and b/docs/source/architecture/async-chunk-architecture.png differ
diff --git a/docs/source/architecture/qwen3-omni-async-chunk.png b/docs/source/architecture/qwen3-omni-async-chunk.png
index e73ca84b283..b2d98b80f33 100644
Binary files a/docs/source/architecture/qwen3-omni-async-chunk.png and b/docs/source/architecture/qwen3-omni-async-chunk.png differ
diff --git a/docs/source/architecture/qwen3-omni-non-async-chunk.png b/docs/source/architecture/qwen3-omni-non-async-chunk.png
index 47a9ba66a5e..da5610a11bb 100644
Binary files a/docs/source/architecture/qwen3-omni-non-async-chunk.png and b/docs/source/architecture/qwen3-omni-non-async-chunk.png differ
diff --git a/docs/source/architecture/vllm-omni-dataflow-between-stages.png b/docs/source/architecture/vllm-omni-dataflow-between-stages.png
index 74abc81ff07..cdbc9a8b7b3 100644
Binary files a/docs/source/architecture/vllm-omni-dataflow-between-stages.png and b/docs/source/architecture/vllm-omni-dataflow-between-stages.png differ
diff --git a/docs/usage/faq.md b/docs/usage/faq.md
index 0539e158b01..c080eae4023 100644
--- a/docs/usage/faq.md
+++ b/docs/usage/faq.md
@@ -4,6 +4,14 @@
A: Now, we support natively disaggregated deployment for different model stages within a model. There is a restriction that one chip can only have one AutoRegressive model stage. This is because the unified KV cache management of vLLM. Stages of other types can coexist within a chip. The restriction will be resolved in later version.
+> Q: When I try to run the examples, I get an error about the librosa or soundfile backend. How can I solve it?
+
+A: If you encounter an error about the librosa backend, try installing ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
+
> Q: I see GPU OOM or "free memory is less than desired GPU memory utilization" errors. How can I fix it?
A: Refer to [GPU memory calculation and configuration](../configuration/gpu_memory_utilization.md) for guidance on tuning `gpu_memory_utilization` and related settings.
diff --git a/docs/user_guide/diffusion/attention_backends.md b/docs/user_guide/diffusion/attention_backends.md
deleted file mode 100644
index 692bcc0f9d0..00000000000
--- a/docs/user_guide/diffusion/attention_backends.md
+++ /dev/null
@@ -1,120 +0,0 @@
-# Diffusion Attention Backends
-
-This document describes the diffusion attention backends available in vLLM-Omni, how to select them, and how to use SageAttention.
-
-## Overview
-
-Diffusion attention backend selection is controlled by the `DIFFUSION_ATTENTION_BACKEND` environment variable and resolved in `vllm_omni.diffusion.attention.selector`.
-
-This backend is used by diffusion attention layers such as the DiT attention in video and image generation models.
-
-On CUDA, the practical choices today are:
-
-- `FLASH_ATTN`: FlashAttention backend. This is the default on supported CUDA systems when FlashAttention is installed.
-- `TORCH_SDPA`: PyTorch `scaled_dot_product_attention`.
-- `SAGE_ATTN`: SageAttention backend, if `sageattention` is installed.
-
-If `DIFFUSION_ATTENTION_BACKEND` is unset, vLLM-Omni asks the current platform to choose the default backend. On CUDA, that normally means `FLASH_ATTN` when available, otherwise `TORCH_SDPA`.
-
-## Backend Options
-
-| Value | Notes |
-|---|---|
-| `FLASH_ATTN` | Default on CUDA when FlashAttention is available. Good default for most diffusion workloads. |
-| `TORCH_SDPA` | Most conservative fallback. Useful for debugging or compatibility. |
-| `SAGE_ATTN` | Requires `sageattention`. Can improve performance on some workloads, but output quality must be validated model-by-model. |
-
-## Selection Priority
-
-Diffusion attention backend selection follows this order:
-
-1. `DIFFUSION_ATTENTION_BACKEND`
-2. Platform default
-
-Example:
-
-```bash
-export DIFFUSION_ATTENTION_BACKEND=SAGE_ATTN
-```
-
-## SageAttention Installation
-
-vLLM-Omni expects SageAttention to be installed into the same Python environment as vLLM-Omni.
-
-Build from source:
-
-```bash
-git clone https://github.com/thu-ml/SageAttention.git
-cd SageAttention
-
-export EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32
-pip install . --no-build-isolation
-```
-
-Quick check:
-
-```bash
-python -c "import sageattention; print(sageattention.__file__)"
-```
-
-## Usage
-
-### Enable SageAttention
-
-Example: HunyuanVideo-1.5 text-to-video
-
-```bash
-DIFFUSION_ATTENTION_BACKEND=SAGE_ATTN python examples/offline_inference/text_to_video/text_to_video.py \
- --model hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v \
- --prompt "A dog running across a field of golden wheat." \
- --height 480 --width 832 --num-frames 33 \
- --num-inference-steps 30 --seed 42 --guidance-scale 6.0 \
- --tensor-parallel-size 2 \
- --output ../tmp/hv15_modelopt_sage.mp4
-```
-
-Example: Wan2.2 TI2V 5B
-
-```bash
-DIFFUSION_ATTENTION_BACKEND=SAGE_ATTN python examples/offline_inference/text_to_video/text_to_video.py \
- --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \
- --prompt "A dog running across a field of golden wheat." \
- --height 704 --width 1280 --num-frames 49 \
- --num-inference-steps 30 --seed 42 --guidance-scale 5.0 \
- --tensor-parallel-size 2 \
- --output outputs/wan22_sage.mp4
-```
-
-### Compare Against FlashAttention
-
-Unset the backend override, or explicitly use `FLASH_ATTN`:
-
-```bash
-python examples/offline_inference/text_to_video/text_to_video.py \
- --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \
- --prompt "A dog running across a field of golden wheat." \
- --height 704 --width 1280 --num-frames 49 \
- --num-inference-steps 30 --seed 42 --guidance-scale 5.0 \
- --tensor-parallel-size 2 \
- --output outputs/wan22_fa3.mp4
-```
-
-## Validation Guidance
-
-Do not assume that a faster attention backend is numerically interchangeable with `FLASH_ATTN`.
-
-Always compare:
-
-- End-to-end runtime
-- DiT / diffusion stage runtime
-- Output quality against a known-good baseline
-
-At minimum, keep the same:
-
-- model
-- prompt
-- seed
-- resolution
-- frame count
-- inference steps
-- parallel config
diff --git a/docs/user_guide/diffusion/cache_acceleration/cache_dit.md b/docs/user_guide/diffusion/cache_acceleration/cache_dit.md
index eaaca84ad6d..824e8c93051 100644
--- a/docs/user_guide/diffusion/cache_acceleration/cache_dit.md
+++ b/docs/user_guide/diffusion/cache_acceleration/cache_dit.md
@@ -283,10 +283,3 @@ Using Cache-DiT acceleration:
1. ✅ **Enable Cache-DiT** - Set `cache_backend="cache_dit"` to get 1.5x-3x speedup with optimized defaults
2. ✅ **(Optional) Customize** - Adjust `cache_config` parameters for specific speed/quality trade-offs
-
----
-
-## Additional Resources
-
-- [Cache-DiT documentation](https://cache-dit.readthedocs.io/en/latest/)
-- [Cache-DiT API reference](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/)
diff --git a/docs/user_guide/diffusion/cpu_offload_diffusion.md b/docs/user_guide/diffusion/cpu_offload_diffusion.md
index 1d3f1811aed..8786ae9649a 100644
--- a/docs/user_guide/diffusion/cpu_offload_diffusion.md
+++ b/docs/user_guide/diffusion/cpu_offload_diffusion.md
@@ -36,45 +36,6 @@ m = Omni(model="Wan-AI/Wan2.2-T2V-A14B-Diffusers", enable_cpu_offload=True)
vllm-omni serve diffusion Wan-AI/Wan2.2-T2V-A14B-Diffusers --enable-cpu-offload
```
-### To Support a Model
-
-Implement the `SupportsModuleOffload` protocol to declare which
-submodules participate in offloading:
-
-```python
-from typing import ClassVar
-from vllm_omni.diffusion.models.interface import SupportsModuleOffload
-
-class MyPipeline(nn.Module, SupportsModuleOffload):
- _dit_modules: ClassVar[list[str]] = ["transformer"]
- _encoder_modules: ClassVar[list[str]] = ["text_encoder", "vision_model"]
- _vae_modules: ClassVar[list[str]] = ["vae"]
- _resident_modules: ClassVar[list[str]] = [] # optional
-
- def __init__(self):
- super().__init__()
- self.transformer = ... # DiT — stays on GPU during denoising
- self.text_encoder = ... # Encoder — offloaded to CPU during denoising
- self.vision_model = ... # Encoder — offloaded to CPU during denoising
- self.vae = ... # VAE — always on GPU
-```
-
-- `_dit_modules`: attribute names of denoising submodules (kept on GPU
- during the diffusion loop).
-- `_encoder_modules`: attribute names of encoder/vision submodules
- (offloaded to CPU during the diffusion loop).
-- `_vae_modules`: attribute names of VAE(s) (always kept on GPU, not
- part of the mutual exclusion hooks).
-- `_resident_modules`: attribute names of small submodules that must
- stay on GPU during layerwise offloading (e.g. embedders, connectors).
- Optional — defaults to `[]`.
-
-All attribute names support dotted paths for nested submodules
-(e.g. `"pipe.transformer"`, `"bagel.time_embedder"`).
-
-Both DiT and encoder lists are needed because the offload hooks use
-mutual exclusion: when one group runs, the other moves to CPU.
-
### Limitations
- Cold start latency increases
- Adds overhead from CPU-GPU transfers between encoder and denoising phases
@@ -130,19 +91,12 @@ Models must define the blocks attribute name for layerwise offloading:
```python
class WanTransformer3DModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["blocks"] # Attribute names containing transformer blocks
+ _layerwise_offload_blocks_attr = "blocks" # Attribute name containing transformer blocks
def __init__(self):
self.blocks = nn.ModuleList([...]) # Transformer blocks
```
-For models with multiple block types:
-
-```python
-class Flux2Transformer2DModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]
-```
-
### Limitations
- Cold start latency increases because of
1) components are loaded to CPU first at the very first during initialization,
@@ -155,19 +109,11 @@ class Flux2Transformer2DModel(nn.Module):
**Module Discovery**
-The offloader discovers pipeline components in two ways:
+The offloader automatically discovers pipeline components:
-1. **Protocol-based** (preferred): If the pipeline implements
- `SupportsModuleOffload`, its `_dit_modules`, `_encoder_modules`,
- `_vae_modules`, and `_resident_modules` class variables are used
- directly. All attribute names support dotted paths (e.g.
- `"pipe.transformer"`, `"bagel.time_embedder"`) for nested submodules.
-
-2. **Fallback attribute scan**: Otherwise, the offloader scans for
- well-known attribute names:
- - **DiT modules**: `transformer`, `transformer_2`, `dit`, `sr_dit`, `language_model`, `transformer_blocks`, `model`
- - **Encoders**: `text_encoder`, `text_encoder_2`, `text_encoder_3`, `image_encoder`
- - **VAE**: `vae`, `audio_vae`
+- **DiT modules**: `transformer`, `transformer_2`, `dit`
+- **Encoders**: `text_encoder`, `text_encoder_2`, `text_encoder_3`, `image_encoder`
+- **VAE**: `vae`
**Hook System**
@@ -186,17 +132,12 @@ Factory function `get_offload_backend()` selects the appropriate backend based o
## Supported Models
-| Architecture | Example Models | DiT Class | Model-Level Offload | Layerwise Offload | Blocks Attrs (Layerwise specific) |
-|--------------|----------------|-----------|---------------------|-------------------|-----------------------------------|
-| LongCatImagePipeline | `meituan-longcat/LongCat-Image` | `LongCatImageTransformer2DModel` | - | ✓ | `"transformer_blocks"`, `"single_transformer_blocks"` |
-| NextStep11Pipeline | `stepfun-ai/NextStep-1.1` | `NextStepModel` | - | ✓ | `"layers"` |
-| OvisImagePipeline | `AIDC-AI/Ovis-Image-7B` | `OvisImageTransformer2DModel` | - | ✓ | `"transformer"` |
-| QwenImagePipeline | `Qwen/Qwen-Image` | `QwenImageTransformer2DModel` | ✓ | ✓ | `"transformer_blocks"` |
-| StableDiffusion3Pipeline | `stabilityai/stable-diffusion-3.5-medium` | `SD3Transformer2DModel` | - | ✓ | `"transformer_blocks"` |
-| Wan22I2VPipeline | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | `WanTransformer3DModel` | ✓ | ✓ | `"blocks"` |
+| Architecture | Example Models | DiT Class | Model-Level Offload | Layerwise Offload | Blocks Attr (Layerwise specific) |
+|--------------|----------------|-----------|---------------------|-------------------|-------------|
| Wan22Pipeline | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | `WanTransformer3DModel` | ✓ | ✓ | `"blocks"` |
-| BagelPipeline | `ByteDance-Seed/BAGEL-7B-MoT` | `Qwen2MoTModel` | - | ✓ | `"layers"`, `"customized modules"` |
+| Wan22I2VPipeline | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | `WanTransformer3DModel` | ✓ | ✓ | `"blocks"` |
+| QwenImagePipeline | `Qwen/Qwen-Image` | `QwenImageTransformer2DModel` | ✓ | ✓ | `"transformer_blocks"` |
**Notes:**
- Model-Level Offloading is expected to be supported by all common diffusion models (DiT and encoders) naturally
-- Layerwise Offloading requires DiT class to define `_layerwise_offload_blocks_attrs` pointing to transformer blocks
+- Layerwise Offloading requires the DiT class to define `_layerwise_offload_blocks_attr` pointing to its transformer blocks
diff --git a/docs/user_guide/diffusion/frame_interpolation.md b/docs/user_guide/diffusion/frame_interpolation.md
deleted file mode 100644
index 349af50c51c..00000000000
--- a/docs/user_guide/diffusion/frame_interpolation.md
+++ /dev/null
@@ -1,92 +0,0 @@
-# Frame Interpolation
-
-## Overview
-
-vLLM-Omni supports post-generation frame interpolation for supported video
-diffusion pipelines. This feature inserts synthesized intermediate frames
-between adjacent generated frames to improve temporal smoothness without
-rerunning the diffusion denoising loop.
-
-Frame interpolation runs in the diffusion worker post-processing path instead
-of the API server encoding path. This allows the interpolation step to reuse
-the worker's current accelerator device and keeps the FastAPI event loop free
-from heavy synchronous PyTorch work.
-
-For an input video with `N` generated frames and interpolation exponent `exp`,
-the output frame count is:
-
-```text
-(N - 1) * 2**exp + 1
-```
-
-The output FPS is multiplied by `2**exp` so the clip duration remains close to
-the original generated video.
-
-## Supported Pipelines
-
-Frame interpolation is currently supported for:
-
-- `WanPipeline` (Wan2.2 text-to-video)
-- `WanImageToVideoPipeline`
-- `Wan22TI2VPipeline`
-
-## Request Parameters
-
-The video APIs `/v1/videos` and `/v1/videos/sync` accept:
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `enable_frame_interpolation` | bool | `false` | Enable post-generation frame interpolation |
-| `frame_interpolation_exp` | int | `1` | Interpolation exponent. `1=2x`, `2=4x`, etc. |
-| `frame_interpolation_scale` | float | `1.0` | RIFE inference scale |
-| `frame_interpolation_model_path` | str | `None` | Local directory or Hugging Face repo ID containing `flownet.pkl` |
-
-## Execution Flow
-
-For supported Wan2.2 pipelines, the execution order is:
-
-1. Diffusion worker finishes denoising and decodes the raw video tensor.
-2. Worker-side model-specific post-processing runs.
-3. If frame interpolation is enabled, RIFE interpolates the decoded video
- tensor on the worker side and records a FPS multiplier in `custom_output`.
-4. The API server receives the already-interpolated video and only performs
- MP4 export.
-
-This design keeps interpolation close to the generated tensor and avoids
-introducing another heavyweight GPU context in the API server process.
-
-## Example
-
-Start the server:
-
-```bash
-vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091
-```
-
-Run a sync request with interpolation enabled:
-
-```bash
-curl -X POST http://localhost:8091/v1/videos/sync \
- -F "prompt=A dog running through a park" \
- -F "num_frames=81" \
- -F "width=832" \
- -F "height=480" \
- -F "fps=16" \
- -F "num_inference_steps=40" \
- -F "guidance_scale=1.0" \
- -F "guidance_scale_2=1.0" \
- -F "enable_frame_interpolation=true" \
- -F "frame_interpolation_exp=1" \
- -F "frame_interpolation_scale=1.0" \
- -F "seed=42" \
- -o sync_t2v_interpolated.mp4
-```
-
-## Notes
-
-- This is a post-processing feature. It does not modify the diffusion denoising
- schedule.
-- Higher interpolation exponents increase post-processing time and memory usage.
-- If the interpolation model weights are not available locally,
- `frame_interpolation_model_path` may point to a Hugging Face repo containing
- `flownet.pkl`.
diff --git a/docs/user_guide/diffusion/lora.md b/docs/user_guide/diffusion/lora.md
index 256698752a1..e45c033b848 100644
--- a/docs/user_guide/diffusion/lora.md
+++ b/docs/user_guide/diffusion/lora.md
@@ -56,92 +56,6 @@ outputs = omni.generate(
!!! note "Server-side Path Requirement"
The LoRA adapter path (`local_path`) must be readable on the **server** machine. If your client and server are on different machines, ensure the LoRA adapter is accessible via a shared mount or copied to the server.
-## Wan2.2 LightX2V Offline Assembly
-
-This workflow is LoRA-adjacent: it uses external LightX2V conversion plus
-`Wan2.2-Distill-Loras` to bake converted Wan2.2 I2V checkpoints into a local
-Diffusers directory, instead of loading LoRA adapters at runtime.
-
-### Required assets
-
-- Base model: `Wan-AI/Wan2.2-I2V-A14B`
-- Diffusers skeleton: `Wan-AI/Wan2.2-I2V-A14B-Diffusers`
-- Optional external converter from the LightX2V project (not shipped in this repository)
-- Optional LoRA weights: `lightx2v/Wan2.2-Distill-Loras`
-
-### Step 1: Optional - convert high/low-noise DiT weights with LightX2V
-
-Install or clone LightX2V from the upstream repository
-(`https://github.com/ModelTC/LightX2V`). After cloning, the converter used
-below is available at `/tools/convert/converter.py`.
-
-```bash
-python /path/to/lightx2v/tools/convert/converter.py \
- --source /path/to/Wan2.2-I2V-A14B/high_noise_model \
- --output /tmp/wan22_lightx2v/high_noise_out \
- --output_ext .safetensors \
- --output_name diffusion_pytorch_model \
- --model_type wan_dit \
- --direction forward \
- --lora_path /path/to/wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors \
- --lora_key_convert auto \
- --single_file
-
-python /path/to/lightx2v/tools/convert/converter.py \
- --source /path/to/Wan2.2-I2V-A14B/low_noise_model \
- --output /tmp/wan22_lightx2v/low_noise_out \
- --output_ext .safetensors \
- --output_name diffusion_pytorch_model \
- --model_type wan_dit \
- --direction forward \
- --lora_path /path/to/wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors \
- --lora_key_convert auto \
- --single_file
-```
-
-If you are not using LightX2V, skip this step and either keep the original
-Diffusers weights from the skeleton or point Step 2 at any other converted
-`transformer/` and `transformer_2/` checkpoints.
-
-### Step 2: Assemble a final Diffusers-style directory
-
-```bash
-python tools/wan22/assemble_wan22_i2v_diffusers.py \
- --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \
- --transformer-weight /tmp/wan22_lightx2v/high_noise_out \
- --transformer-2-weight /tmp/wan22_lightx2v/low_noise_out \
- --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \
- --asset-mode symlink \
- --overwrite
-```
-
-`--transformer-weight` and `--transformer-2-weight` are optional. If you omit
-them, the tool keeps the original weights from the Diffusers skeleton.
-
-### Step 3: Run offline inference
-
-```bash
-python examples/offline_inference/image_to_video/image_to_video.py \
- --model /path/to/Wan2.2-I2V-A14B-Custom-Diffusers \
- --image /path/to/input.jpg \
- --prompt "A cat playing with yarn" \
- --num-frames 81 \
- --num-inference-steps 4 \
- --tensor-parallel-size 4 \
- --height 480 \
- --width 832 \
- --flow-shift 12 \
- --sample-solver euler \
- --guidance-scale 1.0 \
- --guidance-scale-high 1.0 \
- --boundary-ratio 0.875
-```
-
-Notes:
-
-- This route avoids runtime LoRA loading changes in vLLM-Omni when you choose to bake converted weights into a local Diffusers directory.
-- Output quality and speed depend on the replacement checkpoints and sampling params you choose.
-
## See Also
diff --git a/docs/user_guide/diffusion/quantization/autoround.md b/docs/user_guide/diffusion/quantization/autoround.md
index d06627d40a3..48df176b037 100644
--- a/docs/user_guide/diffusion/quantization/autoround.md
+++ b/docs/user_guide/diffusion/quantization/autoround.md
@@ -72,8 +72,6 @@ At load time:
| Model | HF Checkpoint | Scheme | Group Size | Backend |
|-------|--------------|--------|------------|---------|
| FLUX.1-dev | `vllm-project-org/FLUX.1-dev-AutoRound-w4a16` | W4A16 | 128 | GPTQ-Marlin |
-| Qwen2.5-Omni-7B | `Intel/Qwen2.5-Omni-7B-int4-AutoRound` | W4A16 | 128 | GPTQ-Marlin |
-| Qwen3-Omni-30B-A3B-Instruct | `Intel/Qwen3-Omni-30B-A3B-Instruct-int4-AutoRound` | W4A16 | 128 | GPTQ-Marlin |
## Creating a Quantized Checkpoint
diff --git a/docs/user_guide/diffusion/quantization/fp8.md b/docs/user_guide/diffusion/quantization/fp8.md
index ceb3d006c2e..9906631b625 100644
--- a/docs/user_guide/diffusion/quantization/fp8.md
+++ b/docs/user_guide/diffusion/quantization/fp8.md
@@ -65,7 +65,6 @@ The available `ignored_layers` names depend on the model architecture (e.g., `to
| Flux | `black-forest-labs/FLUX.1-dev` | All layers | None |
| HunyuanImage-3 | `tencent/HunyuanImage3` | All layers | None |
| HunyuanVideo-1.5 | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v`, `720p_t2v`, `480p_i2v` | All layers | None |
-| GLM-Image | `zai-org/GLM-Image` | All layers | None |
| Helios | `BestWishYsh/Helios-Base`, `BestWishYsh/Helios-Mid`, `BestWishYsh/Helios-Distilled` | All layers | None |
## Combining with Other Features
diff --git a/docs/user_guide/diffusion/quantization/msmodelslim.md b/docs/user_guide/diffusion/quantization/msmodelslim.md
deleted file mode 100644
index 5492cd9272b..00000000000
--- a/docs/user_guide/diffusion/quantization/msmodelslim.md
+++ /dev/null
@@ -1,56 +0,0 @@
-# msModelSlim Quantization
-
-## Overview
-
-[msModelSlim](https://github.com/Ascend/msmodelslim) is an Ascend-friendly compression tool focused on acceleration, using compression techniques, and built for Ascend hardware. It includes a series of inference optimization technologies such as quantization and compression, aiming to accelerate large language dense models, MoE models, multimodal understanding models, multimodal generation models, etc.
-
-Once you have a quantized model which is generated by **msModelSlim**, you can use vLLM Omni for inference by specifying the --quantization ascend parameter to enable quantization features.
-
-### Supported Schemes
-
-| Scheme | Bits | Status |
-|--------|------|--------|
-| W8A8 | 8 | ✅ Supported |
-| W4A4 | 4 | Planned |
-
-W8A8 is the first supported scheme. Additional schemes will be added in future releases.
-
-## Model Quantization
-
-The following example shows how to generate W8A8 quantized weights for the [Wan2_2 model](https://gitcode.com/Ascend/msmodelslim/blob/master/example/multimodal_sd/Wan2_2/README.md).
-
-**Quantization Script:**
-
-```bash
-msmodelslim quant \
- --model_path /path/to/wan2_2_t2v_float_weights \
- --save_path /path/to/wan2_2_t2v_quantized_weights \
- --device npu \
- --model_type Wan2_2 \
- --config_path /lab_practice/wan2_2/wan2_2_w8a8f8_mxfp_t2v.yaml \
- --trust_remote_code True
-```
-
-After quantization completes, the output directory will contain the quantized model files.
-
-For more examples, refer to the [official examples](https://gitcode.com/Ascend/msit/tree/master/msmodelslim/example).
-
-## Configuration
-
-1. **CLI**: pass `--quantization ascend`.
-
-```bash
-# Offline inference
-python text_to_image.py --model --quantization ascend
-
-# Online serving
-vllm serve --omni --quantization ascend
-```
-
-## Supported Models
-
-| Model | HF Models | Recommendation | `ignored_layers` |
-|-------|-----------|---------------|------------------|
-| HunyuanImage-3.0 | - | All layers | None |
-
-Currently, quantized HunyuanImage-3.0 weights have not been uploaded to public model platforms such as Hugging Face. You can use a [HunyuanImage-3.0-adapted msModelSlim version](https://gitcode.com/betta18/msmodelslim/tree/hyimage3_mxfp8) to generate the quantized weights manually. We will upload the quantized weights as soon as possible.
diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md
index be1602788b7..9cd407d377a 100644
--- a/docs/user_guide/diffusion_features.md
+++ b/docs/user_guide/diffusion_features.md
@@ -12,9 +12,9 @@
vLLM-Omni supports various advanced features for diffusion models:
-- Acceleration: **cache methods**, **parallelism methods**, **startup optimizations**
+- Acceleration: **cache methods**, **parallelism methods**
- Memory optimization: **cpu offloading**, **quantization**
-- Extensions: **LoRA inference**, **frame interpolation**
+- Extensions: **LoRA inference**
- Execution modes: **step execution**
## Supported Features
@@ -44,12 +44,6 @@ Parallelism methods distribute computation across GPUs without quality loss (mat
| **[HSDP](diffusion/parallelism/hsdp.md)** | Weight sharding via FSDP2, redistributed on-demand at runtime | Very large models (14B+) on limited VRAM, combinable with SP |
| **[Expert Parallelism](diffusion/parallelism/expert_parallel.md)** | Shards MoE expert MLP blocks across devices | MoE diffusion models (e.g., HunyuanImage3.0) |
-#### Startup Optimization
-
-| Method | Description | Best For |
-|--------|-------------|----------|
-| **[Multi-Thread Weight Loading](#multi-thread-weight-loading)** | Loads safetensors shards in parallel using a thread pool | All diffusion models; reduces startup from minutes to seconds |
-
**Note:** Some acceleration methods can be combined together for optimized performance. See [Feature Compatibility Table](#feature-compatibility) and [Feature Compatibility Tutorial](feature_compatibility.md) for detailed configuration examples.
### Memory Optimization
@@ -69,7 +63,6 @@ Extension methods add specialized capabilities to diffusion models beyond standa
| Method | Description | Best For |
|--------|-------------|----------|
| **[LoRA Inference](diffusion/lora.md)** | Enables inference with Low-Rank Adaptation (LoRA) adapters weights | Reinforcement learning extensions |
-| **[Frame Interpolation](diffusion/frame_interpolation.md)** | Inserts intermediate video frames after generation for smoother motion | Video generation pipelines that need higher temporal smoothness |
### Execution Modes
@@ -107,55 +100,47 @@ The following tables show which models support each feature:
| Model | ⚡TeaCache | ⚡Cache-DiT | 🔀SP (Ulysses & Ring) | 🔀CFG-Parallel | 🔀Tensor-Parallel | 🔀HSDP | 💾CPU Offload (Layerwise) | 💾VAE-Patch-Parallel | 💾Quantization | 🔄Step Execution |
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
-| **Bagel** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **Bagel** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FLUX.1-dev** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
-| **FLUX.1-schnell** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.2-klein** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
| **FLUX.1-Kontext-dev** | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
-| **FLUX.2-dev** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
+| **FLUX.2-dev** | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
-| **LongCat-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| **LongCat-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| **MagiHuman** | ❌ | ❌ | ❌ | ❓ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **LongCat-Image-Edit** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **MammothModa2(T2I)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| **OmniGen2** | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
-| **Ovis-Image** | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| **Qwen-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ |
-| **Qwen-Image-2512** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ✅ |
-| **Qwen-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
-| **Qwen-Image-Edit-2509** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | ❌ |
-| **Qwen-Image-Layered** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
-| **Stable-Diffusion3.5** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ✅ | ✅ (decode) | ❌ | ❌ |
-| **Z-Image** | ✅ | ✅ | ✅ | ❓ | ✅ (TP=2 only) | ✅ | ❌ | ✅ (decode) | ✅ | ❌ |
+| **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **OmniGen2** | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **Ovis-Image** | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **Qwen-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| **Qwen-Image-2512** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+| **Qwen-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
+| **Qwen-Image-Edit-2509** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
+| **Qwen-Image-Layered** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
+| **Stable-Diffusion3.5** | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
+| **Z-Image** | ✅ | ✅ | ✅ | ❓ | ✅ (TP=2 only) | ✅ | ❌ | ✅ | ✅ | ❌ |
> Notes:
> 1. Nextstep_1(T2I) does not support cache acceleration methods such as TeaCache or Cache-DiT.
-> 2. `Tongyi-MAI/Z-Image-Turbo` and `SII-GAIR/daVinci-MagiHuman-Base-1080p` are distilled models with minimal NFEs; CFG-Parallel is not necessary.
+> 2. `Tongyi-MAI/Z-Image-Turbo` is a distilled model with minimal NFEs; CFG-Parallel is not necessary.
### VideoGen
| Model | ⚡TeaCache | ⚡Cache-DiT | 🔀SP (Ulysses & Ring) | 🔀CFG-Parallel | 🔀Tensor-Parallel | 🔀HSDP | 💾CPU Offload (Layerwise) | 💾VAE-Patch-Parallel | 💾Quantization | 🔄Step Execution |
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
-| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ |
-| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
-| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| **LTX-2.3** | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
+| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
+| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
-| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ |
-| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
-
-**Frame Interpolation Support**
-
-- **Supported**: Wan2.2 text-to-video, image-to-video, and TI2V pipelines
-- **Not supported**: Wan2.1-VACE, LTX-2, LTX-2.3, Helios, HunyuanVideo-1.5, DreamID-Omni
+| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
+| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
### AudioGen
| Model | ⚡TeaCache | ⚡Cache-DiT | 🔀SP (Ulysses & Ring) | 🔀CFG-Parallel | 🔀Tensor-Parallel | 🔀HSDP | 💾CPU Offload (Layerwise) | 💾VAE-Patch-Parallel | 💾Quantization | 🔄Step Execution |
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
-| **Stable-Audio-Open** | ✅ | ❌ | ❓ | ❓ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ |
+| **Stable-Audio-Open** | ❌ | ❌ | ❓ | ❓ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
## Feature Compatibility
@@ -193,59 +178,6 @@ The following tables show which models support each feature:
6. Step Execution is not compatible with cache backends (TeaCache, Cache-DiT) or LoRA.
-## Multi-Thread Weight Loading
-
-Large diffusion models can take several minutes to load weights at startup (e.g., ~3 min for Qwen-Image, ~5 min for Wan2.2 I2V 14B). Multi-thread weight loading speeds up this process by loading safetensors shards in parallel using a thread pool instead of sequentially.
-
-This optimization is **enabled by default** with 4 threads. No configuration is needed for the default behavior.
-
-### Configuration
-
-| Parameter | CLI Flag | Default | Description |
-|-----------|----------|---------|-------------|
-| `enable_multithread_weight_load` | `--disable-multithread-weight-load` | `True` (enabled) | Pass the flag to disable multi-thread loading |
-| `num_weight_load_threads` | `--num-weight-load-threads` | `4` | Number of threads for parallel weight loading |
-
-!!! tip
- The default of 4 threads balances speed and disk I/O contention. On fast NVMe storage you may benefit from more threads (e.g., 8). On HDD or network storage, the default of 4 avoids saturating I/O bandwidth.
-
-### Online Serving
-
-```bash
-# Default (multi-thread enabled, 4 threads)
-vllm serve Qwen/Qwen-Image --omni --port 8091
-
-# Custom thread count
-vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers --omni --num-weight-load-threads 8
-
-# Disable multi-thread loading
-vllm serve Qwen/Qwen-Image --omni --disable-multithread-weight-load
-```
-
-### Offline Inference
-
-```python
-from vllm_omni import Omni
-
-# Default (multi-thread enabled, 4 threads)
-omni = Omni(model="Qwen/Qwen-Image")
-
-# Custom thread count
-omni = Omni(
- model="Wan-AI/Wan2.2-I2V-A14B-Diffusers",
- num_weight_load_threads=8,
-)
-```
-
-### Benchmarks
-
-Measured on NVIDIA H800:
-
-| Model | Before | After | Speedup |
-|-------|--------|-------|---------|
-| **Qwen/Qwen-Image** (53.7 GiB) | 168s | 27s | **6.2x** |
-| **Wan-AI/Wan2.2-I2V-A14B-Diffusers** (64.5 GiB) | 283s | 56s | **5.1x** |
-
## Learn More
**Cache Acceleration:**
@@ -266,16 +198,11 @@ Measured on NVIDIA H800:
**Extensions:**
- **[LoRA Inference Guide](diffusion/lora.md)** - Low-Rank Adaptation for style customization and fine-tuning
-- **[Frame Interpolation Guide](diffusion/frame_interpolation.md)** - Worker-side post-generation video frame interpolation for smoother motion
**Execution Modes:**
- **[Step Execution Guide](diffusion/step_execution.md)** - Per-step denoise execution with mid-request abort support
-**Startup Optimization:**
-
-- **[Multi-Thread Weight Loading](#multi-thread-weight-loading)** - Speed up model startup by loading safetensors shards in parallel
-
**Advanced Topics:**
- **[Feature Compatibility](feature_compatibility.md)** - How to combine multiple features for maximum performance
diff --git a/docs/user_guide/examples/offline_inference/bagel.md b/docs/user_guide/examples/offline_inference/bagel.md
index 0d3498b28d9..5f458750b44 100644
--- a/docs/user_guide/examples/offline_inference/bagel.md
+++ b/docs/user_guide/examples/offline_inference/bagel.md
@@ -2,61 +2,46 @@
Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/bagel>.
-## Setup
-Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
-
-## Architecture
+## Set up
-BAGEL-7B-MoT is a Mixture-of-Transformers (MoT) model supporting both image generation and understanding. It offers two deployment topologies:
+Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
-| Topology | Stages | Description |
-| :------- | :----- | :---------- |
-| **Two-stage** (default) | Stage 0 (Thinker, AR) + Stage 1 (DiT, Diffusion) | Thinker handles text/understanding via vLLM AR engine; DiT handles image generation. KV cache is transferred between stages. |
-| **Single-stage** | Stage 0 (DiT, Diffusion) only | The DiT stage contains a full LLM, ViT, VAE, and tokenizer internally. All modalities are handled within a single diffusion process. |
+## Run examples
-Both topologies support all four modalities: `text2img`, `img2img`, `img2text`, `text2text`.
+**Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, please modify the stage configuration to distribute the model across devices.
-## Quick Start
+Change into the BAGEL example directory:
```bash
cd examples/offline_inference/bagel
-
-# Default two-stage mode (auto-detected)
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat"
-
-# Single-stage mode
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat" \
- --deploy-config vllm_omni/deploy/bagel_single_stage.yaml
```
-> **Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. For dual-GPU setups, modify the deploy YAML to distribute stages across devices.
+### Modality Control
-## Modality Control
+BAGEL-7B-MoT supports multiple modality modes. You can control the mode using the `--modality` argument:
-Control the mode using the `--modality` argument:
+#### Text to Image (text2img)
-| Modality | Input | Output | Description |
-| :------- | :---- | :----- | :---------- |
-| `text2img` | Text | Image | Generate images from text prompts |
-| `img2img` | Image + Text | Image | Transform images using text guidance |
-| `img2text` | Image + Text | Text | Generate text descriptions from images |
-| `text2text` | Text | Text | Pure text generation (language model mode) |
+- **Pipeline**: Text → Thinker → DiT → VAE Decode → Image
+- **Stages Used**: Stage 0 (Thinker) + Stage 1 (DiT)
+- **KV Transfer**: Thinker sends KV cache to DiT for conditioned generation
-### Text to Image (text2img)
+Generate images from text prompts:
```bash
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2img \
- --prompts "A cute cat" \
- --steps 50
+ --prompts "A cute cat"
```
-### Image to Image (img2img)
+#### Image to Image (img2img)
+
+- **Pipeline**: Image → VAE Encode → DiT → VAE Decode → New Image
+- **Stages Used**: Stage 1 (DiT) only
+- **Special**: Bypasses the Thinker stage, direct image-to-image transformation
+
+Transform images based on text prompts:
```bash
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
@@ -65,7 +50,13 @@ python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--prompts "Let the woman wear a blue dress"
```
-### Image to Text (img2text)
+#### Image to Text (img2text)
+
+- **Pipeline**: Image → ViT + VAE Encode → Thinker → Text Output
+- **Stages Used**: Stage 0 (Thinker) only
+- **Special**: Uses both VAE latent encoding AND ViT semantic encoding for comprehensive image understanding
+
+Generate text descriptions from images:
```bash
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
@@ -74,210 +65,205 @@ python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--prompts "Describe this image in detail"
```
-### Text to Text (text2text)
+#### Text to Text (text2text)
+
+- **Pipeline**: Text → Thinker → Text Output
+- **Stages Used**: Stage 0 (Thinker) only
+- **Special**: No visual components involved, operates as pure language model
+
+Pure text generation:
```bash
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2text \
--prompts "What is the capital of France?"
-# Load prompts from a text file (one prompt per line):
+# You can load prompts from a text file (one prompt per line):
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2text \
--txt-prompts /path/to/prompts.txt
```
-## Think Mode
-
-Think mode enables the model to generate `... ` planning/reasoning tokens before producing the final output. This improves generation quality for complex prompts.
+### Inference Steps
-- **Two-stage**: The Thinker (AR) stage decodes think tokens, then transfers the augmented KV cache to the DiT stage for image generation.
-- **Single-stage**: The DiT's internal LLM generates think tokens in-place before proceeding to denoise.
+Control the number of inference steps for image generation:
```bash
-# Think + text2img: plan before generating
+# Increase --steps (e.g., to 100) to improve image quality at the cost of longer generation time
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2img \
- --prompts "A futuristic city with flying cars" \
- --think \
- --max-think-tokens 1000
+ --steps 50 \
+ --prompts "A cute cat"
+```
-# Think + img2img: reason about the edit
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality img2img \
- --image-path /path/to/image.jpg \
- --prompts "Make it look like a watercolor painting" \
- --think
+### Key arguments
-# Think + img2text: reason before describing
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality img2text \
- --image-path /path/to/image.jpg \
- --prompts "What is happening in this image?" \
- --think
+BAGEL-7B-MoT supports **multiple modality modes** for different use cases.
-# Think + text2text: chain-of-thought reasoning
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2text \
- --prompts "Solve: 23 * 47" \
- --think
-```
+The default YAML configuration deploys the Thinker and DiT stages on the same GPU. You can use the default configuration file: [`bagel.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/model_executor/stage_configs/bagel.yaml).
-Think mode parameters:
+#### 📌 Command Line Arguments (end2end.py)
-| Argument | Default | Description |
-| :------- | :------ | :---------- |
-| `--think` | `False` | Enable thinking mode |
-| `--max-think-tokens` | `1000` | Maximum tokens for think generation |
-| `--do-sample` | `False` | Enable sampling (vs. greedy) for text generation |
-| `--text-temperature` | `0.3` | Temperature for text generation sampling |
+| Argument | Type | Default | Description |
+| :--------------------- | :----- | :---------------------------- | :----------------------------------------------------------- |
+| `--model` | string | `ByteDance-Seed/BAGEL-7B-MoT` | Model path or name |
+| `--modality` | choice | `text2img` | Modality mode: `text2img`, `img2img`, `img2text`, `text2text` |
+| `--prompts` | list | `None` | Input text prompts directly |
+| `--txt-prompts` | string | `None` | Path to txt file with one prompt per line |
+| `--image-path` | string | `None` | Input image path (for `img2img`/`img2text`) |
+| `--steps` | int | `50` | Number of inference steps |
+| `--stage-configs-path` | string | `None` | Custom stage config file path |
+| `--worker-backend` | choice | `process` | Worker backend: `process` or `ray` |
+| `--ray-address` | string | `None` | Ray cluster address |
+| `--enable-stats` | flag | `False` | Enable statistics logging |
+| `--init-sleep-seconds` | int | `20` | Initialization sleep time |
+| `--batch-timeout` | int | `5` | Batch timeout |
+| `--init-timeout` | int | `300` | Initialization timeout |
-## Classifier-Free Guidance (CFG)
+------
-CFG controls the trade-off between prompt fidelity and diversity. These parameters apply to image generation modalities (`text2img`, `img2img`).
+#### ⚙️ Stage Configuration Parameters (bagel.yaml)
-```bash
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A photorealistic portrait" \
- --cfg-text-scale 6.0 \
- --cfg-img-scale 2.0 \
- --negative-prompt "blurry, low quality, distorted" \
- --cfg-interval 0.4 1.0 \
- --cfg-renorm-type global \
- --cfg-renorm-min 0.0
-```
+**Stage 0 - Thinker (LLM Stage)**
-| Argument | Default | Description |
-| :------- | :------ | :---------- |
-| `--cfg-text-scale` | `4.0` | Text CFG scale (higher = more prompt-adherent) |
-| `--cfg-img-scale` | `1.5` | Image CFG scale (for img2img) |
-| `--negative-prompt` | `None` | Negative prompt for CFG conditioning |
-| `--cfg-interval` | pipeline default | CFG active interval `[start, end]` as fractions of total timesteps |
-| `--cfg-renorm-type` | `None` | Renormalization type: `global`, `text_channel`, `channel` |
-| `--cfg-renorm-min` | `None` | Minimum renormalization value |
-| `--cfg-parallel-size` | `1` | CFG parallel size: `1` = batched (single GPU), `2` = 2-branch parallel, `3` = full 3-GPU parallel |
+| Parameter | Value | Description |
+| :------------------------------- | :------------------------------ | :----------------------- |
+| `stage_type` | `llm` | Stage type |
+| `devices` | `"0"` | GPU device ID |
+| `max_num_seqs` | `1` | Maximum batch size |
+| `model_stage` | `thinker` | Model stage identifier |
+| `model_arch` | `BagelForConditionalGeneration` | Model architecture |
+| `gpu_memory_utilization` | `0.4` | GPU memory utilization |
+| `tensor_parallel_size` | `1` | Tensor parallel size |
+| `max_num_batched_tokens` | `32768` | Maximum batched tokens |
+| `omni_kv_config.need_send_cache` | `true` | Whether to send KV cache |
-## Deployment Topologies
+------
-### Two-Stage (Default)
+**Stage 1 - DiT (Diffusion Stage)**
-The default topology auto-detected from the model. No extra flags needed.
+| Parameter | Value | Description |
+| :------------------------------- | :---------- | :-------------------------- |
+| `stage_type` | `diffusion` | Stage type |
+| `devices` | `"0"` | GPU device ID |
+| `max_num_seqs` | `1` | Maximum batch size |
+| `model_stage` | `dit` | Model stage identifier |
+| `gpu_memory_utilization` | `0.4` | GPU memory utilization |
+| `omni_kv_config.need_recv_cache` | `true` | Whether to receive KV cache |
+| `engine_input_source` | `[0]` | Input source from Stage 0 |
-```bash
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat"
-```
+------
-The pipeline is defined in [`bagel.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/deploy/bagel.yaml). Stage 0 (Thinker) and Stage 1 (DiT) share GPU 0 by default. For dual-GPU setups, customize the deploy YAML and set `devices: "1"` for stage 1.
+#### Tensor Parallelism (TP)
-### Single-Stage
+For larger models or multi-GPU environments, you can enable Tensor Parallelism (TP) by modifying the stage configuration (e.g., [`bagel.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/model_executor/stage_configs/bagel.yaml)).
-Pass the single-stage deploy config via `--deploy-config`:
+1. **Set `tensor_parallel_size`**: Increase this value (e.g., to `2` or `4`).
+2. **Set `devices`**: Specify the comma-separated GPU IDs to be used for the stage (e.g., `"0,1"`).
-```bash
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat" \
- --deploy-config vllm_omni/deploy/bagel_single_stage.yaml
+Example configuration for TP=2 on GPUs 0 and 1:
+```yaml
+  engine_args:
+    tensor_parallel_size: 2
+    ...
+  runtime:
+    devices: "0,1"
```
-See [`bagel_single_stage.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/deploy/bagel_single_stage.yaml) for configuration details. The `pipeline: bagel_single_stage` field selects the single-stage topology from the pipeline registry.
+------
-### Tensor Parallelism (TP)
+#### 🔗 Runtime Configuration
-For larger models or multi-GPU environments, customize the deploy YAML (see [`bagel.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/deploy/bagel.yaml)) and set per-stage `tensor_parallel_size` and `devices`:
+| Parameter | Value | Description |
+| :-------------------- | :------ | :------------------------------- |
+| `window_size` | `-1` | Window size (-1 means unlimited) |
+| `max_inflight` | `1` | Maximum inflight requests |
+| `shm_threshold_bytes` | `65536` | Shared memory threshold (64KB) |
-```yaml
-# Example: TP=2 on GPUs 0,1 for the Thinker stage
-stages:
- - stage_id: 0
- tensor_parallel_size: 2
- devices: "0,1"
+## Using Mooncake Connector
+
+[Mooncake](https://github.com/kvcache-ai/Mooncake) is a high-performance distributed KV cache transfer engine that enables efficient cross-node data movement via TCP or RDMA, making it ideal for multi-node disaggregated inference.
+
+By default, BAGEL uses `SharedMemoryConnector` for inter-stage communication. You can switch to the Mooncake connector for better performance on multi-GPU setups and to enable multi-node deployment.
+
+### Prerequisites
+
+Install the Mooncake transfer engine:
+
+```bash
+# For CUDA-enabled systems (recommended)
+pip install mooncake-transfer-engine
+
+# For non-CUDA systems
+pip install mooncake-transfer-engine-non-cuda
```
-Then pass the custom deploy YAML:
+### Step 1: Start the Mooncake Master
+
+On the **primary node**, start the Mooncake master service (run in a separate terminal or background with `&`):
```bash
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat" \
- --deploy-config /path/to/custom_bagel.yaml
+# Optional: enable disk-backed storage by creating a directory and passing --root_fs_dir.
+# Without it, Mooncake runs in memory-only mode, which is sufficient for KV cache transfer.
+mkdir -p ./mc_storage
+
+mooncake_master \
+ --rpc_port=50051 \
+ --enable_http_metadata_server=true \
+ --http_metadata_server_host=0.0.0.0 \
+ --http_metadata_server_port=8080 \
+ --metrics_port=9003 \
+ --root_fs_dir=./mc_storage/ \
+ --cluster_id=mc-local-1 &
```
-### FP8 Quantization
+### Step 2: Run Offline Inference with Mooncake
+
+Use the provided Mooncake stage config [`bagel_multiconnector.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml). Before launching, update the `metadata_server` and `master` addresses in the YAML to match your Mooncake master node's IP (use `127.0.0.1` for single-node testing).
```bash
+cd examples/offline_inference/bagel
+
+# Text to Image with Mooncake
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2img \
--prompts "A cute cat" \
- --quantization fp8
+ --stage-configs-path ../../../vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
+
+# Image to Text with Mooncake
+python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
+ --modality img2text \
+ --image-path /path/to/image.jpg \
+ --prompts "Describe this image" \
+ --stage-configs-path ../../../vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
+
+# Text to Text with Mooncake
+python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
+ --modality text2text \
+ --prompts "What is the capital of France?" \
+ --stage-configs-path ../../../vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
```
-## Command Line Reference
-
-### Core Arguments
-
-| Argument | Type | Default | Description |
-| :------- | :--- | :------ | :---------- |
-| `--model` | string | `ByteDance-Seed/BAGEL-7B-MoT` | Model path or HuggingFace name |
-| `--modality` | choice | `text2img` | `text2img`, `img2img`, `img2text`, `text2text` |
-| `--prompts` | list | `None` | Input text prompts |
-| `--txt-prompts` | string | `None` | Path to text file with one prompt per line |
-| `--image-path` | string | `None` | Input image path (required for `img2img`/`img2text`) |
-| `--output` | string | `.` | Output directory for saved images |
-| `--steps` | int | `50` | Number of diffusion inference steps |
-| `--seed` | int | `None` | Random seed for reproducibility |
-
-### Think Mode Arguments
-
-| Argument | Type | Default | Description |
-| :------- | :--- | :------ | :---------- |
-| `--think` | flag | `False` | Enable `... ` planning/reasoning |
-| `--max-think-tokens` | int | `1000` | Maximum tokens for think generation |
-| `--do-sample` | flag | `False` | Use sampling instead of greedy decoding |
-| `--text-temperature` | float | `0.3` | Sampling temperature for text generation |
-
-### CFG Arguments
-
-| Argument | Type | Default | Description |
-| :------- | :--- | :------ | :---------- |
-| `--cfg-text-scale` | float | `4.0` | Text CFG guidance scale |
-| `--cfg-img-scale` | float | `1.5` | Image CFG guidance scale |
-| `--negative-prompt` | string | `None` | Negative prompt for CFG |
-| `--cfg-parallel-size` | int | `1` | CFG parallel GPU count (1, 2, or 3) |
-| `--cfg-interval` | float[2] | pipeline default | CFG active window `[start, end]` |
-| `--cfg-renorm-type` | string | `None` | `global`, `text_channel`, or `channel` |
-| `--cfg-renorm-min` | float | `None` | Minimum renormalization value |
-
-### Engine Arguments
-
-| Argument | Type | Default | Description |
-| :------- | :--- | :------ | :---------- |
-| `--deploy-config` | string | `None` | Path to deploy YAML (auto-detected if omitted) |
-| `--stage-configs-path` | string | `None` | [Deprecated] Legacy path to `stage_args` YAML; prefer `--deploy-config` |
-| `--worker-backend` | choice | `process` | `process` or `ray` |
-| `--ray-address` | string | `None` | Ray cluster address |
-| `--quantization` | string | `None` | Quantization method (e.g. `fp8`) |
-| `--log-stats` | flag | `False` | Enable statistics logging |
-| `--init-timeout` | int | `300` | Initialization timeout (seconds) |
-| `--batch-timeout` | int | `5` | Batch timeout (seconds) |
-| `--enable-diffusion-pipeline-profiler` | flag | `False` | Profile diffusion stage durations |
+For more details on the Mooncake connector and multi-node setup, see the [Mooncake Store Connector documentation](https://github.com/vllm-project/vllm-omni/tree/main/docs/design/feature/omni_connectors/mooncake_store_connector.md).
+
+------
## FAQ
-- If you encounter OOM errors, try decreasing `max_model_len` or `gpu_memory_utilization` in the deploy YAML.
+- If you encounter an error about the librosa audio backend, install ffmpeg with the commands below.
-**Two-stage VRAM usage:**
+```bash
+sudo apt update
+sudo apt install ffmpeg
+```
-| Stage | VRAM |
-| :---- | :--- |
-| Stage 0 (Thinker) | **15.04 GiB + KV Cache** |
-| Stage 1 (DiT) | **26.50 GiB** |
-| Total | **~42 GiB + KV Cache** |
+- If you are unsure how much VRAM the model needs, or you encounter an out-of-memory (OOM) error, try decreasing `max_model_len`. Approximate VRAM usage per stage:
-**Single-stage VRAM usage:** The DiT loads the full model (~42 GiB) in one process.
+| Stage | VRAM |
+| :------------------ | :--------------------------- |
+| Stage-0 (Thinker) | **15.04 GiB** **+ KV Cache** |
+| Stage-1 (DiT) | **26.50 GiB** |
+| Total | **~42 GiB + KV Cache** |
## Example materials
diff --git a/docs/user_guide/examples/offline_inference/cosyvoice3.md b/docs/user_guide/examples/offline_inference/cosyvoice3.md
index d0638f4140f..d912f1c62eb 100644
--- a/docs/user_guide/examples/offline_inference/cosyvoice3.md
+++ b/docs/user_guide/examples/offline_inference/cosyvoice3.md
@@ -10,7 +10,7 @@ Install dependencies:
uv pip install -e .
```
-> **Note:** This includes required libraries such as `soundfile`,
+> **Note:** This includes required libraries such as `librosa`, `soundfile`,
> `onnxruntime`, `x-transformers`, and `einops` via
> `requirements/common.txt` and platform-specific requirements files.
@@ -61,17 +61,10 @@ Key components live in `vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py
- Stage 0 uses `CosyVoice3LM` and outputs speech tokens + conditioning features.
- Stage 1 runs the flow model (DiT-based CFM) and HiFiGAN to synthesize waveform.
-Pipeline topology lives in `vllm_omni/model_executor/models/cosyvoice3/pipeline.py`;
-runtime tunables (batch size, memory limits, sampling) live in
-`vllm_omni/deploy/cosyvoice3.yaml`. The deploy config auto-loads by
-HF `model_type` and defaults to `async_chunk: true` (shared-memory
-streaming). Pass `--no-async-chunk` on `vllm serve` to switch to the
-legacy sync path where stage 1 runs `text2flow` over the full
-speech-token sequence.
+Stage wiring is configured in `vllm_omni/model_executor/stage_configs/cosyvoice3.yaml`.
- Stage 0 emits latent speech tokens.
-- Stage 1 consumes them via `sync_process_input_func` (sync mode) or the
- shared-memory connector (async-chunk mode) and outputs audio.
+- Stage 1 consumes them via `custom_process_input_func` and outputs audio.
## Example materials
diff --git a/docs/user_guide/examples/offline_inference/glm_image.md b/docs/user_guide/examples/offline_inference/glm_image.md
index c6ac6e33ffd..4519e26fda6 100644
--- a/docs/user_guide/examples/offline_inference/glm_image.md
+++ b/docs/user_guide/examples/offline_inference/glm_image.md
@@ -1,87 +1,154 @@
-# GLM-Image Offline Inference
+# GLM-Image Multistage End-to-End Inference
-GLM-Image is a 2-stage image generation model (AR + Diffusion) supported by vLLM-Omni's
-declarative config system. The pipeline topology and stage structure are declared in
-`vllm_omni/model_executor/models/glm_image/pipeline.py`; deployment knobs live in
-`vllm_omni/deploy/glm_image.yaml`.
+Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/offline_inference/glm_image>.
+
+
+This example demonstrates how to run GLM-Image with the vLLM-Omni multistage architecture.
## Architecture
+GLM-Image uses a 2-stage pipeline:
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│ GLM-Image Pipeline │
+├─────────────────────────────────────────────────────────────┤
+│ │
+│ Stage 0 (AR Model) Stage 1 (Diffusion) │
+│ ┌─────────────────┐ ┌─────────────────────┐ │
+│ │ vLLM-optimized │ │ GlmImagePipeline │ │
+│ │ GlmImageFor │ prior │ ┌───────────────┐ │ │
+│ │ Conditional │──tokens───►│ │ DiT Denoiser │ │ │
+│ │ Generation │ │ └───────────────┘ │ │
+│ │ (9B AR model) │ │ │ │ │
+│ └─────────────────┘ │ ▼ │ │
+│ ▲ │ ┌───────────────┐ │ │
+│ │ │ │ VAE Decode │──┼──► Image
+│ Text/Image │ └───────────────┘ │ │
+│ Input └─────────────────────┘ │
+│ │
+└─────────────────────────────────────────────────────────────┘
+```
+
+## Features
+
+- **vLLM-optimized AR**: Uses PagedAttention and tensor parallelism for faster prior token generation
+- **Flexible deployment**: AR and Diffusion stages can run on different GPUs
+- **Text-to-Image**: Generate images from text descriptions
+- **Image-to-Image**: Edit existing images with text prompts
+
+## Usage
+
+### Text-to-Image
+
+```bash
+python end2end.py \
+ --config-path ../../../vllm_omni/model_executor/stage_configs/glm_image.yaml \
+ --prompt "A beautiful sunset over the ocean with sailing boats" \
+ --height 1024 \
+ --width 1024 \
+ --output output_t2i.png
```
-Stage 0 (AR Model) Stage 1 (Diffusion)
-┌───────────────────┐ ┌─────────────────────┐
-│ vLLM-optimized │ prior │ GlmImagePipeline │
-│ GlmImageFor │──tokens──►│ ┌───────────────┐ │
-│ Conditional │ │ │ DiT Denoiser │ │
-│ Generation │ │ └───────┬───────┘ │
-│ (9B AR model) │ │ ▼ │
-└───────────────────┘ │ ┌───────────────┐ │
- ▲ │ │ VAE Decode │──┼──► Image
- │ │ └───────────────┘ │
- Text / Image └─────────────────────┘
- Input
+
+### Image-to-Image (Image Editing)
+
+```bash
+python end2end.py \
+ --config-path ../../../vllm_omni/model_executor/stage_configs/glm_image.yaml \
+ --prompt "Transform this scene into a winter wonderland" \
+ --image input.png \
+ --output output_i2i.png
```
-## Text-to-Image
-
-```python
-from vllm_omni.entrypoints.omni import Omni
-
-if __name__ == "__main__":
- omni = Omni(model="zai-org/GLM-Image")
- outputs = omni.generate(
- "A photorealistic mountain landscape at sunset",
- sampling_params={
- "height": 1024,
- "width": 1024,
- "num_inference_steps": 50,
- "guidance_scale": 1.5,
- "seed": 42,
- },
- )
- outputs[0].request_output.images[0].save("output.png")
+### With Custom Parameters
+
+```bash
+python end2end.py \
+ --model-path /path/to/glm-image \
+ --config-path ../../../vllm_omni/model_executor/stage_configs/glm_image.yaml \
+ --prompt "A photorealistic cat sitting on a window sill" \
+ --height 1024 \
+ --width 1024 \
+ --num-inference-steps 50 \
+ --guidance-scale 1.5 \
+ --seed 42 \
+ --output output.png
```
-## Image-to-Image (Image Editing)
-
-```python
-from vllm_omni.entrypoints.omni import Omni
-
-if __name__ == "__main__":
- omni = Omni(model="zai-org/GLM-Image")
- outputs = omni.generate(
- {
- "prompt": "Convert this image to watercolor style",
- "multi_modal_data": {
- "image": "input.png",
- },
- },
- sampling_params={
- "height": 1024,
- "width": 1024,
- "num_inference_steps": 50,
- "guidance_scale": 1.5,
- "seed": 42,
- },
- )
- outputs[0].request_output.images[0].save("output.png")
+## Shell Scripts
+
+### Run Text-to-Image
+
+```bash
+./run_t2i.sh
+```
+
+### Run Image-to-Image
+
+```bash
+./run_i2i.sh --image /path/to/input.png
+```
+
+## Stage Configuration
+
+The stage config (`glm_image.yaml`) defines:
+
+- **Stage 0 (AR)**: Uses `GPUARWorker` with vLLM engine
+
+ - Model: `GlmImageForConditionalGeneration`
+ - Output: `token_ids` (prior tokens)
+
+- **Stage 1 (Diffusion)**: Uses diffusion engine
+ - Model: `GlmImagePipeline`
+ - Output: Generated image
+
+See `vllm_omni/model_executor/stage_configs/glm_image.yaml` for full configuration.
+
+## Comparison with Single-Stage
+
+| Aspect | Single-Stage (transformers) | Multistage (vLLM) |
+| ----------- | --------------------------- | ------------------- |
+| AR Model | transformers native | vLLM PagedAttention |
+| Memory | Higher (no KV cache opt) | Lower (optimized) |
+| Throughput | Lower | Higher |
+| Flexibility | Single GPU | Multi-GPU support |
+
+## Troubleshooting
+
+### OOM Error
+
+Try reducing memory usage:
+
+```yaml
+# In glm_image.yaml, adjust:
+gpu_memory_utilization: 0.5 # Reduce from 0.6
+```
+
+### Slow Initialization
+
+The first run loads model weights and can be slow; subsequent runs are faster. If stage initialization times out (e.g., on slow storage), increase the timeout:
+
+```bash
+--stage-init-timeout 900 # Increase timeout for slow storage
```
-## Generation Parameters
+## Requirements
-| Parameter | Type | Default | Description |
-| --------------------- | ----- | ------- | ----------------------------------- |
-| `height` | int | 1024 | Image height in pixels |
-| `width` | int | 1024 | Image width in pixels |
-| `num_inference_steps` | int | 50 | Number of diffusion denoising steps |
-| `guidance_scale` | float | 1.5 | Classifier-free guidance scale |
-| `seed` | int | None | Optional random seed |
-| `negative_prompt` | str | None | Negative prompt |
+- vLLM-Omni with GLM-Image support
+- CUDA-capable GPU (recommended: H100/A100 with 80GB)
+- GLM-Image model weights
-## VRAM Requirements
+## Example materials
-| Stage | VRAM |
-| :---------------- | :--------------------- |
-| Stage-0 (AR) | **~18 GiB + KV Cache** |
-| Stage-1 (DiT+VAE) | **~20 GiB** |
-| Total | **~38 GiB + KV Cache** |
+??? abstract "end2end.py"
+ ``````py
+ --8<-- "examples/offline_inference/glm_image/end2end.py"
+ ``````
+??? abstract "run_i2i.sh"
+ ``````sh
+ --8<-- "examples/offline_inference/glm_image/run_i2i.sh"
+ ``````
+??? abstract "run_t2i.sh"
+ ``````sh
+ --8<-- "examples/offline_inference/glm_image/run_t2i.sh"
+ ``````
diff --git a/docs/user_guide/examples/offline_inference/image_to_video.md b/docs/user_guide/examples/offline_inference/image_to_video.md
index 6e105741a7e..7a750aeff3b 100644
--- a/docs/user_guide/examples/offline_inference/image_to_video.md
+++ b/docs/user_guide/examples/offline_inference/image_to_video.md
@@ -62,13 +62,12 @@ Key arguments:
- `--negative-prompt`: Optional list of artifacts to suppress.
- `--boundary-ratio`: Boundary split ratio for two-stage MoE models.
- `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p).
-- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints.
- `--num-inference-steps`: Number of denoising steps (default 50).
- `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video).
- `--output`: Path to save the generated video.
- `--vae-use-slicing`: Enable VAE slicing for memory optimization.
- `--vae-use-tiling`: Enable VAE tiling for memory optimization.
-- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md).
+- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
- `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2).
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
- `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs.
@@ -79,9 +78,6 @@ Key arguments:
> ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage.
-For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA
-assets, see the [LoRA guide](../../diffusion/lora.md#wan22-lightx2v-offline-assembly).
-
## Example materials
??? abstract "image_to_video.py"
diff --git a/docs/user_guide/examples/offline_inference/mimo_audio.md b/docs/user_guide/examples/offline_inference/mimo_audio.md
index 1cba2f77dcf..1a3be15d69a 100644
--- a/docs/user_guide/examples/offline_inference/mimo_audio.md
+++ b/docs/user_guide/examples/offline_inference/mimo_audio.md
@@ -38,6 +38,7 @@ Run a single sample for basic TTS:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft
```
@@ -46,6 +47,7 @@ Run batch samples for basic TTS:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft \
--num-prompts {batch_size}
@@ -63,6 +65,7 @@ Generate speech from text input:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft \
--text "The weather is so nice today."
@@ -74,6 +77,7 @@ Generate speech with explicit voice style instructions:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft_with_instruct \
--text "The weather is so nice today." \
@@ -86,6 +90,7 @@ Generate speech using an audio reference for voice cloning:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft_with_audio \
--text "The weather is so nice today." \
@@ -98,6 +103,7 @@ Generate speech from text containing natural voice descriptions:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft_with_natural_instruction \
--text "In a panting young male voice, he said: I can't run anymore, wait for me!"
@@ -109,6 +115,7 @@ Transcribe audio to text:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type audio_trancribing_sft \
--audio-path "./spoken_dialogue_assistant_turn_1.wav"
@@ -120,6 +127,7 @@ Understand and analyze audio content with text queries:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type audio_understanding_sft \
--text "Summarize the audio." \
@@ -132,6 +140,7 @@ Audio understanding with reasoning chain:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type audio_understanding_sft_with_thinking \
--text "Summarize the audio." \
@@ -144,6 +153,7 @@ Multi-turn dialogue with audio input and output:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type spoken_dialogue_sft_multiturn \
--audio-path "./prompt_speech_zh_m.wav"
@@ -157,6 +167,7 @@ Multi-turn dialogue converting speech to text:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type speech2text_dialogue_sft_multiturn
```
@@ -169,6 +180,7 @@ Multi-turn text-only dialogue:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type text_dialogue_sft_multiturn
```
@@ -177,6 +189,29 @@ Note: This task uses hardcoded message lists in the script.
## Troubleshooting
+### Audio dependencies (soundfile, librosa)
+
+This example depends on **soundfile** (read/write WAV) and **librosa** (load audio including MP3). Install the project requirements first:
+
+```bash
+pip install -r requirements/common.txt
+# or at least: pip install "soundfile>=0.13.1" "librosa>=0.11.0"
+```
+
+- **`soundfile` / libsndfile not found**
+ `soundfile` uses the C library **libsndfile**. On Linux, install the system package before pip:
+ - Debian/Ubuntu: `sudo apt-get install libsndfile1`
+ - For development builds: `sudo apt-get install libsndfile1-dev`
+ - Then: `pip install soundfile`
+
+- **`librosa` fails to load MP3 or reports "No backend available"**
+ Loading MP3 (e.g. in `spoken_dialogue_sft_multiturn` with `.mp3` files) uses **ffmpeg** as the backend. Install ffmpeg:
+ - Debian/Ubuntu: `sudo apt-get install ffmpeg`
+ - macOS: `brew install ffmpeg`
+
+- **`ImportError: No module named 'soundfile'` or `ModuleNotFoundError: ... librosa`**
+ Ensure you are in the same Python environment where vLLM Omni and the example dependencies are installed, and that `requirements/common.txt` (or the packages above) are installed.
+
### Tokenizer path
- **`MIMO_AUDIO_TOKENIZER_PATH` not set or model fails to find tokenizer**
diff --git a/docs/user_guide/examples/offline_inference/qwen2_5_omni.md b/docs/user_guide/examples/offline_inference/qwen2_5_omni.md
index c54976b540d..07a56cf9a06 100644
--- a/docs/user_guide/examples/offline_inference/qwen2_5_omni.md
+++ b/docs/user_guide/examples/offline_inference/qwen2_5_omni.md
@@ -64,6 +64,14 @@ If media file paths are not provided, the script will use default assets. Suppor
- `use_audio_in_video`: Extract audio from video
- `text`: Text-only query
+### FAQ
+
+If you encounter an error about the librosa backend, try installing ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
+
## Example materials
??? abstract "end2end.py"
diff --git a/docs/user_guide/examples/offline_inference/qwen3_omni.md b/docs/user_guide/examples/offline_inference/qwen3_omni.md
index 2d856f7380a..6577092bbfe 100644
--- a/docs/user_guide/examples/offline_inference/qwen3_omni.md
+++ b/docs/user_guide/examples/offline_inference/qwen3_omni.md
@@ -112,6 +112,14 @@ python end2end_async_chunk.py \
> async_chunk example when you need the stage-level concurrency semantics
> described in PR #962 / #1151.
+### FAQ
+
+If you encounter an error about the librosa backend, try installing ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
+
## Example materials
??? abstract "end2end.py"
diff --git a/docs/user_guide/examples/offline_inference/qwen3_tts.md b/docs/user_guide/examples/offline_inference/qwen3_tts.md
index 7226ac1fe4b..19fea4132ce 100644
--- a/docs/user_guide/examples/offline_inference/qwen3_tts.md
+++ b/docs/user_guide/examples/offline_inference/qwen3_tts.md
@@ -18,11 +18,11 @@ Please refer to the [stage configuration documentation](https://docs.vllm.ai/pro
### ROCm Dependencies
-You will need to install the dependency `onnxruntime-rocm`.
+You will need to install these two dependencies: `onnxruntime-rocm` and `sox`.
```
pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm
-pip install onnxruntime-rocm
+pip install onnxruntime-rocm sox
```
## Quick Start
@@ -144,13 +144,13 @@ completes. This demonstrates that audio data is available progressively rather t
## Batched Decoding
-The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, set `max_num_seqs > 1` on both stages via `--stage-overrides` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
+The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_num_seqs > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
```
python end2end.py --query-type CustomVoice \
--txt-prompts benchmark_prompts.txt \
--batch-size 4 \
- --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},"1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
```
**Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_num_seqs >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_num_seqs`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently.
diff --git a/docs/user_guide/examples/offline_inference/text_to_audio.md b/docs/user_guide/examples/offline_inference/text_to_audio.md
index a31a4d7a4d5..62a70e5254d 100644
--- a/docs/user_guide/examples/offline_inference/text_to_audio.md
+++ b/docs/user_guide/examples/offline_inference/text_to_audio.md
@@ -29,22 +29,6 @@ python text_to_audio.py \
--output stable_audio_output.wav
```
-To reduce per-GPU memory for multi-GPU inference, launch with HSDP:
-
-```bash
-python text_to_audio.py \
- --model stabilityai/stable-audio-open-1.0 \
- --prompt "The sound of a hammer hitting a wooden surface" \
- --negative-prompt "Low quality" \
- --seed 42 \
- --guidance-scale 7.0 \
- --audio-length 10.0 \
- --num-inference-steps 100 \
- --use-hsdp \
- --hsdp-shard-size 2 \
- --output stable_audio_output.wav
-```
-
Key arguments:
- `--prompt`: text description (string).
@@ -53,9 +37,6 @@ Key arguments:
- `--guidance-scale`: classifier-free guidance scale.
- `--audio-length`: audio duration in seconds.
- `--num-inference-steps`: diffusion sampling steps.(more steps = higher quality, slower).
-- `--use-hsdp`: enable HSDP weight sharding for the Stable Audio DiT.
-- `--hsdp-shard-size`: number of GPUs used for HSDP sharding.
-- `--hsdp-replicate-size`: number of HSDP replica groups.
- `--output`: path to save the generated WAV file.
## Example materials
diff --git a/docs/user_guide/examples/offline_inference/text_to_video.md b/docs/user_guide/examples/offline_inference/text_to_video.md
index a09dbfc979f..4288c089c60 100644
--- a/docs/user_guide/examples/offline_inference/text_to_video.md
+++ b/docs/user_guide/examples/offline_inference/text_to_video.md
@@ -5,8 +5,6 @@ Source : In the frame, a woman with black long hair is identified as .\n**Overall Environment/Scene**: A lively open-kitchen café at night; stove flames flare, steam rises, and warm pendant lights swing slightly as staff move behind her. The shot is an upper-body close-up.\n**Main Characters/Subjects Appearance**: is a young woman with thick dark wavy hair and a side part. She wears a fitted black top under a light apron, a thin gold chain necklace, and small stud earrings.\n**Main Characters/Subjects Actions**: tastes the sauce with a spoon, then turns her face toward the camera while still holding the spoon, her expression shifting from focused to conflicted.\n maintains eye contact, swallows as if choosing her words, and says, I keep telling myself I’m fine,but some nights it feels like I’m just performing calm." \
- --image-path 9.png \
- --audio-path 9.wav \
- --video-negative-prompt "jitter, bad hands, blur, distortion" \
- --audio-negative-prompt "robotic, muffled, echo, distorted" \
- --cfg-parallel-size 2 \
- --num-inference-steps 45 \
- --height 704 \
- --width 1280 \
- --output out_dreamid_omni_oneip.mp4
-```
-
-
Key arguments:
- `--prompt`: text description (string).
- `--model`: path to the model local directory.
diff --git a/docs/user_guide/examples/online_serving/bagel.md b/docs/user_guide/examples/online_serving/bagel.md
index aa8b33de802..4a6094c0894 100644
--- a/docs/user_guide/examples/online_serving/bagel.md
+++ b/docs/user_guide/examples/online_serving/bagel.md
@@ -2,112 +2,147 @@
Source <https://github.com/vllm-project/vllm-omni/tree/main/examples/online_serving/bagel>.
-## Installation
-Please refer to [README.md](https://github.com/vllm-project/vllm-omni/tree/main/README.md)
-
-## Architecture
-
-BAGEL-7B-MoT is a Mixture-of-Transformers (MoT) model supporting both image generation and understanding. It offers two deployment topologies:
+## 🛠️ Installation
-| Topology | Stages | Description |
-| :------- | :----- | :---------- |
-| **Two-stage** (default) | Stage 0 (Thinker, AR) + Stage 1 (DiT, Diffusion) | Thinker handles text/understanding via vLLM AR engine; DiT handles image generation. KV cache is transferred between stages. |
-| **Single-stage** | Stage 0 (DiT, Diffusion) only | The DiT stage contains a full LLM, ViT, VAE, and tokenizer internally. All modalities are handled within a single diffusion process. |
-
-Both topologies support all four modalities: `text2img`, `img2img`, `img2text`, `text2text`.
-
-> **Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, modify the deploy YAML to distribute stages across devices.
+Please refer to [README.md](https://github.com/vllm-project/vllm-omni/tree/main/README.md)
-## Launch the Server
+## Run examples (BAGEL-7B-MoT)
-### Two-Stage (Default)
+**Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, please modify the stage configuration to distribute the model across devices.
-The default pipeline is auto-detected from the model. No extra flags needed:
+### Launch the Server
```bash
+# Use default configuration
vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091
```
Or use the convenience script:
```bash
-cd examples/online_serving/bagel
+cd /workspace/vllm-omni/examples/online_serving/bagel
bash run_server.sh
+```
-# Launch a single stage per terminal
-bash run_server_stage_cli.sh --stage 0
-bash run_server_stage_cli.sh --stage 1
+```bash
+vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
```
-To use a custom deploy YAML (note: `--stage-configs-path` is deprecated in favor of `--deploy-config`):
+#### 🚀 Tensor Parallelism (TP)
+
+For larger models or multi-GPU environments, you can enable Tensor Parallelism (TP) for the server.
+
+1. **Modify Stage Config**: Create or modify a stage configuration yaml (e.g., [`bagel.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/model_executor/stage_configs/bagel.yaml)). Set `tensor_parallel_size` to `2` (or more) and update `devices` to include multiple GPU IDs (e.g., `"0,1"`).
+```yaml
+ engine_args:
+ tensor_parallel_size: 2
+ ...
+ runtime:
+ devices: "0,1"
+```
+
+2. **Launch Server**:
```bash
-vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 \
- --deploy-config /path/to/deploy_config.yaml
+vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/your/custom_bagel.yaml
```
-See [`bagel.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/deploy/bagel.yaml) for the default two-stage deploy configuration.
+#### Using Mooncake Connector
-### Single-Stage
+By default, BAGEL uses `SharedMemoryConnector` for inter-stage communication. You can use the [Mooncake](https://github.com/kvcache-ai/Mooncake) connector to transfer KV cache between stages, which also enables multi-node deployment.
-The DiT stage contains a full LLM, ViT, VAE, and tokenizer, so it can handle all modalities (text2img, img2img, img2text, text2text, think) without a separate Thinker stage:
+**1. Install Mooncake**
```bash
-vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 \
- --deploy-config vllm_omni/deploy/bagel_single_stage.yaml
+# For CUDA-enabled systems (recommended)
+pip install mooncake-transfer-engine
+
+# For non-CUDA systems
+pip install mooncake-transfer-engine-non-cuda
```
-See [`bagel_single_stage.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/deploy/bagel_single_stage.yaml) for configuration. The `pipeline: bagel_single_stage` field selects the single-stage topology from the pipeline registry.
+**2. Start Mooncake Master** on the primary node:
-### Tensor Parallelism (TP)
+```bash
+# Optional: enable disk-backed storage by creating a directory and passing --root_fs_dir.
+# Without it, Mooncake runs in memory-only mode, which is sufficient for KV cache transfer.
+mkdir -p ./mc_storage
+
+mooncake_master \
+ --rpc_port=50051 \
+ --enable_http_metadata_server=true \
+ --http_metadata_server_host=0.0.0.0 \
+ --http_metadata_server_port=8080 \
+ --metrics_port=9003 \
+ --root_fs_dir=./mc_storage/ \
+ --cluster_id=mc-local-1 &
+```
-For larger models or multi-GPU environments, enable TP via CLI:
+**3. Launch the server** with the Mooncake stage config:
```bash
-vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --tensor-parallel-size 2
+vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
```
-Or set `tensor_parallel_size` per stage in a custom deploy YAML.
+> **Note**: Before launching, edit [`bagel_multiconnector.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml) and replace the `metadata_server` and `master` addresses with your Mooncake master node's actual IP. For single-node testing, `127.0.0.1` works.
+
+The client-side usage is identical to the default setup -- the Mooncake connector is transparent to the API. See the requests section below.
+
+For more details on the Mooncake connector configuration, see the [Mooncake Store Connector documentation](https://github.com/vllm-project/vllm-omni/tree/main/docs/design/feature/omni_connectors/mooncake_store_connector.md).
-### Multi-Node Deployment
+#### Multi-Node Deployment
-Deploy each stage on a **separate node** for better resource utilization. Replace `` with the actual IP address of your orchestrator node.
+You can deploy each stage on a **separate node** for better resource utilization. In this example, the orchestrator (Stage 0 / Thinker) and Stage 1 (DiT) run on different machines, connected via Mooncake.
-**1. Launch Stage 0 (Thinker / Orchestrator)** on the orchestrator node:
+Replace `<ORCHESTRATOR_IP>` (the value passed to `--http_metadata_server_host` and `-oma` in the commands below) with the actual IP address of your orchestrator node (e.g., `10.244.227.244`).
+
+> [!WARNING]
+> **Before launching**, edit [`bagel_multiconnector.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml) and replace the `metadata_server` and `master` addresses with your Mooncake master node's actual IP. Mismatched addresses will cause silent connection failures.
+
+**1. Start Mooncake Master** (on the orchestrator node):
+
+```bash
+mooncake_master \
+ --rpc_port=50051 \
+ --enable_http_metadata_server=true \
+ --http_metadata_server_host=<ORCHESTRATOR_IP> \
+ --http_metadata_server_port=8080 \
+ --metrics_port=9003
+```
+
+**2. Launch Stage 0 (Thinker / Orchestrator)** on the orchestrator node (`--port 8000` is the API server port for client requests):
```bash
-# API server port for client requests: 8000
vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni \
- --port 8000 \
+ --port 8000 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml \
--stage-id 0 \
- --omni-master-address \
- --omni-master-port 8091
+ -oma <ORCHESTRATOR_IP> \
+ -omp 8091
```
-**2. Launch Stage 1 (DiT)** on the remote node in headless mode:
+**3. Launch Stage 1 (DiT)** on the remote node in headless mode:
```bash
vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml \
--stage-id 1 \
--headless \
- --omni-master-address \
- --omni-master-port 8091
+ -oma <ORCHESTRATOR_IP> \
+ -omp 8091
```
-Or use the convenience script:
-
-```bash
-# Terminal 1: Stage 0
-bash run_server_stage_cli.sh --stage 0
+**Mooncake Master arguments:**
-# Terminal 2: Stage 1
-bash run_server_stage_cli.sh --stage 1
-
-# With extra args
-bash run_server_stage_cli.sh --stage 0 -- --tensor-parallel-size 2
-bash run_server_stage_cli.sh --stage 1 -- --gpu-memory-utilization 0.9
-```
+| Argument | Description |
+| :------- | :---------- |
+| `--rpc_port` | Mooncake RPC port for control-plane coordination between stages |
+| `--enable_http_metadata_server` | Enable the HTTP metadata server for service discovery |
+| `--http_metadata_server_host` | IP address to bind the metadata server (use the orchestrator node's IP) |
+| `--http_metadata_server_port` | Port for the HTTP metadata server |
+| `--metrics_port` | Port for Prometheus-compatible metrics endpoint |
**vllm serve arguments:**
@@ -115,31 +150,85 @@ bash run_server_stage_cli.sh --stage 1 -- --gpu-memory-utilization 0.9
| :------- | :---------- |
| `--stage-id` | Which stage this process runs (0 = Thinker, 1 = DiT) |
| `--headless` | Run without the API server (worker-only mode) |
-| `-oma` / `--omni-master-address` | Orchestrator master address |
-| `-omp` / `--omni-master-port` | Orchestrator master port |
+| `-oma` | Orchestrator master address |
+| `-omp` | Orchestrator master port for Stage 1 to connect to Stage 0 for task coordination |
> [!IMPORTANT]
> **Startup Order**: Stage 0 (orchestrator) must be launched **before** Stage 1 (headless).
> Stage 0 will appear to hang on startup until Stage 1 (worker) connects — this is expected behavior.
-### Inter-Stage Connectors
+**Network Requirements**
+
+All nodes must have network connectivity to each other. Ensure the following ports are open **between all participating nodes**:
-When deploying stages across nodes, configure the connector type in the deploy YAML:
+| Port | Protocol | Service | Direction |
+| :--- | :------- | :------ | :-------- |
+| 50051 | TCP | Mooncake Master RPC | Worker → Orchestrator |
+| 8080 | TCP | Mooncake HTTP Metadata Server | Worker → Orchestrator |
+| 8091 | TCP | Orchestrator Master (`-omp`) | Worker → Orchestrator |
+| 8000 | TCP | API Server (`--port`) | Client → Orchestrator |
+| 9003 | TCP | Metrics (optional) | Monitoring → Orchestrator |
-- **SharedMemoryConnector** (default): Used for single-node deployments. No explicit configuration needed.
-- **MooncakeTransferEngineConnector**: For multi-node setups with RDMA hardware. Defined in [`bagel.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/deploy/bagel.yaml) under `connectors.rdma_connector`.
+> **Tip**: If nodes are behind a firewall or in different VPCs/security groups, make sure the above ports are allowed in ingress/egress rules. All nodes should be reachable via their IP addresses (no NAT). Using nodes on the same subnet or VPC is recommended to minimize latency for Mooncake KV cache transfers.
-To use Mooncake, create a custom deploy YAML that binds `output_connectors` / `input_connectors` on each stage to the `rdma_connector` defined in the `connectors` section.
+### Send Multi-modal Request
-## Send Requests
+Get into the bagel folder:
```bash
cd examples/online_serving/bagel
```
+Send a request via the Python client:
+
+```bash
+python openai_chat_client.py --prompt "A cute cat" --modality text2img
+```
+
+The Python client supports the following command-line arguments:
+
+- `--prompt` (or `-p`): Text prompt for generation (default: `A cute cat`)
+- `--output` (or `-o`): Output file path for image results (default: `bagel_output.png`)
+- `--server` (or `-s`): Server URL (default: `http://localhost:8091`)
+- `--image-url` (or `-i`): Input image URL or local file path (for img2img/img2text modes)
+- `--modality` (or `-m`): Task modality (default: `text2img`). Options: `text2img`, `img2img`, `img2text`, `text2text`
+- `--height`: Image height in pixels (default: 512)
+- `--width`: Image width in pixels (default: 512)
+- `--steps`: Number of inference steps (default: 25)
+- `--seed`: Random seed (default: 42)
+- `--negative`: Negative prompt for image generation
+
+Example with custom parameters:
+
+```bash
+python openai_chat_client.py \
+ --prompt "A futuristic city" \
+ --modality text2img \
+ --height 768 \
+ --width 768 \
+ --steps 50 \
+ --seed 42 \
+ --negative "blurry, low quality"
+```
+
+## Modality Control
+
+BAGEL-7B-MoT supports **multiple modality modes** for different use cases.
+
+The default yaml configuration deploys Thinker and DiT on the same GPU. You can use the default configuration file: [`bagel.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/model_executor/stage_configs/bagel.yaml)
+
+| Modality | Input | Output | Description |
+| ----------- | ------------ | ------ | -------------------------------------- |
+| `text2img` | Text | Image | Generate images from text prompts |
+| `img2img` | Image + Text | Image | Transform images using text guidance |
+| `img2text` | Image + Text | Text | Generate text descriptions from images |
+| `text2text` | Text | Text | Pure text generation |
+
### Text to Image (text2img)
-**Python client:**
+Generate images from text prompts:
+
+**Using Python client**
```bash
python openai_chat_client.py \
@@ -149,7 +238,7 @@ python openai_chat_client.py \
--steps 50
```
-**curl:**
+**Using curl**
```bash
curl http://localhost:8091/v1/chat/completions \
@@ -164,9 +253,12 @@ curl http://localhost:8091/v1/chat/completions \
}'
```
+
### Image to Image (img2img)
-**Python client:**
+Transform images based on text prompts:
+
+**Using Python client**
```bash
python openai_chat_client.py \
@@ -176,7 +268,7 @@ python openai_chat_client.py \
--output transformed.png
```
-**curl:**
+**Using curl**
```bash
IMAGE_BASE64=$(base64 -w 0 cat.jpg)
@@ -201,11 +293,14 @@ EOF
curl http://localhost:8091/v1/chat/completions \
-H "Content-Type: application/json" \
-d @payload.json
+
```
### Image to Text (img2text)
-**Python client:**
+Generate text descriptions from images:
+
+**Using Python client**
```bash
python openai_chat_client.py \
@@ -214,7 +309,7 @@ python openai_chat_client.py \
--image-url /path/to/image.jpg
```
-**curl:**
+**Using curl**
```bash
IMAGE_BASE64=$(base64 -w 0 cat.jpg)
@@ -239,7 +334,9 @@ curl http://localhost:8091/v1/chat/completions \
### Text to Text (text2text)
-**Python client:**
+Pure text generation:
+
+**Using Python client**
```bash
python openai_chat_client.py \
@@ -247,81 +344,33 @@ python openai_chat_client.py \
--modality text2text
```
-**curl:**
+**Using curl**
```bash
curl http://localhost:8091/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
- "messages": [{"role": "user", "content": [{"type": "text", "text": "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"}]}],
+ "messages": [{"role": "user", "content": [{"type": "text", "text": "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"}]}],
"modalities": ["text"]
}'
```
-### Python Client Arguments
-
-| Argument | Default | Description |
-| :------- | :------ | :---------- |
-| `--prompt` / `-p` | `A cute cat` | Text prompt |
-| `--output` / `-o` | `bagel_output.png` | Output file path |
-| `--server` / `-s` | `http://localhost:8091` | Server URL |
-| `--image-url` / `-i` | `None` | Input image URL or local path (img2img/img2text) |
-| `--modality` / `-m` | `text2img` | `text2img`, `img2img`, `img2text`, `text2text` |
-| `--height` | `512` | Image height in pixels |
-| `--width` | `512` | Image width in pixels |
-| `--steps` | `25` | Number of inference steps |
-| `--seed` | `42` | Random seed |
-| `--negative` | `None` | Negative prompt for CFG |
+## FAQ
-Example with custom parameters:
+- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below.
```bash
-python openai_chat_client.py \
- --prompt "A futuristic city" \
- --modality text2img \
- --height 768 \
- --width 768 \
- --steps 50 \
- --seed 42 \
- --negative "blurry, low quality"
+sudo apt update
+sudo apt install ffmpeg
```
-## Configuration Reference
-
-### Deploy YAML Files
-
-| File | Description |
-| :--- | :---------- |
-| [`bagel.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/deploy/bagel.yaml) | Two-stage default (Thinker + DiT on GPU 0) |
-| [`bagel_single_stage.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/deploy/bagel_single_stage.yaml) | Single-stage (DiT only) |
-
-### Key Deploy YAML Fields
-
-| Field | Scope | Description |
-| :---- | :---- | :---------- |
-| `pipeline` | top-level | Override auto-detected pipeline (e.g. `bagel_single_stage`) |
-| `stages[].stage_id` | per-stage | Stage identifier (0, 1, ...) |
-| `stages[].devices` | per-stage | GPU device IDs (e.g. `"0"`, `"0,1"`) |
-| `stages[].max_num_seqs` | per-stage | Maximum concurrent sequences |
-| `stages[].gpu_memory_utilization` | per-stage | Fraction of GPU memory to use |
-| `stages[].enforce_eager` | per-stage | Disable CUDA graphs |
-| `stages[].tensor_parallel_size` | per-stage | TP degree for this stage |
-| `connectors` | top-level | Define available connector instances (SHM, Mooncake) |
-| `platforms` | top-level | Platform-specific overrides (e.g. `xpu`) |
-
-## FAQ
-
-- If you encounter OOM errors, try decreasing `max_model_len` or `gpu_memory_utilization` in the deploy YAML.
-
-**Two-stage VRAM usage:**
-
-| Stage | VRAM |
-| :---- | :--- |
-| Stage 0 (Thinker) | **15.04 GiB + KV Cache** |
-| Stage 1 (DiT) | **26.50 GiB** |
-| Total | **~42 GiB + KV Cache** |
+- If you are unsure how much VRAM the model needs, or you encounter an OOM error, try decreasing `max_model_len`. Approximate per-stage VRAM usage:
-**Single-stage VRAM usage:** The DiT loads the full model (~42 GiB) in one process.
+| Stage | VRAM |
+| :------------------ | :--------------------------- |
+| Stage-0 (Thinker) | **15.04 GiB** **+ KV Cache** |
+| Stage-1 (DiT) | **26.50 GiB** |
+| Total | **~42 GiB + KV Cache** |
## Example materials
diff --git a/docs/user_guide/examples/online_serving/diffusers_pipeline_adapter.md b/docs/user_guide/examples/online_serving/diffusers_pipeline_adapter.md
deleted file mode 100644
index ac88071d53f..00000000000
--- a/docs/user_guide/examples/online_serving/diffusers_pipeline_adapter.md
+++ /dev/null
@@ -1,93 +0,0 @@
-# Diffusers Backend Adapter Example
-
-Source .
-
-
-This example demonstrates how to serve any 🤗 Diffusers pipeline through vLLM-Omni
-using the `diffusers` load format.
-
-## Supported Models
-
-Any model loadable via `DiffusionPipeline.from_pretrained()` should be supported, including text-to-image, image-to-image, text-to-video, image-to-video, and text-to-audio.
-
-## Limitations
-
-The diffusers backend is a black-box adapter. The following features are NOT yet supported.
-It is not guaranteed whether they will be supported in the future.
-
-- CFG parallel execution
-- Sequence parallel execution
-- TeaCache / Cache-DiT acceleration
-- Step-wise execution (continuous batching)
-
-For these features, it is recommended to use natively supported pipelines instead.
-
-## Usage
-
-### Option 1: CLI arguments
-
-```bash
-vllm serve "stable-diffusion-v1-5/stable-diffusion-v1-5" \
- --omni \
- --diffusion-load-format diffusers \
- --diffusers-load-kwargs '{"use_safetensors": true}' \
- --diffusers-call-kwargs '{"num_inference_steps": 30, "guidance_scale": 7.5}'
-```
-
-`--diffusers-load-kwargs` and `--diffusers-call-kwargs` are only valid together with `--diffusion-load-format diffusers`.
-
-### Option 2: Stage config YAML
-
-```bash
-vllm serve stable-diffusion-v1-5/stable-diffusion-v1-5 --stage-configs-path examples/online_serving/diffusers_pipeline_adapter/stage_config.yaml --omni
-```
-
-The particular fields of interest are `model`, `diffusion_load_format`, `diffusers_load_kwargs`, and `diffusers_call_kwargs` under `engine_args`. They are the same as the CLI arguments.
-
-## Send a Request
-
-```bash
-curl http://localhost:8000/v1/images/generations \
- -H "Content-Type: application/json" \
- -d '{
- "model": "stable-diffusion-v1-5/stable-diffusion-v1-5",
- "prompt": "a photo of an astronaut riding a horse on mars",
- "n": 1,
- "size": "512x512"
- }'
-```
-
-Or refer to other documentation pages on how to request a particular input/output modality, such as `examples/online_serving/text_to_image/openai_chat_client.py`.
-
-## Configuration Reference
-
-For the diffusers adapter, set options under **`engine_args`**:
-
-### `diffusion_load_format: "diffusers"`
-
-This field selects the Hugging Face diffusers adapter path (see `DiffusersPipelineLoader`).
-
-### `diffusers_load_kwargs`
-
-Passed to `DiffusionPipeline.from_pretrained()`.
-
-This is suitable for model-specific configurations not available through the vLLM-Omni interface (such as `Omni.__init__()`, `vllm serve` CLI arguments, and stage config YAML fields outside `diffusers_load_kwargs`).
-
-When a parameter is available in the vLLM-Omni interface, it will be adapted here.
-But if that parameter is simultaneously set in both the vLLM-Omni interface and `diffusers_load_kwargs`, the **latter** will take precedence.
-
-### `diffusers_call_kwargs`
-
-Passed to `pipeline.__call__()`.
-
-This is suitable for sampling parameters not available through the vLLM-Omni interface (such as `Omni.generate()` and online serving payloads).
-
-When a parameter is available in the vLLM-Omni interface, it will be adapted here.
-But if that parameter is simultaneously set in both the vLLM-Omni interface and `diffusers_call_kwargs`, the **former** will take precedence (because it is set at request time).
-
-## Example materials
-
-??? abstract "stage_config.yaml"
- ``````yaml
- --8<-- "examples/online_serving/diffusers_pipeline_adapter/stage_config.yaml"
- ``````
diff --git a/docs/user_guide/examples/online_serving/fish_speech.md b/docs/user_guide/examples/online_serving/fish_speech.md
index 2a15ef44ac8..7322d06aaaf 100644
--- a/docs/user_guide/examples/online_serving/fish_speech.md
+++ b/docs/user_guide/examples/online_serving/fish_speech.md
@@ -41,11 +41,15 @@ Features:
## Launch the Server
```bash
-vllm serve fishaudio/s2-pro --omni --port 8091
+vllm-omni serve fishaudio/s2-pro \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml \
+ --omni \
+ --port 8091 \
+ --trust-remote-code \
+ --enforce-eager \
+ --gpu-memory-utilization 0.9
```
-The deploy config is auto-loaded from `vllm_omni/deploy/fish_qwen3_omni.yaml`.
-
Or use the convenience script:
```bash
diff --git a/docs/user_guide/examples/online_serving/glm_image.md b/docs/user_guide/examples/online_serving/glm_image.md
index 4cc49e84602..f7027b906db 100644
--- a/docs/user_guide/examples/online_serving/glm_image.md
+++ b/docs/user_guide/examples/online_serving/glm_image.md
@@ -1,96 +1,99 @@
# GLM-Image Online Serving
-GLM-Image is a 2-stage image generation model (AR + Diffusion) supported by vLLM-Omni's
-declarative config system. The pipeline topology and stage structure are declared in
-`vllm_omni/model_executor/models/glm_image/pipeline.py`; deployment knobs (GPU placement,
-memory, sampling params) live in `vllm_omni/deploy/glm_image.yaml`.
+Source .
-## Start Server
+
+This example demonstrates how to deploy GLM-Image for online image generation using vLLM-Omni.
+
+## 🛠️ Installation
+
+Please refer to [README.md](https://github.com/vllm-project/vllm-omni/tree/main/README.md)
+
+## Run examples (GLM-Image)
+
+**Note**: These examples work with the default configuration on **2× NVIDIA A100 (80GB)** or equivalent. Stage 0 (AR) and Stage 1 (Diffusion) each use one GPU by default. For single-GPU setups, modify the stage configuration to share the same device.
+
+### Launch the Server
```bash
+# Use default configuration
vllm serve zai-org/GLM-Image --omni --port 8091
```
-The config system auto-detects the pipeline from the model's `model_index.json` — no
-manual `--stage-configs-path` or `--deploy-config` needed.
+Or use the convenience script:
+
+```bash
+cd examples/online_serving/glm_image
+bash run_server.sh
+```
+
+If you have a custom stage configs file:
+
+```bash
+vllm serve zai-org/GLM-Image --omni --port 8091 --stage-configs-path /path/to/glm_image.yaml
+```
+
+### Send Requests
-By default, stage 0 (AR) runs on GPU 0 and stage 1 (Diffusion) on GPU 1. To colocate
-both stages on a single GPU, override per stage:
+Get into the glm_image folder:
```bash
-vllm serve zai-org/GLM-Image --omni --port 8091 \
- --stage-0-devices 0 --stage-1-devices 0
+cd examples/online_serving/glm_image
```
-## API Calls
+Send request via Python:
+
+```bash
+python openai_chat_client.py --prompt "A cute cat sitting on a window sill"
+```
+
+The Python client supports the following command-line arguments:
+
+- `--prompt` (or `-p`): Text prompt for generation (default: `A beautiful sunset over the ocean with sailing boats`)
+- `--output` (or `-o`): Output file path (default: `glm_image_output.png`)
+- `--server` (or `-s`): Server URL (default: `http://localhost:8091`)
+- `--image` (or `-i`): Input image path (for image-to-image editing)
+- `--height`: Image height in pixels (default: 1024)
+- `--width`: Image width in pixels (default: 1024)
+- `--steps`: Number of inference steps (default: 50)
+- `--guidance-scale`: Classifier-free guidance scale (default: 1.5)
+- `--seed`: Random seed (default: 42)
+- `--negative`: Negative prompt
+
+## Modality Control
+
+GLM-Image supports **text-to-image** and **image-to-image** modes.
+
+The default yaml configuration deploys AR on GPU 0 and DiT on GPU 1. You can use the default configuration file: [`glm_image.yaml`](https://github.com/vllm-project/vllm-omni/tree/main/vllm_omni/model_executor/stage_configs/glm_image.yaml)
+
+| Mode | Input | Output | Description |
+| -------------- | ------------ | ------ | ---------------------------------- |
+| Text-to-Image | Text | Image | Generate images from text prompts |
+| Image-to-Image | Image + Text | Image | Edit images with text instructions |
### Text-to-Image
```bash
-curl -s http://localhost:8091/v1/chat/completions \
- -H "Content-Type: application/json" \
- -d '{
- "messages": [
- {"role": "user", "content": "A photorealistic mountain landscape at sunset"}
- ],
- "extra_body": {
- "height": 1024,
- "width": 1024,
- "num_inference_steps": 50,
- "guidance_scale": 1.5,
- "seed": 42
- }
- }' | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png
+python openai_chat_client.py \
+ --prompt "A photorealistic mountain landscape at sunset" \
+ --height 1024 \
+ --width 1024 \
+ --output landscape.png
+
+# Or use the curl script:
+bash run_curl_text_to_image.sh "A futuristic city skyline at night"
```
### Image-to-Image (Image Editing)
```bash
-curl -s http://localhost:8091/v1/chat/completions \
- -H "Content-Type: application/json" \
- -d '{
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": "Convert this image to watercolor style"},
- {"type": "image_url", "image_url": {"url": "data:image/png;base64,$(base64 -w0 input.png)}"}
- ]
- }
- ],
- "extra_body": {
- "height": 1024,
- "width": 1024,
- "num_inference_steps": 50,
- "guidance_scale": 1.5
- }
- }' | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png
-```
+python openai_chat_client.py \
+ --prompt "Convert this image to watercolor style" \
+ --image input.png \
+ --output watercolor.png
-### Using the OpenAI Python SDK
-
-```python
-from openai import OpenAI
-import base64
-
-client = OpenAI(base_url="http://localhost:8091/v1", api_key="none")
-
-response = client.chat.completions.create(
- model="zai-org/GLM-Image",
- messages=[{"role": "user", "content": "A beautiful sunset over the ocean"}],
- extra_body={
- "height": 1024,
- "width": 1024,
- "num_inference_steps": 50,
- "guidance_scale": 1.5,
- "seed": 42,
- },
-)
-
-img_url = response.choices[0].message.content[0].image_url.url
-_, b64_data = img_url.split(",", 1)
-with open("output.png", "wb") as f:
- f.write(base64.b64decode(b64_data))
+# Or use the curl script:
+bash run_curl_image_edit.sh input.png "Convert to watercolor style"
```
For general-purpose request methods (curl, OpenAI SDK, Python `requests`), see
@@ -101,9 +104,9 @@ guides.
When using `/v1/chat/completions`, pass these inside `extra_body` in the curl
JSON, or via the `extra_body` keyword argument in the OpenAI Python SDK (see the
-[Diffusion Chat API guide](../../../serving/diffusion_chat_api.md)).
-When using the dedicated [`/v1/images/generations`](../../../serving/image_generation_api.md)
-or [`/v1/images/edits`](../../../serving/image_edit_api.md) endpoints, pass
+[Diffusion Chat API guide](../../../../serving/diffusion_chat_api.md)).
+When using the dedicated [`/v1/images/generations`](../../../../serving/image_generation_api.md)
+or [`/v1/images/edits`](../../../../serving/image_edit_api.md) endpoints, pass
the supported generation controls as top-level fields directly. For image
dimensions and count, use `size` and `n` rather than `height` or `width`.
@@ -113,7 +116,7 @@ dimensions and count, use `size` and `n` rather than `height` or `width`.
| `width` | int | 1024 | Image width in pixels |
| `num_inference_steps` | int | 50 | Number of diffusion denoising steps |
| `guidance_scale` | float | 1.5 | Classifier-free guidance scale |
-| `seed` | int | None | Optional random seed |
+| `seed` | int | None | Optional random seed; `/v1/images/*` generates one server-side if omitted |
| `negative_prompt` | str | None | Negative prompt |
## Response Format
@@ -147,12 +150,13 @@ dimensions and count, use `size` and `n` rather than `height` or `width`.
## Extract Image
```bash
+# From a saved JSON response
cat response.json | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png
```
## Architecture
-GLM-Image uses a 2-stage pipeline:
+GLM-Image uses a 2-stage pipeline:
```
Stage 0 (AR Model) Stage 1 (Diffusion)
@@ -177,13 +181,41 @@ Stage 0 (AR Model) Stage 1 (Diffusion)
| Stage-1 (DiT+VAE) | **~20 GiB** |
| Total | **~38 GiB + KV Cache** |
+## File Description
+
+| File | Description |
+| --------------------------- | ------------------------------------- |
+| `run_server.sh` | Server startup script |
+| `run_curl_text_to_image.sh` | Text-to-image curl example |
+| `run_curl_image_edit.sh` | Image-to-image (editing) curl example |
+| `openai_chat_client.py` | Python client (t2i + i2i) |
+
## FAQ
-- If you encounter OOM errors, adjust `gpu_memory_utilization` in the deploy config:
+- If you encounter OOM errors, adjust `gpu_memory_utilization` in the stage config:
```yaml
-# In vllm_omni/deploy/glm_image.yaml, reduce from default 0.6:
+# In glm_image.yaml, reduce from default 0.6:
gpu_memory_utilization: 0.5
```
- The first request may be slow due to model warmup. Subsequent requests will be faster.
+
+## Example materials
+
+??? abstract "openai_chat_client.py"
+ ``````py
+ --8<-- "examples/online_serving/glm_image/openai_chat_client.py"
+ ``````
+??? abstract "run_curl_image_edit.sh"
+ ``````sh
+ --8<-- "examples/online_serving/glm_image/run_curl_image_edit.sh"
+ ``````
+??? abstract "run_curl_text_to_image.sh"
+ ``````sh
+ --8<-- "examples/online_serving/glm_image/run_curl_text_to_image.sh"
+ ``````
+??? abstract "run_server.sh"
+ ``````sh
+ --8<-- "examples/online_serving/glm_image/run_server.sh"
+ ``````
diff --git a/docs/user_guide/examples/online_serving/image_to_video.md b/docs/user_guide/examples/online_serving/image_to_video.md
index 781f0c2a5ed..00b67d74e26 100644
--- a/docs/user_guide/examples/online_serving/image_to_video.md
+++ b/docs/user_guide/examples/online_serving/image_to_video.md
@@ -72,9 +72,6 @@ curl -X POST http://localhost:8091/v1/videos/sync \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
- -F "enable_frame_interpolation=true" \
- -F "frame_interpolation_exp=1" \
- -F "frame_interpolation_scale=1.0" \
-F "seed=42" \
-o sync_i2v_output.mp4
```
@@ -117,9 +114,6 @@ create_response=$(curl -s http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
- -F "enable_frame_interpolation=true" \
- -F "frame_interpolation_exp=1" \
- -F "frame_interpolation_scale=1.0" \
-F "seed=42")
video_id=$(echo "$create_response" | jq -r '.id')
@@ -178,35 +172,9 @@ curl -X POST http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
- -F "enable_frame_interpolation=true" \
- -F "frame_interpolation_exp=1" \
- -F "frame_interpolation_scale=1.0" \
-F "seed=42"
```
-Frame interpolation is also available for supported Wan2.2 I2V requests. See
-[Frame Interpolation](../../diffusion/frame_interpolation.md) for worker-side
-execution details and feature constraints.
-
-### Frame Interpolation Example
-
-```bash
-curl -X POST http://localhost:8091/v1/videos/sync \
- -F "prompt=A bear playing with yarn, smooth motion" \
- -F "input_reference=@/path/to/qwen-bear.png" \
- -F "width=832" \
- -F "height=480" \
- -F "num_frames=33" \
- -F "fps=16" \
- -F "num_inference_steps=40" \
- -F "guidance_scale=1.0" \
- -F "guidance_scale_2=1.0" \
- -F "enable_frame_interpolation=true" \
- -F "frame_interpolation_exp=1" \
- -F "frame_interpolation_scale=1.0" \
- -o sync_i2v_interpolated.mp4
-```
-
## Create Response Format
`POST /v1/videos` returns a job record, not inline base64 video data.
diff --git a/docs/user_guide/examples/online_serving/mimo_audio.md b/docs/user_guide/examples/online_serving/mimo_audio.md
index c8752f5782e..4737eca3664 100644
--- a/docs/user_guide/examples/online_serving/mimo_audio.md
+++ b/docs/user_guide/examples/online_serving/mimo_audio.md
@@ -13,10 +13,10 @@ Please refer to [README.md](https://github.com/vllm-project/vllm-omni/tree/main/
```bash
export MIMO_AUDIO_TOKENIZER_PATH="XiaomiMiMo/MiMo-Audio-Tokenizer"
-vllm serve XiaomiMiMo/MiMo-Audio-7B-Instruct --omni \
- --served-model-name "MiMo-Audio-7B-Instruct" \
- --port 18091 \
- --chat-template ./examples/online_serving/mimo_audio/chat_template.jinja
+vllm-omni serve XiaomiMiMo/MiMo-Audio-7B-Instruct --omni \
+--served-model-name "MiMo-Audio-7B-Instruct" \
+--port 18091 --stage-configs-path ./vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
+--chat-template ./examples/online_serving/mimo_audio/chat_template.jinja
```
> ⚠️ **Important**
> **MiMo-Audio is not compatible with the default chat template.**
diff --git a/docs/user_guide/examples/online_serving/qwen2_5_omni.md b/docs/user_guide/examples/online_serving/qwen2_5_omni.md
index b3a2c9f2ac9..43576469242 100644
--- a/docs/user_guide/examples/online_serving/qwen2_5_omni.md
+++ b/docs/user_guide/examples/online_serving/qwen2_5_omni.md
@@ -218,6 +218,14 @@ The gradio script supports the following arguments:
- `--port`: Port for Gradio server (default: 7861)
- `--share`: Share the Gradio demo publicly (creates a public link)
+### FAQ
+
+If you encounter an error about the librosa backend, try installing ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
+
## Example materials
??? abstract "gradio_demo.py"
diff --git a/docs/user_guide/examples/online_serving/qwen3_omni.md b/docs/user_guide/examples/online_serving/qwen3_omni.md
index 22d89ee8018..69de24852f6 100644
--- a/docs/user_guide/examples/online_serving/qwen3_omni.md
+++ b/docs/user_guide/examples/online_serving/qwen3_omni.md
@@ -15,72 +15,15 @@ Please refer to [README.md](https://github.com/vllm-project/vllm-omni/tree/main/
vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
```
-The default deployment configuration situated at `vllm_omni/deploy/qwen3_omni_moe.yaml` is resolved and loaded
-automatically via the model registry, obviating the necessity for the `--deploy-config` flag in standard deployment topologies.
-Asynchronous chunk streaming is **enabled by default** within the bundled configuration.
+If you want to enable async chunking for Qwen3-Omni, launch the server with the command below:
-To explicitly utilize a custom deployment YAML, specify the configuration path:
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --deploy-config /path/to/deploy_config_file
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
```
-### Launch individual stages (stage-based CLI)
-
-Adopt the stage-based CLI architecture to independently instantiate execution processes per functional stage.
-
-**1. Stage 0 (Thinker + API server)**
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --port 8091 \
- --stage-id 0 \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-**2. Stage 1 (Talker)**
-
+If you have a custom stage configs file, launch the server with the command below:
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 1 \
- --headless \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-**3. Stage 2 (Code2Wav)**
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 2 \
- --headless \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-Add `--deploy-config /path/to/deploy_config_file` to every command if you want
-to override the bundled deploy YAML.
-
-For the regular one-process launch, stage-specific CLI tuning is usually done
-with `--stage-overrides`, for example:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --stage-overrides '{"1": {"gpu_memory_utilization": 0.5}}'
-```
-
-For the stage-based CLI, you usually do **not** need `--stage-overrides` for
-that kind of change. Since each command launches one stage, just pass the knob
-directly on that stage command:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 1 \
- --headless \
- --gpu-memory-utilization 0.5 \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
```
### Send Multi-modal Request
@@ -121,6 +64,15 @@ python openai_chat_completion_client_for_multimodal_generation.py \
bash run_curl_multimodal_generation.sh use_image
```
+
+### FAQ
+
+If you encounter an error about the librosa backend, try installing ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
+
## Modality control
You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance.
@@ -244,7 +196,7 @@ The script supports the following arguments:
- `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct)
- `--server-port`: Port for vLLM server (default: 8091)
- `--gradio-port`: Port for Gradio demo (default: 7861)
-- `--deploy-config`: Path to custom deploy config YAML file (optional)
+- `--stage-configs-path`: Path to custom stage configs YAML file (optional)
- `--server-host`: Host for vLLM server (default: 0.0.0.0)
- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1)
- `--share`: Share Gradio demo publicly (creates a public link)
@@ -259,7 +211,7 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
If you have custom stage configs file:
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
```
**Step 2: Run the Gradio demo**
diff --git a/docs/user_guide/examples/online_serving/qwen3_tts.md b/docs/user_guide/examples/online_serving/qwen3_tts.md
index 95f234f02de..156c4942cd9 100644
--- a/docs/user_guide/examples/online_serving/qwen3_tts.md
+++ b/docs/user_guide/examples/online_serving/qwen3_tts.md
@@ -58,7 +58,7 @@ Then open http://localhost:7860 in your browser.
```bash
# CustomVoice model (predefined speakers)
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -66,7 +66,7 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
# VoiceDesign model
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
- --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -74,7 +74,7 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
# Base model (voice cloning)
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -211,6 +211,14 @@ with open("output.wav", "wb") as f:
f.write(response.content)
```
+### FAQ
+
+If you encounter an error about the librosa backend, try installing ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
+
## API Reference
### Voices Endpoint
diff --git a/docs/user_guide/examples/online_serving/text_to_video.md b/docs/user_guide/examples/online_serving/text_to_video.md
index b918aac19d0..d58296fcc78 100644
--- a/docs/user_guide/examples/online_serving/text_to_video.md
+++ b/docs/user_guide/examples/online_serving/text_to_video.md
@@ -3,28 +3,17 @@
Source .
-This example demonstrates how to deploy text-to-video models for online video generation using vLLM-Omni.
+This example demonstrates how to deploy the Wan2.2 text-to-video model for online video generation using vLLM-Omni.
-## Supported Models
+## Start Server
-| Model | Model ID |
-|-------|----------|
-| Wan2.1 T2V (1.3B) | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` |
-| Wan2.1 T2V (14B) | `Wan-AI/Wan2.1-T2V-14B-Diffusers` |
-| Wan2.2 T2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
-| LTX-2 | `Lightricks/LTX-2` |
-
-## Wan2.2 T2V
-
-### Start Server
-
-#### Basic Start
+### Basic Start
```bash
vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091
```
-#### Start with Parameters
+### Start with Parameters
Or use the startup script:
@@ -165,9 +154,6 @@ curl -X POST http://localhost:8091/v1/videos \
-F "guidance_scale_2=4.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=5.0" \
- -F "enable_frame_interpolation=true" \
- -F "frame_interpolation_exp=1" \
- -F "frame_interpolation_scale=1.0" \
-F "seed=42"
```
@@ -190,35 +176,6 @@ curl -X POST http://localhost:8091/v1/videos \
| `flow_shift` | float | None | Scheduler flow shift (Wan2.2) |
| `seed` | int | None | Random seed (reproducible) |
| `lora` | object | None | LoRA configuration |
-| `enable_frame_interpolation` | bool | false | Enable RIFE frame interpolation before MP4 encoding |
-| `frame_interpolation_exp` | int | 1 | Interpolation exponent; 1=2x temporal resolution, 2=4x |
-| `frame_interpolation_scale` | float | 1.0 | RIFE inference scale; use 0.5 for high-resolution inputs |
-| `frame_interpolation_model_path` | str | None | Local directory or Hugging Face repo ID with `flownet.pkl`; defaults to `elfgum/RIFE-4.22.lite` |
-
-## Frame Interpolation
-
-Frame interpolation is an optional post-processing step for `/v1/videos` and
-`/v1/videos/sync`. It synthesizes intermediate frames between generated frames
-without rerunning the diffusion model. If the generated video has `N` frames,
-the interpolated output frame count is `(N - 1) * 2**exp + 1`. The encoder FPS
-is multiplied by `2**exp` so the output duration remains close to the original.
-
-Frame interpolation runs in the diffusion worker post-processing path instead of
-the API server encoding path, so it can reuse the worker's current accelerator
-device without blocking the FastAPI event loop.
-
-Example: generate 5 frames and interpolate to 9 frames:
-
-```bash
-curl -X POST http://localhost:8091/v1/videos/sync \
- -F "prompt=A dog running through a park" \
- -F "num_frames=5" \
- -F "fps=8" \
- -F "enable_frame_interpolation=true" \
- -F "frame_interpolation_exp=1" \
- -F "frame_interpolation_scale=1.0" \
- -o sync_t2v_interpolated.mp4
-```
## Create Response Format
@@ -277,102 +234,8 @@ while true; do
done
```
-## LTX-2
-
-### Start Server
-
-#### Basic Start
-
-```bash
-vllm serve Lightricks/LTX-2 --omni --port 8098 \
- --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0
-```
-
-For multi-GPU memory reduction, you can enable HSDP:
-
-```bash
-vllm serve Lightricks/LTX-2 --omni --port 8098 \
- --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0 \
- --use-hsdp --hsdp-shard-size 2
-```
-
-#### Start with Optimization Presets
-
-Use the LTX-2 startup script with built-in optimization presets:
-
-```bash
-# Baseline (1 GPU, eager)
-bash run_server_ltx2.sh baseline
-
-# 4-GPU Ulysses sequence parallelism (lossless)
-bash run_server_ltx2.sh ulysses4
-
-# Cache-DiT lossy acceleration (1 GPU, ~1.4× speedup)
-bash run_server_ltx2.sh cache-dit
-
-# Best combo: 4-GPU Ulysses SP + Cache-DiT (~2.2× speedup)
-bash run_server_ltx2.sh best-combo
-```
-
-#### Optimization Benchmarks
-
-Benchmarked on H800, online serving (480×768, 41 frames, 20 steps, `seed=42`).
-"Inference" is the server-reported inference time; excludes HTTP/poll overhead.
-
-| Preset | Server Command | Inference (s) | Speedup | Type |
-|--------|---------------|---------------|---------|------|
-| `baseline` | `--enforce-eager` | 10.3 | 1.00× | — |
-| `compile` | *(default, no --enforce-eager)* | ~10.3 (warm) | ~1.00× | Lossless |
-| `ulysses4` | `--enforce-eager --usp 4` | ~10.3 | ~1.00× | Lossless |
-| `cache-dit` | `--enforce-eager --cache-backend cache_dit` | 7.4 avg | ~1.4× | Lossy |
-| `best-combo` | `--enforce-eager --usp 4 --cache-backend cache_dit` | 4.7 avg | **~2.2×** | Lossless + Lossy |
-
-**Observations**:
-- **torch.compile**: On H800, warm-request inference time matches the eager baseline (~10.3s).
- The first request pays ~6s compilation overhead. Benefit depends on model architecture and GPU.
-- **Ulysses SP (4 GPU)**: No measurable speedup alone for 41-frame generation at this resolution.
- Communication overhead outweighs gains at this sequence length.
-- **Cache-DiT**: Inference varies per request (6–10s) due to dynamic caching decisions.
- Average is ~7.4s (~1.4× speedup) with slight quality tradeoff.
-- **Best combo**: 4-GPU Ulysses SP + Cache-DiT synergize well — Cache-DiT reduces per-step
- computation, making the communication overhead of Ulysses SP worthwhile. Average ~4.7s
- (~2.2× speedup).
-- **FP8 quantization**: Reduces VRAM but does not speed up LTX-2 on H800 (compute-bound).
-
-**Deployment Recommendations**:
-- For **production with quality priority**: use `baseline` with `--enforce-eager`
-- For **maximum throughput** (4 GPUs, quality tradeoff): use `best-combo` (~2.2× speedup)
-- For **single-GPU throughput**: use `cache-dit` (~1.4× speedup)
-- `--enforce-eager` is recommended to avoid torch.compile warmup latency on first request
-
-### Send Requests (curl)
-
-```bash
-# Using the provided script
-bash run_curl_ltx2.sh
-
-# Or directly
-curl -sS -X POST http://localhost:8098/v1/videos \
- -H "Accept: application/json" \
- -F "prompt=A serene lakeside sunrise with mist over the water." \
- -F "width=768" \
- -F "height=480" \
- -F "num_frames=41" \
- -F "fps=24" \
- -F "num_inference_steps=20" \
- -F "guidance_scale=3.0" \
- -F "seed=42"
-```
-
## Example materials
-??? abstract "response.json"
- ``````json
- --8<-- "examples/online_serving/text_to_video/response.json"
- ``````
-??? abstract "run_curl_ltx2.sh"
- ``````sh
- --8<-- "examples/online_serving/text_to_video/run_curl_ltx2.sh"
??? abstract "run_curl_hunyuan_video_15.sh"
``````sh
--8<-- "examples/online_serving/text_to_video/run_curl_hunyuan_video_15.sh"
@@ -385,9 +248,6 @@ curl -sS -X POST http://localhost:8098/v1/videos \
``````sh
--8<-- "examples/online_serving/text_to_video/run_server.sh"
``````
-??? abstract "run_server_ltx2.sh"
- ``````sh
- --8<-- "examples/online_serving/text_to_video/run_server_ltx2.sh"
??? abstract "run_server_hunyuan_video_15.sh"
``````sh
--8<-- "examples/online_serving/text_to_video/run_server_hunyuan_video_15.sh"
diff --git a/examples/offline_inference/bagel/README.md b/examples/offline_inference/bagel/README.md
index 9955fd90db9..226c009f792 100644
--- a/examples/offline_inference/bagel/README.md
+++ b/examples/offline_inference/bagel/README.md
@@ -1,60 +1,44 @@
# BAGEL-7B-MoT
-## Setup
+## Set up
Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
-## Architecture
+## Run examples
-BAGEL-7B-MoT is a Mixture-of-Transformers (MoT) model supporting both image generation and understanding. It offers two deployment topologies:
+**Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, please modify the stage configuration to distribute the model across devices.
-| Topology | Stages | Description |
-| :------- | :----- | :---------- |
-| **Two-stage** (default) | Stage 0 (Thinker, AR) + Stage 1 (DiT, Diffusion) | Thinker handles text/understanding via vLLM AR engine; DiT handles image generation. KV cache is transferred between stages. |
-| **Single-stage** | Stage 0 (DiT, Diffusion) only | The DiT stage contains a full LLM, ViT, VAE, and tokenizer internally. All modalities are handled within a single diffusion process. |
-
-Both topologies support all four modalities: `text2img`, `img2img`, `img2text`, `text2text`.
-
-## Quick Start
+Change into the BAGEL example directory:
```bash
cd examples/offline_inference/bagel
-
-# Default two-stage mode (auto-detected)
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat"
-
-# Single-stage mode
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat" \
- --deploy-config vllm_omni/deploy/bagel_single_stage.yaml
```
-> **Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. For dual-GPU setups, modify the deploy YAML to distribute stages across devices.
+### Modality Control
-## Modality Control
+BAGEL-7B-MoT supports multiple modality modes. You can control the mode using the `--modality` argument:
-Control the mode using the `--modality` argument:
+#### Text to Image (text2img)
-| Modality | Input | Output | Description |
-| :------- | :---- | :----- | :---------- |
-| `text2img` | Text | Image | Generate images from text prompts |
-| `img2img` | Image + Text | Image | Transform images using text guidance |
-| `img2text` | Image + Text | Text | Generate text descriptions from images |
-| `text2text` | Text | Text | Pure text generation (language model mode) |
+- **Pipeline**: Text → Thinker → DiT → VAE Decode → Image
+- **Stages Used**: Stage 0 (Thinker) + Stage 1 (DiT)
+- **KV Transfer**: Thinker sends KV cache to DiT for conditioned generation
-### Text to Image (text2img)
+Generate images from text prompts:
```bash
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2img \
- --prompts "A cute cat" \
- --steps 50
+ --prompts "A cute cat"
```
-### Image to Image (img2img)
+#### Image to Image (img2img)
+
+- **Pipeline**: Image → VAE Encode → DiT → VAE Decode → New Image
+- **Stages Used**: Stage 1 (DiT) only
+- **Special**: Bypasses the Thinker stage, direct image-to-image transformation
+
+Transform images based on text prompts:
```bash
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
@@ -63,7 +47,13 @@ python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--prompts "Let the woman wear a blue dress"
```
-### Image to Text (img2text)
+#### Image to Text (img2text)
+
+- **Pipeline**: Image → ViT + VAE Encode → Thinker → Text Output
+- **Stages Used**: Stage 0 (Thinker) only
+- **Special**: Uses both VAE latent encoding AND ViT semantic encoding for comprehensive image understanding
+
+Generate text descriptions from images:
```bash
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
@@ -72,206 +62,202 @@ python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--prompts "Describe this image in detail"
```
-### Text to Text (text2text)
+#### Text to Text (text2text)
+
+- **Pipeline**: Text → Thinker → Text Output
+- **Stages Used**: Stage 0 (Thinker) only
+- **Special**: No visual components involved, operates as pure language model
+
+Pure text generation:
```bash
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2text \
--prompts "What is the capital of France?"
-# Load prompts from a text file (one prompt per line):
+# You can load prompts from a text file (one prompt per line):
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2text \
--txt-prompts /path/to/prompts.txt
```
-## Think Mode
+### Inference Steps
-Think mode enables the model to generate `... ` planning/reasoning tokens before producing the final output. This improves generation quality for complex prompts.
-
-- **Two-stage**: The Thinker (AR) stage decodes think tokens, then transfers the augmented KV cache to the DiT stage for image generation.
-- **Single-stage**: The DiT's internal LLM generates think tokens in-place before proceeding to denoise.
+Control the number of inference steps for image generation:
```bash
-# Think + text2img: plan before generating
+# You can adjust steps to 100 to improve image quality
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2img \
- --prompts "A futuristic city with flying cars" \
- --think \
- --max-think-tokens 1000
+ --steps 50 \
+ --prompts "A cute cat"
+```
-# Think + img2img: reason about the edit
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality img2img \
- --image-path /path/to/image.jpg \
- --prompts "Make it look like a watercolor painting" \
- --think
+### Key arguments
-# Think + img2text: reason before describing
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality img2text \
- --image-path /path/to/image.jpg \
- --prompts "What is happening in this image?" \
- --think
+BAGEL-7B-MoT supports **multiple modality modes** for different use cases.
-# Think + text2text: chain-of-thought reasoning
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2text \
- --prompts "Solve: 23 * 47" \
- --think
-```
+The default YAML configuration deploys the Thinker and DiT stages on the same GPU. You can use the default configuration file: [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml)
-Think mode parameters:
+#### 📌 Command Line Arguments (end2end.py)
-| Argument | Default | Description |
-| :------- | :------ | :---------- |
-| `--think` | `False` | Enable thinking mode |
-| `--max-think-tokens` | `1000` | Maximum tokens for think generation |
-| `--do-sample` | `False` | Enable sampling (vs. greedy) for text generation |
-| `--text-temperature` | `0.3` | Temperature for text generation sampling |
+| Argument | Type | Default | Description |
+| :--------------------- | :----- | :---------------------------- | :----------------------------------------------------------- |
+| `--model` | string | `ByteDance-Seed/BAGEL-7B-MoT` | Model path or name |
+| `--modality` | choice | `text2img` | Modality mode: `text2img`, `img2img`, `img2text`, `text2text` |
+| `--prompts` | list | `None` | Input text prompts directly |
+| `--txt-prompts` | string | `None` | Path to txt file with one prompt per line |
+| `--image-path` | string | `None` | Input image path (for `img2img`/`img2text`) |
+| `--steps` | int | `50` | Number of inference steps |
+| `--stage-configs-path` | string | `None` | Custom stage config file path |
+| `--worker-backend` | choice | `process` | Worker backend: `process` or `ray` |
+| `--ray-address` | string | `None` | Ray cluster address |
+| `--enable-stats` | flag | `False` | Enable statistics logging |
+| `--init-sleep-seconds` | int | `20` | Initialization sleep time |
+| `--batch-timeout` | int | `5` | Batch timeout |
+| `--init-timeout` | int | `300` | Initialization timeout |
-## Classifier-Free Guidance (CFG)
+------
-CFG controls the trade-off between prompt fidelity and diversity. These parameters apply to image generation modalities (`text2img`, `img2img`).
+#### ⚙️ Stage Configuration Parameters (bagel.yaml)
-```bash
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A photorealistic portrait" \
- --cfg-text-scale 6.0 \
- --cfg-img-scale 2.0 \
- --negative-prompt "blurry, low quality, distorted" \
- --cfg-interval 0.4 1.0 \
- --cfg-renorm-type global \
- --cfg-renorm-min 0.0
-```
+**Stage 0 - Thinker (LLM Stage)**
-| Argument | Default | Description |
-| :------- | :------ | :---------- |
-| `--cfg-text-scale` | `4.0` | Text CFG scale (higher = more prompt-adherent) |
-| `--cfg-img-scale` | `1.5` | Image CFG scale (for img2img) |
-| `--negative-prompt` | `None` | Negative prompt for CFG conditioning |
-| `--cfg-interval` | pipeline default | CFG active interval `[start, end]` as fractions of total timesteps |
-| `--cfg-renorm-type` | `None` | Renormalization type: `global`, `text_channel`, `channel` |
-| `--cfg-renorm-min` | `None` | Minimum renormalization value |
-| `--cfg-parallel-size` | `1` | CFG parallel size: `1` = batched (single GPU), `2` = 2-branch parallel, `3` = full 3-GPU parallel |
+| Parameter | Value | Description |
+| :------------------------------- | :------------------------------ | :----------------------- |
+| `stage_type` | `llm` | Stage type |
+| `devices` | `"0"` | GPU device ID |
+| `max_num_seqs` | `1` | Maximum batch size |
+| `model_stage` | `thinker` | Model stage identifier |
+| `model_arch` | `BagelForConditionalGeneration` | Model architecture |
+| `gpu_memory_utilization` | `0.4` | GPU memory utilization |
+| `tensor_parallel_size` | `1` | Tensor parallel size |
+| `max_num_batched_tokens` | `32768` | Maximum batched tokens |
+| `omni_kv_config.need_send_cache` | `true` | Whether to send KV cache |
-## Deployment Topologies
+------
-### Two-Stage (Default)
+**Stage 1 - DiT (Diffusion Stage)**
-The default topology auto-detected from the model. No extra flags needed.
+| Parameter | Value | Description |
+| :------------------------------- | :---------- | :-------------------------- |
+| `stage_type` | `diffusion` | Stage type |
+| `devices` | `"0"` | GPU device ID |
+| `max_num_seqs` | `1` | Maximum batch size |
+| `model_stage` | `dit` | Model stage identifier |
+| `gpu_memory_utilization` | `0.4` | GPU memory utilization |
+| `omni_kv_config.need_recv_cache` | `true` | Whether to receive KV cache |
+| `engine_input_source` | `[0]` | Input source from Stage 0 |
-```bash
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat"
-```
+------
-The pipeline is defined in [`bagel.yaml`](../../../vllm_omni/deploy/bagel.yaml). Stage 0 (Thinker) and Stage 1 (DiT) share GPU 0 by default. For dual-GPU setups, customize the deploy YAML and set `devices: "1"` for stage 1.
+#### Tensor Parallelism (TP)
-### Single-Stage
+For larger models or multi-GPU environments, you can enable Tensor Parallelism (TP) by modifying the stage configuration (e.g., [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml)).
-Pass the single-stage deploy config via `--deploy-config`:
+1. **Set `tensor_parallel_size`**: Increase this value (e.g., to `2` or `4`).
+2. **Set `devices`**: Specify the comma-separated GPU IDs to be used for the stage (e.g., `"0,1"`).
-```bash
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat" \
- --deploy-config vllm_omni/deploy/bagel_single_stage.yaml
+Example configuration for TP=2 on GPUs 0 and 1:
+```yaml
+ engine_args:
+ tensor_parallel_size: 2
+ ...
+ runtime:
+ devices: "0,1"
```
-See [`bagel_single_stage.yaml`](../../../vllm_omni/deploy/bagel_single_stage.yaml) for configuration details. The `pipeline: bagel_single_stage` field selects the single-stage topology from the pipeline registry.
+------
-### Tensor Parallelism (TP)
+#### 🔗 Runtime Configuration
-For larger models or multi-GPU environments, customize the deploy YAML (see [`bagel.yaml`](../../../vllm_omni/deploy/bagel.yaml)) and set per-stage `tensor_parallel_size` and `devices`:
+| Parameter | Value | Description |
+| :-------------------- | :------ | :------------------------------- |
+| `window_size` | `-1` | Window size (-1 means unlimited) |
+| `max_inflight` | `1` | Maximum inflight requests |
+| `shm_threshold_bytes` | `65536` | Shared memory threshold (64KB) |
-```yaml
-# Example: TP=2 on GPUs 0,1 for the Thinker stage
-stages:
- - stage_id: 0
- tensor_parallel_size: 2
- devices: "0,1"
+## Using Mooncake Connector
+
+[Mooncake](https://github.com/kvcache-ai/Mooncake) is a high-performance distributed KV cache transfer engine that enables efficient cross-node data movement via TCP or RDMA, making it ideal for multi-node disaggregated inference.
+
+By default, BAGEL uses `SharedMemoryConnector` for inter-stage communication. You can switch to the Mooncake connector for better performance on multi-GPU setups and to enable multi-node deployment.
+
+### Prerequisites
+
+Install the Mooncake transfer engine:
+
+```bash
+# For CUDA-enabled systems (recommended)
+pip install mooncake-transfer-engine
+
+# For non-CUDA systems
+pip install mooncake-transfer-engine-non-cuda
```
-Then pass the custom deploy YAML:
+### Step 1: Start the Mooncake Master
+
+On the **primary node**, start the Mooncake master service (run in a separate terminal or background with `&`):
```bash
-python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
- --modality text2img \
- --prompts "A cute cat" \
- --deploy-config /path/to/custom_bagel.yaml
+# Optional: enable disk-backed storage by creating a directory and passing --root_fs_dir.
+# Without it, Mooncake runs in memory-only mode, which is sufficient for KV cache transfer.
+mkdir -p ./mc_storage
+
+mooncake_master \
+ --rpc_port=50051 \
+ --enable_http_metadata_server=true \
+ --http_metadata_server_host=0.0.0.0 \
+ --http_metadata_server_port=8080 \
+ --metrics_port=9003 \
+ --root_fs_dir=./mc_storage/ \
+ --cluster_id=mc-local-1 &
```
-### FP8 Quantization
+### Step 2: Run Offline Inference with Mooncake
+
+Use the provided Mooncake stage config [`bagel_multiconnector.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml). Before launching, update the `metadata_server` and `master` addresses in the YAML to match your Mooncake master node's IP (use `127.0.0.1` for single-node testing).
```bash
+cd examples/offline_inference/bagel
+
+# Text to Image with Mooncake
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
--modality text2img \
--prompts "A cute cat" \
- --quantization fp8
+ --stage-configs-path ../../../vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
+
+# Image to Text with Mooncake
+python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
+ --modality img2text \
+ --image-path /path/to/image.jpg \
+ --prompts "Describe this image" \
+ --stage-configs-path ../../../vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
+
+# Text to Text with Mooncake
+python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
+ --modality text2text \
+ --prompts "What is the capital of France?" \
+ --stage-configs-path ../../../vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
```
-## Command Line Reference
-
-### Core Arguments
-
-| Argument | Type | Default | Description |
-| :------- | :--- | :------ | :---------- |
-| `--model` | string | `ByteDance-Seed/BAGEL-7B-MoT` | Model path or HuggingFace name |
-| `--modality` | choice | `text2img` | `text2img`, `img2img`, `img2text`, `text2text` |
-| `--prompts` | list | `None` | Input text prompts |
-| `--txt-prompts` | string | `None` | Path to text file with one prompt per line |
-| `--image-path` | string | `None` | Input image path (required for `img2img`/`img2text`) |
-| `--output` | string | `.` | Output directory for saved images |
-| `--steps` | int | `50` | Number of diffusion inference steps |
-| `--seed` | int | `None` | Random seed for reproducibility |
-
-### Think Mode Arguments
-
-| Argument | Type | Default | Description |
-| :------- | :--- | :------ | :---------- |
-| `--think` | flag | `False` | Enable `... ` planning/reasoning |
-| `--max-think-tokens` | int | `1000` | Maximum tokens for think generation |
-| `--do-sample` | flag | `False` | Use sampling instead of greedy decoding |
-| `--text-temperature` | float | `0.3` | Sampling temperature for text generation |
-
-### CFG Arguments
-
-| Argument | Type | Default | Description |
-| :------- | :--- | :------ | :---------- |
-| `--cfg-text-scale` | float | `4.0` | Text CFG guidance scale |
-| `--cfg-img-scale` | float | `1.5` | Image CFG guidance scale |
-| `--negative-prompt` | string | `None` | Negative prompt for CFG |
-| `--cfg-parallel-size` | int | `1` | CFG parallel GPU count (1, 2, or 3) |
-| `--cfg-interval` | float[2] | pipeline default | CFG active window `[start, end]` |
-| `--cfg-renorm-type` | string | `None` | `global`, `text_channel`, or `channel` |
-| `--cfg-renorm-min` | float | `None` | Minimum renormalization value |
-
-### Engine Arguments
-
-| Argument | Type | Default | Description |
-| :------- | :--- | :------ | :---------- |
-| `--deploy-config` | string | `None` | Path to deploy YAML (auto-detected if omitted) |
-| `--worker-backend` | choice | `process` | `process` or `ray` |
-| `--ray-address` | string | `None` | Ray cluster address |
-| `--quantization` | string | `None` | Quantization method (e.g. `fp8`) |
-| `--log-stats` | flag | `False` | Enable statistics logging |
-| `--init-timeout` | int | `300` | Initialization timeout (seconds) |
-| `--batch-timeout` | int | `5` | Batch timeout (seconds) |
-| `--enable-diffusion-pipeline-profiler` | flag | `False` | Profile diffusion stage durations |
+For more details on the Mooncake connector and multi-node setup, see the [Mooncake Store Connector documentation](../../../docs/design/feature/omni_connectors/mooncake_store_connector.md).
+
+------
## FAQ
-- If you encounter OOM errors, try decreasing `max_model_len` or `gpu_memory_utilization` in the deploy YAML.
+- If you encounter an error about the librosa backend, install ffmpeg with the commands below.
-**Two-stage VRAM usage:**
+```bash
+sudo apt update
+sudo apt install ffmpeg
+```
-| Stage | VRAM |
-| :---- | :--- |
-| Stage 0 (Thinker) | **15.04 GiB + KV Cache** |
-| Stage 1 (DiT) | **26.50 GiB** |
-| Total | **~42 GiB + KV Cache** |
+- If you are unsure how much VRAM the model needs, or you encounter an out-of-memory (OOM) error, try decreasing `max_model_len`. The table below lists the VRAM used by each stage.
-**Single-stage VRAM usage:** The DiT loads the full model (~42 GiB) in one process.
+| Stage | VRAM |
+| :------------------ | :--------------------------- |
+| Stage-0 (Thinker) | **15.04 GiB + KV Cache** |
+| Stage-1 (DiT) | **26.50 GiB** |
+| Total | **~42 GiB + KV Cache** |
diff --git a/examples/offline_inference/bagel/end2end.py b/examples/offline_inference/bagel/end2end.py
index a6ce1f1314f..2153a31ba70 100644
--- a/examples/offline_inference/bagel/end2end.py
+++ b/examples/offline_inference/bagel/end2end.py
@@ -2,10 +2,7 @@
import os
from vllm_omni.inputs.data import OmniPromptType
-from vllm_omni.model_executor.stage_input_processors.bagel import (
- GEN_THINK_SYSTEM_PROMPT,
- VLM_THINK_SYSTEM_PROMPT,
-)
+from vllm_omni.model_executor.stage_input_processors.bagel import GEN_THINK_SYSTEM_PROMPT
def parse_args():
@@ -53,12 +50,7 @@ def parse_args():
parser.add_argument("--shm-threshold-bytes", type=int, default=65536)
parser.add_argument("--worker-backend", type=str, default="process", choices=["process", "ray"])
parser.add_argument("--ray-address", type=str, default=None)
- parser.add_argument(
- "--deploy-config",
- type=str,
- default=None,
- help="Path to deploy YAML. If unset, auto-loads vllm_omni/deploy/bagel.yaml based on the HF model_type.",
- )
+ parser.add_argument("--stage-configs-path", type=str, default=None)
parser.add_argument("--steps", type=int, default=50, help="Number of inference steps.")
parser.add_argument("--cfg-text-scale", type=float, default=4.0, help="Text CFG scale (default: 4.0)")
@@ -102,28 +94,7 @@ def parse_args():
default=False,
help="Enable thinking mode: AR stage decodes ... planning tokens before image generation.",
)
- parser.add_argument(
- "--max-think-tokens",
- type=int,
- default=1000,
- help="Maximum number of tokens for thinking text generation (default: 1000).",
- )
- parser.add_argument(
- "--do-sample",
- action="store_true",
- default=False,
- help="Enable sampling for text generation (default: greedy).",
- )
- parser.add_argument(
- "--text-temperature",
- type=float,
- default=0.3,
- help="Temperature for text generation sampling (default: 0.3).",
- )
-
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
- nullify_stage_engine_defaults(parser)
args = parser.parse_args()
return args
@@ -134,6 +105,7 @@ def main():
model_name = args.model
prompts: list[OmniPromptType] = []
try:
+ # Preferred: load from txt file (one prompt per line)
if getattr(args, "txt_prompts", None) and args.prompt_type == "text":
with open(args.txt_prompts, encoding="utf-8") as f:
lines = [ln.strip() for ln in f.readlines()]
@@ -146,20 +118,22 @@ def main():
raise
if not prompts:
+ # Default prompt for text2img test if none provided
prompts = ["A cute cat"]
print(f"[Info] No prompts provided, using default: {prompts}")
+ omni_outputs = []
from PIL import Image
from vllm_omni.entrypoints.omni import Omni
omni_kwargs = {}
- deploy_config = args.deploy_config
- if args.think and deploy_config is None:
- deploy_config = "vllm_omni/deploy/bagel_think.yaml"
- print(f"[Info] Think mode enabled, using deploy config: {deploy_config}")
- if deploy_config:
- omni_kwargs["deploy_config"] = deploy_config
+ stage_configs_path = args.stage_configs_path
+ if args.think and stage_configs_path is None:
+ stage_configs_path = "vllm_omni/model_executor/stage_configs/bagel_think.yaml"
+ print(f"[Info] Think mode enabled, using stage config: {stage_configs_path}")
+ if stage_configs_path:
+ omni_kwargs["stage_configs_path"] = stage_configs_path
omni_kwargs.update(
{
@@ -176,7 +150,7 @@ def main():
if args.quantization:
omni_kwargs["quantization_config"] = args.quantization
- omni = Omni.from_cli_args(args, model=model_name, **omni_kwargs)
+ omni = Omni(model=model_name, **omni_kwargs)
formatted_prompts = []
for p in prompts:
@@ -197,10 +171,7 @@ def main():
elif args.modality == "img2text":
if args.image_path:
loaded_image = Image.open(args.image_path).convert("RGB")
- think_prefix = f"<|im_start|>system\n{VLM_THINK_SYSTEM_PROMPT}<|im_end|>\n" if args.think else ""
- final_prompt_text = (
- f"{think_prefix}<|im_start|>user\n<|image_pad|>\n{p}<|im_end|>\n<|im_start|>assistant\n"
- )
+ final_prompt_text = f"<|im_start|>user\n<|image_pad|>\n{p}<|im_end|>\n<|im_start|>assistant\n"
prompt_dict = {
"prompt": final_prompt_text,
"multi_modal_data": {"image": loaded_image},
@@ -208,8 +179,7 @@ def main():
}
formatted_prompts.append(prompt_dict)
elif args.modality == "text2text":
- think_prefix = f"<|im_start|>{VLM_THINK_SYSTEM_PROMPT}<|im_end|>" if args.think else ""
- final_prompt_text = f"{think_prefix}<|im_start|>{p}<|im_end|><|im_start|>"
+ final_prompt_text = f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
prompt_dict = {"prompt": final_prompt_text, "modalities": ["text"]}
formatted_prompts.append(prompt_dict)
else:
@@ -221,63 +191,44 @@ def main():
formatted_prompts.append(prompt_dict)
params_list = omni.default_sampling_params_list
- # Bagel exposes 1 sampling param set for single-stage (DiT-only) and
- # 2 for two-stage (Thinker + DiT). This heuristic may need updating
- # if future pipelines break that 1:1 mapping.
- is_single_stage = len(params_list) == 1
-
- diffusion_params_idx = 0 if is_single_stage else (1 if len(params_list) > 1 else 0)
- diffusion_params = params_list[diffusion_params_idx]
-
if args.modality in ("text2img", "img2img"):
- diffusion_params.num_inference_steps = args.steps # type: ignore
- diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore
- if args.seed is not None:
- diffusion_params.seed = args.seed # type: ignore
-
- extra = getattr(diffusion_params, "extra_args", {}) or {}
- extra["cfg_text_scale"] = args.cfg_text_scale
- extra["cfg_img_scale"] = args.cfg_img_scale
- if args.cfg_interval is not None:
- extra["cfg_interval"] = tuple(args.cfg_interval)
- if args.cfg_renorm_type is not None:
- extra["cfg_renorm_type"] = args.cfg_renorm_type
- if args.cfg_renorm_min is not None:
- extra["cfg_renorm_min"] = args.cfg_renorm_min
- if args.negative_prompt is not None:
- extra["negative_prompt"] = args.negative_prompt
-
- needs_text_gen = is_single_stage and (args.think or args.modality in ("text2text", "img2text"))
- if needs_text_gen:
- if args.think:
- extra["think"] = True
- extra["max_think_tokens"] = args.max_think_tokens
- extra["do_sample"] = args.do_sample
- extra["text_temperature"] = args.text_temperature
- diffusion_params.extra_args = extra # type: ignore
+ if len(params_list) > 1:
+ diffusion_params = params_list[1]
+ diffusion_params.num_inference_steps = args.steps # type: ignore
+ diffusion_params.cfg_parallel_size = args.cfg_parallel_size # type: ignore
+ if args.seed is not None:
+ diffusion_params.seed = args.seed # type: ignore
+ extra = {
+ "cfg_text_scale": args.cfg_text_scale,
+ "cfg_img_scale": args.cfg_img_scale,
+ }
+ if args.cfg_interval is not None:
+ extra["cfg_interval"] = tuple(args.cfg_interval)
+ if args.cfg_renorm_type is not None:
+ extra["cfg_renorm_type"] = args.cfg_renorm_type
+ if args.cfg_renorm_min is not None:
+ extra["cfg_renorm_min"] = args.cfg_renorm_min
+ if args.negative_prompt is not None:
+ extra["negative_prompt"] = args.negative_prompt
+ diffusion_params.extra_args = extra # type: ignore
omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))
img_idx = 0
for req_output in omni_outputs:
- # 2-stage think mode: text output from thinker stage
- ro = getattr(req_output, "request_output", None)
- if ro and getattr(ro, "outputs", None):
- txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs)
- if txt:
- if args.think:
- print(f"[Think]\n{txt}")
- else:
- print(f"[Output] Text:\n{txt}")
-
- # Single-stage DiT: text from custom_output
- custom = getattr(req_output, "_custom_output", {}) or {}
- if custom.get("think_text"):
- print(f"[Think]\n{custom['think_text']}")
- if custom.get("text_output"):
- print(f"[Output] Text:\n{custom['text_output']}")
+ if args.think:
+ text_output = getattr(req_output, "text", None) or getattr(req_output, "outputs", None)
+ if text_output:
+ if isinstance(text_output, list) and text_output:
+ for out in text_output:
+ txt = getattr(out, "text", str(out))
+ if txt:
+ print(f"[Think] {txt}")
+ elif isinstance(text_output, str):
+ print(f"[Think] {text_output}")
images = getattr(req_output, "images", None)
+
if not images:
continue
@@ -287,6 +238,8 @@ def main():
print(f"[Output] Saved image to {save_path}")
img_idx += 1
+ print(omni_outputs)
+
if __name__ == "__main__":
main()
diff --git a/examples/offline_inference/cosyvoice3/README.md b/examples/offline_inference/cosyvoice3/README.md
index 704b49614fb..895d3f660f0 100644
--- a/examples/offline_inference/cosyvoice3/README.md
+++ b/examples/offline_inference/cosyvoice3/README.md
@@ -7,7 +7,7 @@ Install dependencies:
uv pip install -e .
```
-> **Note:** This includes required libraries such as `soundfile`,
+> **Note:** This includes required libraries such as `librosa`, `soundfile`,
> `onnxruntime`, `x-transformers`, and `einops` via
> `requirements/common.txt` and platform-specific requirements files.
@@ -58,14 +58,7 @@ Key components live in `vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py
- Stage 0 uses `CosyVoice3LM` and outputs speech tokens + conditioning features.
- Stage 1 runs the flow model (DiT-based CFM) and HiFiGAN to synthesize waveform.
-Pipeline topology lives in `vllm_omni/model_executor/models/cosyvoice3/pipeline.py`;
-runtime tunables (batch size, memory limits, sampling) live in
-`vllm_omni/deploy/cosyvoice3.yaml`. The deploy config auto-loads by
-HF `model_type` and defaults to `async_chunk: true` (shared-memory
-streaming). Pass `--no-async-chunk` on `vllm serve` to switch to the
-legacy sync path where stage 1 runs `text2flow` over the full
-speech-token sequence.
+Stage wiring is configured in `vllm_omni/model_executor/stage_configs/cosyvoice3.yaml`.
- Stage 0 emits latent speech tokens.
-- Stage 1 consumes them via `sync_process_input_func` (sync mode) or the
- shared-memory connector (async-chunk mode) and outputs audio.
+- Stage 1 consumes them via `custom_process_input_func` and outputs audio.
diff --git a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py
index a5dc564ec3b..68ab72b3870 100644
--- a/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py
+++ b/examples/offline_inference/cosyvoice3/verify_e2e_cosyvoice.py
@@ -2,12 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import os
+from pathlib import Path
+import librosa
import numpy as np
import soundfile as sf
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
-from vllm.multimodal.media.audio import load_audio
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config
@@ -15,6 +16,22 @@
from vllm_omni.model_executor.models.cosyvoice3.utils import extract_text_token
+def _ensure_mel_filters_asset() -> None:
+ repo_root = Path(__file__).resolve().parents[3]
+ filters_path = repo_root / "vllm_omni" / "model_executor" / "models" / "cosyvoice3" / "assets" / "mel_filters.npz"
+ if filters_path.exists():
+ return
+
+ source_url = "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz"
+ raise FileNotFoundError(
+ "Missing CosyVoice3 mel filter asset:\n"
+ f" {filters_path}\n"
+ "Download it with:\n"
+ f" mkdir -p {filters_path.parent} && "
+ f"curl -L {source_url} -o {filters_path}"
+ )
+
+
def run_e2e():
parser = argparse.ArgumentParser()
# ""FunAudioLLM/Fun-CosyVoice3-0.5B-2512
@@ -24,13 +41,7 @@ def run_e2e():
required=True,
help="Path to CosyVoice3 model directory (e.g., pretrained_models/Fun-CosyVoice3-0.5B/).",
)
- parser.add_argument(
- "--deploy-config",
- type=str,
- default=None,
- help="Override the deploy config path. If unset, auto-loads "
- "vllm_omni/deploy/cosyvoice3.yaml based on the HF model_type.",
- )
+ parser.add_argument("--stage-config", type=str, default="vllm_omni/model_executor/stage_configs/cosyvoice3.yaml")
parser.add_argument("--prompt", type=str, default="Hello, this is a test of the CosyVoice system capability.")
parser.add_argument(
"--prompt-text",
@@ -45,18 +56,24 @@ def run_e2e():
help="Path to tokenizer directory (e.g., /CosyVoice-BlankEN).",
)
args = parser.parse_args()
+ _ensure_mel_filters_asset()
# Ensure tokenizer directory exists
if not os.path.exists(args.tokenizer):
raise FileNotFoundError(f"{args.tokenizer} does not exist!")
- if args.deploy_config is not None and not os.path.exists(args.deploy_config):
- raise FileNotFoundError(f"{args.deploy_config} does not exist!")
+ # Ensure stage config exists
+ if not os.path.exists(args.stage_config):
+ raise FileNotFoundError(f"{args.stage_config} does not exist!")
print(f"Initializing cosyvoice E2E with model={args.model}")
+ # Initialize Omni
+ # This spins up the engine(s) based on the stage config
+ # We pass trust_remote_code=True same as Qwen examples
omni = Omni(
model=args.model,
- deploy_config=args.deploy_config,
+ stage_configs_path=args.stage_config,
+ trust_remote_code=True,
tokenizer=args.tokenizer,
log_stats=True,
)
@@ -68,7 +85,7 @@ def run_e2e():
if not os.path.exists(args.audio_path):
raise FileNotFoundError(f"Audio file not found: {args.audio_path}")
# Load at native sample rate
- audio_signal, sr = load_audio(args.audio_path, sr=None)
+ audio_signal, sr = librosa.load(args.audio_path, sr=None)
# Validate sample rate before processing (similar to original CosyVoice)
min_sr = 16000
diff --git a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py
index dc1085c28ef..8ab5e0d9a6c 100644
--- a/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py
+++ b/examples/offline_inference/custom_pipeline/image_to_image/image_edit.py
@@ -44,11 +44,9 @@
import argparse
import asyncio
-import json
import os
import time
from pathlib import Path
-from typing import Any
import torch
from PIL import Image
@@ -60,16 +58,6 @@
from vllm_omni.platforms import current_omni_platform
-def parse_profiler_config(value: str) -> dict[str, Any]:
- try:
- config = json.loads(value)
- except json.JSONDecodeError as e:
- raise argparse.ArgumentTypeError(f"--profiler-config must be valid JSON: {e}") from e
- if not isinstance(config, dict):
- raise argparse.ArgumentTypeError("--profiler-config must be a JSON object")
- return config
-
-
# ===========================
# Argument Parser
# ===========================
@@ -111,16 +99,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--vae-use-slicing", action="store_true")
parser.add_argument("--vae-use-tiling", action="store_true")
parser.add_argument("--enable-cpu-offload", action="store_true")
- parser.add_argument(
- "--profiler-config",
- type=parse_profiler_config,
- default=None,
- help='JSON profiler config for torch/cuda profiling, e.g. \'{"profiler":"torch","torch_profiler_dir":"./perf"}\'.',
- )
-
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
- nullify_stage_engine_defaults(parser)
return parser.parse_args()
@@ -179,13 +158,12 @@ async def main():
enable_cpu_offload=args.enable_cpu_offload,
diffusion_load_format="dummy",
custom_pipeline_args={"pipeline_class": "custom_pipeline.CustomPipeline"},
- profiler_config=args.profiler_config,
)
print(">>> Pipeline loaded successfully")
# ---- Profiling + Info ----
- profiler_enabled = args.profiler_config is not None
+ profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))
print(f"\n{'=' * 60}")
print("Generation Configuration")
print(f"Model: {args.model}")
diff --git a/examples/offline_inference/dynin_omni/README.md b/examples/offline_inference/dynin_omni/README.md
deleted file mode 100644
index d28b360714e..00000000000
--- a/examples/offline_inference/dynin_omni/README.md
+++ /dev/null
@@ -1,110 +0,0 @@
-# Dynin-Omni Offline End2End Example
-
-This folder contains a unified offline inference entrypoint:
-
-- `end2end.py`
-
-## 1. Environment Setup
-
-Run from repository root:
-
-```bash
-cd
-```
-
-If needed, install this repo in editable mode:
-
-```bash
-pip install -e .
-```
-
-## 2. Extra Dependencies (EMOVA)
-
-Install the following packages for EMOVA-related components:
-
-```bash
-pip install \
- "phonemizer==3.3.0" \
- "Unidecode==1.4.0" \
- "hydra-core==1.3.2" \
- "pytorch-lightning==1.1.0" \
- "wget==3.2" \
- "wrapt==2.1.1" \
- "onnx==1.20.1" \
- "frozendict==2.4.7" \
- "inflect==7.5.0" \
- "braceexpand==0.1.7" \
- "webdataset==1.0.2" \
- "torch-stft==0.1.4" \
- "editdistance==0.8.1"
-```
-
-## 3. Hardware and VRAM Requirements
-
-This example uses a 3-stage pipeline on one GPU by default
-([`dynin_omni.yaml`](../../../vllm_omni/model_executor/stage_configs/dynin_omni.yaml)):
-
-- Stage-0 (`token2text`): `gpu_memory_utilization: 0.5`
-- Stage-1 (`token2image`): `gpu_memory_utilization: 0.1`
-- Stage-2 (`token2audio`): `gpu_memory_utilization: 0.1`
-
-### Requested GPU Memory Budget from `gpu_memory_utilization`
-
-| Stage | Utilization | A100 80GB | H200 141GB |
-| :-- | :-- | :-- | :-- |
-| Stage-0 (token2text) | 0.5 | ~40.0 GB | ~70.5 GB |
-| Stage-1 (token2image) | 0.1 | ~8.0 GB | ~14.1 GB |
-| Stage-2 (token2audio) | 0.1 | ~8.0 GB | ~14.1 GB |
-| Total requested budget | 0.7 | ~56.0 GB | ~98.7 GB |
-
-### Observed Runtime Signal (from your log)
-
-- Stage-0 reported: `Model loading took 15.12 GiB memory` (weights footprint signal).
-- Stages 1/2 can still add runtime memory depending on task path and backend allocations.
-- Keep extra headroom for CUDA/PyTorch overhead and temporary allocations.
-
-### GPU Compatibility
-
-- Confirmed target GPUs for this setup: **NVIDIA H200**, **NVIDIA A100**.
-- CI/e2e coverage in this repo also includes CUDA **L4** markers for Dynin tests.
-
-## 4. End2End Run Examples
-
-```bash
-# t2t
-python /examples/offline_inference/dynin_omni/end2end.py \
- --task t2t --model snu-aidas/Dynin-Omni --text
-
-# i2t
-python /examples/offline_inference/dynin_omni/end2end.py \
- --task i2t --model snu-aidas/Dynin-Omni --image --text "Please describe this image in detail."
-
-# s2t
-python /examples/offline_inference/dynin_omni/end2end.py \
- --task s2t --model snu-aidas/Dynin-Omni --audio --text "Transcribe the given audio."
-
-# t2i
-python /examples/offline_inference/dynin_omni/end2end.py \
- --task t2i --model snu-aidas/Dynin-Omni --text
-
-# v2t
-python /examples/offline_inference/dynin_omni/end2end.py \
- --task v2t --model snu-aidas/Dynin-Omni --video --text "Describe this video in detail."
-
-# i2i
-python /examples/offline_inference/dynin_omni/end2end.py \
- --task i2i --model snu-aidas/Dynin-Omni --image --text
-
-# t2s
-python /examples/offline_inference/dynin_omni/end2end.py \
- --task t2s --model snu-aidas/Dynin-Omni --text
-```
-
-## 5. Notes
-
-- Outputs are saved under task-specific directories in `/tmp` by default.
-- You can override output path with `--output-dir`.
-- If you want to force local config resolution, pass `--dynin-config-path `.
-- If you see the warning
- `max_num_batched_tokens (32768) exceeds max_num_seqs * max_model_len (4096)`,
- reduce `max_num_batched_tokens` in stage config (for example, `4096` in CI config).
diff --git a/examples/offline_inference/dynin_omni/end2end.py b/examples/offline_inference/dynin_omni/end2end.py
deleted file mode 100644
index 82cff0c0015..00000000000
--- a/examples/offline_inference/dynin_omni/end2end.py
+++ /dev/null
@@ -1,1451 +0,0 @@
-#!/usr/bin/env python3
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-import argparse
-import json
-import os
-import re
-import sys
-import time
-import types
-from importlib.machinery import ModuleSpec
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import torch
-from PIL import Image
-
-TASK_CHOICES = ("t2t", "t2i", "t2s", "i2i", "i2t", "s2t", "v2t")
-
-TASK_DEFAULT_RUNTIME = {
- "t2t": ("mmu", "mmu", 0, "text"),
- "t2i": ("t2i", "t2i_gen", 2, "image"),
- "t2s": ("t2s_mmu_like", "t2s_gen", 1, "audio"),
- "i2i": ("i2i", "i2i", 2, "image"),
- "i2t": ("mmu", "mmu", 0, "text"),
- "s2t": ("s2t", "s2t", 0, "text"),
- "v2t": ("v2t", "v2t", 0, "text"),
-}
-
-TASK_RUNTIME_FALLBACKS: dict[str, dict[str, Any]] = {
- "t2t": {
- "output_dir": "/tmp/dynin_end2end_outputs",
- "prompt_max_text_len": 1024,
- "max_new_tokens": 1024,
- "steps": 1024,
- "block_length": 16,
- "temperature": 0.0,
- "cfg_scale": 0.0,
- },
- "t2i": {
- "output_dir": "/tmp/dynin_t2i_outputs",
- "prompt_max_text_len": 128,
- "image_token_count": 1024,
- "mask_token_id": 126336,
- "codebook_size": 8192,
- "timesteps": 20,
- "guidance_scale": 3.5,
- "temperature": 1.0,
- },
- "i2i": {
- "output_dir": "/tmp/dynin_i2i_outputs",
- "prompt_max_text_len": 128,
- "mask_token_id": 126336,
- "codebook_size": 8192,
- "timesteps": 64,
- "guidance_scale": 3.5,
- "temperature": 1.0,
- "image_resolution": 336,
- "use_train_i2i_prompt": True,
- },
- "i2t": {
- "output_dir": "/tmp/dynin_i2t_outputs",
- "prompt_max_text_len": 128,
- "max_new_tokens": 128,
- "steps": 128,
- "block_length": 2,
- "temperature": 0.0,
- "cfg_scale": 0.0,
- "mask_token_id": 126336,
- "codebook_size": 8192,
- "image_resolution": 480,
- "remasking": "low_confidence",
- },
- "s2t": {
- "output_dir": "/tmp/dynin_s2t_outputs",
- "prompt_max_text_len": 1024,
- "max_new_tokens": 128,
- "steps": 128,
- "block_length": 2,
- "temperature": 0.0,
- "cfg_scale": 0.0,
- "mask_token_id": 126336,
- "codebook_size": 8192,
- "remasking": "low_confidence",
- },
- "t2s": {
- "output_dir": "/tmp/dynin_t2s_outputs",
- "runtime_task": "t2s_mmu_like",
- "prompting_task": "t2s_gen",
- "prompt_max_text_len": 1024,
- "t2s_token_length": 512,
- "mask_token_id": 126336,
- "codebook_size": 8192,
- "audio_codebook_size": 4096,
- "steps": 512,
- "block_length": 128,
- "temperature": 1.0,
- "cfg_scale": 2.5,
- "t2s_condition": "gender-female_emotion-neutral_speed-normal_pitch-normal",
- },
- "v2t": {
- "output_dir": "/tmp/dynin_v2t_outputs",
- "prompt_max_text_len": 1024,
- "max_new_tokens": 128,
- "steps": 128,
- "block_length": 2,
- "temperature": 0.0,
- "cfg_scale": 0.0,
- "mask_token_id": 126336,
- "codebook_size": 8192,
- "image_resolution": 224,
- "num_frames": 5,
- "remasking": "low_confidence",
- },
-}
-
-DEFAULT_I2T_QUESTION = "Please describe this image in detail."
-DEFAULT_S2T_INSTRUCTION = "Transcribe the given audio."
-DEFAULT_V2T_QUESTION = "Please provide a detailed description of the video."
-DEFAULT_T2T_PROMPT = "Explain multimodal LLM inference in 3 sentences."
-DEFAULT_T2S_INSTRUCTION = "Convert the given text into spoken audio."
-DEFAULT_T2S_PROMPT = "Hello. This is a default text-to-speech sample."
-
-DYNIN_SPECIAL_TOKENS = (
- "<|soi|>",
- "<|eoi|>",
- "<|sov|>",
- "<|eov|>",
- "<|t2i|>",
- "<|mmu|>",
- "<|t2v|>",
- "<|v2v|>",
- "<|lvg|>",
- "<|i2i|>",
- "<|ti2ti|>",
- "<|v2t|>",
- "<|v2s|>",
- "<|s2t|>",
- "<|t2s|>",
- "<|s2s|>",
- "<|soa|>",
- "<|eoa|>",
-)
-
-
-def bootstrap_repo_path() -> Path:
- repo_root = Path(__file__).resolve().parents[3]
- repo_root_str = str(repo_root)
- if repo_root_str not in sys.path:
- sys.path.insert(0, repo_root_str)
- return repo_root
-
-
-def ensure_safe_import_for_vllm() -> None:
- os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
- try:
- import torchvision # noqa: F401
-
- return
- except Exception:
- pass
-
- import enum
-
- class _InterpolationMode(enum.Enum):
- NEAREST = 0
- BILINEAR = 2
- BICUBIC = 3
- LANCZOS = 1
- HAMMING = 4
- BOX = 5
-
- tv_mod = types.ModuleType("torchvision")
- tv_mod.__dict__["__version__"] = "0.0-stub"
- tv_mod.__spec__ = ModuleSpec(name="torchvision", loader=None)
- transforms_mod = types.ModuleType("torchvision.transforms")
- transforms_mod.__spec__ = ModuleSpec(name="torchvision.transforms", loader=None)
- transforms_mod.InterpolationMode = _InterpolationMode
- tv_mod.transforms = transforms_mod
- sys.modules["torchvision"] = tv_mod
- sys.modules["torchvision.transforms"] = transforms_mod
-
-
-def sanitize_repo_id(repo_id: str) -> str:
- return re.sub(r"[^a-zA-Z0-9._-]+", "_", repo_id)
-
-
-def is_hf_repo_id(value: str) -> bool:
- return isinstance(value, str) and value.count("/") == 1 and all(value.split("/", 1))
-
-
-def ensure_local_model_dir(model: str, cache_dir: Path, localize: bool) -> Path:
- model_path = Path(model).expanduser()
- if model_path.is_dir():
- return model_path.resolve()
- if not localize:
- return Path(model)
-
- from huggingface_hub import snapshot_download
-
- cache_dir.mkdir(parents=True, exist_ok=True)
- os.environ.setdefault("HF_HOME", str(cache_dir / ".hf_home"))
- local_dir = cache_dir / sanitize_repo_id(model)
- if not local_dir.exists():
- print(f"[end2end] Downloading model into local cache: {local_dir}")
- snapshot_download(
- repo_id=model,
- local_dir=str(local_dir),
- local_dir_use_symlinks=True,
- resume_download=True,
- )
- return local_dir.resolve()
-
-
-def resolve_local_only(
- override: bool | None,
- source: str,
- default: bool,
-) -> bool:
- if override is not None:
- return bool(override)
- return default or Path(source).expanduser().is_dir()
-
-
-def load_text_tokenizer(tokenizer_source: str, local_files_only: bool):
- from transformers import AutoTokenizer
-
- kwargs = {
- "trust_remote_code": True,
- "padding_side": "left",
- "local_files_only": bool(local_files_only),
- }
- try:
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **kwargs)
- except TypeError:
- kwargs.pop("local_files_only", None)
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, **kwargs)
- return tokenizer
-
-
-def preprocess_image(image: Image.Image, resolution: int) -> torch.Tensor:
- w, h = image.size
- short_side = min(w, h)
- scale = resolution / short_side
- new_w, new_h = round(w * scale), round(h * scale)
- image = image.resize((new_w, new_h), Image.BICUBIC)
- left = (new_w - resolution) // 2
- top = (new_h - resolution) // 2
- image = image.crop((left, top, left + resolution, top + resolution))
- arr = np.array(image, dtype=np.float32) / 255.0
- tensor = torch.from_numpy(arr).permute(2, 0, 1)
- return (tensor - 0.5) / 0.5
-
-
-def load_vq_image_encoder(source: str, local_files_only: bool, device: torch.device) -> Any:
- from vllm_omni.model_executor.models.dynin_omni.dynin_omni_common import get_dynin_magvit_attr
-
- MAGVITv2 = get_dynin_magvit_attr("MAGVITv2", source=source, local_files_only=local_files_only)
- vq_model = MAGVITv2.from_pretrained(source, local_files_only=local_files_only).to(device)
- vq_model.requires_grad_(False)
- vq_model.eval()
- return vq_model
-
-
-def encode_image_tokens(
- image_path: Path,
- vq_model: Any,
- device: torch.device,
- resolution: int,
-) -> torch.Tensor:
- image = Image.open(image_path).convert("RGB")
- image_tensor = preprocess_image(image, resolution=resolution).unsqueeze(0).to(device)
- with torch.no_grad():
- token_ids = vq_model.get_code(image_tensor)
- token_ids = torch.as_tensor(token_ids, dtype=torch.long).detach().cpu()
- if token_ids.ndim == 2 and token_ids.shape[0] == 1:
- token_ids = token_ids[0]
- return token_ids.contiguous()
-
-
-def encode_video_tokens(
- video_path: Path,
- vq_model: Any,
- device: torch.device,
- resolution: int,
- num_frames: int,
-) -> torch.Tensor:
- import cv2
-
- cap = cv2.VideoCapture(str(video_path))
- frames: list[np.ndarray] = []
- while True:
- ok, frame = cap.read()
- if not ok:
- break
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- frames.append(frame)
- cap.release()
- if not frames:
- raise ValueError(f"Video has no readable frames: {video_path}")
- if len(frames) < num_frames:
- raise ValueError(f"Video has {len(frames)} frames, requires >= {num_frames}: {video_path}")
-
- indices = np.linspace(0, len(frames) - 1, num_frames).astype(int)
- token_list: list[torch.Tensor] = []
- for idx in indices:
- pil = Image.fromarray(frames[int(idx)])
- frame_tensor = preprocess_image(pil, resolution=resolution).unsqueeze(0).to(device)
- with torch.no_grad():
- token_list.append(torch.as_tensor(vq_model.get_code(frame_tensor), dtype=torch.long))
- merged = torch.cat(token_list, dim=1).detach().cpu()
- if merged.ndim == 2 and merged.shape[0] == 1:
- merged = merged[0]
- return merged.contiguous()
-
-
-def load_vq_audio_encoder(source: str, local_files_only: bool, device: torch.device) -> Any:
- from transformers import AutoModel
-
- kwargs = {
- "trust_remote_code": True,
- "local_files_only": bool(local_files_only),
- "low_cpu_mem_usage": False,
- }
- try:
- model = AutoModel.from_pretrained(source, **kwargs)
- except TypeError:
- kwargs.pop("low_cpu_mem_usage", None)
- try:
- model = AutoModel.from_pretrained(source, **kwargs)
- except TypeError:
- kwargs.pop("local_files_only", None)
- model = AutoModel.from_pretrained(source, **kwargs)
- model.requires_grad_(False)
- model.eval()
- if hasattr(model, "to"):
- model = model.to(device)
- return model
-
-
-def encode_audio_tokens(audio_path: Path, vq_audio_model: Any) -> torch.Tensor:
- encoded = vq_audio_model.encode(str(audio_path))
- if isinstance(encoded, dict):
- for key in ("input_ids", "token_ids", "codes", "tokens"):
- if key in encoded:
- encoded = encoded[key]
- break
- encoded = torch.as_tensor(encoded, dtype=torch.long).detach().cpu()
- if encoded.ndim == 1:
- encoded = encoded.unsqueeze(0)
- elif encoded.ndim > 2:
- encoded = encoded.view(encoded.shape[0], -1)
- return encoded.contiguous()
-
-
-def build_chat_prompt(content: str) -> str:
- return (
- f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
- )
-
-
-def resolve_task_text(
- *,
- task_name: str,
- text: str,
- instruction: str = "",
- raw_prompt: bool = False,
-) -> str:
- text = str(text or "").strip()
-
- if task_name == "t2t" and not text:
- return DEFAULT_T2T_PROMPT
- if task_name == "i2t" and not text:
- return DEFAULT_I2T_QUESTION
- if task_name == "s2t" and not text:
- return DEFAULT_S2T_INSTRUCTION
- if task_name == "v2t" and not text:
- return DEFAULT_V2T_QUESTION
- if task_name in {"t2i", "i2i"} and not text:
- return "A high quality detailed image."
-
- if task_name != "t2s":
- return text
-
- if not text:
- text = DEFAULT_T2S_PROMPT
-
- if raw_prompt:
- return text
-
- instruction = str(instruction or "").strip() or DEFAULT_T2S_INSTRUCTION
- return build_chat_prompt(f"{instruction}\n{text}")
-
-
-def load_universal_prompting(
- *,
- tokenizer: Any,
- tokenizer_source: str,
- max_text_len: int,
- cond_dropout_prob: float,
- local_files_only: bool,
- max_audio_len: int = 512,
- max_audio_len_short: int = 256,
-) -> Any:
- from vllm_omni.model_executor.models.dynin_omni.dynin_omni_common import (
- DYNIN_REMOTE_SETTINGS,
- resolve_remote_attr,
- )
-
- UniversalPrompting = resolve_remote_attr(
- "UniversalPrompting",
- module_name="prompting_utils",
- settings=DYNIN_REMOTE_SETTINGS,
- source=tokenizer_source,
- local_files_only=bool(local_files_only),
- fallback_module_names=("modeling_dynin_omni",),
- )
- init_kwargs: dict[str, Any] = {
- "max_text_len": int(max_text_len),
- "special_tokens": DYNIN_SPECIAL_TOKENS,
- "ignore_id": -100,
- "cond_dropout_prob": float(cond_dropout_prob),
- "use_reserved_token": True,
- "max_audio_len": int(max_audio_len),
- "max_audio_len_short": int(max_audio_len_short),
- }
- try:
- return UniversalPrompting(tokenizer, **init_kwargs)
- except TypeError:
- init_kwargs.pop("max_audio_len", None)
- init_kwargs.pop("max_audio_len_short", None)
- return UniversalPrompting(tokenizer, **init_kwargs)
-
-
-def _runtime_fallback(task: str, key: str, value: Any) -> Any:
- if isinstance(value, str):
- if value.strip() != "":
- return value
- elif value is not None:
- return value
- return TASK_RUNTIME_FALLBACKS.get(task, {}).get(key)
-
-
-def _validate_generation_args(*, task: str, max_new_tokens: int, steps: int, block_length: int) -> None:
- # Keep i2t/v2t generation constraints aligned with i2t.py/v2t.py.
- if task not in {"i2t", "v2t"}:
- return
- if max_new_tokens <= 0:
- raise ValueError(f"{task} requires max_new_tokens > 0.")
- if block_length <= 0:
- raise ValueError(f"{task} requires block_length > 0.")
- if steps <= 0:
- raise ValueError(f"{task} requires steps > 0.")
- if max_new_tokens % block_length != 0:
- raise ValueError(f"{task} requires max_new_tokens % block_length == 0, got {max_new_tokens} % {block_length}")
- num_blocks = max_new_tokens // block_length
- if num_blocks <= 0:
- raise ValueError(f"{task} has invalid num_blocks.")
- if steps % num_blocks != 0:
- raise ValueError(
- f"{task} requires steps % (max_new_tokens // block_length) == 0, "
- f"got steps={steps}, max_new_tokens={max_new_tokens}, block_length={block_length}"
- )
-
-
-def make_prompt_payload(
- *,
- task: str,
- text: str,
- image_tokens: torch.Tensor | None,
- audio_tokens: torch.Tensor | None,
- video_tokens: torch.Tensor | None,
- image_placeholder_tokens: int,
- audio_placeholder_tokens: int,
- image_token_offset: int,
- speech_token_offset: int,
- mask_token_id: int,
- use_train_i2i_prompt: bool,
-) -> tuple[Any, str]:
- runtime_task, prompting_task, _, _ = TASK_DEFAULT_RUNTIME[task]
- del runtime_task
-
- if task == "t2t":
- payload = ([[]], [build_chat_prompt(text)])
- return payload, prompting_task
-
- if task == "i2t":
- if image_tokens is None:
- raise ValueError("i2t requires image tokens")
- img = image_tokens.view(-1).long() + int(image_token_offset)
- payload = ([[img]], [build_chat_prompt(text)])
- return payload, prompting_task
-
- if task == "s2t":
- if audio_tokens is None:
- raise ValueError("s2t requires audio tokens")
- aud = audio_tokens.long() + int(speech_token_offset)
- if aud.ndim == 1:
- aud = aud.unsqueeze(0)
- payload = ([aud], [build_chat_prompt(text)])
- return payload, prompting_task
-
- if task == "v2t":
- if video_tokens is None:
- raise ValueError("v2t requires video tokens")
- vid = video_tokens.view(-1).long() + int(image_token_offset)
- payload = (vid.unsqueeze(0), [build_chat_prompt(text)])
- return payload, prompting_task
-
- if task == "t2i":
- image_placeholder = torch.full(
- (1, int(image_placeholder_tokens)),
- fill_value=int(mask_token_id),
- dtype=torch.long,
- )
- payload = ([text], image_placeholder)
- return payload, prompting_task
-
- if task == "i2i":
- if image_tokens is None:
- raise ValueError("i2i requires image tokens")
- src = image_tokens.view(1, -1).long() + int(image_token_offset)
- target_len = int(image_placeholder_tokens) if image_placeholder_tokens > 0 else int(src.shape[1])
- image_placeholder = torch.full(
- (1, target_len),
- fill_value=int(mask_token_id),
- dtype=torch.long,
- )
- if use_train_i2i_prompt:
- labels_placeholder = torch.full(
- (1, target_len),
- fill_value=-100,
- dtype=torch.long,
- )
- payload = ([text], src, image_placeholder, labels_placeholder)
- return payload, "i2i"
- payload = ([text], src, image_placeholder)
- return payload, "i2i_gen"
-
- if task == "t2s":
- audio_placeholder = torch.full(
- (1, int(audio_placeholder_tokens)),
- fill_value=int(mask_token_id),
- dtype=torch.long,
- )
- payload = ([text], audio_placeholder)
- return payload, prompting_task
-
- raise ValueError(f"Unsupported task: {task}")
-
-
-def _to_1d_int_list(value: Any) -> list[int]:
- if value is None:
- return []
- if isinstance(value, torch.Tensor):
- tensor = value.detach().to(device="cpu", dtype=torch.long)
- else:
- tensor = torch.as_tensor(value, dtype=torch.long)
- if tensor.ndim == 0:
- tensor = tensor.view(1)
- elif tensor.ndim >= 2:
- tensor = tensor.view(tensor.shape[0], -1)[0]
- return [int(v) for v in tensor.tolist()]
-
-
-def _run_uni_prompting(uni_prompting: Any, payload: Any, prompting_task: str) -> tuple[list[int], list[int]]:
- prepared = uni_prompting(payload, prompting_task)
- if isinstance(prepared, tuple):
- prepared_input_ids = prepared[0] if len(prepared) > 0 else None
- prepared_attention_mask = prepared[1] if len(prepared) > 1 else None
- else:
- prepared_input_ids = prepared
- prepared_attention_mask = None
-
- input_ids = _to_1d_int_list(prepared_input_ids)
- attention_mask = _to_1d_int_list(prepared_attention_mask)
- if not input_ids:
- raise RuntimeError(f"UniversalPrompting returned empty input_ids for task={prompting_task}")
- return input_ids, attention_mask
-
-
-def _get_special_token_id(uni_prompting: Any, token: str) -> int:
- sptids = getattr(uni_prompting, "sptids_dict", None) or {}
- if token not in sptids:
- raise KeyError(f"Special token not found in UniversalPrompting.sptids_dict: {token}")
- token_ids = _to_1d_int_list(sptids[token])
- if not token_ids:
- raise ValueError(f"Special token id is empty for token: {token}")
- return int(token_ids[0])
-
-
-def _tokenize_chat_query(tokenizer: Any, text: str) -> list[int]:
- encoded = tokenizer(build_chat_prompt(text), return_tensors="pt").input_ids[0]
- token_ids = _to_1d_int_list(encoded)
- if not token_ids:
- raise RuntimeError("Failed to tokenize chat query text.")
- return token_ids
-
-
-def _flatten_media_token_ids_with_offset(token_ids: Any, token_offset: int) -> list[int]:
- media_ids = token_ids
- if isinstance(media_ids, torch.Tensor):
- media_ids = media_ids.detach().cpu().reshape(-1).tolist()
- else:
- media_ids = np.asarray(media_ids).reshape(-1).tolist()
- return [int(x) + int(token_offset) for x in media_ids]
-
-
-def _scalar_token_id(value: Any) -> int:
- if isinstance(value, torch.Tensor):
- if value.numel() == 0:
- raise ValueError("Empty special-token tensor.")
- return int(value.view(-1)[0].item())
- if isinstance(value, (list, tuple)):
- if not value:
- raise ValueError("Empty special-token list.")
- return int(value[0])
- return int(value)
-
-
-def build_v2t_input_ids(
- *,
- video_token_ids: Any,
- tokenizer: Any,
- uni_prompting: Any,
- question: str,
- image_token_offset: int,
-) -> tuple[list[int], str]:
- media_ids = video_token_ids
- if isinstance(media_ids, torch.Tensor):
- media_ids = media_ids.detach().cpu().reshape(-1).tolist()
- else:
- media_ids = np.asarray(media_ids).reshape(-1).tolist()
- media_ids = [int(x) + int(image_token_offset) for x in media_ids]
-
- sptids = uni_prompting.sptids_dict
- task_id = _scalar_token_id(sptids["<|v2t|>"])
- soi_id = _scalar_token_id(sptids["<|soi|>"])
- eoi_id = _scalar_token_id(sptids["<|eoi|>"])
- sot_id = _scalar_token_id(sptids["<|sot|>"])
-
- prompt_text = build_v2t_chat_prompt(question)
- query_ids = tokenizer(prompt_text, return_tensors="pt").input_ids[0].detach().cpu().tolist()
- input_ids = [task_id, soi_id] + media_ids + [eoi_id, sot_id] + [int(v) for v in query_ids]
- return input_ids, prompt_text
-
-
-def build_i2t_input_ids(
- *,
- image_token_ids: Any,
- tokenizer: Any,
- uni_prompting: Any,
- question: str,
- image_token_offset: int,
-) -> tuple[list[int], str]:
- image_ids = image_token_ids
- if isinstance(image_ids, torch.Tensor):
- image_ids = image_ids.detach().cpu().reshape(-1).tolist()
- else:
- image_ids = np.asarray(image_ids).reshape(-1).tolist()
- image_ids = [int(x) + int(image_token_offset) for x in image_ids]
-
- sptids = uni_prompting.sptids_dict
- task_id = _scalar_token_id(sptids["<|mmu|>"])
- soi_id = _scalar_token_id(sptids["<|soi|>"])
- eoi_id = _scalar_token_id(sptids["<|eoi|>"])
- sot_id = _scalar_token_id(sptids["<|sot|>"])
-
- prompt_text = build_i2t_chat_prompt(question)
- query_ids = tokenizer(prompt_text, return_tensors="pt").input_ids[0].detach().cpu().tolist()
- input_ids = [task_id, soi_id] + image_ids + [eoi_id, sot_id] + [int(v) for v in query_ids]
- return input_ids, prompt_text
-
-
-def build_v2t_chat_prompt(question: str) -> str:
- return (
- f"<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
- )
-
-
-def build_i2t_chat_prompt(question: str) -> str:
- return (
- f"<|start_header_id|>user<|end_header_id|>\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
- )
-
-
-def make_mmu_prompt(
- *,
- task: str,
- text: str,
- tokenizer: Any,
- uni_prompting: Any,
- image_tokens: torch.Tensor | None,
- audio_tokens: torch.Tensor | None,
- video_tokens: torch.Tensor | None,
- image_token_offset: int,
- speech_token_offset: int,
-) -> tuple[list[int], list[int]]:
- query_ids = _tokenize_chat_query(tokenizer, text)
-
- if task == "i2t":
- token_ids, _ = build_i2t_input_ids(
- image_token_ids=image_tokens,
- tokenizer=tokenizer,
- uni_prompting=uni_prompting,
- question=text,
- image_token_offset=int(image_token_offset),
- )
- token_ids = [int(v) for v in token_ids]
- return token_ids, [1] * len(token_ids)
-
- if task == "v2t":
- token_ids, _ = build_v2t_input_ids(
- video_token_ids=video_tokens,
- tokenizer=tokenizer,
- uni_prompting=uni_prompting,
- question=text,
- image_token_offset=int(image_token_offset),
- )
- token_ids = [int(v) for v in token_ids]
- return token_ids, [1] * len(token_ids)
-
- if task == "s2t":
- if audio_tokens is None:
- raise ValueError("s2t requires audio tokens")
- audio_ids = _to_1d_int_list(audio_tokens.long() + int(speech_token_offset))
- token_ids = [
- _get_special_token_id(uni_prompting, "<|s2t|>"),
- _get_special_token_id(uni_prompting, "<|soa|>"),
- *audio_ids,
- _get_special_token_id(uni_prompting, "<|eoa|>"),
- *query_ids,
- ]
- return token_ids, [1] * len(token_ids)
-
- raise ValueError(f"Unsupported task for validation-style MMU prompt: {task}")
-
-
-def iter_mm_outputs(outputs: list[Any]):
- for omni_out in outputs:
- req_out = getattr(omni_out, "request_output", None)
- req_list = req_out if isinstance(req_out, list) else [req_out]
- for item in req_list:
- if item is None:
- continue
- mm_out = getattr(item, "multimodal_output", None) or {}
- if mm_out:
- yield mm_out
- completions = getattr(item, "outputs", None) or []
- for completion in completions:
- c_mm_out = getattr(completion, "multimodal_output", None) or {}
- if c_mm_out:
- yield c_mm_out
- omni_mm = getattr(omni_out, "multimodal_output", None) or {}
- if omni_mm:
- yield omni_mm
-
-
-def _to_token_list(value: Any) -> list[int]:
- if value is None:
- return []
- if hasattr(value, "detach"):
- value = value.detach()
- if hasattr(value, "cpu"):
- value = value.cpu()
- if hasattr(value, "flatten"):
- value = value.flatten().tolist()
- if isinstance(value, tuple):
- value = list(value)
- if not isinstance(value, list):
- return []
- out: list[int] = []
- for token in value:
- if isinstance(token, bool):
- continue
- try:
- out.append(int(token))
- except Exception:
- continue
- return out
-
-
-def extract_text_output(outputs: list[Any], tokenizer: Any) -> str:
- for mm_out in iter_mm_outputs(outputs):
- text = mm_out.get("text")
- if isinstance(text, list) and text:
- text = text[-1]
- if isinstance(text, str) and text.strip():
- return text.strip()
- for key in ("text_tokens", "token_ids"):
- token_ids = _to_token_list(mm_out.get(key))
- if not token_ids:
- continue
- decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
- if isinstance(decoded, str) and decoded.strip():
- return decoded.strip()
- return ""
-
-
-def extract_image_output(outputs: list[Any]) -> torch.Tensor | None:
- for mm_out in iter_mm_outputs(outputs):
- image = mm_out.get("image")
- if isinstance(image, list) and image:
- image = image[-1]
- if isinstance(image, torch.Tensor):
- return image
- return None
-
-
-def tensor_to_pil_image(image: torch.Tensor) -> Image.Image:
- arr = image.detach().cpu().numpy()
- if arr.ndim == 4:
- arr = arr[0]
- if arr.ndim == 3 and arr.shape[0] in (1, 3, 4):
- arr = np.transpose(arr, (1, 2, 0))
- if arr.dtype != np.uint8:
- arr = arr.astype(np.float32)
- if arr.max() <= 1.0:
- arr = arr * 255.0
- arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
- if arr.ndim == 3 and arr.shape[-1] == 1:
- arr = arr[..., 0]
- return Image.fromarray(arr)
-
-
-def extract_audio_output(outputs: list[Any]) -> tuple[np.ndarray, int] | None:
- for mm_out in iter_mm_outputs(outputs):
- audio = mm_out.get("audio")
- if audio is None:
- audio = mm_out.get("speech")
- if audio is None:
- continue
-
- def _to_wav_array(value: Any) -> np.ndarray:
- if isinstance(value, torch.Tensor):
- return value.detach().cpu().numpy().reshape(-1).astype(np.float32)
- return np.asarray(value).reshape(-1).astype(np.float32)
-
- if isinstance(audio, list):
- chunks = [_to_wav_array(chunk) for chunk in audio]
- wav = np.concatenate(chunks, axis=0) if chunks else np.zeros((0,), dtype=np.float32)
- else:
- wav = _to_wav_array(audio)
- sr = mm_out.get("sr", 24000)
- if hasattr(sr, "item"):
- try:
- sr = int(sr.item())
- except Exception:
- sr = 24000
- elif isinstance(sr, list):
- sr = int(sr[0]) if sr else 24000
- else:
- sr = int(sr)
- return wav, sr
- return None
-
-
-def save_audio_wav(path: Path, wav: np.ndarray, sr: int) -> None:
- try:
- import soundfile as sf
-
- sf.write(str(path), wav, int(sr), format="WAV")
- except Exception:
- from scipy.io import wavfile
-
- wav_i16 = np.clip(wav, -1.0, 1.0)
- wav_i16 = (wav_i16 * 32767.0).astype(np.int16)
- wavfile.write(str(path), int(sr), wav_i16)
-
-
-def parse_args(repo_root: Path) -> argparse.Namespace:
- parser = argparse.ArgumentParser(description="Dynin-Omni unified offline end2end example.")
- parser.add_argument("--task", type=str, required=True, choices=TASK_CHOICES)
- parser.add_argument("--model", type=str, required=True, help="HF repo id or local model directory.")
- parser.add_argument(
- "--stage-config-path",
- type=str,
- default=str(repo_root / "vllm_omni/model_executor/stage_configs/dynin_omni.yaml"),
- help="Path to stage config yaml.",
- )
- parser.add_argument(
- "--dynin-config-path",
- type=str,
- default="",
- help="Path to DYNIN config yaml (passed through additional_information).",
- )
- parser.add_argument(
- "--model-cache-dir",
- type=str,
- default="/tmp/dynin_localized_models",
- help="Cache directory used when --model is HF repo id.",
- )
- parser.add_argument(
- "--localize-model",
- action=argparse.BooleanOptionalAction,
- default=True,
- help="If true and --model is HF repo id, snapshot it under --model-cache-dir.",
- )
- parser.add_argument("--text", type=str, default="", help="Prompt/edit/question text.")
- parser.add_argument("--instruction", type=str, default="", help="Optional extra instruction.")
- parser.add_argument("--raw-prompt", action=argparse.BooleanOptionalAction, default=False)
- parser.add_argument("--image", type=str, default="", help="Input image path for i2i/i2t.")
- parser.add_argument("--audio", type=str, default="", help="Input audio path for s2t.")
- parser.add_argument("--video", type=str, default="", help="Input video path for v2t.")
- parser.add_argument("--image-resolution", type=int, default=None)
- parser.add_argument("--num-frames", type=int, default=None)
- parser.add_argument(
- "--output-dir",
- type=str,
- default="",
- help="Directory for generated outputs.",
- )
- parser.add_argument("--output-prefix", type=str, default="")
- parser.add_argument("--seed", type=int, default=0)
- parser.add_argument("--dtype", type=str, default="auto")
- parser.add_argument("--max-tokens-per-stage", type=int, default=1)
-
- parser.add_argument("--runtime-task", type=str, default="", help="Override runtime task key.")
- parser.add_argument("--prompting-task", type=str, default="", help="Override prompting task key.")
- parser.add_argument("--detok-id", type=int, default=None, help="Override detok id.")
-
- parser.add_argument("--prompt-max-text-len", type=int, default=None)
- parser.add_argument("--cond-dropout-prob", type=float, default=0.0)
- parser.add_argument("--max-new-tokens", type=int, default=None)
- parser.add_argument("--steps", type=int, default=None)
- parser.add_argument("--block-length", type=int, default=None)
- parser.add_argument("--temperature", type=float, default=None)
- parser.add_argument("--cfg-scale", type=float, default=None)
- parser.add_argument("--remasking", type=str, default="low_confidence")
-
- parser.add_argument("--timesteps", type=int, default=None)
- parser.add_argument("--guidance-scale", type=float, default=None)
- parser.add_argument("--noise-type", type=str, default="mask")
- parser.add_argument("--noise-schedule-name", type=str, default="cosine")
- parser.add_argument("--noise-schedule-params", type=str, default="{}")
-
- parser.add_argument("--mask-token-id", type=int, default=None)
- parser.add_argument("--codebook-size", type=int, default=None)
- parser.add_argument("--audio-codebook-size", type=int, default=None)
- parser.add_argument("--image-token-count", type=int, default=None)
- parser.add_argument("--t2s-token-length", type=int, default=None)
- parser.add_argument(
- "--t2s-condition",
- type=str,
- default="",
- )
- parser.add_argument(
- "--use-train-i2i-prompt",
- action="store_true",
- help="Use i2i training prompt template (default behavior of i2i.py).",
- )
- parser.add_argument(
- "--no-use-train-i2i-prompt",
- dest="use_train_i2i_prompt",
- action="store_false",
- help="Use i2i_gen prompt template.",
- )
- parser.set_defaults(use_train_i2i_prompt=None)
-
- parser.add_argument("--tokenizer-path", type=str, default="")
- parser.add_argument("--model-local-files-only", action=argparse.BooleanOptionalAction, default=None)
- parser.add_argument("--tokenizer-local-files-only", action=argparse.BooleanOptionalAction, default=None)
-
- parser.add_argument("--vq-model-image-path", type=str, default="")
- parser.add_argument("--vq-model-image-local-files-only", action=argparse.BooleanOptionalAction, default=None)
- parser.add_argument("--vq-model-audio-path", type=str, default="")
- parser.add_argument("--vq-model-audio-local-files-only", action=argparse.BooleanOptionalAction, default=None)
-
- parser.add_argument("--disable-hf-xet", action=argparse.BooleanOptionalAction, default=True)
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- nullify_stage_engine_defaults(parser)
- return parser.parse_args()
-
-
-def main() -> None:
- repo_root = bootstrap_repo_path()
- ensure_safe_import_for_vllm()
- from vllm_omni.model_executor.models.dynin_omni.dynin_omni_common import (
- DYNIN_PROMPT_SOURCE_KEY,
- DYNIN_PROMPT_SOURCE_OFFLINE_PREBUILT,
- )
-
- args = parse_args(repo_root)
-
- if args.disable_hf_xet:
- os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
-
- np.random.seed(args.seed)
- torch.manual_seed(args.seed)
-
- model_dir = ensure_local_model_dir(
- model=args.model,
- cache_dir=Path(args.model_cache_dir).expanduser(),
- localize=bool(args.localize_model),
- )
- model_source = str(model_dir)
-
- task_name = str(args.task)
- dynin_config_path = str(Path(args.dynin_config_path).expanduser())
- os.environ["DYNIN_CONFIG_PATH"] = dynin_config_path
- default_runtime_task, default_prompting_task, default_detok_id, final_modality = TASK_DEFAULT_RUNTIME[task_name]
- runtime_task = args.runtime_task.strip() or str(
- _runtime_fallback(task_name, "runtime_task", None) or default_runtime_task
- )
- prompting_task = args.prompting_task.strip() or str(
- _runtime_fallback(task_name, "prompting_task", None) or default_prompting_task
- )
- detok_id_default = _runtime_fallback(task_name, "detok_id", None)
- if detok_id_default is None:
- detok_id_default = default_detok_id
- detok_id = int(detok_id_default if args.detok_id is None else args.detok_id)
-
- output_dir_default = _runtime_fallback(task_name, "output_dir", args.output_dir)
- resolved_output_dir = str(output_dir_default or "/tmp/dynin_end2end_outputs")
-
- image_resolution_value = _runtime_fallback(
- task_name,
- "image_resolution",
- args.image_resolution,
- )
- if image_resolution_value is None:
- image_resolution_value = 336
- image_resolution = int(image_resolution_value)
-
- num_frames_value = _runtime_fallback(
- task_name,
- "num_frames",
- args.num_frames,
- )
- if num_frames_value is None:
- num_frames_value = 8
- num_frames = int(num_frames_value)
-
- prompt_max_text_len_value = _runtime_fallback(
- task_name,
- "prompt_max_text_len",
- args.prompt_max_text_len,
- )
- if prompt_max_text_len_value is None:
- prompt_max_text_len_value = 1024
- prompt_max_text_len = int(prompt_max_text_len_value)
-
- max_new_tokens_value = _runtime_fallback(
- task_name,
- "max_new_tokens",
- args.max_new_tokens,
- )
- if max_new_tokens_value is None:
- max_new_tokens_value = 256
- max_new_tokens = int(max_new_tokens_value)
-
- steps_value = _runtime_fallback(
- task_name,
- "steps",
- args.steps,
- )
- if steps_value is None:
- steps_value = 256
- steps = int(steps_value)
-
- block_length_value = _runtime_fallback(
- task_name,
- "block_length",
- args.block_length,
- )
- if block_length_value is None:
- block_length_value = 2
- block_length = int(block_length_value)
-
- temperature_value = _runtime_fallback(
- task_name,
- "temperature",
- args.temperature,
- )
- if temperature_value is None:
- temperature_value = 0.0
- temperature = float(temperature_value)
-
- cfg_scale_value = _runtime_fallback(
- task_name,
- "cfg_scale",
- args.cfg_scale,
- )
- if cfg_scale_value is None:
- cfg_scale_value = 0.0
- cfg_scale = float(cfg_scale_value)
-
- remasking = str(_runtime_fallback(task_name, "remasking", args.remasking) or "low_confidence")
-
- timesteps_value = _runtime_fallback(
- task_name,
- "timesteps",
- args.timesteps,
- )
- if timesteps_value is None:
- timesteps_value = 20
- timesteps = int(timesteps_value)
-
- guidance_scale_value = _runtime_fallback(
- task_name,
- "guidance_scale",
- args.guidance_scale,
- )
- if guidance_scale_value is None:
- guidance_scale_value = 0.0
- guidance_scale = float(guidance_scale_value)
-
- mask_token_id_value = _runtime_fallback(
- task_name,
- "mask_token_id",
- args.mask_token_id,
- )
- if mask_token_id_value is None:
- mask_token_id_value = 126336
- mask_token_id = int(mask_token_id_value)
-
- codebook_size_value = _runtime_fallback(
- task_name,
- "codebook_size",
- args.codebook_size,
- )
- if codebook_size_value is None:
- codebook_size_value = 8192
- codebook_size = int(codebook_size_value)
-
- audio_codebook_size_value = _runtime_fallback(
- task_name,
- "audio_codebook_size",
- args.audio_codebook_size,
- )
- if audio_codebook_size_value is None:
- audio_codebook_size_value = 4096
- audio_codebook_size = int(audio_codebook_size_value)
-
- image_token_count_value = _runtime_fallback(
- task_name,
- "image_token_count",
- args.image_token_count,
- )
- image_token_count = int(image_token_count_value) if image_token_count_value is not None else 0
-
- t2s_token_length_value = _runtime_fallback(
- task_name,
- "t2s_token_length",
- args.t2s_token_length,
- )
- if t2s_token_length_value is None:
- t2s_token_length_value = 383
- t2s_token_length = int(t2s_token_length_value)
-
- t2s_condition = str(
- _runtime_fallback(task_name, "t2s_condition", args.t2s_condition)
- or "gender-female_emotion-neutral_speed-normal_pitch-normal"
- )
-
- _validate_generation_args(
- task=task_name,
- max_new_tokens=max_new_tokens,
- steps=steps,
- block_length=block_length,
- )
-
- use_train_i2i_prompt = _runtime_fallback(task_name, "use_train_i2i_prompt", args.use_train_i2i_prompt)
- if use_train_i2i_prompt is None:
- use_train_i2i_prompt = bool(task_name == "i2i")
- use_train_i2i_prompt = bool(use_train_i2i_prompt)
-
- if task_name in {"i2i", "i2t"} and not args.image:
- raise ValueError(f"--task {task_name} requires --image")
- if task_name == "s2t" and not args.audio:
- raise ValueError("--task s2t requires --audio")
- if task_name == "v2t" and not args.video:
- raise ValueError("--task v2t requires --video")
-
- text = resolve_task_text(
- task_name=task_name,
- text=args.text,
- instruction=args.instruction,
- raw_prompt=bool(args.raw_prompt),
- )
-
- tokenizer_source = args.tokenizer_path.strip() or model_source
- model_local_only = resolve_local_only(
- args.model_local_files_only, model_source, default=Path(model_source).is_dir()
- )
- tokenizer_local_only = resolve_local_only(
- args.tokenizer_local_files_only,
- tokenizer_source,
- default=model_local_only,
- )
- tokenizer = load_text_tokenizer(tokenizer_source, local_files_only=tokenizer_local_only)
- text_vocab_size = int(len(tokenizer))
-
- image_tokens: torch.Tensor | None = None
- audio_tokens: torch.Tensor | None = None
- video_tokens: torch.Tensor | None = None
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- vq_image_source = args.vq_model_image_path.strip() or "snu-aidas/magvitv2"
- vq_audio_source = args.vq_model_audio_path.strip() or "snu-aidas/emova_speech_tokenizer_vllm"
- vq_image_local_only = resolve_local_only(args.vq_model_image_local_files_only, vq_image_source, default=False)
- vq_audio_local_only = resolve_local_only(args.vq_model_audio_local_files_only, vq_audio_source, default=False)
-
- if task_name in {"i2i", "i2t", "v2t"}:
- vq_image = load_vq_image_encoder(vq_image_source, vq_image_local_only, device)
- if task_name in {"i2i", "i2t"}:
- image_tokens = encode_image_tokens(
- Path(args.image).expanduser().resolve(),
- vq_model=vq_image,
- device=device,
- resolution=int(image_resolution),
- )
- if task_name == "v2t":
- video_tokens = encode_video_tokens(
- Path(args.video).expanduser().resolve(),
- vq_model=vq_image,
- device=device,
- resolution=int(image_resolution),
- num_frames=int(num_frames),
- )
- if hasattr(vq_image, "cpu"):
- vq_image = vq_image.cpu()
-
- if task_name == "s2t":
- vq_audio = load_vq_audio_encoder(vq_audio_source, vq_audio_local_only, device)
- audio_tokens = encode_audio_tokens(Path(args.audio).expanduser().resolve(), vq_audio)
- if hasattr(vq_audio, "cpu"):
- vq_audio = vq_audio.cpu()
-
- noise_schedule_params: dict[str, Any] = {}
- try:
- parsed = json.loads(args.noise_schedule_params)
- if isinstance(parsed, dict):
- noise_schedule_params = {str(k): v for k, v in parsed.items()}
- except Exception:
- noise_schedule_params = {}
-
- image_token_count = int(image_token_count)
- if image_token_count <= 0:
- if image_tokens is not None:
- image_token_count = int(image_tokens.numel())
- else:
- base_res = int(image_resolution)
- image_token_count = max(1, (base_res // 16) ** 2)
-
- uncond_input_ids: list[int] | None = None
- uncond_attention_mask: list[int] | None = None
- if task_name == "t2t":
- messages = [{"role": "user", "content": text}]
- if getattr(tokenizer, "chat_template", None):
- prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
- encoded = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
- else:
- encoded = tokenizer(text, return_tensors="pt", add_special_tokens=True)
- prompt_token_ids = _to_1d_int_list(encoded["input_ids"])
- prompt_attention_mask = _to_1d_int_list(encoded.get("attention_mask"))
- if not prompt_attention_mask:
- prompt_attention_mask = [1] * len(prompt_token_ids)
- else:
- max_audio_len_for_prompt = int(max(t2s_token_length, 512))
- if audio_tokens is not None:
- max_audio_len_for_prompt = max(max_audio_len_for_prompt, int(audio_tokens.numel()))
- max_audio_len_short_for_prompt = max(256, max_audio_len_for_prompt // 2)
-
- uni_prompting = load_universal_prompting(
- tokenizer=tokenizer,
- tokenizer_source=tokenizer_source,
- max_text_len=int(prompt_max_text_len),
- cond_dropout_prob=float(args.cond_dropout_prob),
- local_files_only=bool(tokenizer_local_only),
- max_audio_len=int(max_audio_len_for_prompt),
- max_audio_len_short=int(max_audio_len_short_for_prompt),
- )
- prompting_text_vocab_size = int(len(uni_prompting.text_tokenizer))
-
- is_mmu_task = task_name in {"i2t", "s2t", "v2t"} and not args.prompting_task.strip()
- if is_mmu_task:
- prompt_token_ids, prompt_attention_mask = make_mmu_prompt(
- task=task_name,
- text=text,
- tokenizer=uni_prompting.text_tokenizer,
- uni_prompting=uni_prompting,
- image_tokens=image_tokens,
- audio_tokens=audio_tokens,
- video_tokens=video_tokens,
- image_token_offset=prompting_text_vocab_size,
- speech_token_offset=prompting_text_vocab_size + int(codebook_size),
- )
- else:
- prompt_payload, prompting_task = make_prompt_payload(
- task=task_name,
- text=text,
- image_tokens=image_tokens,
- audio_tokens=audio_tokens,
- video_tokens=video_tokens,
- image_placeholder_tokens=image_token_count,
- audio_placeholder_tokens=int(t2s_token_length),
- image_token_offset=text_vocab_size,
- speech_token_offset=text_vocab_size + int(codebook_size),
- mask_token_id=int(mask_token_id),
- use_train_i2i_prompt=use_train_i2i_prompt,
- )
- if args.prompting_task.strip():
- prompting_task = args.prompting_task.strip()
-
- prompt_token_ids, prompt_attention_mask = _run_uni_prompting(
- uni_prompting,
- prompt_payload,
- prompting_task,
- )
-
- if task_name in {"i2t", "s2t", "v2t"}:
- prompt_attention_mask = [1] * len(prompt_token_ids)
- if not prompt_attention_mask:
- prompt_attention_mask = [1] * len(prompt_token_ids)
-
- if task_name in {"t2i", "i2i"} and guidance_scale > 0:
- uncond_payload, uncond_prompting_task = make_prompt_payload(
- task=task_name,
- text="",
- image_tokens=image_tokens,
- audio_tokens=audio_tokens,
- video_tokens=video_tokens,
- image_placeholder_tokens=image_token_count,
- audio_placeholder_tokens=int(t2s_token_length),
- image_token_offset=text_vocab_size,
- speech_token_offset=text_vocab_size + int(codebook_size),
- mask_token_id=int(mask_token_id),
- use_train_i2i_prompt=use_train_i2i_prompt,
- )
- uncond_input_ids, uncond_attention_mask = _run_uni_prompting(
- uni_prompting,
- uncond_payload,
- args.prompting_task.strip() or uncond_prompting_task,
- )
- if not uncond_attention_mask:
- uncond_attention_mask = [1] * len(uncond_input_ids)
-
- runtime_info: dict[str, Any] = {
- "task": [runtime_task],
- "detok_id": [int(detok_id)],
- DYNIN_PROMPT_SOURCE_KEY: [DYNIN_PROMPT_SOURCE_OFFLINE_PREBUILT],
- "dynin_config_path": [str(dynin_config_path)],
- "attention_mask": [prompt_attention_mask],
- "prompt_max_text_len": [int(prompt_max_text_len)],
- "prompting_max_text_len": [int(prompt_max_text_len)],
- "cond_dropout_prob": [float(args.cond_dropout_prob)],
- "prompting_cond_dropout_prob": [float(args.cond_dropout_prob)],
- "tokenizer_path": [str(tokenizer_source)],
- "text_vocab_size": [int(text_vocab_size)],
- "model_local_files_only": [bool(model_local_only)],
- "max_new_tokens": [int(max_new_tokens)],
- "steps": [int(steps)],
- "block_length": [int(block_length)],
- "temperature": [float(temperature)],
- "cfg_scale": [float(cfg_scale)],
- "remasking": [str(remasking)],
- "mask_id": [int(mask_token_id)],
- "mask_token_id": [int(mask_token_id)],
- "codebook_size": [int(codebook_size)],
- "audio_codebook_size": [int(audio_codebook_size)],
- "timesteps": [int(timesteps)],
- "guidance_scale": [float(guidance_scale)],
- "noise_type": [str(args.noise_type)],
- "noise_schedule_name": [str(args.noise_schedule_name)],
- "noise_schedule_params": [noise_schedule_params],
- "seq_len": [int(image_token_count)],
- "condition": [str(t2s_condition)],
- "vq_model_image_path": [str(vq_image_source)],
- "vq_model_image_local_files_only": [bool(vq_image_local_only)],
- "vq_model_audio_path": [str(vq_audio_source)],
- "vq_model_audio_local_files_only": [bool(vq_audio_local_only)],
- }
-
- if task_name in {"t2t", "i2t", "s2t", "v2t"}:
- runtime_info["prompt_length"] = [int(len(prompt_token_ids))]
- if uncond_input_ids is not None:
- runtime_info["uncond_input_ids"] = [uncond_input_ids]
- if uncond_attention_mask is not None:
- runtime_info["uncond_attention_mask"] = [uncond_attention_mask]
-
- if task_name == "t2s":
- runtime_info["max_new_tokens"] = [int(t2s_token_length)]
-
- prompt = {
- "prompt_token_ids": [int(v) for v in prompt_token_ids],
- "additional_information": runtime_info,
- "modalities": [final_modality],
- }
-
- from vllm import SamplingParams
-
- from vllm_omni.entrypoints.omni import Omni
-
- stage_config_path = str(Path(args.stage_config_path).expanduser())
- omni = Omni(model=model_source, stage_configs_path=stage_config_path, dtype=args.dtype)
- sampling_params_list = [
- SamplingParams(max_tokens=int(args.max_tokens_per_stage), temperature=0.0, top_p=1.0, detokenize=False)
- for _ in range(omni.num_stages)
- ]
-
- try:
- outputs = list(omni.generate(prompt, sampling_params_list))
- finally:
- omni.close()
-
- out_dir = Path(resolved_output_dir).expanduser()
- out_dir.mkdir(parents=True, exist_ok=True)
- stamp = time.strftime("%Y%m%d_%H%M%S")
- prefix = args.output_prefix.strip() or f"{task_name}_{stamp}"
-
- if final_modality == "text":
- text_out = extract_text_output(outputs, tokenizer=tokenizer)
- if not text_out:
- raise RuntimeError("No text output found.")
- out_path = out_dir / f"{prefix}.txt"
- out_path.write_text(text_out + "\n", encoding="utf-8")
- print(f"[end2end] text saved: {out_path}")
- print(text_out)
- return
-
- if final_modality == "image":
- image_out = extract_image_output(outputs)
- if image_out is None:
- raise RuntimeError("No image output found.")
- pil = tensor_to_pil_image(image_out)
- out_path = out_dir / f"{prefix}.png"
- pil.save(out_path)
- print(f"[end2end] image saved: {out_path}")
- return
-
- if final_modality == "audio":
- audio_out = extract_audio_output(outputs)
- if audio_out is None:
- raise RuntimeError("No audio output found.")
- wav, sr = audio_out
- out_path = out_dir / f"{prefix}.wav"
- save_audio_wav(out_path, wav, sr)
- print(f"[end2end] audio saved: {out_path} (sr={sr}, samples={wav.shape[0]})")
- return
-
- raise RuntimeError(f"Unsupported final modality: {final_modality}")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/offline_inference/fish_speech/end2end.py b/examples/offline_inference/fish_speech/end2end.py
index 60830d06b7f..31c24d3d5d6 100644
--- a/examples/offline_inference/fish_speech/end2end.py
+++ b/examples/offline_inference/fish_speech/end2end.py
@@ -18,6 +18,7 @@
import logging
import math
import os
+import tempfile
import time
import numpy as np
@@ -87,10 +88,17 @@ def build_prompt(
semantic_len,
)
+ # The model-side structured clone prefill consumes a temporary .npy file and
+ # removes it after loading. Abnormal termination can still leave the file
+ # behind, which is acceptable for this offline example.
+ with tempfile.NamedTemporaryFile(prefix="fish_ref_", suffix=".npy", delete=False) as f:
+ np.save(f, np.asarray(ref_audio_wav, dtype=np.float32))
+ ref_audio_npy_path = f.name
+
additional_information = {
"text": normalized_text,
"ref_text": normalized_ref_text,
- "ref_audio_wav": torch.from_numpy(np.asarray(ref_audio_wav, dtype=np.float32)),
+ "ref_audio_path": ref_audio_npy_path,
"ref_audio_sr": int(ref_audio_sr),
"fish_structured_voice_clone": True,
}
diff --git a/examples/offline_inference/glm_image/README.md b/examples/offline_inference/glm_image/README.md
new file mode 100644
index 00000000000..c3c7c291696
--- /dev/null
+++ b/examples/offline_inference/glm_image/README.md
@@ -0,0 +1,145 @@
+# GLM-Image Multistage End-to-End Inference
+
+This example demonstrates how to run GLM-Image with the vLLM-Omni multistage architecture.
+
+## Architecture
+
+GLM-Image uses a 2-stage pipeline:
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│ GLM-Image Pipeline │
+├─────────────────────────────────────────────────────────────┤
+│ │
+│ Stage 0 (AR Model) Stage 1 (Diffusion) │
+│ ┌─────────────────┐ ┌─────────────────────┐ │
+│ │ vLLM-optimized │ │ GlmImagePipeline │ │
+│ │ GlmImageFor │ prior │ ┌───────────────┐ │ │
+│ │ Conditional │──tokens───►│ │ DiT Denoiser │ │ │
+│ │ Generation │ │ └───────────────┘ │ │
+│ │ (9B AR model) │ │ │ │ │
+│ └─────────────────┘ │ ▼ │ │
+│ ▲ │ ┌───────────────┐ │ │
+│ │ │ │ VAE Decode │──┼──► Image
+│ Text/Image │ └───────────────┘ │ │
+│ Input └─────────────────────┘ │
+│ │
+└─────────────────────────────────────────────────────────────┘
+```
+
+## Features
+
+- **vLLM-optimized AR**: Uses PagedAttention and tensor parallelism for faster prior token generation
+- **Flexible deployment**: AR and Diffusion stages can run on different GPUs
+- **Text-to-Image**: Generate images from text descriptions
+- **Image-to-Image**: Edit existing images with text prompts
+
+## Usage
+
+### Text-to-Image
+
+```bash
+python end2end.py \
+ --config-path ../../../vllm_omni/model_executor/stage_configs/glm_image.yaml \
+ --prompt "A beautiful sunset over the ocean with sailing boats" \
+ --height 1024 \
+ --width 1024 \
+ --output output_t2i.png
+```
+
+### Image-to-Image (Image Editing)
+
+```bash
+python end2end.py \
+ --config-path ../../../vllm_omni/model_executor/stage_configs/glm_image.yaml \
+ --prompt "Transform this scene into a winter wonderland" \
+ --image input.png \
+ --output output_i2i.png
+```
+
+### With Custom Parameters
+
+```bash
+python end2end.py \
+ --model-path /path/to/glm-image \
+ --config-path ../../../vllm_omni/model_executor/stage_configs/glm_image.yaml \
+ --prompt "A photorealistic cat sitting on a window sill" \
+ --height 1024 \
+ --width 1024 \
+ --num-inference-steps 50 \
+ --guidance-scale 1.5 \
+ --seed 42 \
+ --output output.png
+```
+
+## Shell Scripts
+
+### Run Text-to-Image
+
+```bash
+./run_t2i.sh
+```
+
+### Run Image-to-Image
+
+```bash
+./run_i2i.sh --image /path/to/input.png
+```
+
+## Stage Configuration
+
+The stage config (`glm_image.yaml`) defines:
+
+- **Stage 0 (AR)**: Uses `GPUARWorker` with vLLM engine
+
+ - Model: `GlmImageForConditionalGeneration`
+ - Output: `token_ids` (prior tokens)
+
+- **Stage 1 (Diffusion)**: Uses diffusion engine
+ - Model: `GlmImagePipeline`
+ - Output: Generated image
+
+See `vllm_omni/model_executor/stage_configs/glm_image.yaml` for full configuration.
+
+## Comparison with Single-Stage
+
+| Aspect | Single-Stage (transformers) | Multistage (vLLM) |
+| ----------- | --------------------------- | ------------------- |
+| AR Model | transformers native | vLLM PagedAttention |
+| Memory | Higher (no KV cache opt) | Lower (optimized) |
+| Throughput | Lower | Higher |
+| Flexibility | Single GPU | Multi-GPU support |
+
+## Troubleshooting
+
+### OOM Error
+
+Try reducing memory usage:
+
+```bash
+# In glm_image.yaml, adjust:
+gpu_memory_utilization: 0.5 # Reduce from 0.6
+```
+
+### Slow Initialization
+
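+The first run has to load the model weights, so it is slower than subsequent runs. If stage initialization times out (for example, on slow storage), increase the timeout: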
+
+```bash
+--stage-init-timeout 900 # Increase timeout for slow storage
+```
+
+### `Transformers does not recognize this architecture` Error
+
+You have to upgrade the `transformers` package to `5.3.0` or above:
+
+```bash
+pip install --upgrade transformers
+```
+
+## Requirements
+
+- vLLM-Omni with GLM-Image support
+- CUDA-capable GPU (recommended: H100/A100 with 80GB)
+- GLM-Image model weights
+- `transformers` v5.3.0 or above
diff --git a/examples/offline_inference/glm_image/end2end.py b/examples/offline_inference/glm_image/end2end.py
new file mode 100644
index 00000000000..13bcd23f55a
--- /dev/null
+++ b/examples/offline_inference/glm_image/end2end.py
@@ -0,0 +1,511 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+End-to-end offline inference example for GLM-Image with multistage architecture.
+
+This script tests the multistage pipeline where:
+- Stage 0 (AR): vLLM-optimized GlmImageForConditionalGeneration generates prior_token_ids
+- Stage 1 (Diffusion): GlmImagePipeline performs DiT denoising + VAE decode
+
+Usage (text-to-image):
+ python end2end.py \
+ --model-path /path/to/glm-image \
+ --config-path /path/to/glm_image.yaml \
+ --prompt "A beautiful sunset over the ocean" \
+ --output output_t2i.png
+
+Usage (image-to-image / image edit):
+ python end2end.py \
+ --model-path /path/to/glm-image \
+ --config-path /path/to/glm_image.yaml \
+ --prompt "Make it look like winter" \
+ --image input.png \
+ --output output_i2i.png
+
+Usage (with custom parameters):
+ python end2end.py \
+ --model-path /path/to/glm-image \
+ --config-path /path/to/glm_image.yaml \
+ --prompt "A cat sitting on a window sill" \
+ --height 1024 \
+ --width 1024 \
+ --num-inference-steps 50 \
+ --guidance-scale 1.5 \
+ --seed 42
+
+For more options, run:
+ python end2end.py --help
+"""
+
+import argparse
+import os
+import time
+from pathlib import Path
+
+from PIL import Image
+
+from vllm_omni.entrypoints.omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+# Default stage config path (relative to vllm_omni package)
+DEFAULT_CONFIG_PATH = "vllm_omni/model_executor/stage_configs/glm_image.yaml"
+
+SEED = 42
+
+# GLM-Image special tokens
+GLM_IMAGE_EOS_TOKEN_ID = 16385 # eos_token_id from generation_config.json
+GLM_IMAGE_VISION_VOCAB_SIZE = 16512 # top_k should be vision_vocab_size
+
+
+def compute_max_tokens(height: int, width: int, factor: int = 32) -> int:
+    """
+    Return the AR-stage token budget needed for one GLM-Image generation.
+
+    For text-to-image the AR model emits, in order:
+      1. a preview grid at half the target grid size along each axis,
+      2. the full-size target grid,
+      3. a single EOS token.
+
+    Args:
+        height: Target image height in pixels.
+        width: Target image width in pixels.
+        factor: Pixels per AR token along each axis (32 by default).
+
+    Returns:
+        preview tokens + target tokens + 1 (EOS).
+    """
+    # Token grid of the full-resolution image; floor division drops any
+    # remainder when a dimension is not a multiple of ``factor``.
+    rows, cols = height // factor, width // factor
+    full_grid = rows * cols
+
+    # The preview grid halves each axis of the full grid (again flooring),
+    # so an odd grid size loses the leftover row/column instead of rounding up.
+    preview_grid = (rows // 2) * (cols // 2)
+
+    # One extra slot for the EOS token that terminates generation.
+    eos_tokens = 1
+    return preview_grid + full_grid + eos_tokens
+
+
+def load_image(image_path: str) -> Image.Image:
+    """Open ``image_path`` and return it converted to RGB; raise FileNotFoundError when it is missing."""
+    if os.path.exists(image_path):
+        return Image.open(image_path).convert("RGB")
+    raise FileNotFoundError(f"Image file not found: {image_path}")
+
+
+def save_image(image: Image.Image, output_path: str) -> None:
+    """Write ``image`` to ``output_path``, first creating the parent directory when the path has one."""
+    parent = os.path.split(output_path)[0]
+    if parent != "":
+        os.makedirs(parent, exist_ok=True)
+    image.save(output_path)
+    print(f"Image saved to: {output_path}")
+
+
+def build_prompt_for_t2i(
+    prompt: str,
+    height: int = 1024,
+    width: int = 1024,
+) -> dict:
+    """
+    Assemble the request dict for a text-to-image generation.
+
+    Args:
+        prompt: Text describing the image to generate.
+        height: Requested output height.
+        width: Requested output width.
+
+    Returns:
+        A dict with ``prompt``, ``height``, ``width`` and an
+        ``mm_processor_kwargs`` entry repeating the size as ``target_h``/``target_w``.
+    """
+    # The target size is duplicated under mm_processor_kwargs so the AR
+    # processor can emit the matching grid tokens.
+    processor_kwargs = dict(target_h=height, target_w=width)
+    return dict(
+        prompt=prompt,
+        height=height,
+        width=width,
+        mm_processor_kwargs=processor_kwargs,
+    )
+
+
+def build_prompt_for_i2i(
+    prompt: str,
+    image: Image.Image,
+    height: int | None = None,
+    width: int | None = None,
+) -> dict:
+    """
+    Assemble the request dict for an image-to-image (edit) generation.
+
+    Args:
+        prompt: Text instruction describing the edit.
+        image: Source image; attached under ``multi_modal_data``.
+        height: Requested output height; ``None`` means ``image.height``.
+        width: Requested output width; ``None`` means ``image.width``.
+
+    Returns:
+        A dict with ``prompt``, ``multi_modal_data``, ``height``, ``width``
+        and ``mm_processor_kwargs``, where the last entry repeats the
+        resolved size as ``target_h``/``target_w``.
+    """
+    # Only an explicit None falls back to the source size; any other value
+    # supplied by the caller is used as-is.
+    out_h = image.height if height is None else height
+    out_w = image.width if width is None else width
+
+    # The AR processor reads the target size from mm_processor_kwargs to
+    # produce the matching grid tokens.
+    size_kwargs = {"target_h": out_h, "target_w": out_w}
+
+    request: dict = {
+        "prompt": prompt,
+        "multi_modal_data": {"image": image},
+        "height": out_h,
+        "width": out_w,
+        "mm_processor_kwargs": size_kwargs,
+    }
+    return request
+
+
+def main(args: argparse.Namespace) -> None:
+    """Run the two-stage GLM-Image pipeline for the parsed CLI ``args`` and save the generated image(s)."""
+    print("=" * 60)
+    print("GLM-Image Multistage End-to-End Inference")
+    print("=" * 60)
+
+    # argparse already enforces --prompt; this guards programmatic callers of main().
+    if not args.prompt:
+        raise ValueError("--prompt is required")
+
+    # Resolve the stage config: explicit flag, then cwd-relative default, then repo-relative default.
+    config_path = args.config_path
+    if config_path is None:
+        # Works when the script is launched from the repository root.
+        if os.path.exists(DEFAULT_CONFIG_PATH):
+            config_path = DEFAULT_CONFIG_PATH
+        else:
+            # Four .parent hops from examples/offline_inference/glm_image/end2end.py reach the repo root.
+            repo_root = Path(__file__).parent.parent.parent.parent
+            config_path = repo_root / "vllm_omni/model_executor/stage_configs/glm_image.yaml"
+            if not config_path.exists():
+                raise FileNotFoundError(
+                    f"Stage config not found. Please specify --config-path. Tried: {DEFAULT_CONFIG_PATH}, {config_path}"
+                )
+            config_path = str(config_path)
+
+    print(f"Model path: {args.model_path}")
+    print(f"Config path: {config_path}")
+    print(f"Prompt: {args.prompt}")
+
+    # Load source image for image-to-image mode
+    source_image = None
+    if args.image:
+        print(f"Source image: {args.image}")
+        source_image = load_image(args.image)
+        print(f" Image size: {source_image.width}x{source_image.height}")
+
+    # Build prompt based on mode
+    if source_image is not None:
+        # Image-to-image mode: unset --height/--width fall back to the source image size.
+        prompt_dict = build_prompt_for_i2i(
+            prompt=args.prompt,
+            image=source_image,
+            height=args.height,
+            width=args.width,
+        )
+        mode = "image-to-image"
+    else:
+        # Text-to-image mode: unset --height/--width fall back to 1024.
+        prompt_dict = build_prompt_for_t2i(
+            prompt=args.prompt,
+            height=args.height or 1024,
+            width=args.width or 1024,
+        )
+        mode = "text-to-image"
+
+    print(f"Mode: {mode}")
+    print(f"Target size: {prompt_dict.get('width', 1024)}x{prompt_dict.get('height', 1024)}")
+
+    # Add generation parameters to prompt
+    prompt_dict["seed"] = args.seed
+    prompt_dict["num_inference_steps"] = args.num_inference_steps
+    prompt_dict["guidance_scale"] = args.guidance_scale
+
+    if args.negative_prompt:
+        prompt_dict["negative_prompt"] = args.negative_prompt
+
+    # Build cache-dit config if requested
+    cache_config = None
+    if args.cache_backend == "cache_dit":
+        cache_config = {
+            "Fn_compute_blocks": 1,
+            "Bn_compute_blocks": 0,
+            "max_warmup_steps": 4,
+            "residual_diff_threshold": 0.24,
+            "max_continuous_cached_steps": 3,
+            "enable_taylorseer": False,
+            "taylorseer_order": 1,
+            "scm_steps_mask_policy": None,
+            "scm_steps_policy": "dynamic",
+        }
+
+    # Initialize Omni with multistage config
+    print("\nInitializing Omni with multistage pipeline...")
+    print(f" Cache backend: {args.cache_backend or 'None (no acceleration)'}")
+    start_time = time.time()
+
+    omni = Omni(
+        model=args.model_path,
+        stage_configs_path=config_path,
+        log_stats=args.enable_stats,
+        stage_init_timeout=args.stage_init_timeout,
+        cache_backend=args.cache_backend,
+        cache_config=cache_config,
+        enable_cache_dit_summary=getattr(args, "enable_cache_dit_summary", False),
+        enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
+    )
+
+    init_time = time.time() - start_time
+    print(f"Initialization completed in {init_time:.2f}s")
+
+    # The engine now exists: run everything else under try/finally so omni.close() always executes.
+    try:
+        # Prepare prompts (support batch generation); every entry is the same prompt dict.
+        prompts = [prompt_dict for _ in range(args.num_prompts)]
+
+        from vllm import SamplingParams
+
+        # Compute max_tokens dynamically based on target image size
+        target_height = prompt_dict.get("height", 1024)
+        target_width = prompt_dict.get("width", 1024)
+        auto_max_tokens = compute_max_tokens(target_height, target_width)
+
+        # 16384 is the argparse default and doubles as the "auto" sentinel: in that case use the
+        # size-derived budget; any other --max-tokens value is taken literally.
+        effective_max_tokens = auto_max_tokens if args.max_tokens == 16384 else args.max_tokens
+
+        if args.verbose:
+            print(f"AR max_tokens: {effective_max_tokens} (calculated: {auto_max_tokens}, arg: {args.max_tokens})")
+
+        # IMPORTANT: GLM-Image AR model requires these exact sampling parameters
+        # from generation_config.json for proper image token generation.
+        # - temperature=0.9, top_p=0.75, top_k=16512 (vision_vocab_size)
+        # - stop_token_ids=[16385] (eos_token_id) is CRITICAL to stop generation
+        ar_sampling_params = SamplingParams(
+            temperature=0.9,  # From generation_config.json
+            top_p=0.75,  # From generation_config.json
+            top_k=GLM_IMAGE_VISION_VOCAB_SIZE,  # 16512, vision vocabulary size
+            max_tokens=effective_max_tokens,
+            stop_token_ids=[GLM_IMAGE_EOS_TOKEN_ID],  # 16385, CRITICAL for stopping
+            seed=args.seed,
+            detokenize=False,
+        )
+
+        # Diffusion-stage parameters; the size is read back from prompt_dict so it
+        # matches what the AR stage was asked to produce.
+        diffusion_sampling_params = OmniDiffusionSamplingParams(
+            num_inference_steps=args.num_inference_steps,
+            guidance_scale=args.guidance_scale,
+            height=prompt_dict.get("height", 1024),
+            width=prompt_dict.get("width", 1024),
+            seed=args.seed,
+        )
+
+        # One sampling-params entry per stage, in stage order:
+        #   Stage 0 (AR): vLLM SamplingParams
+        #   Stage 1 (Diffusion): OmniDiffusionSamplingParams
+        sampling_params_list = [ar_sampling_params, diffusion_sampling_params]
+
+        # Run generation
+        print(f"\nGenerating {args.num_prompts} image(s)...")
+        gen_start_time = time.time()
+
+        output_dir = os.path.dirname(args.output) if args.output else "outputs"
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
+
+        output_count = 0
+        for stage_outputs in omni.generate(prompts, sampling_params_list, py_generator=True):
+            output = stage_outputs.request_output
+            if stage_outputs.final_output_type == "image":
+                request_id = output.request_id
+
+                # Prefer output.images; fall back to multimodal_output["images"] when that is empty.
+                images = output.images if hasattr(output, "images") else []
+                if not images and hasattr(output, "multimodal_output"):
+                    images = output.multimodal_output.get("images", [])
+
+                # A single prompt yielding a single image keeps the exact --output name; else suffix it.
+                for idx, img in enumerate(images):
+                    if args.num_prompts == 1 and len(images) == 1:
+                        output_path = args.output
+                    else:
+                        base, ext = os.path.splitext(args.output)
+                        output_path = f"{base}_{request_id}_{idx}{ext}"
+
+                    if isinstance(img, Image.Image):
+                        save_image(img, output_path)
+                        output_count += 1  # count only images that were actually written
+                    else:
+                        # Anything that is not a PIL image is reported and skipped, not saved.
+                        print(f"Warning: Unexpected image type for request {request_id}: {type(img)}")
+
+            elif stage_outputs.final_output_type == "text":
+                # AR stage output (intermediate, for debugging)
+                if args.verbose:
+                    print(f"AR output for request {output.request_id}:")
+                    print(f" Token count: {len(output.outputs[0].token_ids)}")
+
+        gen_time = time.time() - gen_start_time
+        print(f"\nGeneration completed in {gen_time:.2f}s")
+        print(f"Generated {output_count} image(s)")
+    finally:
+        # Release the engine even if sampling-param setup or generation failed part-way.
+        omni.close()
+    print("\nDone!")
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse the GLM-Image example's command-line flags; the module docstring is shown as the help epilog."""
+    parser = argparse.ArgumentParser(
+        description="GLM-Image Multistage End-to-End Inference",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+
+    # Model and prompt (only --prompt is strictly required; --model-path has a default)
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        default="zai-org/GLM-Image",
+        help="Path to GLM-Image model directory or HuggingFace model ID",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        required=True,
+        help="Text prompt for image generation",
+    )
+
+    # Optional arguments
+    parser.add_argument(
+        "--config-path",
+        type=str,
+        default=None,
+        help="Path to stage config YAML file (default: auto-detect)",
+    )
+    parser.add_argument(
+        "--image",
+        type=str,
+        default=None,
+        help="Path to source image for image-to-image mode",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="output_glm_image.png",
+        help="Output image path (default: output_glm_image.png)",
+    )
+    parser.add_argument(
+        "--negative-prompt",
+        type=str,
+        default=None,
+        help="Negative prompt for classifier-free guidance",
+    )
+
+    # Generation parameters
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=None,
+        help="Output image height (default: 1024 for t2i, source size for i2i)",
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=None,
+        help="Output image width (default: 1024 for t2i, source size for i2i)",
+    )
+    parser.add_argument(
+        "--num-inference-steps",
+        type=int,
+        default=50,
+        help="Number of diffusion denoising steps (default: 50)",
+    )
+    parser.add_argument(
+        "--guidance-scale",
+        type=float,
+        default=1.5,
+        help="Classifier-free guidance scale (default: 1.5)",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=SEED,
+        help=f"Random seed for reproducibility (default: {SEED})",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        default=16384,  # main() treats this exact value as "not set" and derives the budget itself
+        help="Maximum tokens for AR generation (default: 16384 = auto-compute from the target image size)",
+    )
+
+    # Batch processing
+    parser.add_argument(
+        "--num-prompts",
+        type=int,
+        default=1,
+        help="Number of images to generate (default: 1)",
+    )
+
+    # Cache acceleration
+    parser.add_argument(
+        "--cache-backend",
+        type=str,
+        default=None,
+        choices=["cache_dit"],
+        help="Cache backend for DiT acceleration. Default: None (no cache).",
+    )
+    parser.add_argument(
+        "--enable-cache-dit-summary",
+        action="store_true",
+        help="Enable cache-dit summary logging after diffusion forward passes.",
+    )
+
+    # Runtime options
+    parser.add_argument(
+        "--enable-stats",
+        action="store_true",
+        default=False,
+        help="Enable statistics logging",
+    )
+    parser.add_argument(
+        "--stage-init-timeout",
+        type=int,
+        default=600,
+        help="Timeout for stage initialization in seconds (default: 600)",
+    )
+    parser.add_argument(
+        "--verbose",
+        "-v",
+        action="store_true",
+        default=False,
+        help="Enable verbose output",
+    )
+    parser.add_argument(
+        "--enable-diffusion-pipeline-profiler",
+        action="store_true",
+        help="Enable diffusion pipeline profiler to display stage durations.",
+    )
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    # Script entry point: parse the CLI flags, then hand them to the pipeline driver.
+    main(parse_args())
diff --git a/examples/offline_inference/glm_image/run_i2i.sh b/examples/offline_inference/glm_image/run_i2i.sh
new file mode 100755
index 00000000000..f81b157b0c8
--- /dev/null
+++ b/examples/offline_inference/glm_image/run_i2i.sh
@@ -0,0 +1,93 @@
+#!/bin/bash
+# SPDX-License-Identifier: Apache-2.0
+# Run GLM-Image image-to-image (editing) with multistage pipeline
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"  # defaults below resolve from here, not from the cwd
+MODEL_PATH="${MODEL_PATH:-zai-org/GLM-Image}"
+CONFIG_PATH="${CONFIG_PATH:-${SCRIPT_DIR}/../../../vllm_omni/model_executor/stage_configs/glm_image.yaml}"
+PROMPT="${PROMPT:-Transform this image into an oil painting style}"
+INPUT_IMAGE=""
+OUTPUT="${OUTPUT:-output_i2i.png}"
+NUM_STEPS="${NUM_STEPS:-50}"
+GUIDANCE="${GUIDANCE:-1.5}"
+SEED="${SEED:-42}"
+
+# Parse command line arguments (each flag takes exactly one value and overrides the default above)
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --model-path)
+            MODEL_PATH="$2"
+            shift 2
+            ;;
+        --config-path)
+            CONFIG_PATH="$2"
+            shift 2
+            ;;
+        --prompt)
+            PROMPT="$2"
+            shift 2
+            ;;
+        --image)
+            INPUT_IMAGE="$2"
+            shift 2
+            ;;
+        --output)
+            OUTPUT="$2"
+            shift 2
+            ;;
+        --num-steps)
+            NUM_STEPS="$2"
+            shift 2
+            ;;
+        --guidance)
+            GUIDANCE="$2"
+            shift 2
+            ;;
+        --seed)
+            SEED="$2"
+            shift 2
+            ;;
+        *)
+            echo "Unknown option: $1" >&2
+            exit 1
+            ;;
+    esac
+done
+
+# Check if input image is provided (there is no environment fallback for it)
+if [ -z "${INPUT_IMAGE}" ]; then
+    echo "Error: --image is required for image-to-image mode" >&2
+    echo "Usage: ./run_i2i.sh --image /path/to/input.png [--prompt \"edit instruction\"]" >&2
+    exit 1
+fi
+
+if [ ! -f "${INPUT_IMAGE}" ]; then
+    echo "Error: Input image not found: ${INPUT_IMAGE}" >&2
+    exit 1
+fi
+
+echo "=============================================="
+echo "GLM-Image Image-to-Image Generation"
+echo "=============================================="
+echo "Model: ${MODEL_PATH}"
+echo "Config: ${CONFIG_PATH}"
+echo "Input: ${INPUT_IMAGE}"
+echo "Prompt: ${PROMPT}"
+echo "Output: ${OUTPUT}"
+echo "Steps: ${NUM_STEPS}"
+echo "Guidance: ${GUIDANCE}"
+echo "Seed: ${SEED}"
+echo "=============================================="
+
+python "${SCRIPT_DIR}/end2end.py" \
+    --model-path "${MODEL_PATH}" \
+    --config-path "${CONFIG_PATH}" \
+    --prompt "${PROMPT}" \
+    --image "${INPUT_IMAGE}" \
+    --output "${OUTPUT}" \
+    --num-inference-steps "${NUM_STEPS}" \
+    --guidance-scale "${GUIDANCE}" \
+    --seed "${SEED}" \
+    --verbose
diff --git a/examples/offline_inference/glm_image/run_t2i.sh b/examples/offline_inference/glm_image/run_t2i.sh
new file mode 100755
index 00000000000..5d249960b8f
--- /dev/null
+++ b/examples/offline_inference/glm_image/run_t2i.sh
@@ -0,0 +1,87 @@
+#!/bin/bash
+# SPDX-License-Identifier: Apache-2.0
+# Run GLM-Image text-to-image generation with multistage pipeline
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"  # defaults below resolve from here, not from the cwd
+MODEL_PATH="${MODEL_PATH:-zai-org/GLM-Image}"
+CONFIG_PATH="${CONFIG_PATH:-${SCRIPT_DIR}/../../../vllm_omni/model_executor/stage_configs/glm_image.yaml}"
+PROMPT="${PROMPT:-A beautiful mountain landscape with snow-capped peaks and a clear blue lake}"
+OUTPUT="${OUTPUT:-output_t2i.png}"
+HEIGHT="${HEIGHT:-1024}"
+WIDTH="${WIDTH:-1024}"
+NUM_STEPS="${NUM_STEPS:-50}"
+GUIDANCE="${GUIDANCE:-1.5}"
+SEED="${SEED:-42}"
+
+# Parse command line arguments (each flag takes exactly one value and overrides the default above)
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --model-path)
+            MODEL_PATH="$2"
+            shift 2
+            ;;
+        --config-path)
+            CONFIG_PATH="$2"
+            shift 2
+            ;;
+        --prompt)
+            PROMPT="$2"
+            shift 2
+            ;;
+        --output)
+            OUTPUT="$2"
+            shift 2
+            ;;
+        --height)
+            HEIGHT="$2"
+            shift 2
+            ;;
+        --width)
+            WIDTH="$2"
+            shift 2
+            ;;
+        --num-steps)
+            NUM_STEPS="$2"
+            shift 2
+            ;;
+        --guidance)
+            GUIDANCE="$2"
+            shift 2
+            ;;
+        --seed)
+            SEED="$2"
+            shift 2
+            ;;
+        *)
+            echo "Unknown option: $1" >&2
+            exit 1
+            ;;
+    esac
+done
+
+echo "=============================================="
+echo "GLM-Image Text-to-Image Generation"
+echo "=============================================="
+echo "Model: ${MODEL_PATH}"
+echo "Config: ${CONFIG_PATH}"
+echo "Prompt: ${PROMPT}"
+echo "Output: ${OUTPUT}"
+echo "Size: ${WIDTH}x${HEIGHT}"
+echo "Steps: ${NUM_STEPS}"
+echo "Guidance: ${GUIDANCE}"
+echo "Seed: ${SEED}"
+echo "=============================================="
+
+python "${SCRIPT_DIR}/end2end.py" \
+    --model-path "${MODEL_PATH}" \
+    --config-path "${CONFIG_PATH}" \
+    --prompt "${PROMPT}" \
+    --output "${OUTPUT}" \
+    --height "${HEIGHT}" \
+    --width "${WIDTH}" \
+    --num-inference-steps "${NUM_STEPS}" \
+    --guidance-scale "${GUIDANCE}" \
+    --seed "${SEED}" \
+    --verbose
diff --git a/examples/offline_inference/helios/end2end.py b/examples/offline_inference/helios/end2end.py
index 6cf7dfdcb36..88c3b865d42 100644
--- a/examples/offline_inference/helios/end2end.py
+++ b/examples/offline_inference/helios/end2end.py
@@ -196,9 +196,6 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--cfg-parallel-size", type=int, default=1, choices=[1, 2], help="CFG parallel size.")
parser.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallelism size.")
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- nullify_stage_engine_defaults(parser)
return parser.parse_args()
diff --git a/examples/offline_inference/hunyuan_image3/README.md b/examples/offline_inference/hunyuan_image3/README.md
index 3cd8fa01b2e..da28a44d9e6 100644
--- a/examples/offline_inference/hunyuan_image3/README.md
+++ b/examples/offline_inference/hunyuan_image3/README.md
@@ -1,161 +1,25 @@
-# HunyuanImage-3.0-Instruct
+# HunyuanImage-3.0 Image-to-Text Inference
-## Set up
+This example demonstrates how to run HunyuanImage-3.0 image-to-text inference with vLLM-Omni.
-Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
+## Local CLI Usage
-## Run examples
-
-**Note**: These examples work with the default configuration on **8x NVIDIA L40S (48GB)**. For different GPU setups, modify the stage configuration to adjust device allocation and memory utilization.
-
-Get into the hunyuan_image3 folder:
-
-```bash
-cd examples/offline_inference/hunyuan_image3
-```
-
-### Modality Control
-
-HunyuanImage-3.0-Instruct supports multiple modality modes. You can control the mode using the `--modality` argument:
-
-#### Text to Image (text2img)
-
-- **Pipeline**: Text → AR (CoT + latent tokens) → DiT (denoise) → VAE Decode → Image
-- **Stages Used**: Stage 0 (AR) + Stage 1 (DiT)
-- **KV Transfer**: AR sends KV cache to DiT for conditioned generation
-- **Default Config**: `hunyuan_image3_t2i.yaml`
-
-```bash
-python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
- --modality text2img \
- --prompts "A cute cat sitting on a windowsill watching the sunset"
-```
-
-#### Image to Image (img2img)
-
-- **Pipeline**: Image + Text → AR (CoT + recaption + latent) → DiT → Edited Image
-- **Stages Used**: Stage 0 (AR) + Stage 1 (DiT)
-- **KV Transfer**: AR sends KV cache to DiT
-- **Default Config**: `hunyuan_image3_it2i.yaml`
-
-```bash
-python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
- --modality img2img \
- --image-path /path/to/image.png \
- --prompts "Make the petals neon pink"
-```
-
-#### Image to Text (img2text)
-
-- **Pipeline**: Image + Question → AR → Text description
-- **Stages Used**: Stage 0 (AR) only
-- **Default Config**: `hunyuan_image3_i2t.yaml`
+Download the example image:
```bash
-python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
- --modality img2text \
- --image-path /path/to/image.jpg \
- --prompts "Describe the content of the picture."
+wget https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg
```
-#### Text to Text (text2text)
-
-- **Pipeline**: Text → AR → Text
-- **Stages Used**: Stage 0 (AR) only
-- **Default Config**: `hunyuan_image3_t2t.yaml`
+Run example:
```bash
-python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
- --modality text2text \
- --prompts "What is the capital of France?"
+python image_to_text.py \
+ --image cherry_blossom.jpg \
+ --prompt "<|startoftext|>You are an assistant that understands images and outputs text. Describe the content of the picture."
```
-### Inference Steps & Guidance
-
-Control generation quality for image modalities:
-
-```bash
-python end2end.py --modality text2img \
- --steps 50 \
- --guidance-scale 5.0 \
- --height 1024 --width 1024 \
- --prompts "A photo-realistic sunset over the ocean"
-```
-
-### Key Arguments
-
-#### 📌 Command Line Arguments (end2end.py)
-
-| Argument | Type | Default | Description |
-| :--------------------- | :----- | :----------------------------------- | :----------------------------------------------------------- |
-| `--model` | string | `tencent/HunyuanImage-3.0-Instruct` | Model path or name |
-| `--modality` | choice | `text2img` | Modality: `text2img`, `img2img`, `img2text`, `text2text` |
-| `--prompts` | list | `None` | Input text prompts |
-| `--image-path` | string | `None` | Input image path (for `img2img`/`img2text`) |
-| `--output` | string | `.` | Output directory for saved images |
-| `--steps` | int | `50` | Number of inference steps |
-| `--guidance-scale` | float | `5.0` | Classifier-free guidance scale |
-| `--seed` | int | `42` | Random seed |
-| `--height` | int | `1024` | Output image height |
-| `--width` | int | `1024` | Output image width |
-| `--bot-task` | string | auto | Override prompt task (e.g. `it2i_think`, `t2i_recaption`) |
-| `--sys-type` | string | auto | Override system prompt type (e.g. `en_unified`, `en_vanilla`) |
-| `--stage-configs-path` | string | auto | Custom stage config YAML path |
-| `--enforce-eager` | flag | `False` | Disable torch.compile |
-| `--init-timeout` | int | `300` | Initialization timeout (seconds) |
-
-------
-
-#### ⚙️ Stage Configurations
-
-| Config YAML | Modality | Stages | GPUs | Description |
-| :---------------------------------- | :-------- | :----- | :----- | :------------------------------------ |
-| `hunyuan_image3_t2i.yaml` | text2img | 2 | 8 | T2I with AR→DiT, 4 GPU each |
-| `hunyuan_image3_it2i.yaml` | img2img | 2 | 8 | IT2I with AR→DiT, 4 GPU each |
-| `hunyuan_image3_i2t.yaml` | img2text | 1 | 4 | I2T (AR only) |
-| `hunyuan_image3_t2t.yaml` | text2text | 1 | 4 | T2T (AR only) |
-| `hunyuan_image3_t2i_2gpu.yaml` | text2img | 2 | 2 | T2I for 2-GPU setups |
-| `hunyuan_image3_moe.yaml` | text2img | 2 | 8 | T2I with MoE AR→DiT KV reuse |
-| `hunyuan_image3_moe_dit_2gpu_fp8.yaml` | text2img | 2 | 2 | T2I with FP8 quantization |
-
-------
-
-## Using MoE Config
-
-The `hunyuan_image3_moe.yaml` config enables AR→DiT KV cache reuse with 8 GPUs (4 for AR + 4 for DiT).
-
-```bash
-python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
- --modality text2img \
- --stage-configs-path hunyuan_image3_moe.yaml \
- --prompts "A cute cat"
-```
-
-------
-
-## Prompt Format
-
-HunyuanImage-3.0 uses a pretrain template format:
-
-```
-<|startoftext|>{system_prompt}{ }{trigger_tag}{user_prompt}
-```
-
-- ` `: Placeholder for each input image (auto-inserted by `prompt_utils.py`)
-- Trigger tags: `` (CoT), `` (recaptioning)
-- System prompt: Auto-selected based on task
-
-The `prompt_utils.build_prompt()` handles this formatting automatically.
-
-------
-
-## FAQ
-
-- **OOM errors**: Decrease `gpu_memory_utilization` in the YAML stage config, or use a smaller `max_num_batched_tokens`.
-- **Custom image sizes**: Use `--height` and `--width` flags (multiples of 16 recommended).
+Key arguments:
-| Stage | VRAM (approx) |
-| :---------------- | :------------------- |
-| Stage 0 (AR) | ~15 GiB + KV Cache |
-| Stage 1 (DiT) | ~30 GiB |
-| Total (8-GPU) | ~45 GiB + KV Cache |
+- `--model`: Model name or local path (optional; default: `tencent/HunyuanImage-3.0-Instruct`).
+- `--image`: Path to input image (required).
+- `--prompt`: Text description used to guide image understanding (required).
diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py
deleted file mode 100644
index 2cea303888e..00000000000
--- a/examples/offline_inference/hunyuan_image3/end2end.py
+++ /dev/null
@@ -1,265 +0,0 @@
-"""
-HunyuanImage-3.0-Instruct unified end-to-end inference script.
-
-Supports all modalities through a single entry point:
- - text2img: Text → AR → DiT → Image
- - img2img: Text+Image → AR → DiT → Edited Image (IT2I)
- - img2text: Image+Text → AR → Text description (I2T)
- - text2text: Text → AR → Text (comprehension, no image)
-
-Usage:
- python end2end.py --modality text2img --prompts "A cute cat"
- python end2end.py --modality img2img --image-path input.png --prompts "Make it snowy"
- python end2end.py --modality img2text --image-path input.png --prompts "Describe this image"
-"""
-
-import argparse
-import os
-
-from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import (
- get_system_prompt,
-)
-from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.inputs.data import OmniPromptType
-
-# task → (sys_type, bot_task, trigger_tag)
-_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = {
- "t2t": ("en_unified", None, None),
- "i2t": ("en_unified", None, None),
- "it2i_think": ("en_unified", "think", ""),
- "it2i_recaption": ("en_unified", "recaption", ""),
- "t2i_think": ("en_unified", "think", ""),
- "t2i_recaption": ("en_unified", "recaption", ""),
- "t2i_vanilla": ("en_vanilla", "image", None),
-}
-
-# Modality → prompt_utils task mapping
-_MODALITY_TASK_MAP = {
- "text2img": "t2i_think",
- "img2img": "it2i_think",
- "img2text": "i2t",
- "text2text": "t2t",
-}
-
-
-def build_prompt(
- user_prompt: str,
- task: str = "it2i_think",
- sys_type: str | None = None,
- custom_system_prompt: str | None = None,
-) -> str:
- """Build a HunyuanImage-3.0 prompt using pretrain template format."""
- if task not in _TASK_PRESETS:
- raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}")
-
- preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task]
- effective_sys_type = sys_type or preset_sys_type
-
- system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt)
- sys_text = system_prompt.strip() if system_prompt else ""
-
- has_image_input = task.startswith("i2t") or task.startswith("it2i")
-
- parts = ["<|startoftext|>"]
- if sys_text:
- parts.append(sys_text)
- if has_image_input:
- parts.append(" ")
- if trigger_tag:
- parts.append(trigger_tag)
- parts.append(user_prompt)
-
- return "".join(parts)
-
-
-# Modality → default stage config
-_MODALITY_DEFAULT_CONFIG = {
- "text2img": "hunyuan_image3_t2i.yaml",
- "img2img": "hunyuan_image3_it2i.yaml",
- "img2text": "hunyuan_image3_i2t.yaml",
- "text2text": "hunyuan_image3_t2t.yaml",
-}
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="HunyuanImage-3.0-Instruct end-to-end inference.")
- parser.add_argument(
- "--model",
- default="tencent/HunyuanImage-3.0-Instruct",
- help="Model name or local path.",
- )
- parser.add_argument(
- "--modality",
- default="text2img",
- choices=["text2img", "img2img", "img2text", "text2text"],
- help="Modality mode to control stage execution.",
- )
- parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.")
- parser.add_argument(
- "--image-path",
- type=str,
- default=None,
- help="Path to input image (for img2img/img2text).",
- )
- parser.add_argument(
- "--output",
- type=str,
- default=".",
- help="Output directory to save results.",
- )
-
- # Generation parameters
- parser.add_argument("--steps", type=int, default=50, help="Number of inference steps.")
- parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale.")
- parser.add_argument("--seed", type=int, default=42, help="Random seed.")
- parser.add_argument("--height", type=int, default=1024, help="Output image height.")
- parser.add_argument("--width", type=int, default=1024, help="Output image width.")
-
- # Prompt configuration
- parser.add_argument(
- "--bot-task",
- type=str,
- default=None,
- help="Override prompt task (e.g. it2i_think, t2i_recaption). Default: auto from modality.",
- )
- parser.add_argument(
- "--sys-type",
- type=str,
- default=None,
- help="Override system prompt type (e.g. en_unified, en_vanilla).",
- )
-
- # Omni init args
- parser.add_argument("--stage-configs-path", type=str, default=None, help="Custom stage config YAML path.")
- parser.add_argument("--log-stats", action="store_true", default=False)
- parser.add_argument("--init-timeout", type=int, default=300, help="Initialization timeout in seconds.")
- parser.add_argument("--enforce-eager", action="store_true", help="Disable torch.compile.")
-
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- nullify_stage_engine_defaults(parser)
- return parser.parse_args()
-
-
-def main():
- args = parse_args()
- os.makedirs(args.output, exist_ok=True)
-
- # Determine task for prompt formatting
- task = args.bot_task or _MODALITY_TASK_MAP[args.modality]
-
- # Determine stage config
- stage_configs_path = args.stage_configs_path or _MODALITY_DEFAULT_CONFIG[args.modality]
-
- # Build Omni
- omni_kwargs = {
- "model": args.model,
- "stage_configs_path": stage_configs_path,
- "log_stats": args.log_stats,
- "init_timeout": args.init_timeout,
- "enforce_eager": args.enforce_eager,
- }
- if args.modality in ("text2img", "img2img"):
- omni_kwargs["mode"] = "text-to-image"
-
- omni = Omni(**omni_kwargs)
-
- # Prepare prompts
- prompts = args.prompts or ["A cute cat"]
- if not prompts:
- print("[Info] No prompts provided, using default.")
- prompts = ["A cute cat"]
-
- # Load image if needed
- input_image = None
- if args.modality in ("img2img", "img2text"):
- if not args.image_path or not os.path.exists(args.image_path):
- raise ValueError(f"--image-path required for {args.modality}, got: {args.image_path}")
- from PIL import Image
-
- input_image = Image.open(args.image_path).convert("RGB")
-
- # Format prompts
- formatted_prompts: list[OmniPromptType] = []
- for p in prompts:
- formatted_text = build_prompt(p, task=task, sys_type=args.sys_type)
-
- prompt_dict: dict = {"prompt": formatted_text}
-
- if args.modality == "text2img":
- prompt_dict["modalities"] = ["image"]
- elif args.modality == "img2img":
- prompt_dict["modalities"] = ["image"]
- prompt_dict["multi_modal_data"] = {"image": input_image}
- prompt_dict["height"] = input_image.height
- prompt_dict["width"] = input_image.width
- elif args.modality == "img2text":
- prompt_dict["modalities"] = ["text"]
- prompt_dict["multi_modal_data"] = {"image": input_image}
- elif args.modality == "text2text":
- prompt_dict["modalities"] = ["text"]
-
- formatted_prompts.append(prompt_dict)
-
- # Build sampling params from defaults
- params_list = list(omni.default_sampling_params_list)
-
- # Override diffusion params if applicable
- from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
- for i, sp in enumerate(params_list):
- if isinstance(sp, OmniDiffusionSamplingParams):
- sp.num_inference_steps = args.steps
- sp.guidance_scale = args.guidance_scale
- if args.seed is not None:
- sp.seed = args.seed
- if args.modality in ("text2img",):
- sp.height = args.height
- sp.width = args.width
-
- # Print configuration
- print(f"\n{'=' * 60}")
- print("HunyuanImage-3.0 Generation Configuration:")
- print(f" Model: {args.model}")
- print(f" Modality: {args.modality}")
- print(f" Stage config: {stage_configs_path}")
- print(f" Num stages: {omni.num_stages}")
- if args.modality in ("text2img", "img2img"):
- print(f" Inference steps: {args.steps}")
- print(f" Guidance scale: {args.guidance_scale}")
- print(f" Seed: {args.seed}")
- if args.modality == "text2img":
- print(f" Output size: {args.width}x{args.height}")
- if args.image_path:
- print(f" Input image: {args.image_path}")
- print(f" Prompts: {prompts}")
- print(f"{'=' * 60}\n")
-
- # Generate
- omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))
-
- # Process outputs
- img_idx = 0
- for req_output in omni_outputs:
- # Text output (AR stage or text-only)
- ro = getattr(req_output, "request_output", None)
- if ro and getattr(ro, "outputs", None):
- txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs)
- if txt:
- print(f"[Output] Text:\n{txt}")
-
- # Image output (DiT stage)
- images = getattr(req_output, "images", None)
- if not images and ro and hasattr(ro, "images"):
- images = ro.images
-
- if images:
- for j, img in enumerate(images):
- save_path = os.path.join(args.output, f"output_{img_idx}_{j}.png")
- img.save(save_path)
- print(f"[Output] Saved image to {save_path}")
- img_idx += 1
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/offline_inference/hunyuan_image3/image_to_text.py b/examples/offline_inference/hunyuan_image3/image_to_text.py
new file mode 100644
index 00000000000..d40134ac0a0
--- /dev/null
+++ b/examples/offline_inference/hunyuan_image3/image_to_text.py
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Offline image-to-text example for HunyuanImage-3.0-Instruct.
+
+The tencent/HunyuanImage-3.0-Instruct base model uses the
+tencent/Hunyuan-A13B-Instruct backbone. It utilizes two tokenizer delimiter
+templates:
+
+1) Pretrained template (default for gen_text mode), which concatenates system, image
+   tokens, and user question WITHOUT role delimiters:
+"<|startoftext|>{system_prompt}{image_tokens}{user_question}"
+
+   Example (before image token expansion):
+"<|startoftext|>You are an assistant that understands images and outputs text. Describe the content of the picture."
+
+2) Instruct template, which uses explicit role prefixes and separators.
+
+This script builds the pretrained template itself: pass only the user question
+via ``--prompt``.
+"""
+
+import argparse
+import os
+
+from PIL import Image
+
+from vllm_omni.entrypoints.omni import Omni
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse the command-line options of this example.
+
+    Returns:
+        Namespace with ``model``, ``image``, ``prompt`` and
+        ``enable_diffusion_pipeline_profiler``.
+    """
+    parser = argparse.ArgumentParser(description="Generate text from image using HunyuanImage-3.0-Instruct.")
+    parser.add_argument(
+        "--model",
+        default="tencent/HunyuanImage-3.0-Instruct",
+        help="Model name or local path.",
+    )
+    parser.add_argument(
+        "--image",
+        type=str,
+        required=True,
+        help="Path to input image file (PNG, JPG, etc.).",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        required=True,
+        # main() prepends "<|startoftext|>" and the system prompt, so only the
+        # user question belongs here.
+        help=(
+            "Question about the image. '<|startoftext|>' and the system prompt of the "
+            "pretrain template are prepended by the script, so pass only the question."
+        ),
+    )
+    parser.add_argument(
+        "--enable-diffusion-pipeline-profiler",
+        action="store_true",
+        help="Enable diffusion pipeline profiler to display stage durations.",
+    )
+    return parser.parse_args()
+
+
+def load_image(image_path: str) -> Image.Image:
+    """Load an image file and return it in RGB mode.
+
+    Args:
+        image_path: Path of the image file to read.
+
+    Returns:
+        The image converted to RGB.
+
+    Raises:
+        FileNotFoundError: If ``image_path`` does not exist.
+    """
+    if not os.path.exists(image_path):
+        raise FileNotFoundError(f"Image file not found: {image_path}")
+    # convert() returns a new image, so the file opened by Image.open() can be
+    # closed as soon as the conversion is done.
+    with Image.open(image_path) as img:
+        return img.convert("RGB")
+
+
+def main(args: argparse.Namespace) -> None:
+    """Run one image-to-text request and print the prompt and the generated text.
+
+    Args:
+        args: Options produced by ``parse_args()``.
+
+    Raises:
+        FileNotFoundError: If ``args.image`` does not exist.
+    """
+    # Read the image before building the pipeline so that a wrong path is
+    # reported immediately instead of after Omni() has been constructed.
+    input_image = load_image(args.image)
+
+    omni = Omni(
+        model=args.model,
+        enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
+        mode="image-to-text",
+    )
+
+    # Pretrained template: "<|startoftext|>" + system prompt + user question.
+    prompt = "<|startoftext|>You are an assistant that understands images and outputs text. " + args.prompt
+
+    prompt_dict = {
+        "prompt": prompt,
+        "modalities": ["text"],
+        "multi_modal_data": {"image": input_image},
+    }
+
+    omni_outputs = omni.generate(prompts=[prompt_dict])
+
+    request_output = omni_outputs[0].request_output
+    print(f"Prompt: {request_output.prompt}")
+    print(f"Text: {request_output.outputs[0].text}")
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py
index b857bb22d1b..a8035a3fdcb 100644
--- a/examples/offline_inference/image_to_image/image_edit.py
+++ b/examples/offline_inference/image_to_image/image_edit.py
@@ -87,11 +87,9 @@
"""
import argparse
-import json
import os
import time
from pathlib import Path
-from typing import Any
import torch
from PIL import Image
@@ -103,16 +101,6 @@
from vllm_omni.platforms import current_omni_platform
-def parse_profiler_config(value: str) -> dict[str, Any]:
- try:
- config = json.loads(value)
- except json.JSONDecodeError as e:
- raise argparse.ArgumentTypeError(f"--profiler-config must be valid JSON: {e}") from e
- if not isinstance(config, dict):
- raise argparse.ArgumentTypeError("--profiler-config must be a JSON object")
- return config
-
-
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Edit an image with Qwen-Image-Edit.")
parser.add_argument(
@@ -208,13 +196,6 @@ def parse_args() -> argparse.Namespace:
default=1,
help="Number of GPUs used for ulysses sequence parallelism.",
)
- parser.add_argument(
- "--ulysses-mode",
- type=str,
- default="strict",
- choices=["strict", "advanced_uaa"],
- help="Ulysses sequence-parallel mode: 'strict' (divisibility required) or 'advanced_uaa' (UAA).",
- )
parser.add_argument(
"--ring-degree",
type=int,
@@ -316,8 +297,8 @@ def parse_args() -> argparse.Namespace:
"--cfg-parallel-size",
type=int,
default=1,
- choices=[1, 2, 3],
- help="Number of GPUs used for classifier free guidance parallel size (max 3 branches).",
+ choices=[1, 2],
+ help="Number of GPUs used for classifier free guidance parallel size.",
)
parser.add_argument(
"--enforce-eager",
@@ -344,43 +325,11 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable layerwise (blockwise) offloading on DiT modules.",
)
- parser.add_argument(
- "--vae-patch-parallel-size",
- type=int,
- default=1,
- help="Number of GPUs used for VAE patch/tile parallelism (decode).",
- )
- parser.add_argument(
- "--use-hsdp",
- action="store_true",
- help="Enable HSDP (Hybrid Sharded Data Parallel) for diffusion models.",
- )
- parser.add_argument(
- "--hsdp-shard-size",
- type=int,
- default=1,
- help="Number of GPUs to shard weights across for HSDP.",
- )
- parser.add_argument(
- "--hsdp-replicate-size",
- type=int,
- default=1,
- help="Number of HSDP replica groups.",
- )
parser.add_argument(
"--enable-diffusion-pipeline-profiler",
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- parser.add_argument(
- "--profiler-config",
- type=parse_profiler_config,
- default=None,
- help='JSON profiler config for torch/cuda profiling, e.g. \'{"profiler":"torch","torch_profiler_dir":"./perf"}\'.',
- )
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- nullify_stage_engine_defaults(parser)
return parser.parse_args()
@@ -446,11 +395,11 @@ def main():
enforce_eager=args.enforce_eager,
enable_cpu_offload=args.enable_cpu_offload,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
- profiler_config=args.profiler_config,
)
print("Pipeline loaded")
- profiler_enabled = args.profiler_config is not None
+ # Check if profiling is requested via environment variable
+ profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))
# Time profiling for generation
print(f"\n{'=' * 60}")
diff --git a/examples/offline_inference/image_to_image/image_to_image.md b/examples/offline_inference/image_to_image/image_to_image.md
index 1c1a5ff3a79..2df248e034f 100644
--- a/examples/offline_inference/image_to_image/image_to_image.md
+++ b/examples/offline_inference/image_to_image/image_to_image.md
@@ -51,6 +51,5 @@ Key arguments:
- `--vae-use-tiling`: enable VAE tiling for memory optimization.
- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
-- `--strength`: **Z-Image only** - controls the denoising start timestep for I2I (default: 0.6). Range: [0.0, 1.0]. Lower values preserve more of the original image; higher values allow more creative changes.
> ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage.
diff --git a/examples/offline_inference/image_to_video/README.md b/examples/offline_inference/image_to_video/README.md
index a458850a02b..2692c76df26 100644
--- a/examples/offline_inference/image_to_video/README.md
+++ b/examples/offline_inference/image_to_video/README.md
@@ -59,13 +59,12 @@ Key arguments:
- `--negative-prompt`: Optional list of artifacts to suppress.
- `--boundary-ratio`: Boundary split ratio for two-stage MoE models.
- `--flow-shift`: Scheduler flow shift (5.0 for 720p, 12.0 for 480p).
-- `--sample-solver`: Wan2.2 sampling solver. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints.
- `--num-inference-steps`: Number of denoising steps (default 50).
- `--fps`: Frames per second for the saved MP4 (requires `diffusers` export_to_video).
- `--output`: Path to save the generated video.
- `--vae-use-slicing`: Enable VAE slicing for memory optimization.
- `--vae-use-tiling`: Enable VAE tiling for memory optimization.
-- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](https://github.com/vllm-project/vllm-omni/tree/main/docs/user_guide/diffusion/parallelism/cfg_parallel.md).
+- `--cfg-parallel-size`: set it to 2 to enable CFG Parallel. See more examples in [`user_guide`](../../../docs/user_guide/diffusion/parallelism_acceleration.md#cfg-parallel).
- `--tensor-parallel-size`: tensor parallel size (effective for models that support TP, e.g. LTX2).
- `--enable-cpu-offload`: enable CPU offloading for diffusion models.
- `--use-hsdp`: Enable Hybrid Sharded Data Parallel to shard model weights across GPUs.
@@ -75,6 +74,3 @@ Key arguments:
> ℹ️ If you encounter OOM errors, try using `--vae-use-slicing` and `--vae-use-tiling` to reduce memory usage.
-
-For Wan2.2 LightX2V-converted local Diffusers directories and related LoRA
-assets, see the [LoRA guide](../../../docs/user_guide/diffusion/lora.md#wan22-lightx2v-offline-assembly).
diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py
index fea5178d89b..c8c55c485ad 100644
--- a/examples/offline_inference/image_to_video/image_to_video.py
+++ b/examples/offline_inference/image_to_video/image_to_video.py
@@ -33,10 +33,9 @@
"""
import argparse
-import json
+import os
import time
from pathlib import Path
-from typing import Any
import numpy as np
import PIL.Image
@@ -49,16 +48,6 @@
from vllm_omni.platforms import current_omni_platform
-def parse_profiler_config(value: str) -> dict[str, Any]:
- try:
- config = json.loads(value)
- except json.JSONDecodeError as e:
- raise argparse.ArgumentTypeError(f"--profiler-config must be valid JSON: {e}") from e
- if not isinstance(config, dict):
- raise argparse.ArgumentTypeError("--profiler-config must be a JSON object")
- return config
-
-
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Generate a video from an image (Wan2.2, LTX2, HunyuanVideo-1.5).")
parser.add_argument(
@@ -95,13 +84,6 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--flow-shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)."
)
- parser.add_argument(
- "--sample-solver",
- type=str,
- default="unipc",
- choices=["unipc", "euler"],
- help="Sampling solver for Wan2.2 pipelines. Use 'euler' for Lightning/Distill setups.",
- )
parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).")
parser.add_argument("--fps", type=int, default=None, help="Frames per second for the output video.")
parser.add_argument(
@@ -164,7 +146,7 @@ def parse_args() -> argparse.Namespace:
"--audio-sample-rate",
type=int,
default=24000,
- help="Sample rate for audio output when saved (default: 24000).",
+ help="Sample rate for audio output when saved (default: 24000 for LTX2).",
)
parser.add_argument(
"--cache-backend",
@@ -205,15 +187,6 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- parser.add_argument(
- "--profiler-config",
- type=parse_profiler_config,
- default=None,
- help='JSON profiler config for torch/cuda profiling, e.g. \'{"profiler":"torch","torch_profiler_dir":"./perf"}\'.',
- )
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- nullify_stage_engine_defaults(parser)
return parser.parse_args()
@@ -294,7 +267,8 @@ def main():
"rel_l1_thresh": 0.2,
}
- profiler_enabled = args.profiler_config is not None
+ # Check if profiling is requested via environment variable
+ profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))
parallel_config = DiffusionParallelConfig(
ulysses_degree=args.ulysses_degree,
ring_degree=args.ring_degree,
@@ -319,7 +293,6 @@ def main():
cache_backend=args.cache_backend,
cache_config=cache_config,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
- profiler_config=args.profiler_config,
)
if profiler_enabled:
@@ -332,7 +305,6 @@ def main():
print(f" Model: {args.model}")
print(f" Inference steps: {args.num_inference_steps}")
print(f" Frames: {args.num_frames}")
- print(f" Solver: {args.sample_solver}")
print(
f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size},"
f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}"
@@ -354,14 +326,9 @@ def main():
generator=generator,
guidance_scale=guidance_scale,
guidance_scale_2=args.guidance_scale_high,
- boundary_ratio=args.boundary_ratio,
num_inference_steps=num_inference_steps,
num_frames=num_frames,
frame_rate=frame_rate,
- extra_args={
- "sample_solver": args.sample_solver,
- "flow_shift": args.flow_shift,
- },
),
)
generation_end = time.perf_counter()
@@ -504,9 +471,15 @@ def _ensure_frame_list(video_array):
video_array = _ensure_frame_list(video_array)
- if audio is not None:
- from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
+ use_ltx2_export = is_ltx2
+ encode_video = None
+ if use_ltx2_export:
+ try:
+ from diffusers.pipelines.ltx2.export_utils import encode_video
+ except ImportError:
+ encode_video = None
+ if use_ltx2_export and encode_video is not None:
if isinstance(video_array, list):
frames_np = np.stack(video_array, axis=0)
elif isinstance(video_array, np.ndarray):
@@ -517,24 +490,25 @@ def _ensure_frame_list(video_array):
if frames_np.ndim == 4 and frames_np.shape[-1] == 4:
frames_np = frames_np[..., :3]
- frames_u8 = (np.clip(frames_np, 0.0, 1.0) * 255).round().clip(0, 255).astype("uint8")
-
- audio_np = audio
- if isinstance(audio_np, list):
- audio_np = audio_np[0] if audio_np else None
- if isinstance(audio_np, torch.Tensor):
- audio_np = audio_np.detach().cpu().float().numpy()
- if isinstance(audio_np, np.ndarray):
- audio_np = np.squeeze(audio_np).astype(np.float32)
-
- video_bytes = mux_video_audio_bytes(
- frames_u8,
- audio_np,
- fps=float(fps),
- audio_sample_rate=args.audio_sample_rate,
+ audio_out = None
+ if audio is not None:
+ if isinstance(audio, list):
+ audio = audio[0] if audio else None
+ if isinstance(audio, np.ndarray):
+ audio = torch.from_numpy(audio)
+ if isinstance(audio, torch.Tensor):
+ audio_out = audio
+ if audio_out.dim() > 1:
+ audio_out = audio_out[0]
+ audio_out = audio_out.float().cpu()
+
+ encode_video(
+ frames_np,
+ fps=fps,
+ audio=audio_out,
+ audio_sample_rate=args.audio_sample_rate if audio_out is not None else None,
+ output_path=str(output_path),
)
- with open(str(output_path), "wb") as f:
- f.write(video_bytes)
else:
export_to_video(video_array, str(output_path), fps=fps)
print(f"Saved generated video to {output_path}")
diff --git a/examples/offline_inference/magi_human/README.md b/examples/offline_inference/magi_human/README.md
deleted file mode 100644
index 2b89093d941..00000000000
--- a/examples/offline_inference/magi_human/README.md
+++ /dev/null
@@ -1,72 +0,0 @@
-# MagiHuman Generation
-
-MagiHuman is an advanced, omni-modality model that generates both high-quality video and lip-synced audio from a text prompt.
-
-Because MagiHuman is a very large model featuring a powerful DiT MoE backbone and a ~9B parameter T5Gemma text encoder, it natively supports **Tensor Parallelism (TP)** in vLLM-Omni to run efficiently across multi-GPU setups, reducing device memory bottlenecks.
-
-## Setup
-
-### Install MagiCompiler (recommended)
-
-MagiHuman relies on [MagiCompiler](https://github.com/SandAI-org/MagiCompiler) for custom-op registration used by the DiT attention kernels. While the pipeline can fall back to stub implementations, installing MagiCompiler is **strongly recommended** for correct behaviour.
-
-```bash
-# Clone the repo
-git clone https://github.com/SandAI-org/MagiCompiler.git
-cd MagiCompiler
-
-# System dependencies (optional, for FX graph visualization; Debian/Ubuntu)
-sudo apt update && sudo apt install -y graphviz
-
-# Python dependencies
-pip install -r requirements.txt
-
-# Install MagiCompiler
-pip install . # end users (recommended)
-# pip install -e . # developers (editable install)
-```
-
-### Hardware requirements
-
-Ensure your hardware has enough VRAM. For a standard node with 80GB GPUs, running with `--tensor-parallel-size 4` is recommended to shard both the MoE weights and the T5Gemma text encoder across 4 GPUs, reducing the per-GPU peak VRAM overhead significantly (by roughly ~13.5GB per GPU compared to single-device inference).
-
-Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) for further details on allocating memory.
-
-## Run Examples
-
-Get into the example folder:
-```bash
-cd examples/offline_inference/magi_human
-```
-
-### End-to-End Generation (Text to Video+Audio)
-
-Generate a video with synchronized speech natively generated by the model.
-
-```bash
-python end2end.py \
- --model /proj-tango-pvc/users/zhipeng.wang/workspace/models/daVinci-MagiHuman \
- --prompt "A young woman with long, wavy golden blonde hair..." \
- --tensor-parallel-size 4 \
- --output output_magihuman.mp4
-```
-
-## Common Parameters
-
-| Parameter | Default | Description |
-|-----------|---------|-------------|
-| `--model` | *(Required)* | Local model path or HuggingFace ID |
-| `--prompt` | *(built-in demo prompt)* | Highly detailed text prompt dictating visual look and dialogue text |
-| `--tensor-parallel-size` | `4` | Tensor parallelism size (Number of GPUs) |
-| `--height` | `256` | Initial resolution height |
-| `--width` | `448` | Initial resolution width |
-| `--num-inference-steps` | `8` | Denoising steps |
-| `--seed` | `52` | Random seed |
-| `--output` | `output_magihuman.mp4` | Output video with audio path |
-
-## Example materials
-
-??? abstract "end2end.py"
- ``````py
- --8<-- "examples/offline_inference/magi_human/end2end.py"
- ``````
diff --git a/examples/offline_inference/magi_human/end2end.py b/examples/offline_inference/magi_human/end2end.py
deleted file mode 100644
index 7ea8161385f..00000000000
--- a/examples/offline_inference/magi_human/end2end.py
+++ /dev/null
@@ -1,125 +0,0 @@
-import argparse
-
-from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
-from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-
-def parse_args():
- parser = argparse.ArgumentParser(description="End-to-end inference script for MagiHuman.")
- parser.add_argument("--model", type=str, required=True, help="Path or ID of the MagiHuman model.")
- parser.add_argument(
- "--prompt",
- type=str,
- default="",
- help="Text prompt containing visual description, dialogue, and background sound.",
- )
- parser.add_argument(
- "--tensor-parallel-size", "-tp", type=int, default=4, help="Tensor parallel size (number of GPUs)."
- )
- parser.add_argument(
- "--output", type=str, default="output_magihuman.mp4", help="Path to save the generated mp4 file."
- )
- parser.add_argument("--height", type=int, default=256, help="Video height.")
- parser.add_argument("--width", type=int, default=448, help="Video width.")
- parser.add_argument("--num-inference-steps", type=int, default=8, help="Number of denoising steps.")
- parser.add_argument("--seed", type=int, default=52, help="Random seed for generation.")
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- nullify_stage_engine_defaults(parser)
- return parser.parse_args()
-
-
-def main():
- args = parse_args()
-
- print(f"Initializing MagiHuman pipeline with TP={args.tensor_parallel_size}...")
- omni = Omni(
- model=args.model,
- init_timeout=1200,
- tensor_parallel_size=args.tensor_parallel_size,
- devices=list(range(args.tensor_parallel_size)),
- )
-
- prompt = args.prompt
- if not prompt:
- prompt = (
- "A young woman with long, wavy golden blonde hair and bright blue eyes, "
- "wearing a fitted ivory silk blouse with a delicate lace collar, sits "
- "stationary in front of a softly lit, blurred warm-toned interior. Her "
- "overall disposition is warm, composed, and gently confident. The camera "
- "holds a static medium close-up, framing her from the shoulders up, "
- "with shallow depth of field keeping her face in sharp focus. Soft "
- "directional key light falls from the upper left, casting a gentle "
- "highlight along her cheekbone and nose bridge. She draws a quiet breath, "
- "the levator labii superiors relaxing as her lips part. She speaks in "
- "clear, warm, unhurried American English: "
- "\"The most beautiful things in life aren't things at all — "
- "they're moments, feelings, and the people who make you feel truly alive.\" "
- "Her jaw descends smoothly on each stressed syllable; the orbicularis oris "
- "shapes each vowel with precision. A faint, genuine smile engages the "
- "zygomaticus major, lifting her lip corners fractionally. Her brows rest "
- "in a soft, neutral arch throughout. She maintains steady, forward-facing "
- "eye contact. Head position remains level; no torso displacement occurs.\n\n"
- "Dialogue:\n"
- ": "
- "\"The most beautiful things in life aren't things at all — "
- "they're moments, feelings, and the people who make you feel truly alive.\"\n\n"
- "Background Sound:\n"
- ""
- )
-
- sampling_params = OmniDiffusionSamplingParams(
- height=args.height,
- width=args.width,
- num_inference_steps=args.num_inference_steps,
- seed=args.seed,
- extra_args={
- "seconds": 5,
- "sr_height": 1080,
- "sr_width": 1920,
- "sr_num_inference_steps": 5,
- },
- )
-
- print(f"Generating with prompt: {prompt[:80]}...")
- outputs = omni.generate(
- prompts=[prompt],
- sampling_params_list=[sampling_params],
- )
-
- print(f"Generation complete. Output type: {type(outputs)}")
- if outputs:
- first = outputs[0]
-
- if hasattr(first, "images") and first.images:
- video_frames = first.images[0]
- print(f"Video frames: shape={video_frames.shape}, dtype={video_frames.dtype}")
-
- audio_waveform = None
- mm = first.multimodal_output or {}
- if mm:
- audio_waveform = mm.get("audio")
- if audio_waveform is not None:
- print(f"Audio waveform: shape={audio_waveform.shape}, dtype={audio_waveform.dtype}")
-
- output_fps = float(mm.get("fps", 25))
- output_sr = int(mm.get("audio_sample_rate", 24000))
- print(f"Using fps={output_fps}, audio_sample_rate={output_sr} from model output")
-
- video_bytes = mux_video_audio_bytes(
- video_frames,
- audio_waveform,
- fps=output_fps,
- audio_sample_rate=output_sr,
- )
- with open(args.output, "wb") as f:
- f.write(video_bytes)
- print(f"Saved MP4 ({len(video_bytes)} bytes) to {args.output}")
- print("SUCCESS: MagiHuman pipeline generation completed.")
- else:
- print("WARNING: No outputs returned.")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py
index a742b535f69..ca87b9e9a94 100644
--- a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py
+++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_image_summarize.py
@@ -19,7 +19,6 @@
from vllm.multimodal.image import convert_image_mode
from vllm_omni import Omni
-from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
DEFAULT_SYSTEM = "You are a helpful assistant."
DEFAULT_QUESTION = "Please summarize the content of this image."
@@ -49,7 +48,6 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- nullify_stage_engine_defaults(parser)
return parser.parse_args()
diff --git a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py
index bd1282a2117..a4c41fee1f8 100644
--- a/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py
+++ b/examples/offline_inference/mammothmodal2_preview/run_mammothmoda2_t2i.py
@@ -29,7 +29,6 @@
from vllm.sampling_params import SamplingParams
from vllm_omni import Omni
-from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@@ -118,7 +117,6 @@ def parse_args() -> argparse.Namespace:
)
p.add_argument("--out", type=str, default="output.png", help="Path to save the generated image.")
p.add_argument("--trust-remote-code", action="store_true", help="Trust remote code when loading the model.")
- nullify_stage_engine_defaults(p)
args = p.parse_args()
if not args.prompt:
args.prompt = ["A stylish woman with sunglasses riding a motorcycle in NYC."]
diff --git a/examples/offline_inference/mimo_audio/README.md b/examples/offline_inference/mimo_audio/README.md
index 5615dea5176..747e734cc24 100644
--- a/examples/offline_inference/mimo_audio/README.md
+++ b/examples/offline_inference/mimo_audio/README.md
@@ -39,6 +39,7 @@ Run a single sample for basic TTS:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft
```
@@ -47,6 +48,7 @@ Run batch samples for basic TTS:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft \
--num-prompts {batch_size}
@@ -64,6 +66,7 @@ Generate speech from text input:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft \
--text "The weather is so nice today."
@@ -75,6 +78,7 @@ Generate speech with explicit voice style instructions:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft_with_instruct \
--text "The weather is so nice today." \
@@ -87,6 +91,7 @@ Generate speech using an audio reference for voice cloning:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft_with_audio \
--text "The weather is so nice today." \
@@ -99,6 +104,7 @@ Generate speech from text containing natural voice descriptions:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type tts_sft_with_natural_instruction \
--text "In a panting young male voice, he said: I can't run anymore, wait for me!"
@@ -110,6 +116,7 @@ Transcribe audio to text:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type audio_trancribing_sft \
--audio-path "./spoken_dialogue_assistant_turn_1.wav"
@@ -121,6 +128,7 @@ Understand and analyze audio content with text queries:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type audio_understanding_sft \
--text "Summarize the audio." \
@@ -133,6 +141,7 @@ Audio understanding with reasoning chain:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type audio_understanding_sft_with_thinking \
--text "Summarize the audio." \
@@ -145,6 +154,7 @@ Multi-turn dialogue with audio input and output:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type spoken_dialogue_sft_multiturn \
--audio-path "./prompt_speech_zh_m.wav"
@@ -158,6 +168,7 @@ Multi-turn dialogue converting speech to text:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type speech2text_dialogue_sft_multiturn
```
@@ -170,6 +181,7 @@ Multi-turn text-only dialogue:
```bash
python3 -u end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
--model-name XiaomiMiMo/MiMo-Audio-7B-Instruct \
--query-type text_dialogue_sft_multiturn
```
@@ -178,6 +190,29 @@ Note: This task uses hardcoded message lists in the script.
## Troubleshooting
+### Audio dependencies (soundfile, librosa)
+
+This example depends on **soundfile** (read/write WAV) and **librosa** (load audio including MP3). Install the project requirements first:
+
+```bash
+pip install -r requirements/common.txt
+# or at least: pip install "soundfile>=0.13.1" "librosa>=0.11.0"
+```
+
+- **`soundfile` / libsndfile not found**
+ `soundfile` uses the C library **libsndfile**. On Linux, install the system package before pip:
+ - Debian/Ubuntu: `sudo apt-get install libsndfile1`
+ - For development builds: `sudo apt-get install libsndfile1-dev`
+ - Then: `pip install soundfile`
+
+- **`librosa` fails to load MP3 or reports "No backend available"**
+ Loading MP3 (e.g. in `spoken_dialogue_sft_multiturn` with `.mp3` files) uses **ffmpeg** as the backend. Install ffmpeg:
+ - Debian/Ubuntu: `sudo apt-get install ffmpeg`
+ - macOS: `brew install ffmpeg`
+
+- **`ImportError: No module named 'soundfile'` or `ModuleNotFoundError: ... librosa`**
+  Make sure you are running in the Python environment where vLLM-Omni is installed, and that the packages from `requirements/common.txt` (or at least the ones listed above) are installed in that same environment.
+
### Tokenizer path
- **`MIMO_AUDIO_TOKENIZER_PATH` not set or model fails to find tokenizer**
diff --git a/examples/offline_inference/mimo_audio/end2end.py b/examples/offline_inference/mimo_audio/end2end.py
index 9c652fe2b05..ae044d2e8a1 100644
--- a/examples/offline_inference/mimo_audio/end2end.py
+++ b/examples/offline_inference/mimo_audio/end2end.py
@@ -182,7 +182,7 @@ def main(args):
omni = Omni(
model=model_name,
- deploy_config=args.deploy_config,
+ stage_configs_path=args.stage_configs_path,
log_stats=args.enable_stats,
log_file=("omni_pipeline.log" if args.enable_stats else None),
init_sleep_seconds=args.init_sleep_seconds,
@@ -309,10 +309,7 @@ def main(args):
lines.append("Prompt:\n")
lines.append(str(prompt_text) + "\n")
lines.append("vllm_text_output:\n")
- output_text = str(text_output)
- if "" in output_text or "" in output_text:
- output_text = output_text.replace("", "").replace("", "").strip()
- lines.append(output_text + "\n")
+ lines.append(str(text_output).strip() + "\n")
try:
with open(out_txt, "w", encoding="utf-8") as f:
print("lines", lines)
@@ -354,7 +351,7 @@ def parse_args():
"--text",
"-t",
type=str,
- default="",
+ default="The weather is so nice today.",
help="input text",
)
parser.add_argument(
@@ -431,11 +428,10 @@ def parse_args():
help="Sampling rate for audio.",
)
parser.add_argument(
- "--deploy-config",
+ "--stage-configs-path",
type=str,
- default=None,
- help="Override the deploy config path. If unset, auto-loads "
- "vllm_omni/deploy/mimo_audio.yaml based on the HF model_type.",
+ default="../../../model_executor/stage_configs/mimo_audio.yaml",
+ help="Path to a stage configs file.",
)
return parser.parse_args()
diff --git a/examples/offline_inference/mimo_audio/message_convert.py b/examples/offline_inference/mimo_audio/message_convert.py
index 416f21ccfaf..ebcc59c6b43 100644
--- a/examples/offline_inference/mimo_audio/message_convert.py
+++ b/examples/offline_inference/mimo_audio/message_convert.py
@@ -5,12 +5,12 @@
import re
from collections.abc import Callable
+import librosa
import numpy as np
import torch
import torchaudio
from process_speechdata import InputSegment, StreamingInputSegment
from torchaudio.transforms import MelSpectrogram
-from vllm.multimodal.media.audio import load_audio
speech_zeroemb_idx = 151667
empty_token = "<|empty|>"
@@ -685,7 +685,7 @@ def get_audio_data(audio_url):
# File path
audio_file = audio_url
- audio_signal, sr = load_audio(audio_file, sr=24000)
+ audio_signal, sr = librosa.load(audio_file, sr=24000)
audio_data = (audio_signal.astype(np.float32), sr)
return audio_data
diff --git a/examples/offline_inference/ming_flash_omni/README.md b/examples/offline_inference/ming_flash_omni/README.md
deleted file mode 100644
index be90b408d14..00000000000
--- a/examples/offline_inference/ming_flash_omni/README.md
+++ /dev/null
@@ -1,92 +0,0 @@
-# Ming-flash-omni 2.0
-
-[Ming-flash-omni-2.0](https://github.com/inclusionAI/Ming) is an omni-modal model supporting text, image, video, and audio understanding, with text and speech outputs.
-
-vLLM-Omni supports two deployment modes:
-
-| Mode | Stage config | Output |
-|------|-------------|--------|
-| Thinker only (multimodal understanding) | `ming_flash_omni_thinker.yaml` (default `--omni`) | Text |
-| Thinker + Talker (omni-speech) | `ming_flash_omni.yaml` | Text + Audio |
-
-For standalone TTS (talker only), see [`examples/offline_inference/ming_flash_omni_tts/`](../ming_flash_omni_tts/).
-
-## Setup
-
-Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
-
-The default `--omni` flag runs thinker only. For omni-speech, pass the two-stage config explicitly:
-
-```bash
---stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml
-```
-
-## Run examples
-
-The end-to-end script defaults to built-in assets; pass `--image-path`,
-`--audio-path`, or `--video-path` to override.
-
-```bash
-# Text-only
-python examples/offline_inference/ming_flash_omni/end2end.py --query-type text
-
-# Image / audio / video / mixed understanding
-python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image
-python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio
-python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video --num-frames 16
-python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_mixed_modalities \
- --image-path /path/to/image.jpg --audio-path /path/to/audio.wav
-```
-
-#### Reasoning (Thinking Mode)
-
-Reasoning ("detailed thinking on") is applied by the script when
-`--query-type reasoning` is set. The default prompt matches Ming's cookbook
-and expects the reference figure from the upstream repo — see
-`get_reasoning_query` in `end2end.py`.
-
-```bash
-python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png
-```
-
-### Omni-speech (thinker + talker)
-
-To enable spoken output, use the two-stage config and request `audio` (or `text,audio`) modalities.
-The thinker processes your multimodal input, generates text, then the talker synthesises the response as speech.
-
-**Audio-only output** (speech response, no text):
-```bash
-python examples/offline_inference/ming_flash_omni/end2end.py \
- --query-type text \
- --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml \
- --modalities audio \
- --output-dir output_ming_omni_speech
-```
-
-**Both text and audio output**:
-```bash
-python examples/offline_inference/ming_flash_omni/end2end.py \
- --query-type use_audio \
- --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml \
- --modalities text,audio \
- --output-dir output_ming_omni_speech
-```
-
-Generated `.wav` files are saved to `--output-dir` (default `output_ming`), one per request.
-
-The stage config allocates thinker on GPUs 0–3 and talker on GPU 3 by default. Adjust `devices` in the YAML to match your hardware.
-
-### Modality control
-
-| `--modalities` | Thinker output | Talker | Saved files |
-|---------------|----------------|--------|-------------|
-| `text` (default) | Text | Not run | `.txt` |
-| `audio` | Text (internal) | Runs | `.wav` |
-| `text,audio` | Text | Runs | `.txt` + `.wav` |
-
-Pass `--stage-configs-path /path/to/your_config.yaml` to any of the commands
-above to override the stage config.
-
-## Online serving
-
-For online serving via the OpenAI-compatible API, see [examples/online_serving/ming_flash_omni/README.md](../../online_serving/ming_flash_omni/README.md).
diff --git a/examples/offline_inference/ming_flash_omni/end2end.py b/examples/offline_inference/ming_flash_omni/end2end.py
deleted file mode 100644
index e00dcea7bb3..00000000000
--- a/examples/offline_inference/ming_flash_omni/end2end.py
+++ /dev/null
@@ -1,507 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# Partial example cases are referred from
-# https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/cookbook.ipynb
-import os
-import time
-from typing import NamedTuple
-
-import numpy as np
-import soundfile as sf
-import vllm
-from PIL import Image
-from transformers import AutoProcessor
-from vllm import SamplingParams
-from vllm.assets.audio import AudioAsset
-from vllm.assets.image import ImageAsset
-from vllm.assets.video import VideoAsset, video_to_ndarrays
-from vllm.multimodal.image import convert_image_mode
-from vllm.multimodal.media.audio import load_audio
-from vllm.utils.argparse_utils import FlexibleArgumentParser
-
-import vllm_omni
-from vllm_omni.entrypoints.omni import Omni
-
-# Imports the processor also registers itself
-from vllm_omni.transformers_utils.processors.ming import MingFlashOmniProcessor # noqa: F401
-
-SEED = 42
-MODEL_NAME = "Jonathan1909/Ming-flash-omni-2.0"
-
-
-class QueryResult(NamedTuple):
- inputs: dict
- limit_mm_per_prompt: dict[str, int]
-
-
-def get_text_query(processor: MingFlashOmniProcessor, question: str | None = None) -> QueryResult:
- if question is None:
- question = "请详细介绍鹦鹉的生活习性。"
- conversation = [{"role": "HUMAN", "content": question}]
- prompt = processor.apply_chat_template(conversation, tokenize=False)
- return QueryResult(
- inputs={"prompt": prompt},
- limit_mm_per_prompt={},
- )
-
-
-def get_image_query(
- processor: MingFlashOmniProcessor,
- question: str | None = None,
- image_path: str | None = None,
-) -> QueryResult:
- if question is None:
- question = "Describe this image in detail."
-
- if image_path:
- if not os.path.exists(image_path):
- raise FileNotFoundError(f"Image file not found: {image_path}")
- image_data = convert_image_mode(Image.open(image_path), "RGB")
- else:
- image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
-
- conversation = [
- {
- "role": "HUMAN",
- "content": [
- {"type": "image", "image": image_data},
- {"type": "text", "text": question},
- ],
- }
- ]
- prompt = processor.apply_chat_template(conversation, tokenize=False)
-
- return QueryResult(
- inputs={
- "prompt": prompt,
- "multi_modal_data": {"image": image_data},
- },
- limit_mm_per_prompt={"image": 1},
- )
-
-
-def get_audio_query(
- processor: MingFlashOmniProcessor,
- question: str | None = None,
- audio_path: str | None = None,
- sampling_rate: int = 16000,
-) -> QueryResult:
- if question is None:
- question = "Please recognize the language of this speech and transcribe it. Format: oral."
-
- if audio_path:
- if not os.path.exists(audio_path):
- raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
- audio_data = (audio_signal.astype(np.float32), sr)
- else:
- audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
-
- # Use a string for "audio" so the processor counts it as 1 audio input
- conversation = [
- {
- "role": "HUMAN",
- "content": [
- {"type": "audio", "audio": "input"},
- {"type": "text", "text": question},
- ],
- }
- ]
- prompt = processor.apply_chat_template(conversation, tokenize=False)
-
- return QueryResult(
- inputs={
- "prompt": prompt,
- "multi_modal_data": {"audio": audio_data},
- },
- limit_mm_per_prompt={"audio": 1},
- )
-
-
-def get_video_query(
- processor: MingFlashOmniProcessor,
- question: str | None = None,
- video_path: str | None = None,
- num_frames: int = 16,
-) -> QueryResult:
- if question is None:
- question = "Describe what is happening in this video."
-
- if video_path:
- if not os.path.exists(video_path):
- raise FileNotFoundError(f"Video file not found: {video_path}")
- video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
- else:
- video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays
-
- conversation = [
- {
- "role": "HUMAN",
- "content": [
- {"type": "video"},
- {"type": "text", "text": question},
- ],
- }
- ]
- prompt = processor.apply_chat_template(conversation, tokenize=False)
-
- return QueryResult(
- inputs={
- "prompt": prompt,
- "multi_modal_data": {"video": video_frames},
- },
- limit_mm_per_prompt={"video": 1},
- )
-
-
-def get_mixed_modalities_query(
- processor: MingFlashOmniProcessor,
- image_path: str | None = None,
- audio_path: str | None = None,
- sampling_rate: int = 16000,
-) -> QueryResult:
- """Mixed image + audio understanding."""
- question = "Describe the image, and recognize the language of this speech and transcribe it. Format: oral"
-
- if image_path:
- if not os.path.exists(image_path):
- raise FileNotFoundError(f"Image file not found: {image_path}")
- image_data = convert_image_mode(Image.open(image_path), "RGB")
- else:
- image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
-
- if audio_path:
- if not os.path.exists(audio_path):
- raise FileNotFoundError(f"Audio file not found: {audio_path}")
- sig, sr = load_audio(audio_path, sr=sampling_rate)
- audio_data = (sig.astype(np.float32), sr)
- else:
- audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
-
- conversation = [
- {
- "role": "HUMAN",
- "content": [
- {"type": "image", "image": image_data},
- {"type": "audio", "audio": "input"},
- {"type": "text", "text": question},
- ],
- }
- ]
- prompt = processor.apply_chat_template(conversation, tokenize=False)
-
- return QueryResult(
- inputs={
- "prompt": prompt,
- "multi_modal_data": {"image": image_data, "audio": audio_data},
- },
- limit_mm_per_prompt={"image": 1, "audio": 1},
- )
-
-
-def get_reasoning_query(
- processor: MingFlashOmniProcessor,
- question: str | None = None,
- image_path: str | None = None,
-) -> QueryResult:
- if question is None:
- # NOTE: To use the following default question, input with example figure provided by Ming
- # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png
- # E.g.,
- # python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png
- # Otherwise, the problem solving might be false.
- question = (
- "Based on the following rules:\n•\tYou control the smiley face character\n"
- "•\tYou can move up, down, left, and right, and only a single square at a time\n"
- "•\tWalls are dark grey and cannot be moved into\n•\tThe brown square is a box\n•"
- "\tThe box can be pushed by moving into it (i.e., if you are in the square "
- "adjacent to the box to the left, and move onto the square with the box, "
- "the box will move one square to the right).\n"
- "•\tThe box cannot be pushed into walls\n"
- "•\tThe blue door at the bottom is locked and cannot be passed through, "
- "unless the box is placed on the blue square\n"
- "•\tThe square beneath the blue door is the exit\n"
- "•\tMoving from one square to another\n\n"
- "Let's assume a coordinate system where the smiley face is "
- "on the top left at (1,1) and the square below it is (1,2). "
- "The smiley face performs the following moves: {down, right, right, right}, "
- "such that the smiley face is at square (4,2) and the box is in square (5,2). "
- "What are the next sequence of moves that must be done to move the box down to (5,3)? "
- "Give your answer as a comma separated list."
- )
-
- if image_path:
- if not os.path.exists(image_path):
- raise FileNotFoundError(f"Image file not found: {image_path}")
- image_data = convert_image_mode(Image.open(image_path), "RGB")
- conversation = [
- {
- "role": "HUMAN",
- "content": [
- {"type": "image", "image": image_data},
- {"type": "text", "text": question},
- ],
- }
- ]
- prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True)
- return QueryResult(
- inputs={
- "prompt": prompt,
- "multi_modal_data": {"image": image_data},
- },
- limit_mm_per_prompt={"image": 1},
- )
-
- conversation = [{"role": "HUMAN", "content": question}]
- prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True)
- return QueryResult(
- inputs={"prompt": prompt},
- limit_mm_per_prompt={},
- )
-
-
-query_map = {
- "text": get_text_query,
- "use_audio": get_audio_query,
- "use_image": get_image_query,
- "use_video": get_video_query,
- "use_mixed_modalities": get_mixed_modalities_query,
- "reasoning": get_reasoning_query,
-}
-
-
-def main(args):
- print(
- "=" * 20,
- "\n",
- f"vllm version: {vllm.__version__}\n",
- f"vllm-omni version: {vllm_omni.__version__}\n",
- "=" * 20,
- sep="",
- )
-
- processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
- assert isinstance(processor, MingFlashOmniProcessor), f"Wrong processor type being used: {type(processor)}"
-
- query_func = query_map[args.query_type]
- if args.query_type == "use_image":
- query_result = query_func(processor, image_path=args.image_path)
- elif args.query_type == "use_audio":
- query_result = query_func(processor, audio_path=args.audio_path, sampling_rate=args.sampling_rate)
- elif args.query_type == "use_video":
- query_result = query_func(processor, video_path=args.video_path, num_frames=args.num_frames)
- elif args.query_type == "use_mixed_modalities":
- query_result = query_func(
- processor,
- image_path=args.image_path,
- audio_path=args.audio_path,
- sampling_rate=args.sampling_rate,
- )
- elif args.query_type == "reasoning":
- query_result = query_func(processor, image_path=args.image_path)
- else:
- query_result = query_func(processor)
-
- # Initialize Omni (with thinker-only stage config)
- omni = Omni(
- model=MODEL_NAME,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- init_timeout=args.init_timeout,
- stage_init_timeout=args.stage_init_timeout,
- )
-
- # Thinker sampling params
- thinker_sampling_params = SamplingParams(
- temperature=0.4,
- top_p=0.9,
- max_tokens=args.max_tokens,
- repetition_penalty=1.05,
- seed=SEED,
- detokenize=True,
- )
- # Talker (ming_tts) uses a custom generation loop (CFM + AudioVAE);
- # vLLM sampling is a no-op here — max_tokens=1 just satisfies the scheduler.
- talker_sampling_params = SamplingParams(
- temperature=0.0,
- max_tokens=1,
- )
- all_sampling_params = [thinker_sampling_params, talker_sampling_params]
- # Match sampling params to the number of configured stages
- # (thinker-only yaml → 1, thinker+talker yaml → 2).
- sampling_params_list = all_sampling_params[: omni.num_stages]
-
- prompts = [query_result.inputs for _ in range(args.num_prompts)]
-
- if args.modalities is not None:
- output_modalities = args.modalities.split(",")
- for prompt in prompts:
- prompt["modalities"] = output_modalities
-
- total_requests = len(prompts)
- processed_count = 0
- print(f"Query type: {args.query_type}")
- print(f"Number of prompts: {total_requests}")
-
- output_dir = args.output_dir
- os.makedirs(output_dir, exist_ok=True)
-
- profiler_enabled = args.enable_profiler
- if profiler_enabled:
- omni.start_profile(stages=args.profiler_stages)
-
- for stage_outputs in omni.generate(prompts, sampling_params_list):
- output = stage_outputs.request_output
- if stage_outputs.final_output_type == "text":
- request_id = output.request_id
- text_output = output.outputs[0].text
- lines = []
- lines.append("Prompt:\n")
- lines.append(str(output.prompt) + "\n")
- lines.append("Text Output:\n")
- lines.append(str(text_output).strip() + "\n")
- print(*lines, sep="")
-
- # Save to file
- out_txt = os.path.join(output_dir, f"{request_id}.txt")
- try:
- with open(out_txt, "w", encoding="utf-8") as f:
- f.writelines(lines)
- print(f"Request ID: {request_id}, text saved to {out_txt}")
- except Exception as e:
- print(f"Failed to write output file {out_txt}: {e}")
-
- elif stage_outputs.final_output_type == "audio":
- request_id = output.request_id
- mm = output.outputs[0].multimodal_output
- if mm and "audio" in mm:
- audio = mm["audio"]
- sr_raw = mm.get("sr", 44100)
- sample_rate = int(sr_raw.item() if hasattr(sr_raw, "item") else sr_raw)
- audio_numpy = audio.float().squeeze().cpu().numpy()
- output_wav = os.path.join(output_dir, f"{request_id}.wav")
- sf.write(output_wav, audio_numpy, samplerate=sample_rate, format="WAV")
- print(
- f"Request ID: {request_id}, audio saved to {output_wav} "
- f"({len(audio_numpy) / sample_rate:.2f}s, {sample_rate}Hz)"
- )
-
- processed_count += 1
- if profiler_enabled and processed_count >= total_requests:
- print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...")
- # Stop the profiler while workers are still alive
- omni.stop_profile(stages=args.profiler_stages)
-
- print("[Info] Waiting 30s for workers to write trace files to disk...")
- time.sleep(30)
- print("[Info] Trace export wait time finished.")
-
- omni.close()
-
-
-def parse_args():
- parser = FlexibleArgumentParser(description="Ming-flash-omni 2.0 offline inference example")
- parser.add_argument(
- "--query-type",
- "-q",
- type=str,
- default="text",
- choices=query_map.keys(),
- help="Query type.",
- )
- parser.add_argument(
- "--stage-configs-path",
- type=str,
- default=None,
- help="Path to a stage configs YAML file.",
- )
- parser.add_argument(
- "--log-stats",
- action="store_true",
- default=False,
- help="Enable detailed statistics logging.",
- )
- parser.add_argument("--init-timeout", type=int, default=2000, help="Timeout for initializing in seconds.")
- parser.add_argument(
- "--stage-init-timeout",
- type=int,
- default=2000,
- help="Timeout for initializing a single stage in seconds.",
- )
- parser.add_argument(
- "--enable-profiler",
- action="store_true",
- default=False,
- help="Enables profiling when set.",
- )
- parser.add_argument(
- "--profiler-stages",
- type=int,
- nargs="*",
- default=[0],
- help="List of stage IDs to profile. If not set, profiles all stages.",
- )
- parser.add_argument(
- "--image-path",
- "-i",
- type=str,
- default=None,
- help="Path to local image file. Uses default asset if not provided.",
- )
- parser.add_argument(
- "--audio-path",
- "-a",
- type=str,
- default=None,
- help="Path to local audio file. Uses default asset if not provided.",
- )
- parser.add_argument(
- "--video-path",
- "-v",
- type=str,
- default=None,
- help="Path to local video file. Uses default asset if not provided.",
- )
- parser.add_argument(
- "--num-frames",
- type=int,
- default=16,
- help="Number of frames to extract from video.",
- )
- parser.add_argument(
- "--sampling-rate",
- type=int,
- default=16000,
- help="Sampling rate for audio loading.",
- )
- parser.add_argument(
- "--max-tokens",
- type=int,
- default=16384,
- help="Maximum tokens to generate.",
- )
- parser.add_argument(
- "--num-prompts",
- type=int,
- default=1,
- help="Number of prompts to generate.",
- )
- parser.add_argument(
- "--modalities",
- type=str,
- default=None,
- help="Output modalities (comma-separated).",
- )
- parser.add_argument(
- "--output-dir",
- type=str,
- default="output_ming",
- help="Output directory for results.",
- )
-
- return parser.parse_args()
-
-
-if __name__ == "__main__":
- args = parse_args()
- main(args)
diff --git a/examples/offline_inference/ming_flash_omni_tts/README.md b/examples/offline_inference/ming_flash_omni_tts/README.md
deleted file mode 100644
index 15b84041df2..00000000000
--- a/examples/offline_inference/ming_flash_omni_tts/README.md
+++ /dev/null
@@ -1,47 +0,0 @@
-# Ming-flash-omni Standalone TTS (Offline)
-
-This example runs **Ming-flash-omni-2.0 talker-only** offline inference with:
-
-- `model`: `Jonathan1909/Ming-flash-omni-2.0`
-- `stage config`: `vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml`
-
-It follows the Ming cookbook parameter style:
-
-- `prompt`: `"Please generate speech based on the following description.\n"`
-- `max_decode_steps`: `200`
-- `cfg`: `2.0`
-- `sigma`: `0.25`
-- `temperature`: `0.0`
-
-## Quick Start
-
-```bash
-python examples/offline_inference/ming_flash_omni_tts/end2end.py --case style
-```
-
-## Cases
-
-```bash
-# Style
-python examples/offline_inference/ming_flash_omni_tts/end2end.py --case style
-
-# IP
-python examples/offline_inference/ming_flash_omni_tts/end2end.py --case ip
-
-# Basic (speed/pitch/volume control)
-python examples/offline_inference/ming_flash_omni_tts/end2end.py --case basic
-```
-
-## Useful Arguments
-
-- `--text`: override default text in the selected case
-- `--output`: custom output wav path
-- `--model`: local model path or HF repo id
-- `--stage-configs-path`: custom talker stage config path
-- `--log-stats`: enable runtime stats logs
-
-## Notes
-
-- This directory is for **standalone talker deployment (TTS)**.
-- For Ming thinker multimodal understanding examples, see:
- `examples/offline_inference/ming_flash_omni/`.
diff --git a/examples/offline_inference/ming_flash_omni_tts/end2end.py b/examples/offline_inference/ming_flash_omni_tts/end2end.py
deleted file mode 100644
index 928994510a6..00000000000
--- a/examples/offline_inference/ming_flash_omni_tts/end2end.py
+++ /dev/null
@@ -1,128 +0,0 @@
-"""Offline e2e example for Ming-flash-omni-2.0 standalone talker (TTS)."""
-
-import os
-from typing import Any
-
-import soundfile as sf
-import torch
-from vllm.utils.argparse_utils import FlexibleArgumentParser
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
-from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.inputs.data import OmniTokensPrompt
-from vllm_omni.model_executor.models.ming_flash_omni.prompt_utils import (
- DEFAULT_PROMPT,
- create_instruction,
-)
-
-MODEL_NAME = "Jonathan1909/Ming-flash-omni-2.0"
-DEFAULT_STAGE_CONFIG = "vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml"
-
-
-def get_messages(case: str, text_override: str | None) -> dict[str, Any]:
- if case == "style":
- text = text_override or "我会一直在这里陪着你,直到你慢慢、慢慢地沉入那个最温柔的梦里……好吗?"
- instruction = create_instruction(
- {
- "风格": "这是一种ASMR耳语,属于一种旨在引发特殊感官体验的创意风格。这个女性使用轻柔的普通话进行耳语,声音气音成分重。音量极低,紧贴麦克风,语速极慢,旨在制造触发听者颅内快感的声学刺激。",
- }
- )
- return {
- "prompt": DEFAULT_PROMPT,
- "text": text,
- "instruction": instruction,
- "use_zero_spk_emb": True,
- }
- if case == "ip":
- text = text_override or "这款产品的名字,叫变态坑爹牛肉丸。"
- return {
- "prompt": DEFAULT_PROMPT,
- "text": text,
- "instruction": create_instruction({"IP": "灵小甄"}),
- "use_zero_spk_emb": True,
- }
- if case == "basic":
- text = text_override or "我们当迎着阳光辛勤耕作,去摘取,去制作,去品尝,去馈赠。"
- return {
- "prompt": DEFAULT_PROMPT,
- "text": text,
- "instruction": create_instruction({"语速": "快速", "基频": "中", "音量": "中"}),
- "use_zero_spk_emb": True,
- }
- raise ValueError(f"Unknown case: {case}")
-
-
-def save_audio(mm: dict[str, Any], output_path: str) -> None:
- if not mm or "audio" not in mm:
- raise RuntimeError("No audio found in model output")
- audio = mm["audio"]
- sr_raw = mm.get("sr", 44100)
- if isinstance(sr_raw, torch.Tensor):
- sample_rate = int(sr_raw.item())
- else:
- sample_rate = int(sr_raw)
- waveform = audio.squeeze().float().cpu().numpy()
- sf.write(output_path, waveform, sample_rate)
- print(f"Saved {output_path} ({len(waveform) / sample_rate:.2f}s, {sample_rate}Hz)")
-
-
-def parse_args():
- parser = FlexibleArgumentParser(description="Ming-flash-omni standalone talker offline e2e example")
- parser.add_argument("--model", type=str, default=MODEL_NAME, help="Model name or local path.")
- parser.add_argument(
- "--stage-configs-path",
- type=str,
- default=DEFAULT_STAGE_CONFIG,
- help="Path to stage configs yaml for standalone talker deployment.",
- )
- parser.add_argument(
- "--case",
- type=str,
- default="style",
- choices=["style", "ip", "basic"],
- help="Example case.",
- )
- parser.add_argument("--text", type=str, default=None, help="Override default text for the selected case.")
- parser.add_argument("--output", type=str, default=None, help="Output wav path.")
- parser.add_argument("--log-stats", action="store_true", default=False, help="Enable stats logging.")
- parser.add_argument("--init-timeout", type=int, default=600, help="Engine init timeout in seconds.")
- parser.add_argument("--stage-init-timeout", type=int, default=300, help="Single stage init timeout in seconds.")
- return parser.parse_args()
-
-
-def main():
- args = parse_args()
-
- omni = Omni(
- model=args.model,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- init_timeout=args.init_timeout,
- stage_init_timeout=args.stage_init_timeout,
- )
-
- messages = get_messages(args.case, args.text)
- decode_args = {
- # Standalone TTS deployment
- "ming_task": "instruct",
- "max_decode_steps": 200,
- "cfg": 2.0,
- "sigma": 0.25,
- "temperature": 0.0,
- }
- req = OmniTokensPrompt(
- prompt_token_ids=[0],
- additional_information={**messages, **decode_args},
- )
-
- outputs = omni.generate(req)
- mm = outputs[0].outputs[0].multimodal_output
-
- output_path = args.output or f"output_{args.case}.wav"
- save_audio(mm, output_path)
- omni.close()
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/offline_inference/omnivoice/end2end.py b/examples/offline_inference/omnivoice/end2end.py
index cc6f585c50e..b41379b011a 100644
--- a/examples/offline_inference/omnivoice/end2end.py
+++ b/examples/offline_inference/omnivoice/end2end.py
@@ -89,6 +89,7 @@ def run_e2e():
omni = Omni(
model=args.model,
stage_configs_path=args.stage_config,
+ trust_remote_code=True,
log_stats=True,
)
@@ -102,9 +103,9 @@ def run_e2e():
if not os.path.exists(args.ref_audio):
raise FileNotFoundError(f"Reference audio not found: {args.ref_audio}")
- from vllm.multimodal.media.audio import load_audio
+ import librosa
- audio_signal, sr = load_audio(args.ref_audio, sr=None)
+ audio_signal, sr = librosa.load(args.ref_audio, sr=None)
multi_modal_data["audio"] = (audio_signal.astype(np.float32), sr)
mm_processor_kwargs["ref_text"] = args.ref_text or ""
mm_processor_kwargs["sample_rate"] = sr
diff --git a/examples/offline_inference/qwen2_5_omni/README.md b/examples/offline_inference/qwen2_5_omni/README.md
index e2eae8a96b5..20740a0da02 100644
--- a/examples/offline_inference/qwen2_5_omni/README.md
+++ b/examples/offline_inference/qwen2_5_omni/README.md
@@ -60,3 +60,11 @@ If media file paths are not provided, the script will use default assets. Suppor
- `mixed_modalities`: Audio + image + video
- `use_audio_in_video`: Extract audio from video
- `text`: Text-only query
+
+### FAQ
+
+If you encounter an error about the librosa audio backend, install ffmpeg with the commands below.
+```bash
+sudo apt update
+sudo apt install ffmpeg
+```
diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py
index a65c554a9b0..7bba5998308 100644
--- a/examples/offline_inference/qwen2_5_omni/end2end.py
+++ b/examples/offline_inference/qwen2_5_omni/end2end.py
@@ -5,11 +5,11 @@
with the correct prompt format on Qwen2.5-Omni
"""
-import json
import os
import time
from typing import NamedTuple
+import librosa
import numpy as np
import soundfile as sf
from PIL import Image
@@ -17,7 +17,6 @@
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset, video_to_ndarrays
from vllm.multimodal.image import convert_image_mode
-from vllm.multimodal.media.audio import load_audio
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -97,7 +96,7 @@ def get_mixed_modalities_query(
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
+ audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -131,7 +130,7 @@ def get_use_audio_in_video_query(
raise FileNotFoundError(f"Video file not found: {video_path}")
video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
# Extract audio from video file
- audio_signal, sr = load_audio(video_path, sr=sampling_rate)
+ audio_signal, sr = librosa.load(video_path, sr=sampling_rate)
audio = (audio_signal.astype(np.float32), sr)
else:
asset = VideoAsset(name="baby_reading", num_frames=num_frames)
@@ -166,7 +165,7 @@ def get_multi_audios_query(audio_path: str | None = None, sampling_rate: int = 1
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
+ audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
# Use the provided audio as the first audio, default as second
audio_list = [
@@ -262,7 +261,7 @@ def get_audio_query(question: str = None, audio_path: str | None = None, samplin
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
+ audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -290,10 +289,7 @@ def get_audio_query(question: str = None, audio_path: str | None = None, samplin
def main(args):
- model_name = args.model
- quantization_config = None
- if args.quantization_config is not None:
- quantization_config = json.loads(args.quantization_config)
+ model_name = "Qwen/Qwen2.5-Omni-7B"
# Get paths from args
video_path = getattr(args, "video_path", None)
@@ -324,8 +320,14 @@ def main(args):
query_result = query_func(audio_path=audio_path, sampling_rate=sampling_rate)
else:
query_result = query_func()
- args.quantization_config = quantization_config
- omni = Omni.from_cli_args(args, model=model_name)
+ omni = Omni(
+ model=model_name,
+ log_stats=args.log_stats,
+ stage_init_timeout=args.stage_init_timeout,
+ batch_timeout=args.batch_timeout,
+ init_timeout=args.init_timeout,
+ shm_threshold_bytes=args.shm_threshold_bytes,
+ )
thinker_sampling_params = SamplingParams(
temperature=0.0, # Deterministic - no randomness
top_p=1.0, # Disable nucleus sampling
@@ -429,18 +431,6 @@ def main(args):
def parse_args():
parser = FlexibleArgumentParser(description="Demo on using vLLM for offline inference with audio language models")
- parser.add_argument(
- "--model",
- type=str,
- default="Qwen/Qwen2.5-Omni-7B",
- help="Model name or local path.",
- )
- parser.add_argument(
- "--quantization-config",
- type=str,
- default=None,
- help="Optional JSON string forwarded to Omni(quantization_config=...).",
- )
parser.add_argument(
"--query-type",
"-q",
diff --git a/examples/offline_inference/qwen3_omni/README.md b/examples/offline_inference/qwen3_omni/README.md
index 0710faa133c..b3e8592532e 100644
--- a/examples/offline_inference/qwen3_omni/README.md
+++ b/examples/offline_inference/qwen3_omni/README.md
@@ -70,8 +70,8 @@ For true stage-level concurrency -- where downstream stages (Talker, Code2Wav)
start **before** the upstream stage (Thinker) finishes -- use the async_chunk
example. This requires:
-1. A deploy config YAML with ``async_chunk: true`` (e.g.
- ``qwen3_omni_moe.yaml``).
+1. A stage config YAML with ``async_chunk: true`` (e.g.
+ ``qwen3_omni_moe_async_chunk.yaml``).
2. Hardware that matches the config (e.g. 2x H100 for the default 3-stage
config).
@@ -101,10 +101,18 @@ python end2end_async_chunk.py --query-type text --modalities text
```bash
python end2end_async_chunk.py \
--query-type use_audio \
- --deploy-config /path/to/your_deploy_config.yaml
+ --stage-configs-path /path/to/your_async_chunk.yaml
```
> **Note**: The synchronous ``end2end.py`` (using ``Omni``) is still the
> recommended entry point for non-async-chunk workflows. Only use the
> async_chunk example when you need the stage-level concurrency semantics
> described in PR #962 / #1151.
+
+### FAQ
+
+If you encounter an error about the librosa audio backend, install ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py
index 04aa7914db1..155eca4ed9f 100644
--- a/examples/offline_inference/qwen3_omni/end2end.py
+++ b/examples/offline_inference/qwen3_omni/end2end.py
@@ -9,6 +9,7 @@
import time
from typing import NamedTuple
+import librosa
import numpy as np
import soundfile as sf
import vllm
@@ -18,10 +19,8 @@
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset, video_to_ndarrays
from vllm.multimodal.image import convert_image_mode
-from vllm.multimodal.media.audio import load_audio
from vllm.utils.argparse_utils import FlexibleArgumentParser
-from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
from vllm_omni.entrypoints.omni import Omni
SEED = 42
@@ -130,7 +129,7 @@ def get_audio_query(question: str = None, audio_path: str | None = None, samplin
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
+ audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -184,7 +183,7 @@ def get_mixed_modalities_query(
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
+ audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -295,7 +294,14 @@ def main(args):
else:
query_result = query_func()
- omni = Omni.from_cli_args(args, model=model_name)
+ omni = Omni(
+ model=model_name,
+ dtype=args.dtype,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ stage_init_timeout=args.stage_init_timeout,
+ init_timeout=args.init_timeout,
+ )
thinker_sampling_params = SamplingParams(
temperature=0.9,
@@ -551,7 +557,6 @@ def parse_args():
help="Model dtype (auto, half, float16, bfloat16, float, float32).",
)
- nullify_stage_engine_defaults(parser)
return parser.parse_args()
diff --git a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
index 85c2da20b04..8adbae9eb66 100644
--- a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
+++ b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
@@ -14,7 +14,7 @@
Usage
-----
python end2end_async_chunk.py --query-type use_audio \
- --deploy-config
+ --stage-configs-path
See ``--help`` for all options.
"""
@@ -32,13 +32,13 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+import librosa
from PIL import Image
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset, video_to_ndarrays
from vllm.multimodal.image import convert_image_mode
-from vllm.multimodal.media.audio import load_audio
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm_omni.entrypoints.async_omni import AsyncOmni
@@ -89,7 +89,7 @@ def get_audio_query(
if audio_path:
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
- audio_signal, sr = load_audio(audio_path, sr=sampling_rate)
+ audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
audio_data = (audio_signal.astype(np.float32), sr)
else:
audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
@@ -179,26 +179,20 @@ def clone_prompt_for_request(template: dict) -> dict:
return cloned
-def _default_deploy_config_path() -> str | None:
- """Best-effort default deploy config for running Qwen3-Omni with async_chunk.
+def _default_async_chunk_stage_configs_path() -> str | None:
+ """Best-effort default stage config for running Qwen3-Omni with async_chunk.
- The default ``vllm_omni/deploy/qwen3_omni_moe.yaml`` ships with
- ``async_chunk: true`` at the top level, so loading it is enough to
- enable async-chunk semantics. To disable it, copy the YAML and set
- ``async_chunk: false`` (or pass ``--deploy-config`` to a YAML that
- overrides the flag).
-
- When this example is executed from within the repository, we resolve
- the default YAML path relative to this file. When installed elsewhere,
- the file may not exist and callers should pass ``--deploy-config``
- explicitly.
+ When this example is executed from within the repository, we resolve the
+ default YAML path relative to this file. When installed elsewhere, the
+ file may not exist and callers should pass --stage-configs-path explicitly.
"""
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
candidate = os.path.join(
repo_root,
"vllm_omni",
- "deploy",
- "qwen3_omni_moe.yaml",
+ "model_executor",
+ "stage_configs",
+ "qwen3_omni_moe_async_chunk.yaml",
)
return candidate if os.path.exists(candidate) else None
@@ -242,7 +236,8 @@ async def run_single_request(
if stage_0_first_output_ts is None:
stage_0_first_output_ts = time.perf_counter()
text_output = output.outputs[0].text
- text_parts.append(text_output)
+ if output.finished:
+ text_parts.append(text_output)
elif omni_output.final_output_type == "audio":
mm_out = output.outputs[0].multimodal_output
if mm_out and "audio" in mm_out:
@@ -292,7 +287,7 @@ async def run_single_request(
if text_parts:
text_file = os.path.join(output_dir, f"{request_id}.txt")
with open(text_file, "w", encoding="utf-8") as f:
- f.write("".join(text_parts))
+ f.write("\n".join(text_parts))
result["saved_files"].append(text_file)
print(
f"[Request {request_id}] Text saved to {text_file} "
@@ -379,23 +374,18 @@ async def run_all(args):
prompt["modalities"] = output_modalities
# Create AsyncOmni
- print(f"[Info] Creating AsyncOmni with deploy_config={args.deploy_config}")
+ print(f"[Info] Creating AsyncOmni with stage_configs_path={args.stage_configs_path}")
async_omni = None
try:
- # ``from_cli_args`` expands vars(args) into kwargs and auto-captures
- # ``_cli_explicit_keys`` from ``sys.argv[1:]`` so argparse defaults
- # do not silently override deploy YAML values. Mirrors the
- # ``EngineArgs.from_cli_args`` pattern used throughout vllm /
- # vllm-omni. ``deploy_config=None`` (the default) falls through to
- # the bundled ``vllm_omni/deploy/qwen3_omni_moe.yaml``.
- async_omni = AsyncOmni.from_cli_args(args)
+ async_omni = AsyncOmni(
+ model=args.model,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ stage_init_timeout=args.stage_init_timeout,
+ )
# Use default sampling params from stage config (they are pre-configured
# in the YAML for each stage).
- #
- # NOTE: Since we do not set the sampling params directly, .generate in
- # will automatically set the output kind to delta, since this is what
- # makes sense for most multimodal use-cases.
sampling_params_list = None
output_dir = args.output_dir
@@ -480,11 +470,11 @@ def parse_args():
help="Query type.",
)
parser.add_argument(
- "--deploy-config",
+ "--stage-configs-path",
type=str,
- default=_default_deploy_config_path(),
+ default=_default_async_chunk_stage_configs_path(),
help=(
- "Path to a deploy config YAML. "
+ "Path to an async_chunk stage config YAML. "
"If not set, uses the model's default config "
"(make sure it has async_chunk: true)."
),
diff --git a/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh b/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh
index 2f2be20915a..809054867c3 100755
--- a/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh
+++ b/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh
@@ -17,7 +17,7 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
python "${SCRIPT_DIR}/end2end_async_chunk.py" \
--query-type text \
--txt-prompts "${SCRIPT_DIR}/text_prompts_10.txt" \
- --deploy-config "${REPO_ROOT}/vllm_omni/deploy/qwen3_omni_moe.yaml" \
+ --stage-configs-path "${REPO_ROOT}/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml" \
--output-dir output_audio_async_chunk \
--max-in-flight 2 \
"$@"
diff --git a/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh b/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh
index 9ef69293cb5..918c7ee4fd9 100755
--- a/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh
+++ b/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh
@@ -6,13 +6,13 @@
# achieving true stage-level concurrency via chunk-level streaming.
#
# Prerequisites:
-# - A deploy config YAML (e.g. qwen3_omni_moe.yaml)
+# - An async_chunk stage config YAML (e.g. qwen3_omni_moe_async_chunk.yaml)
# - Hardware matching the config (e.g. 2x H100 for the default 3-stage config)
#
# Usage:
# bash run_single_prompt_async_chunk.sh
# bash run_single_prompt_async_chunk.sh --query-type text --modalities text
-# bash run_single_prompt_async_chunk.sh --deploy-config /path/to/custom.yaml
+# bash run_single_prompt_async_chunk.sh --stage-configs-path /path/to/custom.yaml
set -euo pipefail
@@ -21,6 +21,6 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
python "${SCRIPT_DIR}/end2end_async_chunk.py" \
--query-type use_audio \
- --deploy-config "${REPO_ROOT}/vllm_omni/deploy/qwen3_omni_moe.yaml" \
+ --stage-configs-path "${REPO_ROOT}/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml" \
--output-dir output_audio_async_chunk \
"$@"
diff --git a/examples/offline_inference/qwen3_tts/README.md b/examples/offline_inference/qwen3_tts/README.md
index 2971ad716a2..bf59dc9ba49 100644
--- a/examples/offline_inference/qwen3_tts/README.md
+++ b/examples/offline_inference/qwen3_tts/README.md
@@ -15,11 +15,11 @@ Please refer to the [stage configuration documentation](https://docs.vllm.ai/pro
### ROCm Dependencies
-You will need to install the dependency `onnxruntime-rocm`.
+You will need to install two dependencies: `onnxruntime-rocm` and `sox`.
```
pip uninstall onnxruntime # should be removed before we can install onnxruntime-rocm
-pip install onnxruntime-rocm
+pip install onnxruntime-rocm sox
```
## Quick Start
@@ -104,13 +104,13 @@ completes. This demonstrates that audio data is available progressively rather t
## Batched Decoding
-The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, set `max_num_seqs > 1` on both stages via `--stage-overrides` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
+The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config that sets `max_num_seqs > 1` on both stages and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
```
python end2end.py --query-type CustomVoice \
--txt-prompts benchmark_prompts.txt \
--batch-size 4 \
- --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},"1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
```
**Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_num_seqs >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_num_seqs`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently.
diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py
index 77da356b4f8..901418c39b8 100644
--- a/examples/offline_inference/qwen3_tts/end2end.py
+++ b/examples/offline_inference/qwen3_tts/end2end.py
@@ -366,7 +366,12 @@ def main(args):
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
- omni = Omni.from_cli_args(args, model=model_name)
+ omni = Omni(
+ model=model_name,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ stage_init_timeout=args.stage_init_timeout,
+ )
batch_size = args.batch_size
for batch_start in range(0, len(inputs), batch_size):
@@ -382,7 +387,12 @@ async def main_streaming(args):
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
- omni = AsyncOmni.from_cli_args(args, model=model_name)
+ omni = AsyncOmni(
+ model=model_name,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ stage_init_timeout=args.stage_init_timeout,
+ )
for i, prompt in enumerate(inputs):
request_id = str(i)
diff --git a/examples/offline_inference/text_to_audio/README.md b/examples/offline_inference/text_to_audio/README.md
index 10a8ae37ed1..7edc38092ad 100644
--- a/examples/offline_inference/text_to_audio/README.md
+++ b/examples/offline_inference/text_to_audio/README.md
@@ -23,23 +23,6 @@ python text_to_audio.py \
--guidance-scale 7.0 \
--audio-length 10.0 \
--num-inference-steps 100 \
- --cache-backend tea_cache \
- --output stable_audio_output.wav
-```
-
-To reduce per-GPU memory for multi-GPU inference, launch with HSDP:
-
-```bash
-python text_to_audio.py \
- --model stabilityai/stable-audio-open-1.0 \
- --prompt "The sound of a hammer hitting a wooden surface" \
- --negative-prompt "Low quality" \
- --seed 42 \
- --guidance-scale 7.0 \
- --audio-length 10.0 \
- --num-inference-steps 100 \
- --use-hsdp \
- --hsdp-shard-size 2 \
--output stable_audio_output.wav
```
@@ -51,8 +34,4 @@ Key arguments:
- `--guidance-scale`: classifier-free guidance scale.
- `--audio-length`: audio duration in seconds.
- `--num-inference-steps`: diffusion sampling steps.(more steps = higher quality, slower).
-- `--use-hsdp`: enable HSDP weight sharding for the Stable Audio DiT.
-- `--hsdp-shard-size`: number of GPUs used for HSDP sharding.
-- `--hsdp-replicate-size`: number of HSDP replica groups.
-- `--cache-backend`: cache acceleration backend. Stable Audio currently supports `tea_cache`.
- `--output`: path to save the generated WAV file.
diff --git a/examples/offline_inference/text_to_audio/text_to_audio.py b/examples/offline_inference/text_to_audio/text_to_audio.py
index 2a1613e5e91..a6968c419f6 100644
--- a/examples/offline_inference/text_to_audio/text_to_audio.py
+++ b/examples/offline_inference/text_to_audio/text_to_audio.py
@@ -11,7 +11,6 @@
python text_to_audio.py --prompt "The sound of a dog barking"
python text_to_audio.py --prompt "A piano playing a gentle melody" --audio-length 10.0
python text_to_audio.py --prompt "Thunder and rain sounds" --negative-prompt "Low quality"
- python text_to_audio.py --prompt "A soft synth pad" --cache-backend tea_cache
"""
import argparse
@@ -21,7 +20,6 @@
import numpy as np
import torch
-from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -92,83 +90,11 @@ def parse_args() -> argparse.Namespace:
default=44100,
help="Sample rate for output audio (Stable Audio uses 44100 Hz).",
)
- parser.add_argument(
- "--cache-backend",
- type=str,
- default=None,
- choices=["tea_cache"],
- help=(
- "Cache backend to use for acceleration. "
- "Stable Audio currently supports 'tea_cache'. "
- "Default: None (no cache acceleration)."
- ),
- )
- parser.add_argument(
- "--tea-cache-rel-l1-thresh",
- type=float,
- default=0.2,
- help="[tea_cache] Threshold for accumulated relative L1 distance.",
- )
parser.add_argument(
"--enable-diffusion-pipeline-profiler",
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- parser.add_argument(
- "--use-hsdp",
- action="store_true",
- help="Enable HSDP for Stable Audio DiT weight sharding.",
- )
- parser.add_argument(
- "--hsdp-shard-size",
- type=int,
- default=1,
- help="Number of GPUs to shard Stable Audio DiT weights across when HSDP is enabled.",
- )
- parser.add_argument(
- "--hsdp-replicate-size",
- type=int,
- default=1,
- help="Number of HSDP replica groups. Default 1 means pure sharding.",
- )
- parser.add_argument(
- "--tensor-parallel-size",
- type=int,
- default=1,
- help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
- )
- parser.add_argument(
- "--ulysses-degree",
- type=int,
- default=1,
- help="Number of GPUs used for ulysses sequence parallelism.",
- )
- parser.add_argument(
- "--ulysses-mode",
- type=str,
- default="strict",
- choices=["strict", "advanced_uaa"],
- help="Ulysses sequence-parallel mode: 'strict' (divisibility required) or 'advanced_uaa' (UAA).",
- )
- parser.add_argument(
- "--ring-degree",
- type=int,
- default=1,
- help="Number of GPUs used for ring sequence parallelism.",
- )
- parser.add_argument(
- "--cfg-parallel-size",
- type=int,
- default=1,
- choices=[1, 2],
- help="Number of GPUs used for classifier free guidance parallel size.",
- )
- parser.add_argument(
- "--vae-patch-parallel-size",
- type=int,
- default=1,
- help="Number of GPUs used for VAE patch/tile parallelism (decode).",
- )
return parser.parse_args()
@@ -198,11 +124,6 @@ def save_audio(audio_data: np.ndarray, output_path: str, sample_rate: int = 4410
def main():
args = parse_args()
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed)
- cache_config = None
- if args.cache_backend == "tea_cache":
- cache_config = {
- "rel_l1_thresh": args.tea_cache_rel_l1_thresh,
- }
print(f"\n{'=' * 60}")
print("Stable Audio Open - Text-to-Audio Generation")
@@ -213,26 +134,12 @@ def main():
print(f" Audio length: {args.audio_length}s")
print(f" Inference steps: {args.num_inference_steps}")
print(f" Guidance scale: {args.guidance_scale}")
- print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}")
- if args.use_hsdp:
- print(f" HSDP: enabled (shard_size={args.hsdp_shard_size}, replicate_size={args.hsdp_replicate_size})")
- else:
- print(" HSDP: disabled")
print(f" Seed: {args.seed}")
print(f"{'=' * 60}\n")
- parallel_config = DiffusionParallelConfig(
- use_hsdp=args.use_hsdp,
- hsdp_shard_size=args.hsdp_shard_size,
- hsdp_replicate_size=args.hsdp_replicate_size,
- )
-
# Initialize Omni with Stable Audio model
omni = Omni(
model=args.model,
- parallel_config=parallel_config,
- cache_backend=args.cache_backend,
- cache_config=cache_config,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
)
diff --git a/examples/offline_inference/text_to_image/README.md b/examples/offline_inference/text_to_image/README.md
index c71773972b3..235b710a68e 100644
--- a/examples/offline_inference/text_to_image/README.md
+++ b/examples/offline_inference/text_to_image/README.md
@@ -29,12 +29,10 @@ This folder provides several entrypoints for experimenting with text-to-image di
| `AIDC-AI/Ovis-Image-7B` | 1024 x 1024 | 71.8 | 17.1 |
| `OmniGen2/OmniGen2` | 1024 x 1024 | 20.1 | 14.7 |
| `stabilityai/stable-diffusion-3.5-medium` | 1024 x 1024 | 20.1 | 15.6 |
-| `black-forest-labs/FLUX.1-dev` | 1024 x 1024 | 33.9 | 31.4 |
-| `black-forest-labs/FLUX.1-schnell` | 1024 x 1024 | 33.9 | 31.4 |
+| `black-forest-labs/FLUX.1-dev` | 1024 x 1024 | 77.6 | 31.4 |
| `black-forest-labs/FLUX.2-klein-4B` | 1024 x 1024 | 72.7 | 14.9 |
| `black-forest-labs/FLUX.2-klein-9B` | 1024 x 1024 | 37.1 | 32.3 |
| `black-forest-labs/FLUX.2-dev` | 1024 x 1024 | 65.7 | >80 (CPU offload required) |
-| `HunyuanImage-3.0` | 1024 x 1024 | 80.0 (TP≥3) | 160 |
!!! info
*Peak VRAM: based on basic single-card usage, batch size =1, without any acceleration/optimization features. FLUX.2-dev requires `--enable-cpu-offload` on a single 80 GiB GPU.
@@ -92,8 +90,6 @@ python text_to_image.py \
| `--enable-cpu-offload` | flag | off | Enable CPU offloading for diffusion models |
| `--lora-path` | str | — | Path to PEFT LoRA adapter folder |
| `--lora-scale` | float | `1.0` | Scale factor for LoRA weights |
-| `--use-system-prompt` | str | `None` | System prompt preset: `en_unified`, `en_vanilla`, `en_recaption`, `en_think_recaption`, `dynamic`, `None`, or custom text. Recommended: `en_unified`. Only for HunyuanImage-3.0.|
-| `--system-prompt` | str | `None` | Custom system prompt text. Only used when `--use-system-prompt` is set to `custom`. Only for HunyuanImage-3.0.|
**NextStep-1.1 specific arguments:**
@@ -248,7 +244,7 @@ python examples/offline_inference/text_to_image/text_to_image.py \
#### CFG Parallel
Set `--cfg-parallel-size 2` to enable CFG Parallel for faster inference on multi-GPU setups.
-See more examples in the [cfg_parallel user guide](../../../docs/user_guide/parallelism/cfg_parallel.md#using-cfg-parallel).
+See more examples in the [diffusion acceleration user guide](../../../docs/user_guide/diffusion_acceleration.md#using-cfg-parallel).
#### LoRA
diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py
index bc18c685912..927b0f0b087 100644
--- a/examples/offline_inference/text_to_image/text_to_image.py
+++ b/examples/offline_inference/text_to_image/text_to_image.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
-import json
+import os
import time
from pathlib import Path
from typing import Any
@@ -30,16 +30,6 @@ def is_nextstep_model(model_name: str) -> bool:
return False
-def parse_profiler_config(value: str) -> dict[str, Any]:
- try:
- config = json.loads(value)
- except json.JSONDecodeError as e:
- raise argparse.ArgumentTypeError(f"--profiler-config must be valid JSON: {e}") from e
- if not isinstance(config, dict):
- raise argparse.ArgumentTypeError("--profiler-config must be a JSON object")
- return config
-
-
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Generate an image with supported diffusion models.")
parser.add_argument(
@@ -154,23 +144,6 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable layerwise (blockwise) offloading on DiT modules.",
)
- parser.add_argument(
- "--use-hsdp",
- action="store_true",
- help="Enable HSDP (Hybrid Sharded Data Parallel) for diffusion models.",
- )
- parser.add_argument(
- "--hsdp-shard-size",
- type=int,
- default=1,
- help="Number of GPUs to shard weights across for HSDP.",
- )
- parser.add_argument(
- "--hsdp-replicate-size",
- type=int,
- default=1,
- help="Number of HSDP replica groups.",
- )
parser.add_argument(
"--quantization",
type=str,
@@ -264,46 +237,11 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- parser.add_argument(
- "--profiler-config",
- type=parse_profiler_config,
- default=None,
- help='JSON profiler config for torch/cuda profiling, e.g. \'{"profiler":"torch","torch_profiler_dir":"./perf"}\'.',
- )
parser.add_argument(
"--log-stats",
action="store_true",
help="Enable logging of diffusion pipeline stats.",
)
- parser.add_argument(
- "--init-timeout",
- type=int,
- default=600,
- help="Timeout for initializing a single stage in seconds (default: 600s)",
- )
- parser.add_argument(
- "--stage-init-timeout",
- type=int,
- default=600,
- help="Timeout for initializing a single stage in seconds (default: 600s)",
- )
- parser.add_argument(
- "--use-system-prompt",
- type=str,
- default=None,
- choices=["None", "dynamic", "en_vanilla", "en_recaption", "en_think_recaption", "en_unified", "custom"],
- help="System prompt preset for generation. Recommended: en_unified.",
- )
- parser.add_argument(
- "--system-prompt",
- type=str,
- default=None,
- help=("Custom system prompt. Used when --use-system-prompt is custom. "),
- )
- current_omni_platform.pre_register_and_update(parser)
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- nullify_stage_engine_defaults(parser)
return parser.parse_args()
@@ -353,7 +291,8 @@ def main():
enable_expert_parallel=args.enable_expert_parallel,
)
- profiler_enabled = args.profiler_config is not None
+ # Check if profiling is requested via environment variable
+ profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))
# Prepare LoRA kwargs for Omni initialization
lora_args: dict[str, Any] = {}
@@ -394,9 +333,6 @@ def main():
"mode": "text-to-image",
"log_stats": args.log_stats,
"enable_diffusion_pipeline_profiler": args.enable_diffusion_pipeline_profiler,
- "profiler_config": args.profiler_config,
- "init_timeout": args.init_timeout,
- "stage_init_timeout": args.stage_init_timeout,
**lora_args,
**quant_kwargs,
}
@@ -427,7 +363,7 @@ def main():
f"vae_patch_parallel_size={args.vae_patch_parallel_size}, "
f"enable_expert_parallel={args.enable_expert_parallel}."
)
- print(f" CPU offload: {args.enable_cpu_offload}; CPU Layerwise Offload: {args.enable_layerwise_offload}")
+ print(f" CPU offload: {args.enable_cpu_offload}")
print(f" Image size: {args.width}x{args.height}")
if args.lora_path:
print(f" LoRA: scale={args.lora_scale}")
@@ -446,13 +382,13 @@ def main():
)
generation_start = time.perf_counter()
+
extra_args = {
"timesteps_shift": args.timesteps_shift,
"cfg_schedule": args.cfg_schedule,
"use_norm": args.use_norm,
- "use_system_prompt": args.use_system_prompt,
- "system_prompt": args.system_prompt,
}
+
if lora_request:
extra_args["lora_request"] = lora_request
extra_args["lora_scale"] = args.lora_scale
diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py
index d1bbf27cb45..cf779210977 100644
--- a/examples/offline_inference/text_to_video/text_to_video.py
+++ b/examples/offline_inference/text_to_video/text_to_video.py
@@ -2,10 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
-import json
+import os
import time
from pathlib import Path
-from typing import Any
import numpy as np
import torch
@@ -45,16 +44,6 @@ def _detect_preset(model: str) -> dict:
return _MODEL_PRESETS["wan"]
-def parse_profiler_config(value: str) -> dict[str, Any]:
- try:
- config = json.loads(value)
- except json.JSONDecodeError as e:
- raise argparse.ArgumentTypeError(f"--profiler-config must be valid JSON: {e}") from e
- if not isinstance(config, dict):
- raise argparse.ArgumentTypeError("--profiler-config must be a JSON object")
- return config
-
-
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt. "
@@ -142,13 +131,6 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable layerwise (blockwise) offloading on DiT modules.",
)
- parser.add_argument(
- "--ulysses-mode",
- type=str,
- default="strict",
- choices=["strict", "advanced_uaa"],
- help="Ulysses sequence-parallel mode: 'strict' (divisibility required) or 'advanced_uaa' (UAA).",
- )
parser.add_argument(
"--ulysses-degree",
type=int,
@@ -178,7 +160,7 @@ def parse_args() -> argparse.Namespace:
"--audio-sample-rate",
type=int,
default=24000,
- help="Sample rate for audio output when saved (default: 24000).",
+ help="Sample rate for audio output when saved (default: 24000 for LTX2).",
)
parser.add_argument(
"--vae-patch-parallel-size",
@@ -196,12 +178,6 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- parser.add_argument(
- "--profiler-config",
- type=parse_profiler_config,
- default=None,
- help='JSON profiler config for torch/cuda profiling, e.g. \'{"profiler":"torch","torch_profiler_dir":"./perf"}\'.',
- )
parser.add_argument(
"--quantization",
type=str,
@@ -209,26 +185,6 @@ def parse_args() -> argparse.Namespace:
choices=["fp8", "gguf"],
help="Quantization method for the transformer (fp8 for online FP8 quantization).",
)
- parser.add_argument(
- "--use-hsdp",
- action="store_true",
- help="Enable HSDP (Hybrid Sharded Data Parallel) for diffusion models.",
- )
- parser.add_argument(
- "--hsdp-shard-size",
- type=int,
- default=1,
- help="Number of GPUs to shard weights across for HSDP.",
- )
- parser.add_argument(
- "--hsdp-replicate-size",
- type=int,
- default=1,
- help="Number of HSDP replica groups.",
- )
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- nullify_stage_engine_defaults(parser)
return parser.parse_args()
@@ -268,7 +224,8 @@ def main():
enable_expert_parallel=args.enable_expert_parallel,
)
- profiler_enabled = args.profiler_config is not None
+ # Check if profiling is requested via environment variable
+ profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))
omni_kwargs = dict(
model=args.model,
@@ -282,7 +239,6 @@ def main():
cache_backend=args.cache_backend,
cache_config=cache_config,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
- profiler_config=args.profiler_config,
)
if args.boundary_ratio is not None:
omni_kwargs["boundary_ratio"] = args.boundary_ratio
@@ -482,8 +438,17 @@ def _ensure_frame_list(video_array):
video_array = _ensure_frame_list(video_array)
+ use_ltx2_export = False
+ if args.model and "ltx" in str(args.model).lower():
+ use_ltx2_export = True
if audio is not None:
- from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
+ use_ltx2_export = True
+
+ if use_ltx2_export:
+ try:
+ from diffusers.pipelines.ltx2.export_utils import encode_video
+ except ImportError:
+ raise ImportError("diffusers is required for LTX2 encode_video.")
if isinstance(video_array, list):
frames_np = np.stack(video_array, axis=0)
@@ -492,24 +457,28 @@ def _ensure_frame_list(video_array):
else:
frames_np = np.asarray(video_array)
- frames_u8 = (np.clip(frames_np, 0.0, 1.0) * 255).round().clip(0, 255).astype("uint8")
-
- audio_np = audio
- if isinstance(audio_np, list):
- audio_np = audio_np[0] if audio_np else None
- if isinstance(audio_np, torch.Tensor):
- audio_np = audio_np.detach().cpu().float().numpy()
- if isinstance(audio_np, np.ndarray):
- audio_np = np.squeeze(audio_np).astype(np.float32)
-
- video_bytes = mux_video_audio_bytes(
- frames_u8,
- audio_np,
- fps=float(args.fps),
- audio_sample_rate=args.audio_sample_rate,
+ frames_u8 = (frames_np * 255).round().clip(0, 255).astype("uint8")
+ video_tensor = torch.from_numpy(frames_u8)
+
+ audio_out = None
+ if audio is not None:
+ if isinstance(audio, list):
+ audio = audio[0] if audio else None
+ if isinstance(audio, np.ndarray):
+ audio = torch.from_numpy(audio)
+ if isinstance(audio, torch.Tensor):
+ audio_out = audio
+ if audio_out.dim() > 1:
+ audio_out = audio_out[0]
+ audio_out = audio_out.float().cpu()
+
+ encode_video(
+ video_tensor,
+ fps=args.fps,
+ audio=audio_out,
+ audio_sample_rate=args.audio_sample_rate if audio_out is not None else None,
+ output_path=str(output_path),
)
- with open(str(output_path), "wb") as f:
- f.write(video_bytes)
else:
export_to_video(video_array, str(output_path), fps=args.fps)
print(f"Saved generated video to {output_path}")
diff --git a/examples/offline_inference/vace/vace_video_generation.py b/examples/offline_inference/vace/vace_video_generation.py
index 5fad3736635..6ca0d74c52e 100644
--- a/examples/offline_inference/vace/vace_video_generation.py
+++ b/examples/offline_inference/vace/vace_video_generation.py
@@ -71,9 +71,6 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--ulysses-degree", type=int, default=1, help="Ulysses SP degree.")
parser.add_argument("--ring-degree", type=int, default=1, help="Ring attention degree.")
parser.add_argument("--cfg-parallel-size", type=int, default=1, choices=[1, 2], help="CFG parallel size.")
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- nullify_stage_engine_defaults(parser)
return parser.parse_args()
diff --git a/examples/offline_inference/voxcpm/README.md b/examples/offline_inference/voxcpm/README.md
deleted file mode 100644
index 1eaea9b0dba..00000000000
--- a/examples/offline_inference/voxcpm/README.md
+++ /dev/null
@@ -1,123 +0,0 @@
-# VoxCPM Offline Example
-
-This directory contains the minimal offline VoxCPM example for vLLM Omni.
-
-`end2end.py` is intentionally small and only covers:
-
-- single text-to-speech
-- single voice cloning with `ref_audio` + `ref_text`
-- non-streaming with `vllm_omni/model_executor/stage_configs/voxcpm.yaml`
-- streaming with `vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml`
-
-Advanced workflows were moved out of the getting-started example:
-
-- `benchmarks/voxcpm/vllm_omni/bench_tts_offline.py`: warmup, batch prompts, profiler, offline TTFP / RTF
-- `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py`: fixed offline smoke matrix
-- `benchmarks/voxcpm/`: benchmark scripts and benchmark docs
-
-## Prerequisites
-
-Install VoxCPM in one of these ways:
-
-```bash
-pip install voxcpm
-```
-
-or point vLLM Omni to the local VoxCPM source tree:
-
-```bash
-export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/VoxCPM/src
-```
-
-The example writes WAV files with `soundfile`:
-
-```bash
-pip install soundfile
-```
-
-## Model Path
-
-Pass the native VoxCPM model directory directly:
-
-```bash
-export VOXCPM_MODEL=/path/to/voxcpm-model
-```
-
-If the native VoxCPM `config.json` does not contain HuggingFace metadata such as
-`model_type`, prepare a persistent HF-compatible config directory and point the
-stage configs to it with `VLLM_OMNI_VOXCPM_HF_CONFIG_PATH`:
-
-```bash
-export VLLM_OMNI_VOXCPM_HF_CONFIG_PATH=/tmp/voxcpm_hf_config
-mkdir -p "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"
-cp "$VOXCPM_MODEL/config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/config.json"
-cp "$VOXCPM_MODEL/generation_config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/generation_config.json" 2>/dev/null || true
-python3 -c 'import json, os; p=os.path.join(os.environ["VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"], "config.json"); cfg=json.load(open(p, "r", encoding="utf-8")); cfg["model_type"]="voxcpm"; cfg.setdefault("architectures", ["VoxCPMForConditionalGeneration"]); json.dump(cfg, open(p, "w", encoding="utf-8"), indent=2, ensure_ascii=False)'
-```
-
-If the model directory itself already has `model_type`, this extra directory is
-not required.
-
-## Quick Start
-
-Single text-to-speech, non-streaming:
-
-```bash
-python examples/offline_inference/voxcpm/end2end.py \
- --model "$VOXCPM_MODEL" \
- --text "This is a split-stage VoxCPM synthesis example running on vLLM Omni."
-```
-
-Single voice cloning, non-streaming:
-
-```bash
-python examples/offline_inference/voxcpm/end2end.py \
- --model "$VOXCPM_MODEL" \
- --text "This sentence is synthesized with a cloned voice." \
- --ref-audio /path/to/reference.wav \
- --ref-text "The exact transcript spoken in reference.wav."
-```
-
-Streaming:
-
-```bash
-python examples/offline_inference/voxcpm/end2end.py \
- --model "$VOXCPM_MODEL" \
- --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \
- --text "This is a split-stage VoxCPM streaming example running on vLLM Omni."
-```
-
-By default, `end2end.py` writes to `output_audio/` for non-streaming and
-`output_audio_streaming/` for streaming.
-
-## Advanced Workflows
-
-Use `benchmarks/voxcpm/vllm_omni/bench_tts_offline.py` when you need:
-
-- warmup runs
-- prompt files
-- batch JSONL inputs
-- profiler injection
-- offline TTFP / RTF emission
-
-Use `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py` when you need the fixed offline smoke matrix that previously lived in `test.py`.
-
-Full matrix benchmark example:
-
-```bash
-python benchmarks/voxcpm/vllm_omni/run_offline_matrix.py \
- --model "$VOXCPM_MODEL" \
- --ref-audio /path/to/reference.wav \
- --ref-text "The exact transcript spoken in reference.wav."
-```
-
-For online serving examples, see [examples/online_serving/voxcpm](../../online_serving/voxcpm/README.md).
-
-For benchmark reporting, see [benchmarks/voxcpm](../../../benchmarks/voxcpm/README.md).
-
-## Notes
-
-- `voxcpm.yaml` is the default non-streaming stage config.
-- `voxcpm_async_chunk.yaml` is the streaming stage config.
-- Streaming is currently single-request oriented; the fixed smoke matrix now lives in `benchmarks/voxcpm/vllm_omni/run_offline_matrix.py`.
-- `ref_text` must be the real transcript of the reference audio. Mismatched text usually causes obvious quality degradation.
diff --git a/examples/offline_inference/voxcpm/end2end.py b/examples/offline_inference/voxcpm/end2end.py
deleted file mode 100644
index 980410feaeb..00000000000
--- a/examples/offline_inference/voxcpm/end2end.py
+++ /dev/null
@@ -1,206 +0,0 @@
-"""Minimal offline VoxCPM example for vLLM Omni."""
-
-from __future__ import annotations
-
-import asyncio
-import time
-from pathlib import Path
-from typing import Any
-
-import soundfile as sf
-import torch
-from vllm.utils.argparse_utils import FlexibleArgumentParser
-
-from vllm_omni import AsyncOmni, Omni
-
-REPO_ROOT = Path(__file__).resolve().parents[3]
-DEFAULT_SYNC_STAGE_CONFIG = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml"
-
-
-def _build_prompt(args) -> dict[str, Any]:
- additional_information: dict[str, list[Any]] = {
- "text": [args.text],
- "cfg_value": [args.cfg_value],
- "inference_timesteps": [args.inference_timesteps],
- "min_len": [args.min_len],
- "max_new_tokens": [args.max_new_tokens],
- }
- if args.streaming_prefix_len is not None:
- additional_information["streaming_prefix_len"] = [args.streaming_prefix_len]
- if args.ref_audio is not None:
- additional_information["ref_audio"] = [args.ref_audio]
- if args.ref_text is not None:
- additional_information["ref_text"] = [args.ref_text]
- return {
- "prompt_token_ids": [1],
- "additional_information": additional_information,
- }
-
-
-def _extract_audio_tensor(mm: dict[str, Any]) -> torch.Tensor:
- audio = mm.get("audio", mm.get("model_outputs"))
- if audio is None:
- raise ValueError("No audio output found in multimodal output.")
- if isinstance(audio, list):
- parts = [torch.as_tensor(item).float().cpu().reshape(-1) for item in audio]
- audio = torch.cat(parts, dim=-1) if parts else torch.zeros(0)
- if not isinstance(audio, torch.Tensor):
- audio = torch.as_tensor(audio)
- return audio.float().cpu().reshape(-1)
-
-
-def _extract_sample_rate(mm: dict[str, Any]) -> int:
- sr_raw = mm.get("sr", 24000)
- if isinstance(sr_raw, list) and sr_raw:
- sr_raw = sr_raw[-1]
- if hasattr(sr_raw, "item"):
- return int(sr_raw.item())
- return int(sr_raw)
-
-
-def _is_streaming_stage_config(stage_config_path: str) -> bool:
- return "async_chunk" in Path(stage_config_path).stem
-
-
-def _save_audio(audio: torch.Tensor, sample_rate: int, output_dir: Path, request_id: str) -> Path:
- output_dir.mkdir(parents=True, exist_ok=True)
- output_path = output_dir / f"output_{request_id}.wav"
- sf.write(
- output_path,
- audio.float().cpu().clamp(-1.0, 1.0).numpy(),
- sample_rate,
- format="WAV",
- subtype="PCM_16",
- )
- return output_path
-
-
-async def _run_streaming(args) -> Path:
- prompt = _build_prompt(args)
- output_dir = Path(args.output_dir) if args.output_dir is not None else Path("output_audio_streaming")
- request_id = "streaming_example"
- sample_rate = 24000
- buffered_samples = 0
- chunks: list[torch.Tensor] = []
- started = time.perf_counter()
- omni = AsyncOmni(
- model=args.model,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
- try:
- async for stage_output in omni.generate(prompt, request_id=request_id):
- mm = getattr(stage_output, "multimodal_output", None)
- if not isinstance(mm, dict):
- request_output = getattr(stage_output, "request_output", None)
- if request_output is None:
- continue
- mm = getattr(request_output, "multimodal_output", None)
- if not isinstance(mm, dict) and getattr(request_output, "outputs", None):
- mm = getattr(request_output.outputs[0], "multimodal_output", None)
- if not isinstance(mm, dict):
- continue
- audio = _extract_audio_tensor(mm)
- if audio.numel() == 0:
- continue
- sample_rate = _extract_sample_rate(mm)
- if audio.numel() > buffered_samples:
- delta = audio[buffered_samples:]
- buffered_samples = int(audio.numel())
- else:
- delta = audio
- buffered_samples += int(delta.numel())
- if delta.numel() > 0:
- chunks.append(delta)
- if not chunks:
- raise RuntimeError("No streaming audio chunks received from VoxCPM.")
- output_audio = torch.cat(chunks, dim=0)
- output_path = _save_audio(output_audio, sample_rate, output_dir, request_id)
- print(f"Saved streaming audio to: {output_path} ({time.perf_counter() - started:.2f}s)")
- return output_path
- finally:
- omni.shutdown()
-
-
-def _run_sync(args) -> Path:
- prompt = _build_prompt(args)
- output_dir = Path(args.output_dir) if args.output_dir is not None else Path("output_audio")
- request_id = "sync_example"
- started = time.perf_counter()
- last_mm: dict[str, Any] | None = None
- omni = Omni(
- model=args.model,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
- for stage_outputs in omni.generate(prompt):
- request_output = getattr(stage_outputs, "request_output", None)
- if request_output is None:
- continue
- outputs = getattr(request_output, "outputs", None)
- if outputs:
- for output in outputs:
- mm = getattr(output, "multimodal_output", None)
- if isinstance(mm, dict):
- last_mm = mm
- mm = getattr(request_output, "multimodal_output", None)
- if isinstance(mm, dict):
- last_mm = mm
- if last_mm is None:
- raise RuntimeError("No audio output received from VoxCPM.")
- output_path = _save_audio(
- _extract_audio_tensor(last_mm),
- _extract_sample_rate(last_mm),
- output_dir,
- request_id,
- )
- print(f"Saved audio to: {output_path} ({time.perf_counter() - started:.2f}s)")
- return output_path
-
-
-def parse_args():
- parser = FlexibleArgumentParser(description="Minimal offline VoxCPM example for vLLM Omni.")
- parser.add_argument("--model", type=str, required=True, help="Local VoxCPM model directory.")
- parser.add_argument(
- "--stage-configs-path",
- type=str,
- default=str(DEFAULT_SYNC_STAGE_CONFIG),
- help=("Stage config path. Use voxcpm.yaml for non-streaming or voxcpm_async_chunk.yaml for streaming."),
- )
- parser.add_argument("--text", type=str, required=True, help="Input text for synthesis.")
- parser.add_argument("--ref-audio", type=str, default=None, help="Reference audio path for voice cloning.")
- parser.add_argument("--ref-text", type=str, default=None, help="Transcript of the reference audio.")
- parser.add_argument("--output-dir", type=str, default=None, help="Output directory for generated wav files.")
- parser.add_argument("--cfg-value", type=float, default=2.0, help="Guidance value passed to VoxCPM.")
- parser.add_argument("--inference-timesteps", type=int, default=10, help="Number of diffusion timesteps.")
- parser.add_argument("--min-len", type=int, default=2, help="Minimum latent length.")
- parser.add_argument("--max-new-tokens", type=int, default=4096, help="Maximum latent length.")
- parser.add_argument(
- "--streaming-prefix-len",
- type=int,
- default=3,
- help="Streaming prefix length used by voxcpm_async_chunk.yaml.",
- )
- parser.add_argument("--stage-init-timeout", type=int, default=600, help="Stage initialization timeout in seconds.")
- parser.add_argument("--log-stats", action="store_true", help="Enable vLLM Omni stats logging.")
- args = parser.parse_args()
- if (args.ref_audio is None) != (args.ref_text is None):
- raise ValueError("Voice cloning requires --ref-audio and --ref-text together.")
- return args
-
-
-def main(args) -> None:
- route = "streaming" if _is_streaming_stage_config(args.stage_configs_path) else "sync"
- print(f"Model: {args.model}")
- print(f"Stage config: {args.stage_configs_path}")
- print(f"Route: {route}")
- if route == "streaming":
- asyncio.run(_run_streaming(args))
- else:
- _run_sync(args)
-
-
-if __name__ == "__main__":
- main(parse_args())
diff --git a/examples/offline_inference/voxcpm2/README.md b/examples/offline_inference/voxcpm2/README.md
deleted file mode 100644
index e9827307997..00000000000
--- a/examples/offline_inference/voxcpm2/README.md
+++ /dev/null
@@ -1,83 +0,0 @@
-# VoxCPM2 Offline Inference (Native AR)
-
-VoxCPM2 is a 2B-parameter tokenizer-free diffusion AR TTS model. It produces 48kHz audio and supports 30+ languages with a single-stage native AR pipeline backed by MiniCPM4.
-
-## Prerequisites
-
-Install the `voxcpm` package, or set the environment variable pointing to the source tree:
-
-```bash
-# Option A: install package
-pip install voxcpm
-
-# Option B: use source checkout
-export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/voxcpm
-```
-
-## Quick Start
-
-Zero-shot synthesis:
-
-```bash
-python examples/offline_inference/voxcpm2/end2end.py \
- --model openbmb/VoxCPM2 \
- --text "Hello, this is a VoxCPM2 demo." \
- --output-dir output_audio
-```
-
-Voice cloning with a reference audio:
-
-```bash
-python examples/offline_inference/voxcpm2/end2end.py \
- --text "Hello, this is a voice clone demo." \
- --reference-audio /path/to/reference.wav \
- --output-dir output_clone
-```
-
-Prompt continuation (matched audio + text prefix):
-
-```bash
-python examples/offline_inference/voxcpm2/end2end.py \
- --text "Continuation target sentence." \
- --prompt-audio /path/to/prompt.wav \
- --prompt-text "Transcript of the prompt audio." \
- --output-dir output_cont
-```
-
-The script accepts the following arguments:
-
-| Argument | Default | Description |
-|---|---|---|
-| `--model` | `openbmb/VoxCPM2` | HuggingFace repo ID or local path |
-| `--text` | (example sentence) | Text to synthesize |
-| `--output-dir` | `output_audio` | Directory for output WAV files |
-| `--stage-configs-path` | `voxcpm2.yaml` | Stage config YAML path |
-| `--reference-audio` | `None` | Reference audio for voice cloning (isolated) |
-| `--prompt-audio` | `None` | Prompt audio for continuation mode |
-| `--prompt-text` | `None` | Transcript matching `--prompt-audio` |
-
-## Performance
-
-Measured on a single H20 GPU (80 GB):
-
-| Input length | RTF | Sample rate |
-|---|---|---|
-| Short (~10 tokens) | ~0.28 | 48 kHz |
-| Long (~100 tokens) | ~0.34 | 48 kHz |
-
-RTF < 1.0 means faster than real time.
-
-## Architecture
-
-VoxCPM2 uses a single-stage native AR pipeline:
-
-```
-feat_encoder
-└─► MiniCPM4 (base LM)
- └─► FSQ (finite scalar quantization)
- └─► residual_lm (residual AR)
- └─► LocDiT (local diffusion transformer)
- └─► AudioVAE → 48 kHz waveform
-```
-
-All stages are fused into one vllm-native execution graph via `voxcpm2.yaml`, eliminating inter-stage coordination overhead and enabling true end-to-end batching.
diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py
deleted file mode 100644
index 6b6bf78ddf1..00000000000
--- a/examples/offline_inference/voxcpm2/end2end.py
+++ /dev/null
@@ -1,171 +0,0 @@
-"""Offline VoxCPM2 inference example (native AR pipeline).
-
-Uses the single-stage native AR config (voxcpm2.yaml).
-Requires the `voxcpm` package or VLLM_OMNI_VOXCPM_CODE_PATH env var.
-"""
-
-from __future__ import annotations
-
-import os
-import time
-from pathlib import Path
-
-import soundfile as sf
-import torch
-from vllm.utils.argparse_utils import FlexibleArgumentParser
-
-from vllm_omni import Omni
-
-REPO_ROOT = Path(__file__).resolve().parents[3]
-DEFAULT_STAGE_CONFIGS_PATH = str(REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm2.yaml")
-SAMPLE_RATE = 48_000
-
-
-def parse_args():
- parser = FlexibleArgumentParser(description="Offline VoxCPM2 native AR inference")
- parser.add_argument(
- "--model",
- type=str,
- default="openbmb/VoxCPM2",
- help="VoxCPM2 model path or HuggingFace repo ID.",
- )
- parser.add_argument(
- "--text",
- type=str,
- default="This is a VoxCPM2 native AR synthesis example running on vLLM Omni.",
- help="Text to synthesize.",
- )
- parser.add_argument(
- "--output-dir",
- type=str,
- default="output_audio",
- help="Directory for output WAV files.",
- )
- parser.add_argument(
- "--stage-configs-path",
- type=str,
- default=DEFAULT_STAGE_CONFIGS_PATH,
- help="Path to the stage config YAML file.",
- )
- parser.add_argument(
- "--reference-audio",
- type=str,
- default=None,
- help="Path to reference audio for voice cloning (isolated ref mode).",
- )
- parser.add_argument(
- "--prompt-audio",
- type=str,
- default=None,
- help="Path to prompt audio for continuation mode (requires --prompt-text).",
- )
- parser.add_argument(
- "--prompt-text",
- type=str,
- default=None,
- help="Text matching --prompt-audio for continuation mode.",
- )
- parser.add_argument(
- "--ref-text",
- type=str,
- default=None,
- help="Optional transcript of --reference-audio (enables ref_continuation mode).",
- )
- return parser.parse_args()
-
-
-def extract_audio(multimodal_output: dict) -> torch.Tensor:
- """Extract the final complete audio tensor from multimodal output.
-
- The output processor concatenates per-step delta tensors under
- ``model_outputs``. Falls back to ``audio`` for backwards compat.
- """
- audio = multimodal_output.get("model_outputs")
- if audio is None:
- audio = multimodal_output.get("audio")
- if audio is None:
- raise ValueError(f"No audio key in multimodal_output: {list(multimodal_output.keys())}")
-
- if isinstance(audio, list):
- # Defensive: usually the output processor consolidates into a single
- # tensor at request completion, but concatenate here too in case the
- # caller consumes intermediate (pre-consolidation) outputs.
- valid = [torch.as_tensor(a).float().cpu().reshape(-1) for a in audio if a is not None]
- if not valid:
- raise ValueError("Audio list is empty or all elements are None.")
- return torch.cat(valid, dim=0) if len(valid) > 1 else valid[0]
-
- return torch.as_tensor(audio).float().cpu().reshape(-1)
-
-
-def main():
- args = parse_args()
-
- output_dir = Path(args.output_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
-
- engine = Omni(
- model=args.model,
- stage_configs_path=args.stage_configs_path,
- )
-
- from transformers import AutoTokenizer
-
- from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import (
- build_cjk_split_map,
- build_voxcpm2_prompt,
- )
-
- tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
- split_map = build_cjk_split_map(tokenizer)
- hf_config = engine.engine.stage_vllm_configs[0].model_config.hf_config
-
- ref_audio_arg = args.reference_audio or args.prompt_audio
- ref_text_arg = args.ref_text or args.prompt_text
- ref_wav, ref_sr = (None, None)
- if ref_audio_arg:
- ref_wav_arr, ref_sr = sf.read(ref_audio_arg)
- ref_wav = ref_wav_arr.mean(axis=-1).tolist() if ref_wav_arr.ndim > 1 else ref_wav_arr.tolist()
-
- prompt = build_voxcpm2_prompt(
- hf_config=hf_config,
- tokenizer=tokenizer,
- split_map=split_map,
- text=args.text,
- ref_audio=ref_wav,
- ref_sr=ref_sr,
- ref_text=ref_text_arg,
- )
-
- print(f"Model : {args.model}")
- print(f"Text : {args.text}")
- if ref_audio_arg:
- print(f"Ref audio : {ref_audio_arg}")
- if ref_text_arg:
- print(f"Ref text : {ref_text_arg}")
- print(f"Output dir : {output_dir}")
-
- t_start = time.perf_counter()
- outputs = engine.generate([prompt])
- elapsed = time.perf_counter() - t_start
-
- # outputs[0].outputs[0].multimodal_output["audio"] is a list of tensors
- request_output = outputs[0]
- mm = request_output.outputs[0].multimodal_output
- audio = extract_audio(mm)
-
- duration = audio.numel() / SAMPLE_RATE
- rtf = elapsed / duration if duration > 0 else float("inf")
-
- output_path = output_dir / "output.wav"
- sf.write(str(output_path), audio.numpy(), SAMPLE_RATE, format="WAV")
-
- print(f"Saved : {output_path}")
- print(f"Duration : {duration:.2f}s")
- print(f"Inference : {elapsed:.2f}s")
- print(f"RTF : {rtf:.3f}")
-
-
-if __name__ == "__main__":
- os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
- main()
diff --git a/examples/offline_inference/voxtral_tts/README.md b/examples/offline_inference/voxtral_tts/README.md
index a55ce8830ee..5f3d5413be7 100644
--- a/examples/offline_inference/voxtral_tts/README.md
+++ b/examples/offline_inference/voxtral_tts/README.md
@@ -10,24 +10,28 @@ When `mistral_common` has `SpeechRequest` support, prompt token IDs are built vi
```bash
# Basic single-prompt with cheerful_female voice preset
python3 examples/offline_inference/voxtral_tts/end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxtral_tts.yaml \
--write-audio --voice cheerful_female \
--model mistralai/Voxtral-4B-TTS-2603 \
--text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?"
# 32 replicate prompts with cheerful_female voice preset
python3 examples/offline_inference/voxtral_tts/end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxtral_tts.yaml \
--num-prompts 32 --write-audio --voice cheerful_female \
--model mistralai/Voxtral-4B-TTS-2603 \
--text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?"
# Streaming with neutral_female voice preset
python3 examples/offline_inference/voxtral_tts/end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxtral_tts.yaml \
--streaming --write-audio --voice neutral_female \
--model mistralai/Voxtral-4B-TTS-2603 \
--text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?"
# 32 prompts, 8 concurrent requests per wave, streaming with neutral_female voice
python3 examples/offline_inference/voxtral_tts/end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxtral_tts.yaml \
--num-prompts 32 --concurrency 8 --streaming --write-audio --voice neutral_female \
--model mistralai/Voxtral-4B-TTS-2603 \
--text "That eerie silence after the first storm was just the calm before another round of chaos, wasn't it?"
@@ -35,6 +39,7 @@ python3 examples/offline_inference/voxtral_tts/end2end.py \
# Short debug prompt with reference audio
# Note: Reference audio capability is not yet released.
python3 examples/offline_inference/voxtral_tts/end2end.py \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/voxtral_tts.yaml \
--write-audio \
--model mistralai/Voxtral-4B-TTS-2603 \
--text "This is a test message." \
@@ -49,7 +54,7 @@ python3 examples/offline_inference/voxtral_tts/end2end.py \
| `--text TEXT` | Text to synthesize (default: `"This is a test message."`) |
| `--audio-path PATH` | Path to reference audio file for voice cloning |
| `--output-dir DIR` | Directory to write output WAV files (default: `output_audio`) |
-| `--deploy-config PATH` | Override the deploy config path. If unset, auto-loads `vllm_omni/deploy/voxtral_tts.yaml` from the HF `model_type`. |
+| `--stage-configs-path PATH` | Path to stage configs YAML (currently it must be set for VoxtralTTS) |
| `--num-prompts N` | Number of replicate prompts to run for measuring performance (default: 1) |
| `--streaming` | Use streaming generation via `AsyncOmni` (default: blocking `Omni`) |
| `--concurrency N` | Max concurrent requests per wave (must be used with `--streaming`, must evenly divide `--num-prompts`) |
diff --git a/examples/offline_inference/voxtral_tts/end2end.py b/examples/offline_inference/voxtral_tts/end2end.py
index 0a6f88715a9..0750246450a 100644
--- a/examples/offline_inference/voxtral_tts/end2end.py
+++ b/examples/offline_inference/voxtral_tts/end2end.py
@@ -39,7 +39,7 @@
async def run_streaming(inputs, sampling_params_list, model_name, args, output_dir):
async_omni = AsyncOmni(
model=model_name,
- deploy_config=args.deploy_config,
+ stage_configs_path=args.stage_configs_path,
log_stats=args.log_stats,
)
@@ -192,7 +192,7 @@ def run_non_streaming(inputs, sampling_params_list, model_name, args, output_dir
llm = Omni(
model=model_name,
log_stats=args.log_stats,
- deploy_config=args.deploy_config,
+ stage_configs_path=args.stage_configs_path,
)
if args.profiling_mode:
@@ -253,11 +253,10 @@ def parse_args() -> Namespace:
help="Directory to write output wav files.",
)
parser.add_argument(
- "--deploy-config",
+ "--stage-configs-path",
type=str,
default=None,
- help="Override the deploy config path. If unset, auto-loads "
- "vllm_omni/deploy/voxtral_tts.yaml based on the HF model_type.",
+ help="Path to stage configs YAML. Auto-resolved from model if not set.",
)
parser.add_argument(
"--num-prompts", type=int, default=1, help="Number of replicate prompts to run for measuring performance"
@@ -298,12 +297,6 @@ def parse_args() -> Namespace:
default=None,
help="Voice to use instead of audio file.",
)
- parser.add_argument(
- "--cfg-alpha",
- type=float,
- default=None,
- help="CFG alpha for flow-matching guidance (default: use value from stage config, typically 1.2).",
- )
return parser.parse_args()
@@ -355,13 +348,8 @@ def main(args: Any) -> None:
inputs = compose_request(model_name, text_chunk, audio_prompt_file, args)
- extra_args = {}
- if args.cfg_alpha is not None:
- extra_args["cfg_alpha"] = args.cfg_alpha
-
sampling_params = SamplingParams(
max_tokens=max_num_tokens,
- extra_args=extra_args if extra_args else None,
)
sampling_params_list = [
sampling_params,
diff --git a/examples/offline_inference/x_to_video_audio/download_dreamid_omni.py b/examples/offline_inference/x_to_video_audio/download_dreamid_omni.py
index 2f66d5f7789..0dbf402e9e3 100644
--- a/examples/offline_inference/x_to_video_audio/download_dreamid_omni.py
+++ b/examples/offline_inference/x_to_video_audio/download_dreamid_omni.py
@@ -82,6 +82,7 @@ def main(output_dir: str):
data = {
"_class_name": "DreamIDOmniPipeline",
+ "fusion": "DreamID-Omni/dreamid_omni.safetensors",
}
with open(os.path.join(output_dir, "model_index.json"), "w", encoding="utf-8") as f:
@@ -89,12 +90,6 @@ def main(output_dir: str):
print(f"model_index.json created at {os.path.join(output_dir, 'model_index.json')}")
- transformer_dir = os.path.join(output_dir, "transformer")
- os.makedirs(transformer_dir, exist_ok=True)
- with open(os.path.join(transformer_dir, "config.json"), "w", encoding="utf-8") as f:
- json.dump({"fusion": "DreamID-Omni/dreamid_omni.safetensors"}, f)
- print(f"transformer/config.json created at {os.path.join(transformer_dir, 'config.json')}")
-
# now we download the dependency code
download_dependency()
diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.md b/examples/offline_inference/x_to_video_audio/x_to_video_audio.md
index 13f2cfe7c0a..59b993a728d 100644
--- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.md
+++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.md
@@ -24,15 +24,13 @@ dreamid_omni/
│ ├── models_t5_umt5-xxl-enc-bf16.pth
│ ├── Wan2.2_VAE.pth
│
-├── model_index.json
-└── transformer/
- └── config.json # create by download_dreamid_omni.py
+├── model_index.json # created by download_dreamid_omni.py
```
### Run the Inference
-```python
+```
python x_to_video_audio.py \
- --model /path/to/dreamid_omni \
+ --model /xx/dreamid_omni \
--prompt "Two people walking together and singing happily" \
--image-path ./example0.png ./example1.png \
--audio-path ./example0.wav ./example1.wav \
@@ -42,33 +40,11 @@ python x_to_video_audio.py \
--num-inference-steps 45 \
--height 704 \
--width 1280 \
- --output out_dreamid_omni_twoip.mp4
+ --output dreamid_omni.mp4
```
In the current test scenario (2 images + 2 audio inputs), the VRAM requirement is 72GB, regardless of whether cfg-parallel is enabled or disabled.
The VRAM usage can be reduced by enabling CPU offload via --enable-cpu-offload.
-
-You could take reference images/audios from the test cases in the official repo: https://github.com/Guoxu1233/DreamID-Omni
-
-For example, single IP ref resources can be found under https://github.com/Guoxu1233/DreamID-Omni/tree/main/test_case/oneip, you could download them correspondingly to your local and use them for testing.
-
-```python
-# Example usage for oneip, ref media from the official repo DreamID-Omni
-python x_to_video_audio.py \
- --model /path/to/dreamid_omni \
- --prompt ": In the frame, a woman with black long hair is identified as .\n**Overall Environment/Scene**: A lively open-kitchen café at night; stove flames flare, steam rises, and warm pendant lights swing slightly as staff move behind her. The shot is an upper-body close-up.\n**Main Characters/Subjects Appearance**: is a young woman with thick dark wavy hair and a side part. She wears a fitted black top under a light apron, a thin gold chain necklace, and small stud earrings.\n**Main Characters/Subjects Actions**: tastes the sauce with a spoon, then turns her face toward the camera while still holding the spoon, her expression shifting from focused to conflicted.\n maintains eye contact, swallows as if choosing her words, and says, I keep telling myself I’m fine,but some nights it feels like I’m just performing calm." \
- --image-path 9.png \
- --audio-path 9.wav \
- --video-negative-prompt "jitter, bad hands, blur, distortion" \
- --audio-negative-prompt "robotic, muffled, echo, distorted" \
- --cfg-parallel-size 2 \
- --num-inference-steps 45 \
- --height 704 \
- --width 1280 \
- --output out_dreamid_omni_oneip.mp4
-```
-
-
Key arguments:
- `--prompt`: text description (string).
- `--model`: path to the model local directory.
diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
index 497284ceb96..17d0f06c3c5 100644
--- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
+++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
@@ -5,12 +5,10 @@
import re
import time
-import numpy as np
+import librosa
from PIL import Image
-from vllm.multimodal.media.audio import load_audio
from vllm_omni.diffusion.data import DiffusionParallelConfig
-from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -38,8 +36,8 @@ def parse_args() -> argparse.Namespace:
"--cfg-parallel-size",
type=int,
default=1,
- choices=[1, 2, 3, 4],
- help="Number of GPUs used for classifier free guidance parallel size (max 4 branches).",
+ choices=[1, 2],
+ help="Number of GPUs used for classifier free guidance parallel size.",
)
parser.add_argument(
"--video-negative-prompt",
@@ -58,11 +56,6 @@ def parse_args() -> argparse.Namespace:
default=False,
help="Enable CPU offloading for diffusion models.",
)
- parser.add_argument(
- "--enable-layerwise-offload",
- action="store_true",
- help="Enable layerwise (blockwise) offloading on DiT modules.",
- )
return parser.parse_args()
@@ -76,7 +69,7 @@ def load_image_and_audio(image_paths, audio_paths):
image.append(img)
for path in audio_paths:
- audio_array, sr = load_audio(path, sr=16000)
+ audio_array, sr = librosa.load(path, sr=16000)
audio_array = audio_array[int(sr * 1) : int(sr * 3)]
audio.append(audio_array)
return image, audio
@@ -131,7 +124,6 @@ def main() -> None:
parallel_config=parallel_config,
model_type=args.model_type,
enable_cpu_offload=args.enable_cpu_offload,
- enable_layerwise_offload=args.enable_layerwise_offload,
)
start = time.perf_counter()
outputs = omni.generate(prompt, sampling_params)
@@ -139,35 +131,15 @@ def main() -> None:
if not outputs:
raise RuntimeError("No output returned from DreamID-Omni.")
- result = outputs[0]
- if not result.images:
- raise RuntimeError("No video frames found in DreamID-Omni output.")
- generated_video = result.images[0]
- mm = result.multimodal_output or {}
- generated_audio = mm.get("audio")
- fps = int(mm.get("fps", 24))
- sample_rate = int(mm.get("audio_sample_rate", 16000))
-
- # DreamID-Omni returns video as (C, F, H, W) float32 in [-1, 1].
- # mux_video_audio_bytes expects (F, H, W, C) uint8.
- if not isinstance(generated_video, np.ndarray) or generated_video.ndim != 4:
- raise RuntimeError(f"Unexpected video shape: {getattr(generated_video, 'shape', None)}")
- frames = generated_video.transpose(1, 2, 3, 0)
- frames = (np.clip((frames + 1.0) / 2.0, 0.0, 1.0) * 255.0).round().astype(np.uint8)
-
- audio_np = None
- if generated_audio is not None:
- audio_np = np.squeeze(np.asarray(generated_audio)).astype(np.float32)
-
+ output = outputs[0].request_output
+ generated_video = output[0].images[0][0]
+ generated_audio = output[0].images[0][1]
+ try:
+ from dreamid_omni.utils.io_utils import save_video
+ except Exception as e:
+ raise RuntimeError(f"Failed to extract video and audio from DreamID-Omni output. Error: {e}")
output_path = args.output
- video_bytes = mux_video_audio_bytes(
- frames,
- audio_np,
- fps=float(fps),
- audio_sample_rate=sample_rate,
- )
- with open(output_path, "wb") as f:
- f.write(video_bytes)
+ save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
print(f"Saved generated video to {output_path}")
print(f"Total time: {elapsed:.2f}s")
diff --git a/examples/online_serving/bagel/README.md b/examples/online_serving/bagel/README.md
index 763927222cf..9b74acae10e 100644
--- a/examples/online_serving/bagel/README.md
+++ b/examples/online_serving/bagel/README.md
@@ -1,111 +1,145 @@
# BAGEL-7B-MoT
-## Installation
+## 🛠️ Installation
Please refer to [README.md](../../../README.md)
-## Architecture
+## Run examples (BAGEL-7B-MoT)
-BAGEL-7B-MoT is a Mixture-of-Transformers (MoT) model supporting both image generation and understanding. It offers two deployment topologies:
+**Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, please modify the stage configuration to distribute the model across devices.
-| Topology | Stages | Description |
-| :------- | :----- | :---------- |
-| **Two-stage** (default) | Stage 0 (Thinker, AR) + Stage 1 (DiT, Diffusion) | Thinker handles text/understanding via vLLM AR engine; DiT handles image generation. KV cache is transferred between stages. |
-| **Single-stage** | Stage 0 (DiT, Diffusion) only | The DiT stage contains a full LLM, ViT, VAE, and tokenizer internally. All modalities are handled within a single diffusion process. |
-
-Both topologies support all four modalities: `text2img`, `img2img`, `img2text`, `text2text`.
-
-> **Note**: These examples work with the default configuration on an **NVIDIA A100 (80GB)**. We also tested on dual **NVIDIA RTX 5000 Ada (32GB each)**. For dual-GPU setups, modify the deploy YAML to distribute stages across devices.
-
-## Launch the Server
-
-### Two-Stage (Default)
-
-The default pipeline is auto-detected from the model. No extra flags needed:
+### Launch the Server
```bash
+# Use default configuration
vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091
```
Or use the convenience script:
```bash
-cd examples/online_serving/bagel
+cd examples/online_serving/bagel
bash run_server.sh
+```
-# Initialize each stage in a discrete isolated process terminal
-bash run_server_stage_cli.sh --stage 0
-bash run_server_stage_cli.sh --stage 1
+```bash
+vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
```
-To use a custom deploy YAML, pass it via `--deploy-config`:
+#### 🚀 Tensor Parallelism (TP)
+
+For larger models or multi-GPU environments, you can enable Tensor Parallelism (TP) for the server.
+1. **Modify Stage Config**: Create or modify a stage configuration yaml (e.g., [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml)). Set `tensor_parallel_size` to `2` (or more) and update `devices` to include multiple GPU IDs (e.g., `"0,1"`).
+
+```yaml
+ engine_args:
+ tensor_parallel_size: 2
+ ...
+ runtime:
+ devices: "0,1"
+```
+
+2. **Launch Server**:
```bash
-vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 \
- --deploy-config /path/to/deploy_config.yaml
+vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --stage-configs-path /path/to/your/custom_bagel.yaml
```
-See [`bagel.yaml`](../../../vllm_omni/deploy/bagel.yaml) for the default two-stage deploy configuration.
+#### Using Mooncake Connector
-### Single-Stage
+By default, BAGEL uses `SharedMemoryConnector` for inter-stage communication. You can use the [Mooncake](https://github.com/kvcache-ai/Mooncake) connector to transfer KV cache between stages, which also enables multi-node deployment.
-The DiT stage contains a full LLM, ViT, VAE, and tokenizer, so it can handle all modalities (text2img, img2img, img2text, text2text, think) without a separate Thinker stage:
+**1. Install Mooncake**
```bash
-vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 \
- --deploy-config vllm_omni/deploy/bagel_single_stage.yaml
+# For CUDA-enabled systems (recommended)
+pip install mooncake-transfer-engine
+
+# For non-CUDA systems
+pip install mooncake-transfer-engine-non-cuda
```
-See [`bagel_single_stage.yaml`](../../../vllm_omni/deploy/bagel_single_stage.yaml) for configuration. The `pipeline: bagel_single_stage` field selects the single-stage topology from the pipeline registry.
+**2. Start Mooncake Master** on the primary node:
-### Tensor Parallelism (TP)
+```bash
+# Optional: enable disk-backed storage by creating a directory and passing --root_fs_dir.
+# Without it, Mooncake runs in memory-only mode, which is sufficient for KV cache transfer.
+mkdir -p ./mc_storage
+
+mooncake_master \
+ --rpc_port=50051 \
+ --enable_http_metadata_server=true \
+ --http_metadata_server_host=0.0.0.0 \
+ --http_metadata_server_port=8080 \
+ --metrics_port=9003 \
+ --root_fs_dir=./mc_storage/ \
+ --cluster_id=mc-local-1 &
+```
-For larger models or multi-GPU environments, enable TP via CLI:
+**3. Launch the server** with the Mooncake stage config:
```bash
-vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 --tensor-parallel-size 2
+vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni --port 8091 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
```
-Or set `tensor_parallel_size` per stage in a custom deploy YAML.
+> **Note**: Before launching, edit [`bagel_multiconnector.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml) and replace the `metadata_server` and `master` addresses with your Mooncake master node's actual IP. For single-node testing, `127.0.0.1` works.
+
+The client-side usage is identical to the default setup -- the Mooncake connector is transparent to the API. See the requests section below.
+
+For more details on the Mooncake connector configuration, see the [Mooncake Store Connector documentation](../../../docs/design/feature/omni_connectors/mooncake_store_connector.md).
+
+#### Multi-Node Deployment
-### Multi-Node Deployment
+You can deploy each stage on a **separate node** for better resource utilization. In this example, the orchestrator (Stage 0 / Thinker) and Stage 1 (DiT) run on different machines, connected via Mooncake.
-Deploy each stage on a **separate node** for better resource utilization. Replace `` with the actual IP address of your orchestrator node.
+Replace `<ORCHESTRATOR_IP>` below with the actual IP address of your orchestrator node (e.g., `10.244.227.244`).
-**1. Launch Stage 0 (Thinker / Orchestrator)** on the orchestrator node:
+> [!WARNING]
+> **Before launching**, edit [`bagel_multiconnector.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml) and replace the `metadata_server` and `master` addresses with your Mooncake master node's actual IP. Mismatched addresses will cause silent connection failures.
+
+**1. Start Mooncake Master** (on the orchestrator node):
+
+```bash
+mooncake_master \
+ --rpc_port=50051 \
+ --enable_http_metadata_server=true \
+ --http_metadata_server_host=<ORCHESTRATOR_IP> \
+ --http_metadata_server_port=8080 \
+ --metrics_port=9003
+```
+
+**2. Launch Stage 0 (Thinker / Orchestrator)** on the orchestrator node (the API server listens on port 8000 for client requests):
```bash
-# API server port for client requests: 8000
vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni \
- --port 8000 \
+ --port 8000 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml \
--stage-id 0 \
- --omni-master-address \
- --omni-master-port 8091
+ -oma <ORCHESTRATOR_IP> \
+ -omp 8091
```
-**2. Launch Stage 1 (DiT)** on the remote node in headless mode:
+**3. Launch Stage 1 (DiT)** on the remote node in headless mode:
```bash
vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml \
--stage-id 1 \
--headless \
- --omni-master-address \
- --omni-master-port 8091
+ -oma <ORCHESTRATOR_IP> \
+ -omp 8091
```
-Or use the convenience script:
-
-```bash
-# Terminal 1: Stage 0
-bash run_server_stage_cli.sh --stage 0
+**Mooncake Master arguments:**
-# Terminal 2: Stage 1
-bash run_server_stage_cli.sh --stage 1
-
-# With extra args
-bash run_server_stage_cli.sh --stage 0 -- --tensor-parallel-size 2
-bash run_server_stage_cli.sh --stage 1 -- --gpu-memory-utilization 0.9
-```
+| Argument | Description |
+| :------- | :---------- |
+| `--rpc_port` | Mooncake RPC port for control-plane coordination between stages |
+| `--enable_http_metadata_server` | Enable the HTTP metadata server for service discovery |
+| `--http_metadata_server_host` | IP address to bind the metadata server (use the orchestrator node's IP) |
+| `--http_metadata_server_port` | Port for the HTTP metadata server |
+| `--metrics_port` | Port for Prometheus-compatible metrics endpoint |
**vllm serve arguments:**
@@ -113,31 +147,85 @@ bash run_server_stage_cli.sh --stage 1 -- --gpu-memory-utilization 0.9
| :------- | :---------- |
| `--stage-id` | Which stage this process runs (0 = Thinker, 1 = DiT) |
| `--headless` | Run without the API server (worker-only mode) |
-| `-oma` / `--omni-master-address` | Orchestrator master address |
-| `-omp` / `--omni-master-port` | Orchestrator master port |
+| `-oma` | Orchestrator master address |
+| `-omp` | Orchestrator master port for Stage 1 to connect to Stage 0 for task coordination |
> [!IMPORTANT]
> **Startup Order**: Stage 0 (orchestrator) must be launched **before** Stage 1 (headless).
> Stage 0 will appear to hang on startup until Stage 1 (worker) connects — this is expected behavior.
-### Inter-Stage Connectors
+**Network Requirements**
+
+All nodes must have network connectivity to each other. Ensure the following ports are open **between all participating nodes**:
-When deploying stages across nodes, configure the connector type in the deploy YAML:
+| Port | Protocol | Service | Direction |
+| :--- | :------- | :------ | :-------- |
+| 50051 | TCP | Mooncake Master RPC | Worker → Orchestrator |
+| 8080 | TCP | Mooncake HTTP Metadata Server | Worker → Orchestrator |
+| 8091 | TCP | Orchestrator Master (`-omp`) | Worker → Orchestrator |
+| 8000 | TCP | API Server (`--port`) | Client → Orchestrator |
+| 9003 | TCP | Metrics (optional) | Monitoring → Orchestrator |
-- **SharedMemoryConnector** (default): Used for single-node deployments. No explicit configuration needed.
-- **MooncakeTransferEngineConnector**: For multi-node setups with RDMA hardware. Defined in [`bagel.yaml`](../../../vllm_omni/deploy/bagel.yaml) under `connectors.rdma_connector`.
+> **Tip**: If nodes are behind a firewall or in different VPCs/security groups, make sure the above ports are allowed in ingress/egress rules. All nodes should be reachable via their IP addresses (no NAT). Using nodes on the same subnet or VPC is recommended to minimize latency for Mooncake KV cache transfers.
-To use Mooncake, create a custom deploy YAML that binds `output_connectors` / `input_connectors` on each stage to the `rdma_connector` defined in the `connectors` section.
+### Send Multi-modal Request
-## Send Requests
+Get into the bagel folder:
```bash
cd examples/online_serving/bagel
```
+Send request via Python
+
+```bash
+python openai_chat_client.py --prompt "A cute cat" --modality text2img
+```
+
+The Python client supports the following command-line arguments:
+
+- `--prompt` (or `-p`): Text prompt for generation (default: `A cute cat`)
+- `--output` (or `-o`): Output file path for image results (default: `bagel_output.png`)
+- `--server` (or `-s`): Server URL (default: `http://localhost:8091`)
+- `--image-url` (or `-i`): Input image URL or local file path (for img2img/img2text modes)
+- `--modality` (or `-m`): Task modality (default: `text2img`). Options: `text2img`, `img2img`, `img2text`, `text2text`
+- `--height`: Image height in pixels (default: 512)
+- `--width`: Image width in pixels (default: 512)
+- `--steps`: Number of inference steps (default: 25)
+- `--seed`: Random seed (default: 42)
+- `--negative`: Negative prompt for image generation
+
+Example with custom parameters:
+
+```bash
+python openai_chat_client.py \
+ --prompt "A futuristic city" \
+ --modality text2img \
+ --height 768 \
+ --width 768 \
+ --steps 50 \
+ --seed 42 \
+ --negative "blurry, low quality"
+```
+
+## Modality Control
+
+BAGEL-7B-MoT supports **multiple modality modes** for different use cases.
+
+The default yaml configuration deploys Thinker and DiT on the same GPU. You can use the default configuration file: [`bagel.yaml`](../../../vllm_omni/model_executor/stage_configs/bagel.yaml)
+
+| Modality | Input | Output | Description |
+| ----------- | ------------ | ------ | -------------------------------------- |
+| `text2img` | Text | Image | Generate images from text prompts |
+| `img2img` | Image + Text | Image | Transform images using text guidance |
+| `img2text` | Image + Text | Text | Generate text descriptions from images |
+| `text2text` | Text | Text | Pure text generation |
+
### Text to Image (text2img)
-**Python client:**
+Generate images from text prompts:
+
+**Using Python client**
```bash
python openai_chat_client.py \
@@ -147,7 +235,7 @@ python openai_chat_client.py \
--steps 50
```
-**curl:**
+**Using curl**
```bash
curl http://localhost:8091/v1/chat/completions \
@@ -162,9 +250,12 @@ curl http://localhost:8091/v1/chat/completions \
}'
```
+
### Image to Image (img2img)
-**Python client:**
+Transform images based on text prompts:
+
+**Using Python client**
```bash
python openai_chat_client.py \
@@ -174,7 +265,7 @@ python openai_chat_client.py \
--output transformed.png
```
-**curl:**
+**Using curl**
```bash
IMAGE_BASE64=$(base64 -w 0 cat.jpg)
@@ -199,11 +290,14 @@ EOF
curl http://localhost:8091/v1/chat/completions \
-H "Content-Type: application/json" \
-d @payload.json
+
```
### Image to Text (img2text)
-**Python client:**
+Generate text descriptions from images:
+
+**Using Python client**
```bash
python openai_chat_client.py \
@@ -212,7 +306,7 @@ python openai_chat_client.py \
--image-url /path/to/image.jpg
```
-**curl:**
+**Using curl**
```bash
IMAGE_BASE64=$(base64 -w 0 cat.jpg)
@@ -237,7 +331,9 @@ curl http://localhost:8091/v1/chat/completions \
### Text to Text (text2text)
-**Python client:**
+Pure text generation:
+
+**Using Python client**
```bash
python openai_chat_client.py \
@@ -245,78 +341,30 @@ python openai_chat_client.py \
--modality text2text
```
-**curl:**
+**Using curl**
```bash
curl http://localhost:8091/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
- "messages": [{"role": "user", "content": [{"type": "text", "text": "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"}]}],
+ "messages": [{"role": "user", "content": [{"type": "text", "text": "<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"}]}],
"modalities": ["text"]
}'
```
-### Python Client Arguments
-
-| Argument | Default | Description |
-| :------- | :------ | :---------- |
-| `--prompt` / `-p` | `A cute cat` | Text prompt |
-| `--output` / `-o` | `bagel_output.png` | Output file path |
-| `--server` / `-s` | `http://localhost:8091` | Server URL |
-| `--image-url` / `-i` | `None` | Input image URL or local path (img2img/img2text) |
-| `--modality` / `-m` | `text2img` | `text2img`, `img2img`, `img2text`, `text2text` |
-| `--height` | `512` | Image height in pixels |
-| `--width` | `512` | Image width in pixels |
-| `--steps` | `25` | Number of inference steps |
-| `--seed` | `42` | Random seed |
-| `--negative` | `None` | Negative prompt for CFG |
+## FAQ
-Example with custom parameters:
+- If you encounter an error about the backend of librosa, try to install ffmpeg with the command below.
```bash
-python openai_chat_client.py \
- --prompt "A futuristic city" \
- --modality text2img \
- --height 768 \
- --width 768 \
- --steps 50 \
- --seed 42 \
- --negative "blurry, low quality"
+sudo apt update
+sudo apt install ffmpeg
```
-## Configuration Reference
-
-### Deploy YAML Files
-
-| File | Description |
-| :--- | :---------- |
-| [`bagel.yaml`](../../../vllm_omni/deploy/bagel.yaml) | Two-stage default (Thinker + DiT on GPU 0) |
-| [`bagel_single_stage.yaml`](../../../vllm_omni/deploy/bagel_single_stage.yaml) | Single-stage (DiT only) |
-
-### Key Deploy YAML Fields
-
-| Field | Scope | Description |
-| :---- | :---- | :---------- |
-| `pipeline` | top-level | Override auto-detected pipeline (e.g. `bagel_single_stage`) |
-| `stages[].stage_id` | per-stage | Stage identifier (0, 1, ...) |
-| `stages[].devices` | per-stage | GPU device IDs (e.g. `"0"`, `"0,1"`) |
-| `stages[].max_num_seqs` | per-stage | Maximum concurrent sequences |
-| `stages[].gpu_memory_utilization` | per-stage | Fraction of GPU memory to use |
-| `stages[].enforce_eager` | per-stage | Disable CUDA graphs |
-| `stages[].tensor_parallel_size` | per-stage | TP degree for this stage |
-| `connectors` | top-level | Define available connector instances (SHM, Mooncake) |
-| `platforms` | top-level | Platform-specific overrides (e.g. `xpu`) |
-
-## FAQ
-
-- If you encounter OOM errors, try decreasing `max_model_len` or `gpu_memory_utilization` in the deploy YAML.
-
-**Two-stage VRAM usage:**
-
-| Stage | VRAM |
-| :---- | :--- |
-| Stage 0 (Thinker) | **15.04 GiB + KV Cache** |
-| Stage 1 (DiT) | **26.50 GiB** |
-| Total | **~42 GiB + KV Cache** |
+- If you are unsure how much VRAM the model needs, or you encounter an out-of-memory (OOM) error, try decreasing `max_model_len`. Approximate VRAM usage per stage:
-**Single-stage VRAM usage:** The DiT loads the full model (~42 GiB) in one process.
+| Stage | VRAM |
+| :------------------ | :--------------------------- |
+| Stage-0 (Thinker) | **15.04 GiB + KV Cache** |
+| Stage-1 (DiT) | **26.50 GiB** |
+| Total | **~42 GiB + KV Cache** |
diff --git a/examples/online_serving/bagel/run_server_stage_cli.sh b/examples/online_serving/bagel/run_server_stage_cli.sh
index 912e212f97e..51639153f73 100644
--- a/examples/online_serving/bagel/run_server_stage_cli.sh
+++ b/examples/online_serving/bagel/run_server_stage_cli.sh
@@ -1,164 +1,34 @@
#!/bin/bash
-# Bagel multi-stage online serving startup script.
-#
-# Usage:
-# ./run_server_stage_cli.sh --stage 0
-# ./run_server_stage_cli.sh --stage 1
-# ./run_server_stage_cli.sh --stage 0 -- --tensor-parallel-size 2
-# ./run_server_stage_cli.sh --stage 1 -- --gpu-memory-utilization 0.9
-#
-# By default, `--stage all` keeps the old behavior and launches both stages in
-# one session. Use `--stage 0` / `--stage 1` to launch each stage separately in
-# different terminal sessions, with stage-specific extra CLI arguments passed
-# after `--`.
-
-set -euo pipefail
+# Bagel multi-stage online serving startup script
+# Starts stage 0 as master with API server, and stage 1 in headless mode
MODEL="${MODEL:-ByteDance-Seed/BAGEL-7B-MoT}"
PORT="${PORT:-8091}"
MASTER_ADDRESS="${MASTER_ADDRESS:-127.0.0.1}"
MASTER_PORT="${MASTER_PORT:-8092}"
-STAGE="all"
-SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-DEPLOY_CONFIG="${DEPLOY_CONFIG:-$SCRIPT_DIR/../../../vllm_omni/deploy/bagel.yaml}"
-EXTRA_ARGS=()
-
-usage() {
- cat <&2
- usage
- exit 1
- ;;
- esac
-done
-
-if [[ "$STAGE" != "0" && "$STAGE" != "1" && "$STAGE" != "all" ]]; then
- echo "Invalid --stage value: $STAGE" >&2
- usage
- exit 1
-fi
-
-print_config() {
- echo "Model: $MODEL"
- echo "API Port: $PORT"
- echo "Master Address: $MASTER_ADDRESS"
- echo "Master Port: $MASTER_PORT"
- echo "Deploy Config: $DEPLOY_CONFIG"
- echo "Selected Stage: $STAGE"
- if [[ ${#EXTRA_ARGS[@]} -gt 0 ]]; then
- echo "Extra Args: ${EXTRA_ARGS[*]}"
- fi
-}
-
-run_stage_0() {
- echo "Starting Stage 0 (Thinker) as master..."
- vllm serve "$MODEL" --omni \
- --port "$PORT" \
- --deploy-config "$DEPLOY_CONFIG" \
- --stage-id 0 \
- --omni-master-address "$MASTER_ADDRESS" \
- --omni-master-port "$MASTER_PORT" \
- "${EXTRA_ARGS[@]}"
-}
-
-run_stage_1() {
- echo "Starting Stage 1 (DiT) in headless mode..."
- vllm serve "$MODEL" --omni \
- --deploy-config "$DEPLOY_CONFIG" \
- --stage-id 1 \
- --headless \
- --omni-master-address "$MASTER_ADDRESS" \
- --omni-master-port "$MASTER_PORT" \
- "${EXTRA_ARGS[@]}"
-}
+STAGE_CONFIGS_PATH="$(dirname "$0")/../../../vllm_omni/model_executor/stage_configs/bagel.yaml"
echo "Starting Bagel multi-stage server..."
-print_config
-
-case "$STAGE" in
- 0)
- run_stage_0
- ;;
- 1)
- run_stage_1
- ;;
- all)
- echo "Launching both stages in one session (legacy mode)..."
- echo "Starting Stage 0 (Thinker) in background first..."
- run_stage_0 &
- STAGE_0_PID=$!
-
- cleanup() {
- if [[ -n "${STAGE_0_PID:-}" ]]; then
- kill "$STAGE_0_PID" 2>/dev/null || true
- wait "$STAGE_0_PID" 2>/dev/null || true
- fi
- }
-
- trap cleanup EXIT INT TERM
-
- echo "Waiting briefly for Stage 0 to initialize..."
- sleep 2
- run_stage_1
- ;;
-esac
+echo "Model: $MODEL"
+echo "API Port: $PORT"
+echo "Master Address: $MASTER_ADDRESS"
+echo "Master Port: $MASTER_PORT"
+echo "Stage Configs: $STAGE_CONFIGS_PATH"
+
+# Start stage 1 (DiT) in headless mode first
+echo "Starting Stage 1 (DiT) in headless mode..."
+vllm serve "$MODEL" --omni \
+ --stage-configs-path "$STAGE_CONFIGS_PATH" \
+ --stage-id 1 \
+ --headless \
+ -oma "$MASTER_ADDRESS" \
+ -omp "$MASTER_PORT" &
+
+# Start stage 0 (Thinker) as master with API server
+echo "Starting Stage 0 (Thinker) as master..."
+vllm serve "$MODEL" --omni \
+ --port "$PORT" \
+ --stage-configs-path "$STAGE_CONFIGS_PATH" \
+ --stage-id 0 \
+ -oma "$MASTER_ADDRESS" \
+ -omp "$MASTER_PORT"
diff --git a/examples/online_serving/diffusers_pipeline_adapter/README.md b/examples/online_serving/diffusers_pipeline_adapter/README.md
deleted file mode 100644
index 8dbf9369ae8..00000000000
--- a/examples/online_serving/diffusers_pipeline_adapter/README.md
+++ /dev/null
@@ -1,83 +0,0 @@
-# Diffusers Backend Adapter Example
-
-This example demonstrates how to serve any 🤗 Diffusers pipeline through vLLM-Omni
-using the `diffusers` load format.
-
-## Supported Models
-
-Any model loadable via `DiffusionPipeline.from_pretrained()` should be supported, including text-to-image, image-to-image, text-to-video, image-to-video, and text-to-audio.
-
-## Limitations
-
-The diffusers backend is a black-box adapter. The following features are NOT yet supported.
-It is not guaranteed whether they will be supported in the future.
-
-- CFG parallel execution
-- Sequence parallel execution
-- TeaCache / Cache-DiT acceleration
-- Step-wise execution (continuous batching)
-
-For these features, it is recommended to use natively supported pipelines instead.
-
-## Usage
-
-### Option 1: CLI arguments
-
-```bash
-vllm serve "stable-diffusion-v1-5/stable-diffusion-v1-5" \
- --omni \
- --diffusion-load-format diffusers \
- --diffusers-load-kwargs '{"use_safetensors": true}' \
- --diffusers-call-kwargs '{"num_inference_steps": 30, "guidance_scale": 7.5}'
-```
-
-`--diffusers-load-kwargs` and `--diffusers-call-kwargs` are only valid together with `--diffusion-load-format diffusers`.
-
-### Option 2: Stage config YAML
-
-```bash
-vllm serve stable-diffusion-v1-5/stable-diffusion-v1-5 --stage-configs-path examples/online_serving/diffusers_pipeline_adapter/stage_config.yaml --omni
-```
-
-The particular fields of interest are `model`, `diffusion_load_format`, `diffusers_load_kwargs`, and `diffusers_call_kwargs` under `engine_args`. They are the same as the CLI arguments.
-
-## Send a Request
-
-```bash
-curl http://localhost:8000/v1/images/generations \
- -H "Content-Type: application/json" \
- -d '{
- "model": "stable-diffusion-v1-5/stable-diffusion-v1-5",
- "prompt": "a photo of an astronaut riding a horse on mars",
- "n": 1,
- "size": "512x512"
- }'
-```
-
-Or refer to other documentation pages on how to request a particular input/output modality, such as `examples/online_serving/text_to_image/openai_chat_client.py`.
-
-## Configuration Reference
-
-For the diffusers adapter, set options under **`engine_args`**:
-
-### `diffusion_load_format: "diffusers"`
-
-This field selects the Hugging Face diffusers adapter path (see `DiffusersPipelineLoader`).
-
-### `diffusers_load_kwargs`
-
-Passed to `DiffusionPipeline.from_pretrained()`.
-
-This is suitable for model-specific configurations not available through the vLLM-Omni interface (such as `Omni.__init__()`, `vllm serve` CLI arguments, and stage config YAML fields outside `diffusers_load_kwargs`).
-
-When a parameter is available in the vLLM-Omni interface, it will be adapted here.
-But if that parameter is simultaneously set in both the vLLM-Omni interface and `diffusers_load_kwargs`, the **latter** will take precedence.
-
-### `diffusers_call_kwargs`
-
-Passed to `pipeline.__call__()`.
-
-This is suitable for sampling parameters not available through the vLLM-Omni interface (such as `Omni.generate()` and online serving payloads).
-
-When a parameter is available in the vLLM-Omni interface, it will be adapted here.
-But if that parameter is simultaneously set in both the vLLM-Omni interface and `diffusers_call_kwargs`, the **former** will take precedence (because it is set at request time).
diff --git a/examples/online_serving/diffusers_pipeline_adapter/stage_config.yaml b/examples/online_serving/diffusers_pipeline_adapter/stage_config.yaml
deleted file mode 100644
index 7c96eb6c167..00000000000
--- a/examples/online_serving/diffusers_pipeline_adapter/stage_config.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-# Example stage config for diffusers backend
-# This config demonstrates serving Stable Diffusion 1.5 via the diffusers adapter.
-# Users should copy and modify this for their own models.
-
-model_type: diffusion
-
-stage_args:
- - stage_id: 0
- stage_type: diffusion
- engine_args:
- model_stage: diffusion
- model: "stable-diffusion-v1-5/stable-diffusion-v1-5"
- distributed_executor_backend: "mp"
- # gpu_memory_utilization: 0.9
- engine_output_type: image
- # Select the HF diffusers adapter
- diffusion_load_format: "diffusers"
- # model_class_name: "DiffusersAdapterPipeline" # default when diffusion_load_format is diffusers
- diffusers_load_kwargs:
- # Passed to DiffusionPipeline.from_pretrained().
- # Good for model-specific loading parameters not covered by OmniDiffusionConfig.
- # During model load time, parameters here override their counterparts in the vLLM-Omni interface.
- use_safetensors: true
- diffusers_call_kwargs:
- # Passed to pipeline.__call__().
- # Good for model-specific sampling parameters not covered by OmniDiffusionSamplingParams.
- # During request time, parameters here are overridden by the counterparts in OmniDiffusionSamplingParams.
- num_inference_steps: 30
- guidance_scale: 7.5
- final_output: true
- final_output_type: image
diff --git a/examples/online_serving/dynin_omni/README.md b/examples/online_serving/dynin_omni/README.md
deleted file mode 100644
index d8526d42373..00000000000
--- a/examples/online_serving/dynin_omni/README.md
+++ /dev/null
@@ -1,97 +0,0 @@
-# Dynin-Omni Online Serving Example
-
-## Installation
-
-Please refer to [README.md](../../../README.md).
-
-## Launch the Server
-
-First, find the `transformers_modules` path:
-
-```bash
-python - <<'PY'
-from transformers.utils.hub import HF_MODULES_CACHE
-print(HF_MODULES_CACHE)
-PY
-```
-
-Then export it for both `PYTHONPATH` and `HF_MODULES_CACHE`:
-
-```bash
-export PYTHONPATH=:$PYTHONPATH
-export HF_MODULES_CACHE=
-```
-
-Run from repository root:
-
-```bash
-vllm-omni serve snu-aidas/Dynin-Omni \
- --omni \
- --port 8091 \
- --stage-configs-path "$(pwd)/vllm_omni/model_executor/stage_configs/dynin_omni.yaml"
-```
-
-If `vllm-omni` is not in PATH, run:
-
-```bash
-PYTHONPATH="$(pwd)" python -m vllm_omni.entrypoints.cli.main serve snu-aidas/Dynin-Omni \
- --omni \
- --port 8091 \
- --stage-configs-path "$(pwd)/vllm_omni/model_executor/stage_configs/dynin_omni.yaml"
-```
-
-Wait until the server logs show both `All stages initialized successfully` and
-`Application startup complete.` before sending requests.
-
-## Send Requests via Python Client
-
-Move to the example directory:
-
-```bash
-cd examples/online_serving/dynin_omni
-```
-
-### Text -> Image
-
-```bash
-python openai_chat_completion_client_for_multimodal_generation.py \
- --query-type t2i \
- --prompt "A realistic indoor living room with natural daylight."
-```
-
-### Image -> Image
-
-```bash
-python openai_chat_completion_client_for_multimodal_generation.py \
- --query-type i2i \
- --image-path ../../offline_inference/dynin_omni/data/image/sofa_under_water.jpg \
- --prompt "Transform this surreal underwater setting into a realistic indoor living room while preserving the sofa layout."
-```
-
-### Text -> Speech
-
-```bash
-python openai_chat_completion_client_for_multimodal_generation.py \
- --query-type t2s \
- --prompt "Hello. This is Dynin-omni."
-```
-
-## CLI Arguments
-
-- `--query-type` (`t2i|t2s|i2i`)
-- `--model` (default: `snu-aidas/Dynin-Omni`)
-- `--host` / `--port` (OpenAI-compatible vLLM endpoint)
-- `--prompt` (custom text)
-- `--image-path` (required for `i2i`)
-- `--modalities` (optional output modalities override)
-- `--output-dir` (default: `/tmp/dynin_online_outputs`)
-
-## Notes
-
-- This client currently supports only `t2i`, `t2s`, and `i2i`.
-- `t2t` is intentionally not exposed in this online example.
-- This example intentionally uses the OpenAI-compatible chat completion endpoint.
-- Task routing for non-text outputs relies on Dynin task trigger tokens (`<|t2i|>`, `<|i2i|>`, `<|t2s|>`) injected by the client.
-- Outputs are saved under `/tmp/dynin_online_outputs` by default.
-- Dynin stage-0 warmup can take a while on first startup; do not send requests before startup completes.
-- Dynin itself can execute text-returning tasks such as `t2t`, `s2t`, `i2t`, and `v2t`, but this online serving example currently runs stage-0 in `generation` mode. In that path, the generation worker does not surface the final text as `output.text`, so OpenAI chat responses for those text-output tasks may complete internally but still return empty text.
diff --git a/examples/online_serving/dynin_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/dynin_omni/openai_chat_completion_client_for_multimodal_generation.py
deleted file mode 100644
index 9728555431f..00000000000
--- a/examples/online_serving/dynin_omni/openai_chat_completion_client_for_multimodal_generation.py
+++ /dev/null
@@ -1,342 +0,0 @@
-#!/usr/bin/env python3
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-import argparse
-import base64
-import json
-import mimetypes
-import os
-import time
-from pathlib import Path
-from typing import Any
-
-DEFAULT_MODEL = "snu-aidas/Dynin-Omni"
-DEFAULT_OUTPUT_DIR = "/tmp/dynin_online_outputs"
-
-QUERY_CHOICES = ("t2i", "t2s", "i2i")
-DEFAULT_PROMPT_BY_QUERY = {
- "t2i": "A high quality detailed living room interior photo.",
- "t2s": "Please read this sentence naturally: Hello from Dynin-Omni online serving.",
- "i2i": "Transform this image into a realistic indoor living room while preserving layout.",
-}
-DEFAULT_MODALITIES_BY_QUERY = {
- "t2i": ["image"],
- "t2s": ["audio"],
- "i2i": ["image"],
-}
-OFFLINE_PARITY_STAGE_COUNT = 3
-OFFLINE_PARITY_STAGE_SAMPLING = {
- "max_tokens": 1,
- "temperature": 0.0,
- "top_p": 1.0,
- "detokenize": False,
-}
-
-
-def _infer_mime_type(path: Path) -> str:
- mime_type, _ = mimetypes.guess_type(str(path))
- return mime_type or "application/octet-stream"
-
-
-def _encode_file_as_data_url(path: Path) -> str:
- mime_type = _infer_mime_type(path)
- raw = path.read_bytes()
- encoded = base64.b64encode(raw).decode("utf-8")
- return f"data:{mime_type};base64,{encoded}"
-
-
-def _to_image_url(path_or_url: str) -> str:
- value = str(path_or_url)
- if value.startswith(("http://", "https://", "data:image/")):
- return value
- path = Path(value).expanduser().resolve()
- if not path.exists():
- raise FileNotFoundError(f"Image file not found: {path}")
- return _encode_file_as_data_url(path)
-
-
-def _build_user_content(query_type: str, prompt: str, image_path: str | None) -> list[dict[str, Any]]:
- if query_type == "t2i":
- return [{"type": "text", "text": f"<|t2i|> {prompt}"}]
-
- if query_type == "t2s":
- return [{"type": "text", "text": f"<|t2s|> {prompt}"}]
-
- if query_type == "i2i":
- if not image_path:
- raise ValueError("--image-path is required for query type i2i")
- return [
- {"type": "text", "text": f"<|i2i|> {prompt}"},
- {"type": "image_url", "image_url": {"url": _to_image_url(image_path)}},
- ]
-
- raise ValueError(f"Unsupported query_type: {query_type}")
-
-
-def _collect_text_from_content(content: Any) -> list[str]:
- texts: list[str] = []
- if isinstance(content, str):
- stripped = content.strip()
- if stripped:
- texts.append(stripped)
- return texts
-
- if isinstance(content, dict):
- for key in ("text", "content", "value", "output_text"):
- text_value = content.get(key)
- if isinstance(text_value, str) and text_value.strip():
- texts.append(text_value.strip())
- return texts
-
- if isinstance(content, list):
- for item in content:
- texts.extend(_collect_text_from_content(item))
- return texts
-
- content_text = getattr(content, "text", None)
- if isinstance(content_text, str) and content_text.strip():
- texts.append(content_text.strip())
- content_value = getattr(content, "content", None)
- if isinstance(content_value, str) and content_value.strip():
- texts.append(content_value.strip())
- output_text = getattr(content, "output_text", None)
- if isinstance(output_text, str) and output_text.strip():
- texts.append(output_text.strip())
- return texts
-
-
-def _extract_text_outputs(chat_completion: Any) -> list[str]:
- texts: list[str] = []
- for choice in getattr(chat_completion, "choices", []) or []:
- message = getattr(choice, "message", None)
- if message is None:
- continue
- content = getattr(message, "content", None)
- texts.extend(_collect_text_from_content(content))
- reasoning_content = getattr(message, "reasoning_content", None)
- if isinstance(reasoning_content, str) and reasoning_content.strip():
- texts.append(reasoning_content.strip())
- choice_text = getattr(choice, "text", None)
- if isinstance(choice_text, str) and choice_text.strip():
- texts.append(choice_text.strip())
- top_level_output_text = getattr(chat_completion, "output_text", None)
- if isinstance(top_level_output_text, str) and top_level_output_text.strip():
- texts.append(top_level_output_text.strip())
- return texts
-
-
-def _extract_image_data_urls(chat_completion: Any) -> list[str]:
- urls: list[str] = []
- for choice in getattr(chat_completion, "choices", []) or []:
- message = getattr(choice, "message", None)
- if message is None:
- continue
- content = getattr(message, "content", None)
- if not isinstance(content, list):
- continue
- for item in content:
- if not isinstance(item, dict):
- continue
- if item.get("type") != "image_url":
- continue
- image_url = (item.get("image_url") or {}).get("url")
- if isinstance(image_url, str) and image_url.startswith("data:image"):
- urls.append(image_url)
- return urls
-
-
-def _extract_audio_payloads(chat_completion: Any) -> list[bytes]:
- payloads: list[bytes] = []
- for choice in getattr(chat_completion, "choices", []) or []:
- message = getattr(choice, "message", None)
- if message is None:
- continue
- message_audio = getattr(message, "audio", None)
- if message_audio is None:
- continue
- data_b64 = getattr(message_audio, "data", None)
- if isinstance(data_b64, str) and data_b64:
- try:
- payloads.append(base64.b64decode(data_b64))
- except Exception:
- continue
- return payloads
-
-
-def _decode_data_url(data_url: str) -> tuple[bytes, str]:
- header, data = data_url.split(",", 1)
- mime_type = "image/png"
- if ";" in header and ":" in header:
- mime_type = header.split(":", 1)[1].split(";", 1)[0]
- return base64.b64decode(data), mime_type
-
-
-def _image_extension_from_mime(mime_type: str) -> str:
- if mime_type == "image/jpeg":
- return ".jpg"
- if mime_type == "image/webp":
- return ".webp"
- if mime_type == "image/gif":
- return ".gif"
- return ".png"
-
-
-def _save_outputs(
- *,
- query_type: str,
- chat_completion: Any,
- output_dir: Path,
-) -> None:
- output_dir.mkdir(parents=True, exist_ok=True)
- stamp = time.strftime("%Y%m%d_%H%M%S")
-
- text_outputs = _extract_text_outputs(chat_completion)
- image_data_urls = _extract_image_data_urls(chat_completion)
- audio_payloads = _extract_audio_payloads(chat_completion)
-
- if text_outputs:
- text_path = output_dir / f"{query_type}_{stamp}.txt"
- text_path.write_text("\n\n".join(text_outputs) + "\n", encoding="utf-8")
- print(f"[dynin-online] text saved: {text_path}")
- print(text_outputs[0])
-
- for idx, image_url in enumerate(image_data_urls):
- image_bytes, mime_type = _decode_data_url(image_url)
- ext = _image_extension_from_mime(mime_type)
- image_path = output_dir / f"{query_type}_{stamp}_{idx}{ext}"
- image_path.write_bytes(image_bytes)
- print(f"[dynin-online] image saved: {image_path}")
-
- for idx, audio_bytes in enumerate(audio_payloads):
- audio_path = output_dir / f"{query_type}_{stamp}_{idx}.wav"
- audio_path.write_bytes(audio_bytes)
- print(f"[dynin-online] audio saved: {audio_path}")
-
- if not text_outputs and not image_data_urls and not audio_payloads:
- print("[dynin-online] no output extracted from response")
- raw_path = output_dir / f"{query_type}_{stamp}_raw_response.json"
- try:
- if hasattr(chat_completion, "model_dump_json"):
- serialized = chat_completion.model_dump_json(indent=2)
- else:
- if hasattr(chat_completion, "model_dump"):
- raw_payload: Any = chat_completion.model_dump(mode="json")
- else:
- raw_payload = chat_completion
- try:
- serialized = json.dumps(raw_payload, ensure_ascii=False, indent=2)
- except Exception:
- serialized = json.dumps({"repr": repr(raw_payload)}, ensure_ascii=False, indent=2)
- raw_path.write_text(serialized + "\n", encoding="utf-8")
- print(f"[dynin-online] raw response saved: {raw_path}")
- except Exception:
- pass
-
-
-def _build_offline_parity_sampling_params_list() -> list[dict[str, Any]]:
- return [dict(OFFLINE_PARITY_STAGE_SAMPLING) for _ in range(OFFLINE_PARITY_STAGE_COUNT)]
-
-
-def run_request(args: argparse.Namespace) -> None:
- from openai import OpenAI
-
- client = OpenAI(
- api_key="EMPTY",
- base_url=f"http://{args.host}:{args.port}/v1",
- )
- prompt = args.prompt.strip() if args.prompt else DEFAULT_PROMPT_BY_QUERY[args.query_type]
- user_content = _build_user_content(
- query_type=args.query_type,
- prompt=prompt,
- image_path=args.image_path,
- )
- if args.modalities:
- modalities = [item.strip() for item in args.modalities.split(",") if item.strip()]
- else:
- modalities = DEFAULT_MODALITIES_BY_QUERY[args.query_type]
-
- extra_body = {
- "sampling_params_list": _build_offline_parity_sampling_params_list(),
- }
- chat_completion = client.chat.completions.create(
- model=args.model,
- messages=[{"role": "user", "content": user_content}],
- modalities=modalities,
- extra_body=extra_body,
- )
- _save_outputs(
- query_type=args.query_type,
- chat_completion=chat_completion,
- output_dir=Path(args.output_dir).expanduser(),
- )
-
-
-def parse_args() -> argparse.Namespace:
- parser = argparse.ArgumentParser(description="Dynin-Omni online chat completion client")
- parser.add_argument(
- "--query-type",
- "-q",
- type=str,
- default="t2i",
- choices=QUERY_CHOICES,
- help="Dynin query type",
- )
- parser.add_argument(
- "--model",
- "-m",
- type=str,
- default=DEFAULT_MODEL,
- help="Model name/path",
- )
- parser.add_argument(
- "--host",
- type=str,
- default="localhost",
- help="Host/IP of the vLLM Omni API server",
- )
- parser.add_argument(
- "--port",
- type=int,
- default=8091,
- help="Port of the vLLM Omni API server",
- )
- parser.add_argument(
- "--prompt",
- "-p",
- type=str,
- default="",
- help="Custom prompt text",
- )
- parser.add_argument(
- "--image-path",
- "-i",
- type=str,
- default=None,
- help="Image path/URL for i2i",
- )
- parser.add_argument(
- "--modalities",
- type=str,
- default="",
- help="Comma-separated output modalities override (e.g., text,image,audio)",
- )
- parser.add_argument(
- "--output-dir",
- "-o",
- type=str,
- default=DEFAULT_OUTPUT_DIR,
- help="Directory to save outputs",
- )
- return parser.parse_args()
-
-
-def main() -> None:
- args = parse_args()
- os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
- run_request(args)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/online_serving/fish_speech/README.md b/examples/online_serving/fish_speech/README.md
index 9b4e3cc403d..ae968d3bada 100644
--- a/examples/online_serving/fish_speech/README.md
+++ b/examples/online_serving/fish_speech/README.md
@@ -29,12 +29,15 @@ Features:
## Launch the Server
```bash
-vllm serve fishaudio/s2-pro --omni --port 8091
+vllm-omni serve fishaudio/s2-pro \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml \
+ --omni \
+ --port 8091 \
+ --trust-remote-code \
+ --enforce-eager \
+ --gpu-memory-utilization 0.9
```
-The deploy config is auto-loaded from `vllm_omni/deploy/fish_qwen3_omni.yaml`
-(the HF `model_type` on the fishaudio checkpoints is `fish_qwen3_omni`).
-
Or use the convenience script:
```bash
diff --git a/examples/online_serving/fish_speech/run_gradio_demo.sh b/examples/online_serving/fish_speech/run_gradio_demo.sh
index a0370b9cc88..98a69664437 100755
--- a/examples/online_serving/fish_speech/run_gradio_demo.sh
+++ b/examples/online_serving/fish_speech/run_gradio_demo.sh
@@ -11,13 +11,18 @@ MODEL="${MODEL:-fishaudio/s2-pro}"
PORT="${PORT:-8091}"
GRADIO_PORT="${GRADIO_PORT:-7860}"
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+REPO_ROOT="$(cd "$SCRIPT_DIR/../../.." && pwd)"
echo "Starting Fish Speech S2 Pro server (port $PORT)..."
FLASHINFER_DISABLE_VERSION_CHECK=1 \
-vllm serve "$MODEL" \
- --omni \
+vllm-omni serve "$MODEL" \
+ --stage-configs-path "$REPO_ROOT/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml" \
--host 0.0.0.0 \
- --port "$PORT" &
+ --port "$PORT" \
+ --gpu-memory-utilization 0.9 \
+ --trust-remote-code \
+ --enforce-eager \
+ --omni &
SERVER_PID=$!
cleanup() {
diff --git a/examples/online_serving/fish_speech/run_server.sh b/examples/online_serving/fish_speech/run_server.sh
index a865daf9378..59c09c7fe05 100755
--- a/examples/online_serving/fish_speech/run_server.sh
+++ b/examples/online_serving/fish_speech/run_server.sh
@@ -13,7 +13,11 @@ PORT="${PORT:-8091}"
echo "Starting Fish Speech S2 Pro server with model: $MODEL"
FLASHINFER_DISABLE_VERSION_CHECK=1 \
-vllm serve "$MODEL" \
- --omni \
+vllm-omni serve "$MODEL" \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml \
--host 0.0.0.0 \
- --port "$PORT"
+ --port "$PORT" \
+ --gpu-memory-utilization 0.9 \
+ --trust-remote-code \
+ --enforce-eager \
+ --omni
diff --git a/examples/online_serving/glm_image/README.md b/examples/online_serving/glm_image/README.md
new file mode 100644
index 00000000000..7ba4b501ca9
--- /dev/null
+++ b/examples/online_serving/glm_image/README.md
@@ -0,0 +1,204 @@
+# GLM-Image Online Serving
+
+This example demonstrates how to deploy GLM-Image for online image generation using vLLM-Omni.
+
+## 🛠️ Installation
+
+Please refer to [README.md](../../../README.md)
+
+## Run examples (GLM-Image)
+
+**Note**: These examples work with the default configuration on **2× NVIDIA A100 (80GB)** or equivalent. Stage 0 (AR) and Stage 1 (Diffusion) each use one GPU by default. For single-GPU setups, modify the stage configuration to share the same device.
+
+### Launch the Server
+
+```bash
+# Use default configuration
+vllm serve zai-org/GLM-Image --omni --port 8091
+```
+
+Or use the convenience script:
+
+```bash
+cd examples/online_serving/glm_image
+bash run_server.sh
+```
+
+If you have a custom stage configs file:
+
+```bash
+vllm serve zai-org/GLM-Image --omni --port 8091 --stage-configs-path /path/to/glm_image.yaml
+```
+
+### Send Requests
+
+Get into the glm_image folder:
+
+```bash
+cd examples/online_serving/glm_image
+```
+
+Send request via Python:
+
+```bash
+python openai_chat_client.py --prompt "A cute cat sitting on a window sill"
+```
+
+The Python client supports the following command-line arguments:
+
+- `--prompt` (or `-p`): Text prompt for generation (default: `A beautiful sunset over the ocean with sailing boats`)
+- `--output` (or `-o`): Output file path (default: `glm_image_output.png`)
+- `--server` (or `-s`): Server URL (default: `http://localhost:8091`)
+- `--image` (or `-i`): Input image path (for image-to-image editing)
+- `--height`: Image height in pixels (default: 1024)
+- `--width`: Image width in pixels (default: 1024)
+- `--steps`: Number of inference steps (default: 50)
+- `--guidance-scale`: Classifier-free guidance scale (default: 1.5)
+- `--seed`: Random seed (default: 42)
+- `--negative`: Negative prompt
+
+## Modality Control
+
+GLM-Image supports **text-to-image** and **image-to-image** modes.
+
+The default yaml configuration deploys AR on GPU 0 and DiT on GPU 1. You can use the default configuration file: [`glm_image.yaml`](../../../vllm_omni/model_executor/stage_configs/glm_image.yaml)
+
+| Mode | Input | Output | Description |
+| -------------- | ------------ | ------ | ---------------------------------- |
+| Text-to-Image | Text | Image | Generate images from text prompts |
+| Image-to-Image | Image + Text | Image | Edit images with text instructions |
+
+### Text-to-Image
+
+```bash
+python openai_chat_client.py \
+ --prompt "A photorealistic mountain landscape at sunset" \
+ --height 1024 \
+ --width 1024 \
+ --output landscape.png
+
+# Or use the curl script:
+bash run_curl_text_to_image.sh "A futuristic city skyline at night"
+```
+
+### Image-to-Image (Image Editing)
+
+```bash
+python openai_chat_client.py \
+ --prompt "Convert this image to watercolor style" \
+ --image input.png \
+ --output watercolor.png
+
+# Or use the curl script:
+bash run_curl_image_edit.sh input.png "Convert to watercolor style"
+```
+
+For general-purpose request methods (curl, OpenAI SDK, Python `requests`), see
+the [Text-to-Image](../text_to_image/README.md) and
+[Image-to-Image](../image_to_image/README.md) guides.
+
+## Generation Parameters
+
+When using `/v1/chat/completions`, pass these inside `extra_body` in the curl
+JSON, or via the `extra_body` keyword argument in the OpenAI Python SDK.
+When using the dedicated `/v1/images/generations` or `/v1/images/edits`
+endpoints, pass the supported generation controls as top-level fields directly.
+For image dimensions and count, use `size` and `n` rather than `height` or
+`width`.
+
+| Parameter | Type | Default | Description |
+| --------------------- | ----- | ------- | ----------------------------------- |
+| `height` | int | 1024 | Image height in pixels |
+| `width` | int | 1024 | Image width in pixels |
+| `num_inference_steps` | int | 50 | Number of diffusion denoising steps |
+| `guidance_scale` | float | 1.5 | Classifier-free guidance scale |
+| `seed` | int | None | Optional random seed; `/v1/images/*` generates one server-side if omitted |
+| `negative_prompt` | str | None | Negative prompt |
+
+## Response Format
+
+```json
+{
+ "id": "chatcmpl-xxx",
+ "created": 1234567890,
+ "model": "zai-org/GLM-Image",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "data:image/png;base64,..."
+ }
+ }
+ ]
+ },
+ "finish_reason": "stop"
+ }
+ ],
+ "usage": {}
+}
+```
+
+## Extract Image
+
+```bash
+# From a saved JSON response
+cat response.json | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > output.png
+```
+
+## Architecture
+
+GLM-Image uses a two-stage pipeline:
+
+```
+Stage 0 (AR Model) Stage 1 (Diffusion)
+┌───────────────────┐ ┌─────────────────────┐
+│ vLLM-optimized │ prior │ GlmImagePipeline │
+│ GlmImageFor │──tokens──►│ ┌───────────────┐ │
+│ Conditional │ │ │ DiT Denoiser │ │
+│ Generation │ │ └───────┬───────┘ │
+│ (9B AR model) │ │ ▼ │
+└───────────────────┘ │ ┌───────────────┐ │
+ ▲ │ │ VAE Decode │──┼──► Image
+ │ │ └───────────────┘ │
+ Text / Image └─────────────────────┘
+ Input
+```
+
+## VRAM Requirements
+
+| Stage | VRAM |
+| :---------------- | :--------------------- |
+| Stage-0 (AR) | **~18 GiB + KV Cache** |
+| Stage-1 (DiT+VAE) | **~20 GiB** |
+| Total | **~38 GiB + KV Cache** |
+
+## File Description
+
+| File | Description |
+| --------------------------- | ------------------------------------- |
+| `run_server.sh` | Server startup script |
+| `run_curl_text_to_image.sh` | Text-to-image curl example |
+| `run_curl_image_edit.sh` | Image-to-image (editing) curl example |
+| `openai_chat_client.py` | Python client (t2i + i2i) |
+
+## FAQ
+
+- If you encounter OOM errors, adjust `gpu_memory_utilization` in the stage config:
+
+```yaml
+# In glm_image.yaml, reduce from default 0.6:
+gpu_memory_utilization: 0.5
+```
+
+- The first request may be slow due to model warmup. Subsequent requests will be faster.
+
+- If you encounter the `Transformers does not recognize this architecture` error, you have to upgrade the `transformers` package to `5.3.0` or above:
+
+```
+pip install --upgrade transformers
+```
diff --git a/examples/online_serving/glm_image/openai_chat_client.py b/examples/online_serving/glm_image/openai_chat_client.py
new file mode 100644
index 00000000000..e142b071904
--- /dev/null
+++ b/examples/online_serving/glm_image/openai_chat_client.py
@@ -0,0 +1,172 @@
+#!/usr/bin/env python3
+"""
+GLM-Image OpenAI-compatible chat client for text-to-image and image-to-image.
+
+Usage:
+ # Text-to-image
+ python openai_chat_client.py --prompt "A cute cat" --output output.png
+
+ # Image-to-image (image editing)
+ python openai_chat_client.py --prompt "Convert to watercolor style" --image input.png --output output.png
+"""
+
+import argparse
+import base64
+from pathlib import Path
+
+import requests
+
+
+def generate_image(
+ prompt: str,
+ server_url: str = "http://localhost:8091",
+ image_path: str | None = None,
+ height: int = 1024,
+ width: int = 1024,
+ steps: int = 50,
+ guidance_scale: float = 1.5,
+ seed: int | None = None,
+ negative_prompt: str | None = None,
+) -> bytes | None:
+ """Generate or edit an image using the chat completions API.
+
+ Args:
+ prompt: Text description or editing instruction
+ server_url: Server URL
+ image_path: Path to input image (for image-to-image editing)
+ height: Image height in pixels
+ width: Image width in pixels
+ steps: Number of inference steps
+ guidance_scale: Classifier-free guidance scale
+ seed: Random seed for reproducibility
+ negative_prompt: Negative prompt
+
+ Returns:
+ Image bytes or None if failed
+ """
+ # Build message content
+ content: list[dict] = [{"type": "text", "text": prompt}]
+
+ if image_path:
+ img_path = Path(image_path)
+ if not img_path.exists():
+ print(f"Error: Image file not found: {image_path}")
+ return None
+ b64_data = base64.b64encode(img_path.read_bytes()).decode("utf-8")
+ suffix = img_path.suffix.lstrip(".").lower()
+ mime = {"jpg": "jpeg", "jpeg": "jpeg", "png": "png", "webp": "webp"}.get(suffix, "png")
+ content.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/{mime};base64,{b64_data}"},
+ }
+ )
+
+ messages = [{"role": "user", "content": content}]
+
+ # Build request payload
+ extra_body: dict = {
+ "height": height,
+ "width": width,
+ "num_inference_steps": steps,
+ "guidance_scale": guidance_scale,
+ }
+ if seed is not None:
+ extra_body["seed"] = seed
+ if negative_prompt:
+ extra_body["negative_prompt"] = negative_prompt
+
+ payload = {"messages": messages, "extra_body": extra_body}
+
+ # Send request
+ try:
+ mode = "image-to-image" if image_path else "text-to-image"
+ print(f"Sending {mode} request to {server_url}...")
+ response = requests.post(
+ f"{server_url}/v1/chat/completions",
+ headers={"Content-Type": "application/json"},
+ json=payload,
+ timeout=600,
+ )
+ response.raise_for_status()
+ data = response.json()
+
+ # Extract image from response
+ choices = data.get("choices", [])
+ for choice in choices:
+ choice_content = choice.get("message", {}).get("content")
+ if isinstance(choice_content, list):
+ for item in choice_content:
+ if isinstance(item, dict) and "image_url" in item:
+ img_url = item["image_url"].get("url", "")
+ if img_url.startswith("data:image"):
+ _, b64 = img_url.split(",", 1)
+ return base64.b64decode(b64)
+
+ print(f"Unexpected response format: {data}")
+ return None
+
+ except Exception as e:
+ print(f"Error: {e}")
+ return None
+
+
+def main():
+ parser = argparse.ArgumentParser(description="GLM-Image chat client")
+ parser.add_argument(
+ "--prompt",
+ "-p",
+ default="A beautiful sunset over the ocean with sailing boats",
+ help="Text prompt",
+ )
+ parser.add_argument("--output", "-o", default="glm_image_output.png", help="Output file")
+ parser.add_argument("--server", "-s", default="http://localhost:8091", help="Server URL")
+
+ # Image-to-image
+ parser.add_argument(
+ "--image",
+ "-i",
+ type=str,
+ help="Input image path (for image-to-image editing)",
+ )
+
+ # Generation parameters
+ parser.add_argument("--height", type=int, default=1024, help="Image height")
+ parser.add_argument("--width", type=int, default=1024, help="Image width")
+ parser.add_argument("--steps", type=int, default=50, help="Inference steps")
+ parser.add_argument("--guidance-scale", type=float, default=1.5, help="CFG guidance scale")
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
+ parser.add_argument("--negative", help="Negative prompt")
+
+ args = parser.parse_args()
+
+ mode = "image-to-image" if args.image else "text-to-image"
+ print(f"Mode: {mode}")
+ print(f"Prompt: {args.prompt}")
+ if args.image:
+ print(f"Input image: {args.image}")
+
+ image_bytes = generate_image(
+ prompt=args.prompt,
+ server_url=args.server,
+ image_path=args.image,
+ height=args.height,
+ width=args.width,
+ steps=args.steps,
+ guidance_scale=args.guidance_scale,
+ seed=args.seed,
+ negative_prompt=args.negative,
+ )
+
+ if image_bytes:
+ output_path = Path(args.output)
+ output_path.write_bytes(image_bytes)
+ print(f"Image saved to: {output_path}")
+ print(f"Size: {len(image_bytes) / 1024:.1f} KB")
+ else:
+ print("Failed to generate image")
+ exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/online_serving/glm_image/run_curl_image_edit.sh b/examples/online_serving/glm_image/run_curl_image_edit.sh
new file mode 100755
index 00000000000..bb1e851ba32
--- /dev/null
+++ b/examples/online_serving/glm_image/run_curl_image_edit.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+# GLM-Image image-edit (image-to-image) curl example
+
+set -euo pipefail
+
+if [[ $# -lt 2 ]]; then
+ echo "Usage: $0 <input_image> \"<prompt>\" [output_file]" >&2
+ exit 1
+fi
+
+INPUT_IMG=$1
+PROMPT=$2
+SERVER="${SERVER:-http://localhost:8091}"
+CURRENT_TIME=$(date +%Y%m%d%H%M%S)
+OUTPUT="${3:-glm_image_i2i_${CURRENT_TIME}.png}"
+
+if [[ ! -f "$INPUT_IMG" ]]; then
+ echo "Input image not found: $INPUT_IMG" >&2
+ exit 1
+fi
+
+# base64-encode onto a single line (macOS: strip newlines with tr; Linux: -w0 disables wrapping)
+if [[ "$(uname)" == "Darwin" ]]; then
+ IMG_B64=$(base64 < "$INPUT_IMG" | tr -d '\n')
+else
+ IMG_B64=$(base64 -w0 "$INPUT_IMG")
+fi
+
+REQUEST_JSON=$(
+ jq -n --arg prompt "$PROMPT" --arg img "$IMG_B64" '{
+ messages: [{
+ role: "user",
+ content: [
+ {"type": "text", "text": $prompt},
+ {"type": "image_url", "image_url": {"url": ("data:image/png;base64," + $img)}}
+ ]
+ }],
+ extra_body: {
+ height: 1024,
+ width: 1024,
+ num_inference_steps: 50,
+ guidance_scale: 1.5,
+ seed: 42
+ }
+ }'
+)
+
+echo "Generating edited image..."
+echo "Server: $SERVER"
+echo "Prompt: $PROMPT"
+echo "Input : $INPUT_IMG"
+echo "Output: $OUTPUT"
+
+curl -s "$SERVER/v1/chat/completions" \
+ -H "Content-Type: application/json" \
+ -d "$REQUEST_JSON" \
+ | jq -r '.choices[0].message.content[0].image_url.url' \
+ | cut -d',' -f2- \
+ | base64 -d > "$OUTPUT"
+
+echo "Image saved to: $OUTPUT"
diff --git a/examples/online_serving/glm_image/run_curl_text_to_image.sh b/examples/online_serving/glm_image/run_curl_text_to_image.sh
new file mode 100755
index 00000000000..aecb6953c45
--- /dev/null
+++ b/examples/online_serving/glm_image/run_curl_text_to_image.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+# GLM-Image text-to-image curl example
+
+set -euo pipefail
+
+PROMPT="${1:-A beautiful sunset over the ocean with sailing boats}"
+SERVER="${SERVER:-http://localhost:8091}"
+OUTPUT="${OUTPUT:-glm_image_t2i_output.png}"
+
+echo "Generating image..."
+echo "Server: $SERVER"
+echo "Prompt: $PROMPT"
+echo "Output: $OUTPUT"
+
+curl -s "$SERVER/v1/chat/completions" \
+ -H "Content-Type: application/json" \
+ -d "{
+ \"messages\": [
+ {\"role\": \"user\", \"content\": \"$PROMPT\"}
+ ],
+ \"extra_body\": {
+ \"height\": 1024,
+ \"width\": 1024,
+ \"num_inference_steps\": 50,
+ \"guidance_scale\": 1.5,
+ \"seed\": 42
+ }
+ }" | jq -r '.choices[0].message.content[0].image_url.url' | cut -d',' -f2- | base64 -d > "$OUTPUT"
+
+echo "Image saved to: $OUTPUT"
diff --git a/examples/online_serving/glm_image/run_server.sh b/examples/online_serving/glm_image/run_server.sh
new file mode 100755
index 00000000000..b47d9f88504
--- /dev/null
+++ b/examples/online_serving/glm_image/run_server.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+# GLM-Image online serving startup script
+
+MODEL="${MODEL:-zai-org/GLM-Image}"
+PORT="${PORT:-8091}"
+
+echo "Starting GLM-Image server..."
+echo "Model: $MODEL"
+echo "Port: $PORT"
+
+vllm serve "$MODEL" --omni \
+ --port "$PORT"
diff --git a/examples/online_serving/image_to_image/README.md b/examples/online_serving/image_to_image/README.md
index 59b1f0e2c15..789258473fd 100644
--- a/examples/online_serving/image_to_image/README.md
+++ b/examples/online_serving/image_to_image/README.md
@@ -314,7 +314,6 @@ count, use `size` and `n` rather than `height`, `width`, or
| `seed` | int | None | Random seed (reproducible) |
| `negative_prompt` | str | None | Negative prompt |
| `num_outputs_per_prompt` | int | 1 | Number of images to generate |
-| `strength` | float | 0.6 | **Z-Image only** - Denoising start timestep for I2I. Range: [0.0, 1.0]. Lower preserves more of original image. |
| `layers` | int | 4 | Number of layers (Qwen-Image-Layered) |
| `resolution` | int | 640 | Resolution, 640 or 1024 (Qwen-Image-Layered) |
diff --git a/examples/online_serving/image_to_video/README.md b/examples/online_serving/image_to_video/README.md
index 285eeb27983..49283bd9a06 100644
--- a/examples/online_serving/image_to_video/README.md
+++ b/examples/online_serving/image_to_video/README.md
@@ -26,23 +26,6 @@ The script allows overriding:
- `CACHE_BACKEND` (default: `none`)
- `ENABLE_CACHE_DIT_SUMMARY` (default: `0`)
-### Ascend / Local LightX2V Example
-
-For a local Wan2.2-LightX2V Diffusers directory on Ascend/NPU, you can start the server like this:
-
-```bash
-vllm serve /path/to/Wan2.2-I2V-A14B-LightX2V-Diffusers-Lightning \
- --omni \
- --port 8091 \
- --flow-shift 12 \
- --cfg-parallel-size 1 \
- --ulysses-degree 4 \
- --use-hsdp \
- --trust-remote-code \
- --allowed-local-media-path / \
- --seed 42
-```
-
## Async Job Behavior
`POST /v1/videos` is asynchronous. It creates a video job and immediately
@@ -86,35 +69,10 @@ curl -X POST http://localhost:8091/v1/videos/sync \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
- -F 'extra_params={"sample_solver":"euler"}' \
-F "seed=42" \
-o sync_i2v_output.mp4
```
-For Wan Lightning/Distill checkpoints, pass `{"sample_solver":"euler"}` via `extra_params`. The default solver is `unipc`.
-
-Example matching the local LightX2V deployment above:
-
-```bash
-curl -sS -X POST http://localhost:8091/v1/videos/sync \
- -H "Accept: video/mp4" \
- -F "prompt=A cat playing with yarn" \
- -F "input_reference=@/path/to/input.jpg" \
- -F "width=832" \
- -F "height=480" \
- -F "num_frames=81" \
- -F "fps=16" \
- -F "num_inference_steps=4" \
- -F "guidance_scale=1.0" \
- -F "guidance_scale_2=1.0" \
- -F "boundary_ratio=0.875" \
- -F "seed=42" \
- -F 'extra_params={"sample_solver":"euler"}' \
- -o ./output.mp4
-```
-
-Use `/v1/videos/sync` if you want to write the MP4 directly to a file. `POST /v1/videos` is async and returns job metadata, not inline `b64_json`.
-
## Storage
Generated video files are stored on local disk by the async video API.
@@ -138,9 +96,6 @@ export VLLM_OMNI_STORAGE_MAX_CONCURRENCY=8
# Basic image-to-video generation
bash run_curl_image_to_video.sh
-# Wan Lightning/Distill checkpoints
-SAMPLE_SOLVER=euler bash run_curl_image_to_video.sh
-
# Or execute directly (OpenAI-style multipart)
create_response=$(curl -s http://localhost:8091/v1/videos \
-H "Accept: application/json" \
@@ -156,7 +111,6 @@ create_response=$(curl -s http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
- -F 'extra_params={"sample_solver":"euler"}' \
-F "seed=42")
video_id=$(echo "$create_response" | jq -r '.id')
@@ -215,12 +169,9 @@ curl -X POST http://localhost:8091/v1/videos \
-F "guidance_scale_2=1.0" \
-F "boundary_ratio=0.875" \
-F "flow_shift=12.0" \
- -F 'extra_params={"sample_solver":"euler"}' \
-F "seed=42"
```
-`sample_solver` is supported by Wan2.2 online serving through the existing `extra_params` field, which is merged into the pipeline `extra_args`. Use `unipc` for the default multistep solver, or `euler` for Lightning/Distill checkpoints.
-
## Create Response Format
`POST /v1/videos` returns a job record, not inline base64 video data.
diff --git a/examples/online_serving/image_to_video/run_curl_image_to_video.sh b/examples/online_serving/image_to_video/run_curl_image_to_video.sh
index 6f6a6f96d59..f4c1496a69a 100644
--- a/examples/online_serving/image_to_video/run_curl_image_to_video.sh
+++ b/examples/online_serving/image_to_video/run_curl_image_to_video.sh
@@ -7,7 +7,6 @@ INPUT_IMAGE="${INPUT_IMAGE:-../../offline_inference/image_to_video/qwen-bear.png
BASE_URL="${BASE_URL:-http://localhost:8099}"
OUTPUT_PATH="${OUTPUT_PATH:-wan22_i2v_output.mp4}"
NEGATIVE_PROMPT="${NEGATIVE_PROMPT:-}"
-SAMPLE_SOLVER="${SAMPLE_SOLVER:-}"
POLL_INTERVAL="${POLL_INTERVAL:-2}"
if [ ! -f "$INPUT_IMAGE" ]; then
@@ -35,10 +34,6 @@ if [ -n "${NEGATIVE_PROMPT}" ]; then
create_cmd+=(-F "negative_prompt=${NEGATIVE_PROMPT}")
fi
-if [ -n "${SAMPLE_SOLVER}" ]; then
- create_cmd+=(-F "extra_params={\"sample_solver\":\"${SAMPLE_SOLVER}\"}")
-fi
-
create_response="$("${create_cmd[@]}")"
video_id="$(echo "${create_response}" | jq -r '.id')"
if [ -z "${video_id}" ] || [ "${video_id}" = "null" ]; then
diff --git a/examples/online_serving/mimo_audio/README.md b/examples/online_serving/mimo_audio/README.md
index 9c1be7f21c8..9f70d59cbe8 100644
--- a/examples/online_serving/mimo_audio/README.md
+++ b/examples/online_serving/mimo_audio/README.md
@@ -13,10 +13,10 @@ Please refer to [README.md](../../../README.md)
```bash
export MIMO_AUDIO_TOKENIZER_PATH="XiaomiMiMo/MiMo-Audio-Tokenizer"
-vllm serve XiaomiMiMo/MiMo-Audio-7B-Instruct --omni \
- --served-model-name "MiMo-Audio-7B-Instruct" \
- --port 18091 \
- --chat-template ./examples/online_serving/mimo_audio/chat_template.jinja
+vllm-omni serve XiaomiMiMo/MiMo-Audio-7B-Instruct --omni \
+--served-model-name "MiMo-Audio-7B-Instruct" \
+--port 18091 --stage-configs-path ./vllm_omni/model_executor/stage_configs/mimo_audio.yaml \
+--chat-template ./examples/online_serving/mimo_audio/chat_template.jinja
```
> ⚠️ **Important**
> **MiMo-Audio is not compatible with the default chat template.**
diff --git a/examples/online_serving/ming_flash_omni/README.md b/examples/online_serving/ming_flash_omni/README.md
deleted file mode 100644
index 8b7d03e211a..00000000000
--- a/examples/online_serving/ming_flash_omni/README.md
+++ /dev/null
@@ -1,95 +0,0 @@
-# Ming-flash-omni 2.0
-
-## Installation
-
-Please refer to [README.md](../../../README.md)
-
-## Deployment modes
-
-| Mode | Launch command | Output |
-|------|---------------|--------|
-| Thinker only (multimodal understanding) | `vllm serve ... --omni` | Text |
-| Thinker + Talker (omni-speech) | `vllm serve ... --omni --stage-configs-path ming_flash_omni.yaml` | Text + Audio |
-
-For standalone TTS (talker only), see [`examples/online_serving/ming_flash_omni_tts/`](../ming_flash_omni_tts/).
-
-## Run examples (Ming-flash-omni 2.0)
-
-### Launch the Server
-
-**Thinker only (text output):**
-```bash
-vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091
-```
-
-**Thinker + Talker (omni-speech, text + audio output):**
-```bash
-vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 \
- --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml
-```
-
-Pass `--stage-configs-path /path/to/your_config.yaml` to use a custom stage
-config.
-
-### Send Multi-modal Request
-
-Shared Python client (supports `text | use_image | use_audio | use_video |
-use_mixed_modalities`; pass `--image-path` / `--audio-path` / `--video-path`
-for local files or URLs, `--modalities text` for output, `--help` for the
-full flag list):
-
-```bash
-python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \
- --model Jonathan1909/Ming-flash-omni-2.0 \
- --query-type use_mixed_modalities \
- --port 8091 --host localhost \
- --modalities text
-```
-
-Parameterized curl wrapper in this directory:
-
-```bash
-bash run_curl_multimodal_generation.sh text
-bash run_curl_multimodal_generation.sh use_image
-bash run_curl_multimodal_generation.sh use_audio
-bash run_curl_multimodal_generation.sh use_video
-bash run_curl_multimodal_generation.sh use_mixed_modalities
-```
-
-## Modality control
-
-| `modalities` | Server config | Output |
-|-------------|--------------|--------|
-| `["text"]` or omitted | Thinker only | Text |
-| `["audio"]` | Thinker + Talker | Audio (speech) |
-| `["text", "audio"]` | Thinker + Talker | Text + Audio |
-
-For ready-to-copy curl examples (text / audio / multimodal input, SSE
-streaming, reasoning mode), see the recipe at
-[`recipes/inclusionAI/Ming-flash-omni-2.0.md`](../../../recipes/inclusionAI/Ming-flash-omni-2.0.md).
-
-## OpenAI Python SDK — streaming
-
-```python
-from openai import OpenAI
-
-client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")
-
-response = client.chat.completions.create(
- model="Jonathan1909/Ming-flash-omni-2.0",
- messages=[
- {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
- {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"},
- ],
- modalities=["text"],
- stream=True,
-)
-for chunk in response:
- for choice in chunk.choices:
- if hasattr(choice, "delta") and choice.delta.content:
- print(choice.delta.content, end="", flush=True)
-print()
-```
-
-The `--stream` flag on the Python client script above shows the same pattern
-driven by the shared multimodal client.
diff --git a/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh b/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh
deleted file mode 100755
index 768a424e451..00000000000
--- a/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh
+++ /dev/null
@@ -1,145 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-# Server port
-PORT="${PORT:-8091}"
-# Default query type
-QUERY_TYPE="${1:-text}"
-
-# Validate query type
-if [[ ! "$QUERY_TYPE" =~ ^(text|use_audio|use_image|use_video|use_mixed_modalities)$ ]]; then
- echo "Error: Invalid query type '$QUERY_TYPE'"
- echo "Usage: $0 [text|use_audio|use_image|use_video|use_mixed_modalities]"
- echo " text: Text-only query"
- echo " use_audio: Audio + Text query"
- echo " use_image: Image + Text query"
- echo " use_video: Video + Text query"
- echo " use_mixed_modalities: Audio + Image + Video + Text query"
- exit 1
-fi
-
-thinker_sampling_params='{
- "temperature": 0.4,
- "top_p": 0.9,
- "top_k": -1,
- "max_tokens": 16384,
- "seed": 42,
- "detokenize": true,
- "repetition_penalty": 1.05
-}'
-# Above is optional, it has a default setting in stage_configs of the corresponding model.
-
-# Define URLs for assets
-MARY_HAD_LAMB_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg"
-CHERRY_BLOSSOM_IMAGE_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"
-SAMPLE_VIDEO_URL="https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4"
-
-# Build user content based on query type
-case "$QUERY_TYPE" in
- text)
- user_content='[
- {
- "type": "text",
- "text": "请详细介绍鹦鹉的生活习性。"
- }
- ]'
- ;;
- use_image)
- user_content='[
- {
- "type": "image_url",
- "image_url": {
- "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'"
- }
- },
- {
- "type": "text",
- "text": "Describe this image in detail."
- }
- ]'
- ;;
- use_audio)
- user_content='[
- {
- "type": "audio_url",
- "audio_url": {
- "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'"
- }
- },
- {
- "type": "text",
- "text": "Please recognize the language of this speech and transcribe it. Format: oral."
- }
- ]'
- ;;
- use_video)
- user_content='[
- {
- "type": "video_url",
- "video_url": {
- "url": "'"$SAMPLE_VIDEO_URL"'"
- }
- },
- {
- "type": "text",
- "text": "Describe what is happening in this video."
- }
- ]'
- ;;
- use_mixed_modalities)
- user_content='[
- {
- "type": "image_url",
- "image_url": {
- "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'"
- }
- },
- {
- "type": "audio_url",
- "audio_url": {
- "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'"
- }
- },
- {
- "type": "text",
- "text": "Describe the image, and recognize the language of this speech and transcribe it. Format: oral"
- }
- ]'
- ;;
-esac
-
-echo "Running query type: $QUERY_TYPE"
-echo ""
-
-request_body=$(cat < None:
- payload = {
- "model": args.model,
- "input": args.text,
- "response_format": args.response_format,
- }
-
- instructions = args.instructions
- if args.instruction_json:
- if instructions:
- sys.exit("--instructions and --instruction-json are mutually exclusive")
-
- try:
- parsed = json.loads(args.instruction_json)
- except json.JSONDecodeError as exc:
- sys.exit(f"--instruction-json must be valid JSON: {exc}")
- if not isinstance(parsed, dict):
- sys.exit("--instruction-json must decode to a JSON object")
- # Re-encode with ensure_ascii=False so UTF-8 Chinese keys/values
- # arrive at the server intact rather than as \\uXXXX escapes.
- instructions = json.dumps(parsed, ensure_ascii=False)
- if instructions:
- payload["instructions"] = instructions
-
- print(f"Model: {args.model}")
- print(f"Text: {args.text}")
- print("Generating audio...")
-
- api_url = f"{args.api_base}/v1/audio/speech"
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {args.api_key}",
- }
-
- with httpx.Client(timeout=300.0) as client:
- response = client.post(api_url, json=payload, headers=headers)
-
- if response.status_code != 200:
- print(f"Error: {response.status_code}")
- print(response.text)
- return
-
- output_path = args.output or "ming_tts_output.wav"
- with open(output_path, "wb") as f:
- f.write(response.content)
- print(f"Audio saved to: {output_path}")
-
-
-def main():
- parser = argparse.ArgumentParser(description="Ming standalone TTS speech client")
- parser.add_argument("--api-base", default=DEFAULT_API_BASE, help="API base URL")
- parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key")
- parser.add_argument("--model", "-m", default=DEFAULT_MODEL, help="Model name or local path")
- parser.add_argument("--text", required=True, help="Text to synthesize")
- parser.add_argument(
- "--response-format",
- default="wav",
- choices=["wav", "mp3", "flac", "pcm", "aac", "opus"],
- help="Audio format (default: wav)",
- )
- parser.add_argument("--output", "-o", default=None, help="Output file path")
- parser.add_argument(
- "--instructions",
- default=None,
- help="Free-form style description (mapped to caption 风格 on the server).",
- )
- parser.add_argument(
- "--instruction-json",
- default=None,
- help=(
- "Structured caption JSON forwarded as `instructions`. Accepts Ming "
- "caption keys: 方言, 风格, 语速, 基频, 音量, 情感, IP, 说话人, BGM. "
- ),
- )
- args = parser.parse_args()
- run_tts(args)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/online_serving/qwen2_5_omni/README.md b/examples/online_serving/qwen2_5_omni/README.md
index c528732064a..91aab3b6518 100644
--- a/examples/online_serving/qwen2_5_omni/README.md
+++ b/examples/online_serving/qwen2_5_omni/README.md
@@ -208,3 +208,11 @@ The gradio script supports the following arguments:
- `--ip`: Host/IP for Gradio server (default: 127.0.0.1)
- `--port`: Port for Gradio server (default: 7861)
- `--share`: Share the Gradio demo publicly (creates a public link)
+
+### FAQ
+
+If you encounter an error about the librosa backend, install ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
diff --git a/examples/online_serving/qwen3_omni/README.md b/examples/online_serving/qwen3_omni/README.md
index c85970555f9..c3171e43667 100644
--- a/examples/online_serving/qwen3_omni/README.md
+++ b/examples/online_serving/qwen3_omni/README.md
@@ -12,221 +12,17 @@ Please refer to [README.md](../../../README.md)
vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
```
-The default deployment configuration, situated at `vllm_omni/deploy/qwen3_omni_moe.yaml`, is resolved and loaded
-automatically via the model registry, obviating the `--deploy-config` flag in standard deployment topologies.
-Asynchronous chunk streaming operates as **enabled by default** within this bundled configuration.
-Additionally, NPU, ROCm, and XPU per-platform configuration deltas are deterministically merged from the
-`platforms`: section of the corresponding YAML.
-
-**Note:** The OpenAI-style **`/v1/realtime`** WebSocket interface (facilitating streaming PCM audio input alongside audio and transcription output)
-is currently **unsupported** while the `async_chunk` configuration attribute is enabled.
-It is requisite to instantiate the default omni architecture or utilize a deployment configuration specifying `async_chunk: false` to facilitate real-time streaming sessions.
-
-To explicitly utilize a custom deployment YAML, mandate the configuration path accordingly:
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --deploy-config /path/to/your_deploy_config.yaml
-```
-
-### Launch individual stages (stage-based CLI)
-
-Use the stage-based CLI when you want to run one stage per process.
-
-**1. Stage 0 (Thinker + API server)**
+To enable async chunking for Qwen3-Omni, launch the server with the command below:
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --port 8091 \
- --stage-id 0 \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
```
-**2. Stage 1 (Talker)**
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 1 \
- --headless \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-**3. Stage 2 (Code2Wav)**
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 2 \
- --headless \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-Append `--deploy-config /path/to/your_deploy_config.yaml` to each node invocation if it is necessary
-to explicitly override the bundled deployment YAML schema.
-
-For standard **unified-process** launcher, stage-specific CLI configuration tuning is conventionally implemented
-via the `--stage-overrides` directive, as demonstrated below:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --stage-overrides '{"1": {"gpu_memory_utilization": 0.5}}'
-```
-
-Conversely, within the stage-based CLI paradigm, `--stage-overrides` modifiers are typically **unnecessary**
-for this category of optimization. Given that each instantiation strictly initiates a single functional stage,
-parameter flags can be systematically assigned directly onto that specific stage's command sequence:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni \
- --stage-id 1 \
- --headless \
- --gpu-memory-utilization 0.5 \
- --omni-master-address 127.0.0.1 \
- --omni-master-port 26000
-```
-
-### Tuning deployment parameters
-
-Most engine knobs (`max_num_batched_tokens`, `max_model_len`, `enforce_eager`,
-`gpu_memory_utilization`, `tensor_parallel_size`, …) can be tuned without
-editing the YAML. There are three layers, in increasing specificity:
-
-#### 1. Global CLI flags (apply to every stage)
-
+If you have a custom stage configs file, launch the server with the command below:
```bash
-# Tighter memory budget on a smaller GPU
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --gpu-memory-utilization 0.85
-
-# Disable cudagraphs (e.g. for debugging)
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --enforce-eager
-
-# Reduce context length
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --max-model-len 32768
-
-# Toggle prefix caching on every stage (yaml default: off)
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --enable-prefix-caching
-# ...or force it off if the yaml turned it on
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --no-enable-prefix-caching
-
-# Toggle pipeline-wide async chunked streaming between stages
-# (yaml default for qwen3_omni_moe: on)
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --no-async-chunk
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
```
-For the TTS counterpart (synchronous codec variant), see
-[qwen3_tts README](../qwen3_tts/README.md#sync-vs-async-chunk-mode).
-
-Explicit CLI flags **override** the deploy YAML (which itself overrides the
-parser defaults). If you don't pass a flag, the YAML value wins.
-
-> **Note on `--no-async-chunk`**: Flips the deploy yaml's `async_chunk:`
-> bool. Pipelines that implement alternate processor functions for
-> chunked vs end-to-end modes (e.g. qwen3_tts code2wav) dispatch
-> automatically based on that bool — no extra flag or variant yaml is
-> needed.
-
-> ⚠️ **For multi-stage models that share GPUs (qwen3_omni_moe by default
-> shares cuda:1 between stages 1 and 2), avoid using global memory flags.**
-> A global `--gpu-memory-utilization 0.85` would apply to every stage and
-> oversubscribe the shared device. Use per-stage overrides instead — see
-> below.
-
-#### 2. Per-stage overrides via `--stage-overrides` (recommended for memory)
-
-```bash
-# Lower stage 1's memory budget; leave others at the YAML default
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --stage-overrides '{
- "1": {"gpu_memory_utilization": 0.5},
- "2": {"max_num_batched_tokens": 65536}
- }'
-```
-
-Per-stage values are always treated as explicit and beat YAML defaults for
-the named stage. Other stages keep their YAML values.
-
-If you switch to the stage-based CLI, the same per-stage tuning can usually be
-passed directly on that stage's command instead of using `--stage-overrides`.
-
-#### 3. Custom deploy YAML
-
-When per-stage overrides get long, write a small overlay YAML that inherits
-from the bundled default:
-
-```yaml
-# my_qwen3_omni_overrides.yaml
-base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml
-
-stages:
- - stage_id: 0
- max_num_batched_tokens: 65536
- enforce_eager: true
- - stage_id: 1
- gpu_memory_utilization: 0.5
- - stage_id: 2
- max_model_len: 8192
-```
-
-Then start the server with `--deploy-config my_qwen3_omni_overrides.yaml`.
-The `base_config:` line tells the loader to inherit everything else (stages,
-connectors, edges, platforms section) from the bundled production YAML, so
-you only need to spell out the deltas.
-
-#### 4. Multi-node deployment (cross-host transfer connector)
-
-The bundled `qwen3_omni_moe.yaml` uses `SharedMemoryConnector` between stages,
-which only works when all stages run on the same physical host. For
-**cross-node** deployments, write a small overlay YAML that swaps in a
-network-capable connector (e.g. `MooncakeStoreConnector`) and re-points each
-stage's connector wiring at it. The connector spec carries your own server
-addresses — there is no checked-in default because every cluster is
-different.
-
-```yaml
-# my_qwen3_omni_multinode.yaml
-base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml
-
-connectors:
- mooncake_connector:
- name: MooncakeStoreConnector
- extra:
- host: "127.0.0.1"
- metadata_server: "http://YOUR_METADATA_HOST:8080/metadata"
- master: "YOUR_MASTER_HOST:50051"
- segment: 512000000 # 512 MB transfer segment
- localbuf: 64000000 # 64 MB local buffer
- proto: "tcp"
-
-stages:
- - stage_id: 0
- output_connectors:
- to_stage_1: mooncake_connector
- - stage_id: 1
- input_connectors:
- from_stage_0: mooncake_connector
- output_connectors:
- to_stage_2: mooncake_connector
- - stage_id: 2
- input_connectors:
- from_stage_1: mooncake_connector
-```
-
-Then launch with `--deploy-config my_qwen3_omni_multinode.yaml`. Same
-pattern works for Qwen2.5-Omni — replace `base_config:` with the path to
-`vllm_omni/deploy/qwen2_5_omni.yaml`.
-
-> ⚠️ Replace `YOUR_METADATA_HOST` / `YOUR_MASTER_HOST` with the actual
-> mooncake server addresses for your cluster. The `base_config:` overlay
-> inherits all stage budgets, devices, and edges from the bundled prod
-> YAML — you only need to spell out the connector swap.
-
### Send Multi-modal Request
Get into the example folder
@@ -242,43 +38,38 @@ python examples/online_serving/openai_chat_completion_client_for_multimodal_gene
#### Realtime WebSocket client (`openai_realtime_client.py`)
-[`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://:/v1/realtime`**, streams a local WAV as **PCM16 mono @ 16 kHz** in fixed-size chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and receives **`response.audio.delta`** (incremental PCM for the reply) plus **`transcription.*`** events. By default it concatenates audio deltas and writes **`--output-wav`** (model output is typically **24 kHz**). Optional **`--delta-dump-dir`** saves each delta as `delta_000001.wav`, … for debugging.
-
-Streaming input works well for translation-style use cases; if the Thinker runs while input is still incomplete, consider limiting **`max_tokens`** in your session / server defaults to avoid over-generation.
+[`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://<host>:<port>/v1/realtime`**, uploads a local audio file as **PCM16 mono @ 16 kHz** chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and prints **streaming transcription** (`transcription.delta` / `transcription.done`).
**Dependencies:**
```bash
-pip install websockets
+pip install websockets librosa numpy
```
+(ffmpeg may be required by `librosa` for some formats; see the FAQ below.)
+
**From this directory** (`examples/online_serving/qwen3_omni`):
```bash
python openai_realtime_client.py \
- --url ws://localhost:8091/v1/realtime \
+ --host localhost \
+ --port 8091 \
--model Qwen/Qwen3-Omni-30B-A3B-Instruct \
- --input-wav /path/to/input_16k_mono.wav \
- --output-wav realtime_output.wav \
- --delta-dump-dir ./rt_delta_wavs
+ --audio_path /path/to/your.wav
```
+If `--audio_path` is omitted, the script uses a bundled default clip (`mary_had_lamb` via vLLM assets).
+
**Arguments:**
| Flag | Default | Description |
|------|---------|-------------|
-| `--url` | `ws://localhost:8091/v1/realtime` | Full WebSocket URL including path |
-| `--model` | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Must match the served model (sent in `session.update`) |
-| `--input-wav` | *(required)* | Input WAV: mono, 16-bit PCM, **16 kHz** |
-| `--output-wav` | `realtime_output.wav` | Output path for concatenated reply audio |
-| `--output-text` | *(optional)* | If set, write final transcription text to this path |
-| `--chunk-ms` | `200` | Size of each uploaded audio chunk (milliseconds of audio) |
-| `--send-delay-ms` | `0` | Delay between chunk sends (simulate realtime upload) |
-| `--delta-dump-dir` | *(optional)* | Directory to write per-`response.audio.delta` WAV files |
-| `--num-requests` | `1` | Number of sequential sessions (see `--concurrency`) |
-| `--concurrency` | `1` | Max concurrent WebSocket sessions when `--num-requests` > 1 |
-
-Ensure the server is running **without** `async_chunk` if you use `/v1/realtime`, for example:
+| `--host` | `localhost` | API server host |
+| `--port` | `8000` | API server port (match your `vllm serve` port, e.g. `8091`) |
+| `--model` | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Must match the served model (also sent in `session.update`) |
+| `--audio_path` | *(optional)* | Path to input audio; resampled to 16 kHz mono inside the client |
+
+Ensure the vLLM-Omni server is running with realtime support for this endpoint, for example:
```bash
vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
@@ -314,6 +105,12 @@ bash run_curl_multimodal_generation.sh use_image
### FAQ
+If you encounter an error about the librosa backend, install ffmpeg with the commands below.
+```
+sudo apt update
+sudo apt install ffmpeg
+```
+
## Modality control
You can control output modalities to specify which types of output the model should generate. This is useful when you only need text output and want to skip audio generation stages for better performance.
@@ -487,7 +284,7 @@ The script supports the following arguments:
- `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct)
- `--server-port`: Port for vLLM server (default: 8091)
- `--gradio-port`: Port for Gradio demo (default: 7861)
-- `--deploy-config`: Path to custom deploy config YAML file (optional)
+- `--stage-configs-path`: Path to custom stage configs YAML file (optional)
- `--server-host`: Host for vLLM server (default: 0.0.0.0)
- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1)
- `--share`: Share Gradio demo publicly (creates a public link)
@@ -502,7 +299,7 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
If you have custom stage configs file:
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
```
**Step 2: Run the Gradio demo**
diff --git a/examples/online_serving/qwen3_omni/openai_realtime_client.py b/examples/online_serving/qwen3_omni/openai_realtime_client.py
index 79e30a3f50b..4fa043c481d 100644
--- a/examples/online_serving/qwen3_omni/openai_realtime_client.py
+++ b/examples/online_serving/qwen3_omni/openai_realtime_client.py
@@ -1,118 +1,81 @@
-"""Realtime client for vLLM-Omni /v1/realtime (audio + text events).
-
-This client:
-1) Reads a local WAV file (must be mono, 16-bit PCM, 16kHz),
-2) Streams PCM16 chunks to /v1/realtime with OpenAI-style events,
-3) Receives response.audio.* and transcription.* events,
-4) Saves synthesized audio to an output WAV file and optional text file.
+"""
+This script demonstrates how to use the vLLM-Omni Realtime WebSocket API to perform
+audio transcription by uploading an audio file.
-By default each ``response.audio.delta`` is treated as an **incremental PCM**
-chunk and all chunks are concatenated into the final ``--output-wav``.
+Before running this script, you must start the vLLM-Omni server with a realtime-capable
+model, for example:
-Optional debugging: pass ``--delta-dump-dir DIR`` to write every
-``response.audio.delta`` payload as ``delta_000001.wav``, ``delta_000002.wav``, …
+ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni
-Usage:
- python openai_realtime_client.py \
- --url ws://localhost:8091/v1/realtime \
- --model Qwen/Qwen3-Omni-30B-A3B-Instruct \
- --input-wav input_16k_mono.wav \
- --output-wav realtime_output.wav \
- --delta-dump-dir ./rt_delta_wavs
+Requirements:
+- vllm with audio support
+- websockets
+- librosa
+- numpy
-Dependencies:
- pip install websockets
+The script:
+1. Connects to the Realtime WebSocket endpoint
+2. Converts an audio file to PCM16 @ 16kHz
+3. Sends audio chunks to the server
+4. Receives and prints transcription as it streams
"""
-from __future__ import annotations
-
import argparse
import asyncio
import base64
import json
-import wave
-from pathlib import Path
-
-try:
- import websockets
-except ImportError:
- print("Please install websockets: pip install websockets")
- raise SystemExit(1)
-
-
-def _read_wav_pcm16(path: Path) -> bytes:
- with wave.open(str(path), "rb") as wf:
- nchannels = wf.getnchannels()
- sampwidth = wf.getsampwidth()
- framerate = wf.getframerate()
- comptype = wf.getcomptype()
- nframes = wf.getnframes()
-
- if nchannels != 1:
- raise ValueError(f"Input WAV must be mono (got {nchannels} channels).")
- if sampwidth != 2:
- raise ValueError(f"Input WAV must be 16-bit PCM (got sample width={sampwidth}).")
- if framerate != 16000:
- raise ValueError(f"Input WAV must be 16kHz (got {framerate} Hz).")
- if comptype != "NONE":
- raise ValueError(f"Input WAV must be uncompressed PCM (got comptype={comptype}).")
- if nframes <= 0:
- raise ValueError("Input WAV has no audio frames.")
-
- return wf.readframes(nframes)
-
-
-def _write_wav_pcm16(path: Path, pcm16_bytes: bytes, sample_rate_hz: int) -> None:
- with wave.open(str(path), "wb") as wf:
- wf.setnchannels(1)
- wf.setsampwidth(2)
- wf.setframerate(sample_rate_hz)
- wf.writeframes(pcm16_bytes)
-
-
-async def run_client(
- url: str,
- model: str,
- input_wav: Path,
- output_wav: Path,
- output_text: Path | None,
- chunk_ms: int,
- send_delay_ms: int,
- delta_dump_dir: Path | None,
- request_idx: int = 1,
- total_requests: int = 1,
-) -> None:
- log_prefix = f"[req {request_idx:02d}/{total_requests:02d}] " if total_requests > 1 else ""
- pcm16 = _read_wav_pcm16(input_wav)
- bytes_per_ms = 16000 * 2 // 1000 # mono PCM16 at 16kHz
- chunk_bytes = max(bytes_per_ms * chunk_ms, 2)
- incremental_pcm_parts: list[bytes] = []
- output_sample_rate = 24000
- delta_index = 0
- text_chunks: list[str] = []
- final_text: str = ""
-
- if delta_dump_dir is not None:
- delta_dump_dir.mkdir(parents=True, exist_ok=True)
-
- async with websockets.connect(url, max_size=64 * 1024 * 1024) as ws:
- # 1) Validate model.
- await ws.send(
- json.dumps(
- {
- "type": "session.update",
- "model": model,
- }
- )
- )
-
- # 2) Start generation once (non-final commit).
- await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": False}))
-
- # 3) Stream audio chunks.
- for i in range(0, len(pcm16), chunk_bytes):
- chunk = pcm16[i : i + chunk_bytes]
+import librosa
+import numpy as np
+import websockets
+from vllm.assets.audio import AudioAsset
+
+
+def audio_to_pcm16_base64(audio_path: str) -> str:
+ """
+ Load an audio file and convert it to base64-encoded PCM16 @ 16kHz.
+ """
+ # Load audio and resample to 16kHz mono
+ audio, _ = librosa.load(audio_path, sr=16000, mono=True)
+ # Convert to PCM16
+ pcm16 = (audio * 32767).astype(np.int16)
+ # Encode as base64
+ return base64.b64encode(pcm16.tobytes()).decode("utf-8")
+
+
+async def realtime_transcribe(audio_path: str, host: str, port: int, model: str):
+ """
+ Connect to the Realtime API and transcribe an audio file.
+ """
+ uri = f"ws://{host}:{port}/v1/realtime"
+
+ async with websockets.connect(uri) as ws:
+ # Wait for session.created
+ response = json.loads(await ws.recv())
+ if response["type"] == "session.created":
+ print(f"Session created: {response['id']}")
+ else:
+ print(f"Unexpected response: {response}")
+ return
+
+ # Validate model
+ await ws.send(json.dumps({"type": "session.update", "model": model}))
+
+ # Signal ready to start
+ await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
+
+ # Convert audio file to base64 PCM16
+ print(f"Loading audio from: {audio_path}")
+ audio_base64 = audio_to_pcm16_base64(audio_path)
+
+ # Send audio in chunks (4KB of raw audio = ~8KB base64)
+ chunk_size = 4096
+ audio_bytes = base64.b64decode(audio_base64)
+ total_chunks = (len(audio_bytes) + chunk_size - 1) // chunk_size
+
+ print(f"Sending {total_chunks} audio chunks...")
+ for i in range(0, len(audio_bytes), chunk_size):
+ chunk = audio_bytes[i : i + chunk_size]
await ws.send(
json.dumps(
{
@@ -121,212 +84,63 @@ async def run_client(
}
)
)
- if send_delay_ms > 0:
- await asyncio.sleep(send_delay_ms / 1000.0)
- # 4) Final commit closes input stream.
+ # Signal all audio is sent
await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
+ print("Audio sent. Waiting for transcription...\n")
- # 5) Receive server events until audio done.
+ # Receive transcription
+ print("Transcription: ", end="", flush=True)
while True:
- message = await ws.recv()
- if isinstance(message, bytes):
- # We only expect JSON text frames.
- continue
-
- event = json.loads(message)
- event_type = event.get("type")
-
- if event_type == "session.created":
- continue
-
- if event_type == "response.audio.delta":
- sr = event.get("sample_rate_hz")
- if isinstance(sr, int) and sr > 0:
- output_sample_rate = sr
- audio_b64 = event.get("audio", "")
- if audio_b64:
- pcm_delta = base64.b64decode(audio_b64)
- incremental_pcm_parts.append(pcm_delta)
- if delta_dump_dir is not None and pcm_delta:
- delta_index += 1
- dump_path = delta_dump_dir / f"delta_{delta_index:06d}.wav"
- _write_wav_pcm16(dump_path, pcm_delta, output_sample_rate)
- print(
- f"{log_prefix}delta dump #{delta_index}: {dump_path} "
- f"(pcm bytes={len(pcm_delta)}, sr={output_sample_rate})"
- )
- continue
-
- if event_type == "transcription.delta":
- delta = event.get("delta", "")
- if delta:
- text_chunks.append(delta)
- print(delta, end="", flush=True)
- continue
-
- if event_type == "transcription.done":
- final_text = event.get("text", "") or "".join(text_chunks)
- usage = event.get("usage")
- final_text_with_tag = f"Final transcription: {final_text}"
- if text_chunks:
- print()
- print(f"{log_prefix}{final_text_with_tag}")
- if usage:
- print(f"{log_prefix}text usage: {usage}")
- continue
-
- if event_type == "response.audio.done":
+ response = json.loads(await ws.recv())
+ if response["type"] == "transcription.delta":
+ print(response["delta"], end="", flush=True)
+ elif response["type"] == "transcription.done":
+ print(f"\n\nFinal transcription: {response['text']}")
+ if response.get("usage"):
+ print(f"Usage: {response['usage']}")
+ break
+ elif response["type"] == "error":
+ print(f"\nError: {response['error']}")
break
- if event_type == "error":
- raise RuntimeError(f"Server error: {event}")
-
- all_pcm16 = b"".join(incremental_pcm_parts)
- if not all_pcm16:
- raise RuntimeError("No audio received from server.")
-
- output_wav.parent.mkdir(parents=True, exist_ok=True)
- _write_wav_pcm16(output_wav, all_pcm16, output_sample_rate)
- print(f"{log_prefix}Saved realtime audio to: {output_wav} (incremental chunks joined)")
-
- if output_text is not None:
- text_to_save = final_text if final_text else "".join(text_chunks)
- output_text.parent.mkdir(parents=True, exist_ok=True)
- output_text.write_text(text_to_save, encoding="utf-8")
- print(f"{log_prefix}Saved realtime text to: {output_text}")
-
-
-def _indexed_output_path(path: Path | None, index: int, total: int) -> Path | None:
- if path is None or total <= 1:
- return path
- return path.with_name(f"{path.stem}_{index:02d}{path.suffix}")
-
-
-async def run_clients_concurrent(
- *,
- url: str,
- model: str,
- input_wav: Path,
- output_wav: Path,
- output_text: Path | None,
- chunk_ms: int,
- send_delay_ms: int,
- delta_dump_dir: Path | None,
- num_requests: int,
- concurrency: int,
-) -> None:
- sem = asyncio.Semaphore(concurrency)
-
- async def _run_one(index: int) -> tuple[int, bool, str | None]:
- per_output_wav = _indexed_output_path(output_wav, index, num_requests)
- per_output_text = _indexed_output_path(output_text, index, num_requests)
- per_delta_dir = None
- if delta_dump_dir is not None:
- per_delta_dir = delta_dump_dir / f"req_{index:02d}"
- async with sem:
- try:
- await run_client(
- url=url,
- model=model,
- input_wav=input_wav,
- output_wav=per_output_wav,
- output_text=per_output_text,
- chunk_ms=chunk_ms,
- send_delay_ms=send_delay_ms,
- delta_dump_dir=per_delta_dir,
- request_idx=index,
- total_requests=num_requests,
- )
- return index, True, None
- except Exception as exc:
- return index, False, str(exc)
- tasks = [asyncio.create_task(_run_one(i), name=f"rt-client-{i}") for i in range(1, num_requests + 1)]
- results = await asyncio.gather(*tasks)
+def main(args):
+ if args.audio_path:
+ audio_path = args.audio_path
+ else:
+ # Use default audio asset
+ audio_path = str(AudioAsset("mary_had_lamb").get_local_path())
+ print(f"No audio path provided, using default: {audio_path}")
- failed = [(idx, err) for idx, ok, err in results if not ok]
- succeeded = num_requests - len(failed)
- print(f"[summary] succeeded={succeeded}, failed={len(failed)}, total={num_requests}")
- if failed:
- for idx, err in failed:
- print(f"[summary] req {idx:02d} failed: {err}")
- raise RuntimeError(f"{len(failed)} concurrent request(s) failed")
+ asyncio.run(realtime_transcribe(audio_path, args.host, args.port, args.model))
-def main() -> None:
- parser = argparse.ArgumentParser(description="Realtime audio/text client for vLLM-Omni")
- parser.add_argument("--url", default="ws://localhost:8091/v1/realtime", help="WebSocket URL")
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Realtime WebSocket Transcription Client")
parser.add_argument(
"--model",
+ type=str,
default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
- help="Model name for session.update",
+ help="Model that is served and should be pinged.",
)
- parser.add_argument("--input-wav", required=True, type=Path, help="Input WAV (mono, PCM16, 16kHz)")
- parser.add_argument("--output-wav", default=Path("realtime_output.wav"), type=Path, help="Output WAV path")
parser.add_argument(
- "--output-text",
+ "--audio_path",
+ type=str,
default=None,
- type=Path,
- help="Optional output text path for final transcription",
+ help="Path to the audio file to transcribe.",
)
- parser.add_argument("--chunk-ms", type=int, default=200, help="Input chunk size in milliseconds")
parser.add_argument(
- "--send-delay-ms",
- type=int,
- default=0,
- help="Delay between chunk sends; set >0 to simulate realtime upload",
+ "--host",
+ type=str,
+ default="localhost",
+ help="vLLM-Omni server host (default: localhost)",
)
parser.add_argument(
- "--delta-dump-dir",
- type=Path,
- default=None,
- help="If set, each response.audio.delta is saved as delta_NNNNNN.wav under this directory",
- )
- parser.add_argument("--num-requests", type=int, default=1, help="Total number of requests to send")
- parser.add_argument(
- "--concurrency",
+ "--port",
type=int,
- default=1,
- help="Maximum number of concurrent websocket requests",
+ default=8000,
+ help="vLLM-Omni server port (default: 8000)",
)
args = parser.parse_args()
-
- if args.num_requests <= 0:
- raise ValueError("--num-requests must be >= 1")
- if args.concurrency <= 0:
- raise ValueError("--concurrency must be >= 1")
- concurrency = min(args.concurrency, args.num_requests)
-
- if args.num_requests == 1:
- asyncio.run(
- run_client(
- url=args.url,
- model=args.model,
- input_wav=args.input_wav,
- output_wav=args.output_wav,
- output_text=args.output_text,
- chunk_ms=args.chunk_ms,
- send_delay_ms=args.send_delay_ms,
- delta_dump_dir=args.delta_dump_dir,
- )
- )
- else:
- asyncio.run(
- run_clients_concurrent(
- url=args.url,
- model=args.model,
- input_wav=args.input_wav,
- output_wav=args.output_wav,
- output_text=args.output_text,
- chunk_ms=args.chunk_ms,
- send_delay_ms=args.send_delay_ms,
- delta_dump_dir=args.delta_dump_dir,
- num_requests=args.num_requests,
- concurrency=concurrency,
- )
- )
-
-
-if __name__ == "__main__":
- main()
+ main(args)
diff --git a/examples/online_serving/qwen3_omni/streaming_video_client.py b/examples/online_serving/qwen3_omni/streaming_video_client.py
deleted file mode 100644
index 58f26d24557..00000000000
--- a/examples/online_serving/qwen3_omni/streaming_video_client.py
+++ /dev/null
@@ -1,208 +0,0 @@
-#!/usr/bin/env python3
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Example WebSocket client for the /v1/video/chat/stream endpoint.
-
-Sends video frames from a local file (or generates synthetic ones), submits a
-query, and prints the streamed text response.
-
-Requirements:
- pip install websockets pillow
-
-Usage:
- # With a video file (requires opencv-python):
- python streaming_video_client.py --video my_clip.mp4 \\
- --query "What is happening in this video?"
-
- # Synthetic frames (no extra deps):
- python streaming_video_client.py \\
- --query "Describe what you see." \\
- --synthetic-frames 10
-
- # With audio (Phase 3):
- python streaming_video_client.py --video my_clip.mp4 \\
- --audio my_audio.pcm \\
- --query "What is the person saying and doing?"
-"""
-
-from __future__ import annotations
-
-import argparse
-import asyncio
-import base64
-import io
-import json
-import sys
-
-try:
- import websockets
-except ImportError:
- print("Please install websockets: pip install websockets")
- sys.exit(1)
-
-from PIL import Image
-
-
-def _generate_synthetic_frame(index: int, width: int = 320, height: int = 240) -> bytes:
- """Generate a simple synthetic JPEG frame with a colour gradient."""
- r = (index * 37) % 256
- g = (index * 73) % 256
- b = (index * 113) % 256
- img = Image.new("RGB", (width, height), (r, g, b))
- buf = io.BytesIO()
- img.save(buf, format="JPEG", quality=80)
- return buf.getvalue()
-
-
-def _load_video_frames(path: str, max_frames: int = 64, fps: int = 2) -> list[bytes]:
- """Extract frames from a video file using OpenCV."""
- try:
- import cv2
- except ImportError:
- print("opencv-python is required to read video files: pip install opencv-python")
- sys.exit(1)
-
- cap = cv2.VideoCapture(path)
- if not cap.isOpened():
- print(f"Cannot open video: {path}")
- sys.exit(1)
-
- video_fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
- frame_interval = max(1, int(video_fps / fps))
-
- frames: list[bytes] = []
- idx = 0
- while len(frames) < max_frames:
- ret, frame = cap.read()
- if not ret:
- break
- if idx % frame_interval == 0:
- _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
- frames.append(buf.tobytes())
- idx += 1
-
- cap.release()
- print(f"Loaded {len(frames)} frames from {path} (interval={frame_interval})")
- return frames
-
-
-async def run(args: argparse.Namespace) -> None:
- uri = f"ws://{args.host}:{args.port}/v1/video/chat/stream"
-
- # Prepare frames
- if args.video:
- frames = _load_video_frames(args.video, max_frames=args.max_frames, fps=args.fps)
- else:
- frames = [_generate_synthetic_frame(i) for i in range(args.synthetic_frames)]
- print(f"Generated {len(frames)} synthetic frames")
-
- # Prepare audio (optional, Phase 3)
- audio_data: bytes | None = None
- if args.audio:
- with open(args.audio, "rb") as f:
- audio_data = f.read()
- print(f"Loaded audio: {len(audio_data)} bytes")
-
- async with websockets.connect(uri, max_size=16 * 1024 * 1024) as ws:
- # 1. Send session.config
- config = {
- "type": "session.config",
- "model": args.model,
- "modalities": ["text", "audio"] if audio_data else ["text"],
- "max_frames": args.max_frames,
- "num_frames": args.num_sample_frames,
- "enable_frame_filter": args.evs,
- "frame_filter_threshold": args.evs_threshold,
- "use_audio_in_video": bool(audio_data),
- }
- await ws.send(json.dumps(config))
- print(f"Sent session.config: model={args.model} evs={args.evs}")
-
- # 2. Send frames
- for i, frame in enumerate(frames):
- msg = {
- "type": "video.frame",
- "data": base64.b64encode(frame).decode(),
- }
- await ws.send(json.dumps(msg))
- if (i + 1) % 10 == 0:
- print(f" Sent {i + 1}/{len(frames)} frames")
- print(f"Sent all {len(frames)} frames")
-
- # 3. Send audio chunks (Phase 3)
- if audio_data:
- chunk_size = 16000 * 2 # 1 second of 16 kHz 16-bit PCM
- for offset in range(0, len(audio_data), chunk_size):
- chunk = audio_data[offset : offset + chunk_size]
- msg = {
- "type": "audio.chunk",
- "data": base64.b64encode(chunk).decode(),
- }
- await ws.send(json.dumps(msg))
- print(f"Sent audio in {(len(audio_data) + chunk_size - 1) // chunk_size} chunks")
-
- # 4. Send query, then immediately send video.done so the server
- # knows the session is complete (avoids deadlock where client
- # waits for session.done while server waits for video.done).
- await ws.send(json.dumps({"type": "video.query", "text": args.query}))
- print(f"\nQuery: {args.query}")
- print("Response: ", end="", flush=True)
-
- # Signal end of session right after the query. The server will
- # process the query first (it's already queued), then handle
- # video.done and reply with session.done.
- await ws.send(json.dumps({"type": "video.done"}))
-
- # 5. Receive response until session.done
- recv_timeout = 120 # seconds — avoid infinite hang if server stalls
- while True:
- raw = await asyncio.wait_for(ws.recv(), timeout=recv_timeout)
- data = json.loads(raw)
- msg_type = data.get("type")
-
- if msg_type == "response.text.delta":
- print(data.get("delta", ""), end="", flush=True)
- elif msg_type == "response.text.done":
- print() # newline
- elif msg_type == "response.evs_stats":
- retained = data.get("retained_count", 0)
- dropped = data.get("dropped_count", 0)
- rate = data.get("drop_rate", 0)
- print(f"\nEVS stats: retained={retained} dropped={dropped} drop_rate={rate:.1%}")
- elif msg_type == "session.done":
- print("Session complete.")
- break
- elif msg_type == "error":
- print(f"\nError: {data.get('message')}")
- break
- elif msg_type == "response.start":
- pass # expected
- else:
- print(f"\n[unknown message] {data}")
-
-
-def main() -> None:
- parser = argparse.ArgumentParser(description="Streaming video chat client")
- parser.add_argument("--host", default="localhost")
- parser.add_argument("--port", type=int, default=8000)
- parser.add_argument("--model", default="Qwen/Qwen3-Omni-MoE")
- parser.add_argument("--video", help="Path to video file (requires opencv-python)")
- parser.add_argument("--audio", help="Path to raw PCM 16kHz audio file (Phase 3)")
- parser.add_argument("--query", default="What do you see in this video?")
- parser.add_argument(
- "--synthetic-frames", type=int, default=10, help="Number of synthetic frames if --video is not set"
- )
- parser.add_argument("--max-frames", type=int, default=64)
- parser.add_argument("--num-sample-frames", type=int, default=16)
- parser.add_argument("--fps", type=int, default=2, help="Frame extraction rate from video")
- parser.add_argument(
- "--no-evs", dest="evs", action="store_false", help="Disable EVS frame filtering (enabled by default)"
- )
- parser.set_defaults(evs=True)
- parser.add_argument("--evs-threshold", type=float, default=0.95)
- args = parser.parse_args()
- asyncio.run(run(args))
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/online_serving/qwen3_tts/README.md b/examples/online_serving/qwen3_tts/README.md
index 350fcb71cac..5504b5737a8 100644
--- a/examples/online_serving/qwen3_tts/README.md
+++ b/examples/online_serving/qwen3_tts/README.md
@@ -43,7 +43,7 @@ Then open http://localhost:7860 in your browser.
### Launch the Server
-The default deploy config is located at `vllm_omni/deploy/qwen3_tts.yaml` and is loaded automatically by the model registry — no `--deploy-config` flag needed for default use. Platform-specific deltas (NPU, ROCm, XPU) are merged in automatically from the `platforms:` block of the same YAML based on the detected runtime.
+The default stage config is located at `vllm_omni/model_executor/stage_configs/qwen3_tts.yaml`. For other platforms (e.g., NPU), refer to `vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml`.
```bash
# CustomVoice model (predefined speakers)
@@ -70,22 +70,6 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
--port 8091
```
-#### Sync vs async-chunk mode
-
-Qwen3-TTS supports both **chunked streaming** (default, lower latency) and
-**synchronous end-to-end** modes from the same deploy YAML. The bundled
-`qwen3_tts.yaml` ships with `async_chunk: true`; flip with `--no-async-chunk`
-and the pipeline automatically dispatches to the end-to-end codec processor:
-
-```bash
-vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice --omni --port 8091 \
- --no-async-chunk
-```
-
-No variant YAML or extra flag is needed — the `StagePipelineConfig` on each
-stage declares both processor functions and the runtime picks based on the
-`async_chunk:` bool.
-
Alternatively, use the convenience script:
```bash
./run_server.sh # Default: CustomVoice model
@@ -208,6 +192,14 @@ with open("output.wav", "wb") as f:
f.write(response.content)
```
+### FAQ
+
+If you encounter an error about the librosa audio backend, install ffmpeg with the commands below.
+```bash
+sudo apt update
+sudo apt install ffmpeg
+```
+
## API Reference
### Voices Endpoint
@@ -394,54 +386,6 @@ Server -> Client:
{"type": "session.done", "total_sentences": 1}
```
-## Choosing an Execution Backend: Uniproc vs Multiprocessing
-
-Qwen3-TTS stage configs support two execution backends controlled by the
-`distributed_executor_backend` engine arg. The performance tradeoff between
-them is **both hardware- and task-dependent**, so there is no single best
-default (see [#2603](https://github.com/vllm-project/vllm-omni/issues/2603),
-[#2604](https://github.com/vllm-project/vllm-omni/pull/2604) for the full
-investigation).
-
-| Backend | Stage config setting | Behaviour |
-| ------- | -------------------- | --------- |
-| **Uniproc** (default, world_size=1) | `distributed_executor_backend` omitted | Both stages run inside the orchestrator process. Avoids IPC serialisation, D2H copies, and msgpack overhead between stages. |
-| **Multiprocessing** | `distributed_executor_backend: "mp"` | Each stage runs in its own subprocess. The Talker can continue decoding while Code2Wav runs the vocoder in parallel, improving pipeline utilisation under concurrency. |
-
-> **Note:** When `distributed_executor_backend` is omitted and `world_size=1`,
-> vLLM [automatically uses the uniproc executor](https://github.com/vllm-project/vllm/blob/main/vllm/config/parallel.py#L825).
-> When `world_size > 1`, it defaults to `mp`.
-
-### When uniproc wins
-
-The uniproc path eliminates inter-process data transfer (D2H copies,
-msgpack serialisation/deserialisation, tensor detaching). This matters most
-when per-request processing is heavy relative to autoregressive decode.
-
-The Base cloning task involves reference-audio encoding on every request, making IPC
-overhead a larger fraction of total cost. Qwen3-Omni shows a similar pattern.
-
-### When multiprocessing (`mp`) wins
-
-For lighter per-request workloads, process-level parallelism between the
-Talker and Code2Wav stages dominates.
-
-CustomVoice is lighter per-request (no reference audio encoding), so the
-process-level parallelism of `mp` outweighs its serialisation cost at
-concurrency ≥ 4.
-
-### How to switch
-
-To use the uniproc executor on a single-GPU setup, pass the
-`qwen3_tts_uniproc.yaml` stage config:
-
-```bash
-vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --omni \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml \
- --port 8091
-```
-
## Limitations
- **Single request**: Batch processing is not yet optimized for online serving.
diff --git a/examples/online_serving/qwen3_tts/batch_speech_client.py b/examples/online_serving/qwen3_tts/batch_speech_client.py
index 47fdc3691c7..7d48e650f88 100644
--- a/examples/online_serving/qwen3_tts/batch_speech_client.py
+++ b/examples/online_serving/qwen3_tts/batch_speech_client.py
@@ -5,13 +5,11 @@
batch level and generate many utterances in the cloned voice without repeating
the reference for each item.
-Start the server (with batch-optimized stage settings for best throughput):
+Start the server (with batch-optimized config for best throughput):
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --omni \
- --trust-remote-code \
- --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},
- "1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml \
+ --trust-remote-code
Examples:
# Batch with a predefined voice
diff --git a/examples/online_serving/qwen3_tts/openai_speech_client.py b/examples/online_serving/qwen3_tts/openai_speech_client.py
index 77e13b08ed2..4741a47158c 100644
--- a/examples/online_serving/qwen3_tts/openai_speech_client.py
+++ b/examples/online_serving/qwen3_tts/openai_speech_client.py
@@ -71,7 +71,7 @@ def run_tts_generation(args) -> None:
payload = {
"model": args.model,
"input": args.text,
- "voice": args.speaker,
+ "speaker": args.speaker,
"response_format": args.response_format,
}
diff --git a/examples/online_serving/qwen3_tts/run_gradio_demo.sh b/examples/online_serving/qwen3_tts/run_gradio_demo.sh
index d79be3c2abd..bcc0ddb7cf5 100644
--- a/examples/online_serving/qwen3_tts/run_gradio_demo.sh
+++ b/examples/online_serving/qwen3_tts/run_gradio_demo.sh
@@ -127,7 +127,7 @@ echo "Starting vLLM server..."
LOG_FILE="/tmp/vllm_tts_server_${SERVER_PORT}.log"
vllm-omni serve "$MODEL" \
- --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
--host "$SERVER_HOST" \
--port "$SERVER_PORT" \
--gpu-memory-utilization 0.9 \
diff --git a/examples/online_serving/qwen3_tts/run_server.sh b/examples/online_serving/qwen3_tts/run_server.sh
index 78dd2c305d3..6f4aa83a0b9 100755
--- a/examples/online_serving/qwen3_tts/run_server.sh
+++ b/examples/online_serving/qwen3_tts/run_server.sh
@@ -31,7 +31,7 @@ esac
echo "Starting Qwen3-TTS server with model: $MODEL"
vllm-omni serve "$MODEL" \
- --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
--host 0.0.0.0 \
--port 8091 \
--gpu-memory-utilization 0.9 \
diff --git a/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py b/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py
index 7790fa51276..e6786f8869f 100644
--- a/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py
+++ b/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py
@@ -5,7 +5,7 @@
using SLERP and sends the result to the /v1/audio/speech API.
Requirements:
- pip install torch soundfile numpy httpx
+ pip install torch librosa soundfile numpy httpx
Examples:
# Extract and save an embedding
@@ -143,18 +143,17 @@ def _load_speaker_encoder_weights(encoder: torch.nn.Module, model_path: str) ->
def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor:
"""Compute 128-bin mel spectrogram matching Qwen3-TTS's extraction pipeline."""
- from vllm.multimodal.audio import AudioResampler
+ import librosa
# Resample to 24kHz if needed
if sr != 24000:
- resampler = AudioResampler(target_sr=24000)
- audio = resampler.resample(audio.astype(np.float32), orig_sr=sr)
+ audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=24000)
y = torch.from_numpy(audio).unsqueeze(0).float()
- from vllm_omni.utils.audio import mel_filter_bank
+ from librosa.filters import mel as librosa_mel_fn
- mel_basis = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)
+ mel_basis = torch.from_numpy(librosa_mel_fn(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)).float()
n_fft = 1024
hop_size = 256
@@ -181,9 +180,9 @@ def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor:
@torch.inference_mode()
def extract_embedding(encoder: torch.nn.Module, audio_path: str, device: str = "cpu") -> np.ndarray:
"""Extract a 1024-dim speaker embedding from an audio file."""
- from vllm.multimodal.media.audio import load_audio
+ import librosa
- audio, sr = load_audio(audio_path, sr=None, mono=True)
+ audio, sr = librosa.load(audio_path, sr=None, mono=True)
mel = compute_mel_spectrogram(audio, sr).to(device)
embedding = encoder(mel.to(next(encoder.parameters()).dtype))[0]
return embedding.float().cpu().numpy()
diff --git a/examples/online_serving/text_to_image/README.md b/examples/online_serving/text_to_image/README.md
index 17d377ea3e2..87b6a56438e 100644
--- a/examples/online_serving/text_to_image/README.md
+++ b/examples/online_serving/text_to_image/README.md
@@ -231,8 +231,6 @@ count, use `size` and `n` rather than `height`, `width`, or
| `seed` | int | None | Random seed (reproducible) |
| `negative_prompt` | str | None | Negative prompt |
| `num_outputs_per_prompt` | int | 1 | Number of images to generate |
-| `use_system_prompt` | str | None | System prompt preset: `en_unified`, `en_vanilla`, `en_recaption`, `en_think_recaption`, `dynamic`, `None`, or custom text string. Only for HunyuanImage-3.0. |
-| `system_prompt` | str | None | Custom system prompt text. Only used when `use_system_prompt` is set to `custom`. Only for HunyuanImage-3.0. |
## Response Format
diff --git a/examples/online_serving/text_to_image/openai_chat_client.py b/examples/online_serving/text_to_image/openai_chat_client.py
index f3c43086a14..828827aba2d 100644
--- a/examples/online_serving/text_to_image/openai_chat_client.py
+++ b/examples/online_serving/text_to_image/openai_chat_client.py
@@ -28,8 +28,6 @@ def generate_image(
lora_name: str | None = None,
lora_scale: float | None = None,
lora_int_id: int | None = None,
- use_system_prompt: str | None = None,
- system_prompt: str | None = None,
) -> bytes | None:
"""Generate an image using the images generation API.
@@ -47,8 +45,6 @@ def generate_image(
lora_name: LoRA name (optional, defaults to path stem)
lora_scale: LoRA scale factor (default: 1.0)
lora_int_id: LoRA integer ID (optional, derived from path if not provided)
- use_system_prompt: System prompt for generation.
- system_prompt: Custom system prompt.
Returns:
Image bytes or None if failed
@@ -74,10 +70,7 @@ def generate_image(
payload["negative_prompt"] = negative_prompt
if seed is not None:
payload["seed"] = seed
- if use_system_prompt is not None:
- payload["use_system_prompt"] = use_system_prompt
- if system_prompt is not None:
- payload["system_prompt"] = system_prompt
+
# Add LoRA if provided
if lora_path:
lora_body: dict = {
@@ -135,21 +128,9 @@ def main():
default=None,
help="LoRA integer id (cache key). If omitted, the server derives a stable id from lora_path.",
)
- parser.add_argument(
- "--use-system-prompt",
- type=str,
- default=None,
- help=(
- "System prompt for generation. Use predefined types: 'en_unified', 'en_vanilla', 'en_recaption', 'en_think_recaption', 'dynamic', or 'None'; Or provide custom text string directly. Recommended en_unified. "
- ),
- )
- parser.add_argument(
- "--system-prompt",
- type=str,
- default=None,
- help=("Custom system prompt. Used when --use-system-prompt is custom. "),
- )
+
args = parser.parse_args()
+
print(f"Generating image for: {args.prompt}")
image_bytes = generate_image(
@@ -165,8 +146,6 @@ def main():
lora_name=args.lora_name,
lora_scale=args.lora_scale if args.lora_path else None,
lora_int_id=args.lora_int_id if args.lora_path else None,
- use_system_prompt=args.use_system_prompt,
- system_prompt=args.system_prompt,
)
if image_bytes:
diff --git a/examples/online_serving/text_to_video/README.md b/examples/online_serving/text_to_video/README.md
index c01e0602ff9..44e676671fe 100644
--- a/examples/online_serving/text_to_video/README.md
+++ b/examples/online_serving/text_to_video/README.md
@@ -1,27 +1,16 @@
# Text-To-Video
-This example demonstrates how to deploy text-to-video models for online video generation using vLLM-Omni.
+This example demonstrates how to deploy the Wan2.2 text-to-video model for online video generation using vLLM-Omni.
-## Supported Models
+## Start Server
-| Model | Model ID |
-|-------|----------|
-| Wan2.1 T2V (1.3B) | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` |
-| Wan2.1 T2V (14B) | `Wan-AI/Wan2.1-T2V-14B-Diffusers` |
-| Wan2.2 T2V | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
-| LTX-2 | `Lightricks/LTX-2` |
-
-## Wan2.2 T2V
-
-### Start Server
-
-#### Basic Start
+### Basic Start
```bash
vllm serve Wan-AI/Wan2.2-T2V-A14B-Diffusers --omni --port 8091
```
-#### Start with Parameters
+### Start with Parameters
Or use the startup script:
@@ -241,82 +230,3 @@ while true; do
sleep 2
done
```
-
-## LTX-2
-
-### Start Server
-
-#### Basic Start
-
-```bash
-vllm serve Lightricks/LTX-2 --omni --port 8098 \
- --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0
-```
-
-#### Start with Optimization Presets
-
-Use the LTX-2 startup script with built-in optimization presets:
-
-```bash
-# Baseline (1 GPU, eager)
-bash run_server_ltx2.sh baseline
-
-# 4-GPU Ulysses sequence parallelism (lossless)
-bash run_server_ltx2.sh ulysses4
-
-# Cache-DiT lossy acceleration (1 GPU, ~1.4× speedup)
-bash run_server_ltx2.sh cache-dit
-
-# Best combo: 4-GPU Ulysses SP + Cache-DiT (~2.2× speedup)
-bash run_server_ltx2.sh best-combo
-```
-
-#### Optimization Benchmarks
-
-Benchmarked on H800, online serving (480×768, 41 frames, 20 steps, `seed=42`).
-"Inference" is the server-reported inference time; excludes HTTP/poll overhead.
-
-| Preset | Server Command | Inference (s) | Speedup | Type |
-|--------|---------------|---------------|---------|------|
-| `baseline` | `--enforce-eager` | 10.3 | 1.00× | — |
-| `compile` | *(default, no --enforce-eager)* | ~10.3 (warm) | ~1.00× | Lossless |
-| `ulysses4` | `--enforce-eager --usp 4` | ~10.3 | ~1.00× | Lossless |
-| `cache-dit` | `--enforce-eager --cache-backend cache_dit` | 7.4 avg | ~1.4× | Lossy |
-| `best-combo` | `--enforce-eager --usp 4 --cache-backend cache_dit` | 4.7 avg | **~2.2×** | Lossless + Lossy |
-
-**Observations**:
-- **torch.compile**: On H800, warm-request inference time matches the eager baseline (~10.3s).
- The first request pays ~6s compilation overhead. Benefit depends on model architecture and GPU.
-- **Ulysses SP (4 GPU)**: No measurable speedup alone for 41-frame generation at this resolution.
- Communication overhead outweighs gains at this sequence length.
-- **Cache-DiT**: Inference varies per request (6–10s) due to dynamic caching decisions.
- Average is ~7.4s (~1.4× speedup) with slight quality tradeoff.
-- **Best combo**: 4-GPU Ulysses SP + Cache-DiT synergize well — Cache-DiT reduces per-step
- computation, making the communication overhead of Ulysses SP worthwhile. Average ~4.7s
- (~2.2× speedup).
-- **FP8 quantization**: Reduces VRAM but does not speed up LTX-2 on H800 (compute-bound).
-
-**Deployment Recommendations**:
-- For **production with quality priority**: use `baseline` with `--enforce-eager`
-- For **maximum throughput** (4 GPUs, quality tradeoff): use `best-combo` (~2.2× speedup)
-- For **single-GPU throughput**: use `cache-dit` (~1.4× speedup)
-- `--enforce-eager` is recommended to avoid torch.compile warmup latency on first request
-
-### Send Requests (curl)
-
-```bash
-# Using the provided script
-bash run_curl_ltx2.sh
-
-# Or directly
-curl -sS -X POST http://localhost:8098/v1/videos \
- -H "Accept: application/json" \
- -F "prompt=A serene lakeside sunrise with mist over the water." \
- -F "width=768" \
- -F "height=480" \
- -F "num_frames=41" \
- -F "fps=24" \
- -F "num_inference_steps=20" \
- -F "guidance_scale=3.0" \
- -F "seed=42"
-```
diff --git a/examples/online_serving/text_to_video/run_curl_ltx2.sh b/examples/online_serving/text_to_video/run_curl_ltx2.sh
deleted file mode 100644
index b82f672eaab..00000000000
--- a/examples/online_serving/text_to_video/run_curl_ltx2.sh
+++ /dev/null
@@ -1,66 +0,0 @@
-#!/bin/bash
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-#
-# LTX-2 text-to-video curl example using the async video job API.
-# Start the server first: bash run_server_ltx2.sh best-combo
-
-set -euo pipefail
-
-BASE_URL="${BASE_URL:-http://localhost:8098}"
-OUTPUT_PATH="${OUTPUT_PATH:-ltx2_output.mp4}"
-POLL_INTERVAL="${POLL_INTERVAL:-2}"
-
-PROMPT="${PROMPT:-A serene lakeside sunrise with mist over the water.}"
-
-create_response=$(
- curl -sS -X POST "${BASE_URL}/v1/videos" \
- -H "Accept: application/json" \
- -F "prompt=${PROMPT}" \
- -F "width=768" \
- -F "height=480" \
- -F "num_frames=41" \
- -F "fps=24" \
- -F "num_inference_steps=20" \
- -F "guidance_scale=3.0" \
- -F "seed=42"
-)
-
-video_id="$(echo "${create_response}" | jq -r '.id')"
-if [ -z "${video_id}" ] || [ "${video_id}" = "null" ]; then
- echo "Failed to create video job:"
- echo "${create_response}" | jq .
- exit 1
-fi
-
-echo "Created video job ${video_id}"
-echo "${create_response}" | jq .
-
-while true; do
- status_response="$(curl -sS "${BASE_URL}/v1/videos/${video_id}")"
- status="$(echo "${status_response}" | jq -r '.status')"
-
- case "${status}" in
- queued|in_progress)
- echo "Video job ${video_id} status: ${status}"
- sleep "${POLL_INTERVAL}"
- ;;
- completed)
- echo "${status_response}" | jq .
- break
- ;;
- failed)
- echo "Video generation failed:"
- echo "${status_response}" | jq .
- exit 1
- ;;
- *)
- echo "Unexpected status response:"
- echo "${status_response}" | jq .
- exit 1
- ;;
- esac
-done
-
-curl -sS -L "${BASE_URL}/v1/videos/${video_id}/content" -o "${OUTPUT_PATH}"
-echo "Saved video to ${OUTPUT_PATH}"
diff --git a/examples/online_serving/text_to_video/run_server_ltx2.sh b/examples/online_serving/text_to_video/run_server_ltx2.sh
deleted file mode 100644
index f4597d3cd28..00000000000
--- a/examples/online_serving/text_to_video/run_server_ltx2.sh
+++ /dev/null
@@ -1,84 +0,0 @@
-#!/bin/bash
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-#
-# LTX-2 online serving startup script with optimization presets.
-#
-# Usage:
-# bash run_server_ltx2.sh # baseline (1 GPU, eager)
-# bash run_server_ltx2.sh ulysses4 # 4-GPU Ulysses SP
-# bash run_server_ltx2.sh cache-dit # 1 GPU + Cache-DiT
-# bash run_server_ltx2.sh best-combo # 4-GPU Ulysses SP + Cache-DiT
-#
-# Online serving benchmarks on H800 (480×768, 41 frames, 20 steps):
-# baseline : 10.3s inference (1.00×)
-# compile : ~10.3s warm (~1.00×) first request +6s warmup
-# ulysses4 : ~10.3s (~1.00×) no gain at 41 frames
-# cache-dit : 7.4s avg (~1.4×) lossy, variable per request
-# best-combo : 4.7s avg (~2.2×) 4-GPU ulysses + cache-dit
-
-set -euo pipefail
-
-MODEL="${MODEL:-Lightricks/LTX-2}"
-PORT="${PORT:-8098}"
-FLOW_SHIFT="${FLOW_SHIFT:-1.0}"
-BOUNDARY_RATIO="${BOUNDARY_RATIO:-1.0}"
-
-PRESET="${1:-baseline}"
-
-EXTRA_ARGS=()
-case "$PRESET" in
- baseline)
- echo "=== LTX-2 Preset: baseline (1 GPU, enforce-eager) ==="
- EXTRA_ARGS+=(--enforce-eager)
- ;;
- ulysses2)
- echo "=== LTX-2 Preset: 2-GPU Ulysses SP (lossless) ==="
- EXTRA_ARGS+=(--enforce-eager --usp 2)
- ;;
- ulysses4)
- echo "=== LTX-2 Preset: 4-GPU Ulysses SP (lossless) ==="
- EXTRA_ARGS+=(--enforce-eager --usp 4)
- ;;
- cache-dit)
- echo "=== LTX-2 Preset: Cache-DiT (1 GPU, lossy) ==="
- EXTRA_ARGS+=(--enforce-eager --cache-backend cache_dit)
- ;;
- best-combo)
- echo "=== LTX-2 Preset: 4-GPU Ulysses SP + Cache-DiT (best combo) ==="
- EXTRA_ARGS+=(--enforce-eager --usp 4 --cache-backend cache_dit)
- ;;
- compile)
- echo "=== LTX-2 Preset: torch.compile (1 GPU, lossless) ==="
- # torch.compile is the default (no --enforce-eager)
- ;;
- *)
- echo "Usage: $0 {baseline|ulysses2|ulysses4|cache-dit|best-combo|compile}"
- echo ""
- echo "Presets:"
- echo " baseline - 1 GPU, eager execution (reference)"
- echo " ulysses2 - 2-GPU Ulysses SP (lossless)"
- echo " ulysses4 - 4-GPU Ulysses SP (lossless)"
- echo " cache-dit - 1 GPU + Cache-DiT (lossy, ~1.4× speedup)"
- echo " best-combo - 4-GPU Ulysses SP + Cache-DiT (~2.2× speedup)"
- echo " compile - 1 GPU + torch.compile (slower first request)"
- echo ""
- echo "Environment variables:"
- echo " MODEL - Model path (default: Lightricks/LTX-2)"
- echo " PORT - Server port (default: 8098)"
- echo " FLOW_SHIFT - Scheduler flow shift (default: 1.0)"
- echo " BOUNDARY_RATIO - Boundary ratio (default: 1.0)"
- exit 1
- ;;
-esac
-
-echo "Model: $MODEL"
-echo "Port: $PORT"
-echo "Flow shift: $FLOW_SHIFT"
-echo "Boundary ratio: $BOUNDARY_RATIO"
-
-vllm serve "$MODEL" --omni \
- --port "$PORT" \
- --flow-shift "$FLOW_SHIFT" \
- --boundary-ratio "$BOUNDARY_RATIO" \
- "${EXTRA_ARGS[@]}"
diff --git a/examples/online_serving/voxcpm/README.md b/examples/online_serving/voxcpm/README.md
deleted file mode 100644
index 78e1bf4aaa3..00000000000
--- a/examples/online_serving/voxcpm/README.md
+++ /dev/null
@@ -1,166 +0,0 @@
-# VoxCPM
-
-## Prerequisites
-
-Install VoxCPM in one of these ways:
-
-```bash
-pip install voxcpm
-```
-
-or point vLLM-Omni to a local VoxCPM source tree:
-
-```bash
-export VLLM_OMNI_VOXCPM_CODE_PATH=/path/to/VoxCPM/src
-```
-
-If the native VoxCPM `config.json` lacks HF metadata such as `model_type`,
-prepare a persistent HF-compatible config directory and export:
-
-```bash
-export VLLM_OMNI_VOXCPM_HF_CONFIG_PATH=/tmp/voxcpm_hf_config
-mkdir -p "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"
-cp "$VOXCPM_MODEL/config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/config.json"
-cp "$VOXCPM_MODEL/generation_config.json" "$VLLM_OMNI_VOXCPM_HF_CONFIG_PATH/generation_config.json" 2>/dev/null || true
-python3 -c 'import json, os; p=os.path.join(os.environ["VLLM_OMNI_VOXCPM_HF_CONFIG_PATH"], "config.json"); cfg=json.load(open(p, "r", encoding="utf-8")); cfg["model_type"]="voxcpm"; cfg.setdefault("architectures", ["VoxCPMForConditionalGeneration"]); json.dump(cfg, open(p, "w", encoding="utf-8"), indent=2, ensure_ascii=False)'
-```
-
-The VoxCPM stage configs read `VLLM_OMNI_VOXCPM_HF_CONFIG_PATH` directly. The `python3 -c` form above avoids heredoc/indentation issues in interactive shells.
-
-## Launch the Server
-
-Use the async-chunk stage config by default:
-
-```bash
-export VOXCPM_MODEL=/path/to/voxcpm-model
-cd examples/online_serving/voxcpm
-./run_server.sh
-```
-
-Use the non-streaming stage config:
-
-```bash
-./run_server.sh sync
-```
-
-You can also launch the server directly:
-
-```bash
-vllm serve "$VOXCPM_MODEL" \
- --stage-configs-path vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml \
- --trust-remote-code \
- --enforce-eager \
- --omni \
- --port 8091
-```
-
-## Send Requests
-
-### Basic text-to-speech
-
-```bash
-python openai_speech_client.py \
- --model "$VOXCPM_MODEL" \
- --text "This is a VoxCPM online text-to-speech example."
-```
-
-### Voice cloning
-
-```bash
-python openai_speech_client.py \
- --model "$VOXCPM_MODEL" \
- --text "This sentence is synthesized with a cloned voice." \
- --ref-audio /path/to/reference.wav \
- --ref-text "The exact transcript spoken in reference.wav."
-```
-
-`ref_text` must be the real transcript of the reference audio. Placeholder text or mismatched text will usually degrade quality badly.
-
-### Streaming PCM output
-
-```bash
-python openai_speech_client.py \
- --model "$VOXCPM_MODEL" \
- --text "This is a streaming VoxCPM request." \
- --stream \
- --output voxcpm_stream.pcm
-```
-
-### Using curl
-
-```bash
-curl -X POST http://localhost:8091/v1/audio/speech \
- -H "Content-Type: application/json" \
- -d '{
- "model": "OpenBMB/VoxCPM1.5",
- "input": "Hello from VoxCPM online serving.",
- "response_format": "wav"
- }' --output output.wav
-```
-
-Voice cloning:
-
-```bash
-curl -X POST http://localhost:8091/v1/audio/speech \
- -H "Content-Type: application/json" \
- -d '{
- "model": "OpenBMB/VoxCPM1.5",
- "input": "This sentence uses a cloned voice.",
- "ref_audio": "https://example.com/reference.wav",
- "ref_text": "The exact transcript spoken in the reference audio.",
- "response_format": "wav"
- }' --output cloned.wav
-```
-
-Streaming PCM:
-
-```bash
-curl -X POST http://localhost:8091/v1/audio/speech \
- -H "Content-Type: application/json" \
- -d '{
- "model": "OpenBMB/VoxCPM1.5",
- "input": "This is a streaming VoxCPM request.",
- "stream": true,
- "response_format": "pcm"
- }' --output output.pcm
-```
-
-## Supported Request Shape
-
-VoxCPM online serving currently supports:
-
-- plain text-to-speech
-- voice cloning with `ref_audio` + `ref_text`
-- `stream=true` with `response_format=pcm` or `wav`
-
-VoxCPM online serving does not use these generic TTS fields:
-
-- `voice`
-- `instructions`
-- `language`
-- `speaker_embedding`
-- `x_vector_only_mode`
-
-## Streaming vs Non-Streaming
-
-- `voxcpm_async_chunk.yaml` enables async-chunk streaming and is best for single-request streaming latency.
-- `voxcpm.yaml` performs one-shot latent generation then VAE decode.
-
-Like native VoxCPM, the async streaming path should be treated as single-request. If you need stable throughput benchmarking, prefer `voxcpm.yaml`.
-
-Do not use `voxcpm_async_chunk.yaml` for concurrent online streaming or `/v1/audio/speech/batch`. For multiple requests, prefer `voxcpm.yaml`.
-
-## Benchmark
-
-The serving benchmark reports TTFP and RTF:
-
-```bash
-python benchmarks/voxcpm/vllm_omni/bench_tts_serve.py \
- --host 127.0.0.1 \
- --port 8091 \
- --num-prompts 10 \
- --max-concurrency 1 \
- --result-dir /tmp/voxcpm_bench
-```
-
-For the async-chunk server, keep `--max-concurrency 1`.
diff --git a/examples/online_serving/voxcpm/openai_speech_client.py b/examples/online_serving/voxcpm/openai_speech_client.py
deleted file mode 100644
index c400114e8be..00000000000
--- a/examples/online_serving/voxcpm/openai_speech_client.py
+++ /dev/null
@@ -1,155 +0,0 @@
-"""OpenAI-compatible client for VoxCPM via /v1/audio/speech.
-
-Examples:
- # Basic text-to-speech
- python openai_speech_client.py --text "Hello from VoxCPM"
-
- # Voice cloning
- python openai_speech_client.py \
- --text "This sentence uses the cloned voice." \
- --ref-audio /path/to/reference.wav \
- --ref-text "The exact transcript spoken in the reference audio."
-
- # Streaming PCM output
- python openai_speech_client.py \
- --text "This is a streaming VoxCPM request." \
- --stream \
- --output output.pcm
-"""
-
-import argparse
-import base64
-import os
-
-import httpx
-
-DEFAULT_API_BASE = "http://localhost:8091"
-DEFAULT_API_KEY = "EMPTY"
-DEFAULT_MODEL = "OpenBMB/VoxCPM1.5"
-
-
-def encode_audio_to_base64(audio_path: str) -> str:
- """Encode a local audio file to base64 data URL."""
- if not os.path.exists(audio_path):
- raise FileNotFoundError(f"Audio file not found: {audio_path}")
-
- ext = audio_path.lower().rsplit(".", 1)[-1]
- mime_map = {
- "wav": "audio/wav",
- "mp3": "audio/mpeg",
- "flac": "audio/flac",
- "ogg": "audio/ogg",
- }
- mime_type = mime_map.get(ext, "audio/wav")
-
- with open(audio_path, "rb") as f:
- audio_b64 = base64.b64encode(f.read()).decode("utf-8")
- return f"data:{mime_type};base64,{audio_b64}"
-
-
-def build_payload(args) -> dict[str, object]:
- payload: dict[str, object] = {
- "model": args.model,
- "input": args.text,
- "response_format": "pcm" if args.stream else args.response_format,
- }
-
- if args.ref_audio:
- if args.ref_audio.startswith(("http://", "https://", "data:")):
- payload["ref_audio"] = args.ref_audio
- else:
- payload["ref_audio"] = encode_audio_to_base64(args.ref_audio)
- if args.ref_text:
- payload["ref_text"] = args.ref_text
- if args.max_new_tokens is not None:
- payload["max_new_tokens"] = args.max_new_tokens
- if args.stream:
- payload["stream"] = True
-
- return payload
-
-
-def run_tts(args) -> None:
- payload = build_payload(args)
- api_url = f"{args.api_base}/v1/audio/speech"
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {args.api_key}",
- }
-
- print(f"Model: {args.model}")
- print(f"Text: {args.text}")
- if args.ref_audio:
- print("Mode: voice cloning")
- print(f"Reference audio: {args.ref_audio}")
- else:
- print("Mode: text-to-speech")
-
- if args.stream:
- output_path = args.output or "voxcpm_output.pcm"
- with httpx.Client(timeout=300.0) as client:
- with client.stream("POST", api_url, json=payload, headers=headers) as response:
- if response.status_code != 200:
- print(f"Error: {response.status_code}")
- print(response.read().decode("utf-8", errors="ignore"))
- return
-
- total_bytes = 0
- with open(output_path, "wb") as f:
- for chunk in response.iter_bytes():
- if not chunk:
- continue
- f.write(chunk)
- total_bytes += len(chunk)
- print(f"Streamed {total_bytes} bytes to: {output_path}")
- return
-
- with httpx.Client(timeout=300.0) as client:
- response = client.post(api_url, json=payload, headers=headers)
-
- if response.status_code != 200:
- print(f"Error: {response.status_code}")
- print(response.text)
- return
-
- try:
- text = response.content.decode("utf-8")
- if text.startswith('{"error"'):
- print(f"Error: {text}")
- return
- except UnicodeDecodeError:
- pass
-
- output_path = args.output or "voxcpm_output.wav"
- with open(output_path, "wb") as f:
- f.write(response.content)
- print(f"Audio saved to: {output_path}")
-
-
-def main():
- parser = argparse.ArgumentParser(description="VoxCPM OpenAI-compatible speech client")
- parser.add_argument("--api-base", default=DEFAULT_API_BASE, help="API base URL")
- parser.add_argument("--api-key", default=DEFAULT_API_KEY, help="API key")
- parser.add_argument("--model", "-m", default=DEFAULT_MODEL, help="Model name or path")
- parser.add_argument("--text", required=True, help="Text to synthesize")
- parser.add_argument("--ref-audio", default=None, help="Reference audio path, URL, or data URL")
- parser.add_argument(
- "--ref-text",
- default=None,
- help="The exact transcript spoken in the reference audio",
- )
- parser.add_argument("--stream", action="store_true", help="Enable streaming PCM output")
- parser.add_argument(
- "--response-format",
- default="wav",
- choices=["wav", "pcm", "flac", "mp3", "aac", "opus"],
- help="Audio format for non-streaming mode (default: wav)",
- )
- parser.add_argument("--max-new-tokens", type=int, default=None, help="Maximum tokens to generate")
- parser.add_argument("--output", "-o", default=None, help="Output file path")
- args = parser.parse_args()
- run_tts(args)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/online_serving/voxcpm/run_server.sh b/examples/online_serving/voxcpm/run_server.sh
deleted file mode 100755
index ab4b6fe854e..00000000000
--- a/examples/online_serving/voxcpm/run_server.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/bin/bash
-# Launch vLLM-Omni server for VoxCPM online speech serving.
-#
-# Usage:
-# ./run_server.sh # default: async_chunk stage config
-# ./run_server.sh async # async_chunk stage config
-# ./run_server.sh sync # no-async-chunk stage config
-# VOXCPM_MODEL=/path/to/model ./run_server.sh
-
-set -e
-
-MODE="${1:-async}"
-MODEL="${VOXCPM_MODEL:-OpenBMB/VoxCPM1.5}"
-
-case "$MODE" in
- async)
- STAGE_CONFIG="vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml"
- ;;
- sync)
- STAGE_CONFIG="vllm_omni/model_executor/stage_configs/voxcpm.yaml"
- ;;
- *)
- echo "Unknown mode: $MODE"
- echo "Supported: async, sync"
- exit 1
- ;;
-esac
-
-echo "Starting VoxCPM server with model: $MODEL"
-echo "Stage config: $STAGE_CONFIG"
-
-vllm serve "$MODEL" \
- --stage-configs-path "$STAGE_CONFIG" \
- --host 0.0.0.0 \
- --port 8091 \
- --trust-remote-code \
- --enforce-eager \
- --omni
diff --git a/examples/online_serving/voxcpm2/README.md b/examples/online_serving/voxcpm2/README.md
deleted file mode 100644
index 9ca2ae708a3..00000000000
--- a/examples/online_serving/voxcpm2/README.md
+++ /dev/null
@@ -1,43 +0,0 @@
-# VoxCPM2 Online Serving
-
-Serve VoxCPM2 TTS via the OpenAI-compatible `/v1/audio/speech` endpoint.
-
-## Start the Server
-
-```bash
-vllm serve openbmb/VoxCPM2 --omni --host 0.0.0.0 --port 8000
-```
-
-The deploy config is auto-loaded from `vllm_omni/deploy/voxcpm2.yaml`. Pass
-`--deploy-config ` to override, or `--stage-N- ` (e.g.
-`--stage-0-max-num-seqs 8`) for per-stage runtime overrides.
-
-## Zero-shot Synthesis
-
-```bash
-python openai_speech_client.py --text "Hello, this is VoxCPM2."
-```
-
-Or with curl:
-
-```bash
-curl -X POST http://localhost:8000/v1/audio/speech \
- -H "Content-Type: application/json" \
- -d '{"model": "voxcpm2", "input": "Hello, this is VoxCPM2.", "voice": "default"}' \
- --output output.wav
-```
-
-## Voice Cloning
-
-Clone a speaker's voice using a reference audio file:
-
-```bash
-python openai_speech_client.py \
- --text "This should sound like the reference speaker." \
- --ref-audio /path/to/reference.wav
-```
-
-The `--ref-audio` parameter accepts:
-- Local file path (auto-encoded to base64)
-- URL (`https://...`)
-- Base64 data URI (`data:audio/wav;base64,...`)
diff --git a/examples/online_serving/voxcpm2/gradio_demo.py b/examples/online_serving/voxcpm2/gradio_demo.py
deleted file mode 100644
index c6706198ae4..00000000000
--- a/examples/online_serving/voxcpm2/gradio_demo.py
+++ /dev/null
@@ -1,599 +0,0 @@
-"""Gradio demo for VoxCPM2 TTS with gapless streaming audio playback.
-
-Uses a custom AudioWorklet-based player for gap-free streaming
-(adapted from the Qwen3-TTS demo). Audio is streamed from the vLLM
-server through a same-origin proxy and played via the Web Audio API's
-AudioWorklet, which maintains a FIFO buffer queue and plays samples at
-the audio clock rate.
-
-Usage:
- # Start the vLLM server first:
- vllm serve openbmb/VoxCPM2 --omni --host 0.0.0.0 --port 8000
-
- # Then launch the demo:
- python gradio_demo.py --api-base http://localhost:8000
-"""
-
-from __future__ import annotations
-
-import argparse
-import base64
-import io
-import json
-import logging
-
-import gradio as gr
-import httpx
-import numpy as np
-import soundfile as sf
-from fastapi import FastAPI, Request
-from fastapi.responses import Response, StreamingResponse
-
-logger = logging.getLogger(__name__)
-
-SAMPLE_RATE = 48000
-
-# ── AudioWorklet processor (loaded in browser via Blob URL) ──────────
-WORKLET_JS = r"""
-class TTSPlaybackProcessor extends AudioWorkletProcessor {
- constructor() {
- super();
- this.queue = [];
- this.buf = null;
- this.pos = 0;
- this.playing = false;
- this.played = 0;
- this.port.onmessage = (e) => {
- if (e.data && e.data.type === 'clear') {
- this.queue = []; this.buf = null; this.pos = 0; this.played = 0;
- if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped'}); }
- return;
- }
- this.queue.push(e.data);
- };
- }
- process(inputs, outputs) {
- const out = outputs[0][0];
- for (let i = 0; i < out.length; i++) {
- if (!this.buf || this.pos >= this.buf.length) {
- if (this.queue.length > 0) {
- this.buf = this.queue.shift(); this.pos = 0;
- } else {
- for (let j = i; j < out.length; j++) out[j] = 0;
- if (this.playing) { this.playing = false; this.port.postMessage({type:'stopped', played:this.played}); }
- return true;
- }
- }
- out[i] = this.buf[this.pos++] / 32768;
- this.played++;
- }
- if (!this.playing) { this.playing = true; this.port.postMessage({type:'started'}); }
- return true;
- }
-}
-registerProcessor('tts-playback-processor', TTSPlaybackProcessor);
-"""
-
-PLAYER_HTML = """
-
-"""
-
-
-def _build_player_js() -> str:
- return f"""
-
-"""
-
-
-def _encode_audio(audio_data: tuple) -> str:
- sr, audio_np = audio_data
- if audio_np.dtype in (np.float32, np.float64):
- audio_np = np.clip(audio_np, -1.0, 1.0)
- audio_np = (audio_np * 32767).astype(np.int16)
- elif audio_np.dtype != np.int16:
- audio_np = audio_np.astype(np.int16)
- buf = io.BytesIO()
- sf.write(buf, audio_np, sr, format="WAV")
- return f"data:audio/wav;base64,{base64.b64encode(buf.getvalue()).decode()}"
-
-
-def create_app(api_base: str):
- app = FastAPI()
- _pending: dict[str, dict] = {}
-
- @app.post("/proxy/v1/audio/speech")
- async def proxy_speech(request: Request):
- body = await request.json()
- req_id = body.get("_req_id")
- if req_id and req_id in _pending:
- body = _pending.pop(req_id)
- logger.info("Proxy: %s", {k: (f"<{len(str(v))} chars>" if k == "ref_audio" else v) for k, v in body.items()})
- try:
- client = httpx.AsyncClient(timeout=300)
- resp = await client.send(
- client.build_request(
- "POST",
- f"{api_base}/v1/audio/speech",
- json=body,
- headers={"Authorization": "Bearer EMPTY", "Content-Type": "application/json"},
- ),
- stream=True,
- )
- except Exception as exc:
- logger.exception("Proxy connection error")
- await client.aclose()
- return Response(content=str(exc), status_code=502)
- if resp.status_code != 200:
- content = await resp.aread()
- await resp.aclose()
- await client.aclose()
- return Response(content=content, status_code=resp.status_code)
-
- async def relay():
- try:
- async for chunk in resp.aiter_bytes():
- yield chunk
- finally:
- await resp.aclose()
- await client.aclose()
-
- return StreamingResponse(relay(), media_type="application/octet-stream")
-
- css = """
- #generate-btn button { width: 100%; }
- #streaming-player { border: 1px solid var(--border-color-primary) !important; border-radius: var(--block-radius) !important; padding: var(--block-padding) !important; }
- """
- theme = gr.themes.Default(
- primary_hue=gr.themes.Color(
- c50="#f0f5ff",
- c100="#dce6f9",
- c200="#b8cef3",
- c300="#8eb2eb",
- c400="#6496e0",
- c500="#4A90D9",
- c600="#3a7bc8",
- c700="#2d66b0",
- c800="#1f4f8f",
- c900="#163a6e",
- c950="#0e2650",
- ),
- )
-
- with gr.Blocks(title="VoxCPM2 TTS Demo") as demo:
- gr.HTML(f"""
-
-
-
-
VoxCPM2 Streaming Demo
-
- Served by vLLM-Omni
- · {api_base}
- · 48 kHz
-
-
-
- """)
-
- gr.Markdown(
- "**Three modes:** "
- "**Voice Design** (control instruction only) · "
- "**Controllable Cloning** (ref audio + optional style control) · "
- "**Ultimate Cloning** (ref audio + transcript for audio continuation)"
- )
-
- with gr.Row():
- with gr.Column(scale=3):
- text_input = gr.Textbox(
- label="Target Text",
- placeholder="Enter text to synthesize...",
- lines=4,
- )
- control_instruction = gr.Textbox(
- label="Control Instruction (optional)",
- placeholder="e.g. A warm young woman / Excited and fast-paced",
- lines=2,
- info="Describe voice style, emotion, pace. Works for both Voice Design and Controllable Cloning.",
- )
-
- with gr.Accordion("Voice Cloning", open=False):
- ref_audio = gr.Audio(
- label="Reference Audio (upload for cloning)",
- type="numpy",
- sources=["upload", "microphone"],
- )
- ref_audio_url = gr.Textbox(
- label="or Reference Audio URL",
- placeholder="https://example.com/reference.wav",
- )
- ultimate_clone = gr.Checkbox(
- label="Ultimate Cloning Mode",
- value=False,
- info="Provide transcript of ref audio for audio continuation (disables control instruction)",
- )
- prompt_text = gr.Textbox(
- label="Reference Audio Transcript",
- placeholder="Transcript of your reference audio (for ultimate cloning)",
- lines=2,
- visible=False,
- )
-
- with gr.Row():
- stream_checkbox = gr.Checkbox(
- label="Stream (gapless)",
- value=True,
- info="AudioWorklet streaming",
- )
- with gr.Row():
- generate_btn = gr.Button(
- "Generate Speech",
- variant="primary",
- size="lg",
- elem_id="generate-btn",
- scale=3,
- )
- reset_btn = gr.Button("Reset", variant="secondary", size="lg", scale=1)
-
- with gr.Column(scale=2):
- player_html = gr.HTML(
- value=PLAYER_HTML,
- visible=True,
- label="streaming player",
- elem_id="streaming-player",
- )
- audio_output = gr.Audio(
- label="generated audio",
- interactive=False,
- autoplay=True,
- visible=False,
- )
- gr.Examples(
- examples=[
- ["Hello, this is a VoxCPM2 demo running on vLLM-Omni.", ""],
- [
- "I have a dream that my four little children will one day live in a nation "
- "where they will not be judged by the color of their skin but by the content "
- "of their character.",
- "",
- ],
- [
- "I never asked you to stay. It's not like I care or anything. "
- "But why does it still hurt so much now that you're gone?",
- "A young girl with a soft, sweet voice. Speaks slowly with a melancholic tone.",
- ],
- ],
- inputs=[text_input, control_instruction],
- label="examples",
- )
- gr.HTML("""
-
- """)
-
- hidden_payload = gr.Textbox(visible=False, elem_id="tts-payload")
-
- def on_ultimate_toggle(checked):
- return (
- gr.update(visible=checked), # prompt_text
- gr.update(interactive=not checked), # control_instruction
- )
-
- ultimate_clone.change(
- fn=on_ultimate_toggle,
- inputs=[ultimate_clone],
- outputs=[prompt_text, control_instruction],
- )
-
- def on_stream_change(stream: bool):
- if stream:
- return gr.update(visible=True), gr.update(visible=False)
- return gr.update(visible=False), gr.update(visible=True)
-
- stream_checkbox.change(
- fn=on_stream_change,
- inputs=[stream_checkbox],
- outputs=[player_html, audio_output],
- )
-
- def on_reset():
- return "", "", None, "", False, "", PLAYER_HTML
-
- reset_btn.click(
- fn=on_reset,
- outputs=[
- text_input,
- control_instruction,
- audio_output,
- hidden_payload,
- ultimate_clone,
- prompt_text,
- player_html,
- ],
- js="() => { if (window.ttsStop) window.ttsStop(); }",
- )
-
- def on_generate(stream_enabled, text, ctrl_instr, ref_a, ref_url, ult_clone, p_text):
- import time as _time
-
- if not text or not text.strip():
- raise gr.Error("Please enter text to synthesize.")
-
- # VoxCPM2 uses "(instruction)text" format for control
- ctrl = ctrl_instr.strip() if ctrl_instr and not ult_clone else ""
- final_text = f"({ctrl}){text.strip()}" if ctrl else text.strip()
-
- payload: dict = {
- "input": final_text,
- "voice": "default",
- "response_format": "pcm" if stream_enabled else "wav",
- "stream": stream_enabled,
- }
-
- # Reference audio for cloning
- ref_url_s = ref_url.strip() if ref_url else ""
- if ref_url_s:
- payload["ref_audio"] = ref_url_s
- elif ref_a is not None:
- payload["ref_audio"] = _encode_audio(ref_a)
-
- # Ultimate cloning: prompt_audio + prompt_text for continuation
- if ult_clone and p_text and p_text.strip():
- if ref_url_s:
- payload["prompt_audio"] = ref_url_s
- elif ref_a is not None:
- payload["prompt_audio"] = payload.get("ref_audio", "")
- payload["prompt_text"] = p_text.strip()
-
- if stream_enabled:
- if ref_a is not None and not ref_url_s:
- req_id = f"req-{int(_time.time() * 1000)}"
- _pending[req_id] = payload
- browser_payload = {"_req_id": req_id, "_nonce": int(_time.time() * 1000)}
- return json.dumps(browser_payload), gr.update()
- payload["_nonce"] = int(_time.time() * 1000)
- return json.dumps(payload), gr.update()
- else:
- try:
- with httpx.Client(timeout=300.0) as client:
- resp = client.post(
- f"{api_base}/v1/audio/speech",
- json=payload,
- headers={"Content-Type": "application/json", "Authorization": "Bearer EMPTY"},
- )
- except httpx.ConnectError:
- raise gr.Error(f"Cannot connect to server at {api_base}.")
- if resp.status_code != 200:
- raise gr.Error(f"Server error ({resp.status_code}): {resp.text[:200]}")
- audio_np, sr = sf.read(io.BytesIO(resp.content))
- if audio_np.ndim > 1:
- audio_np = audio_np[:, 0]
- return "", (sr, audio_np.astype(np.float32))
-
- generate_btn.click(
- fn=on_generate,
- inputs=[
- stream_checkbox,
- text_input,
- control_instruction,
- ref_audio,
- ref_audio_url,
- ultimate_clone,
- prompt_text,
- ],
- outputs=[hidden_payload, audio_output],
- ).then(
- fn=lambda p: p,
- inputs=[hidden_payload],
- outputs=[hidden_payload],
- js="(p) => { if (p && p.trim()) { const d = JSON.parse(p); delete d._nonce; window.ttsGenerate(d); } return p; }",
- )
-
- demo.queue()
-
- return gr.mount_gradio_app(app, demo, path="/", css=css, theme=theme, head=_build_player_js())
-
-
-def main():
- parser = argparse.ArgumentParser(description="VoxCPM2 streaming Gradio demo")
- parser.add_argument("--api-base", default="http://localhost:8000", help="vLLM API server URL")
- parser.add_argument("--host", default="0.0.0.0", help="Gradio server host")
- parser.add_argument("--port", type=int, default=7860, help="Gradio server port")
- args = parser.parse_args()
-
- logging.basicConfig(level=logging.INFO)
- print(f"Connecting to vLLM server at: {args.api_base}")
-
- import uvicorn
-
- uvicorn.run(create_app(args.api_base), host=args.host, port=args.port)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/online_serving/voxcpm2/openai_speech_client.py b/examples/online_serving/voxcpm2/openai_speech_client.py
deleted file mode 100644
index 127b8cebb09..00000000000
--- a/examples/online_serving/voxcpm2/openai_speech_client.py
+++ /dev/null
@@ -1,105 +0,0 @@
-"""OpenAI-compatible client for VoxCPM2 TTS via /v1/audio/speech endpoint.
-
-Examples:
- # Zero-shot synthesis
- python openai_speech_client.py --text "Hello, this is VoxCPM2."
-
- # Voice cloning with a local reference audio file
- python openai_speech_client.py --text "Hello world" \
- --ref-audio /path/to/reference.wav
-
- # Voice cloning with a URL
- python openai_speech_client.py --text "Hello world" \
- --ref-audio "https://example.com/reference.wav"
-
-Server setup:
- vllm serve openbmb/VoxCPM2 --omni --host 0.0.0.0 --port 8000
-"""
-
-from __future__ import annotations
-
-import argparse
-import base64
-import os
-
-import httpx
-
-DEFAULT_API_BASE = "http://localhost:8000"
-DEFAULT_API_KEY = "sk-empty"
-
-
-def encode_audio_to_base64(audio_path: str) -> str:
- """Encode a local audio file to a base64 data URL."""
- if not os.path.exists(audio_path):
- raise FileNotFoundError(f"Audio file not found: {audio_path}")
-
- ext = audio_path.lower().rsplit(".", 1)[-1]
- mime = {
- "wav": "audio/wav",
- "mp3": "audio/mpeg",
- "flac": "audio/flac",
- "ogg": "audio/ogg",
- }.get(ext, "audio/wav")
-
- with open(audio_path, "rb") as f:
- b64 = base64.b64encode(f.read()).decode("utf-8")
- return f"data:{mime};base64,{b64}"
-
-
-def main() -> None:
- parser = argparse.ArgumentParser(description="VoxCPM2 OpenAI speech client")
- parser.add_argument("--text", type=str, required=True, help="Text to synthesize")
- parser.add_argument(
- "--ref-audio",
- type=str,
- default=None,
- help="Reference audio for voice cloning (local path, URL, or data: URI)",
- )
- parser.add_argument("--model", type=str, default="voxcpm2")
- parser.add_argument("--output", type=str, default="output.wav")
- parser.add_argument("--api-base", type=str, default=DEFAULT_API_BASE)
- parser.add_argument("--api-key", type=str, default=DEFAULT_API_KEY)
- parser.add_argument("--response-format", type=str, default="wav")
- args = parser.parse_args()
-
- # VoxCPM2 has no predefined voices. The "voice" field is required by
- # the OpenAI API schema but ignored by VoxCPM2 — use any placeholder.
- # For voice cloning, pass --ref-audio instead.
- payload: dict = {
- "model": args.model,
- "input": args.text,
- "voice": "default",
- "response_format": args.response_format,
- }
-
- if args.ref_audio:
- ref = args.ref_audio
- if ref.startswith(("http://", "https://", "data:")):
- payload["ref_audio"] = ref
- else:
- payload["ref_audio"] = encode_audio_to_base64(ref)
-
- url = f"{args.api_base}/v1/audio/speech"
- print(f"POST {url}")
- print(f" text: {args.text}")
- if args.ref_audio:
- print(f" ref_audio: {args.ref_audio[:80]}...")
-
- with httpx.Client(timeout=300) as client:
- resp = client.post(
- url,
- json=payload,
- headers={"Authorization": f"Bearer {args.api_key}"},
- )
-
- if resp.status_code != 200:
- print(f"Error {resp.status_code}: {resp.text[:500]}")
- return
-
- with open(args.output, "wb") as f:
- f.write(resp.content)
- print(f"Saved: {args.output} ({len(resp.content):,} bytes)")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/online_serving/voxtral_tts/gradio_demo.py b/examples/online_serving/voxtral_tts/gradio_demo.py
index 7905c62618c..35d6b590c97 100644
--- a/examples/online_serving/voxtral_tts/gradio_demo.py
+++ b/examples/online_serving/voxtral_tts/gradio_demo.py
@@ -216,7 +216,6 @@ def update_voice_dropdown(language: str) -> gr.Dropdown:
def run_inference(
voice_name: str,
text_prompt: str,
- cfg_alpha: float,
base_url: str,
model: str,
) -> tuple[int, np.ndarray]:
@@ -234,7 +233,6 @@ def run_inference(
"model": model,
"response_format": "wav",
"voice": voice_name,
- "extra_params": {"cfg_alpha": cfg_alpha},
}
response = httpx.post(
@@ -379,14 +377,6 @@ def main(
placeholder="Enter the text you want to synthesize...",
lines=4,
)
- cfg_alpha_slider = gr.Slider(
- minimum=1.0,
- maximum=2.0,
- step=0.1,
- value=1.2,
- label="CFG Alpha",
- info="Flow-matching guidance strength (default: 1.2)",
- )
with gr.Row():
reset_btn = gr.Button("Clear")
submit_btn = gr.Button("Generate audio", interactive=False)
@@ -425,9 +415,9 @@ def _toggle_submit(text: str):
)
# --- Wiring inference + persistence to the button ---
- def _on_submit(voice: str, text: str, cfg_alpha: float):
+ def _on_submit(voice: str, text: str):
assert text.strip() != ""
- sr, audio_array = run_inference(voice, text, cfg_alpha, base_url, model)
+ sr, audio_array = run_inference(voice, text, base_url, model)
if outputs_dir is not None:
share_id, saved_audio_path = _save_example(
outputs_dir,
@@ -442,7 +432,7 @@ def _on_submit(voice: str, text: str, cfg_alpha: float):
submit_btn.click(
fn=_on_submit,
- inputs=[voice_name, text_prompt, cfg_alpha_slider],
+ inputs=[voice_name, text_prompt],
outputs=[output_audio, share_link_box],
)
@@ -456,7 +446,6 @@ def _on_reset():
language, # language_dropdown
voice, # voice_name
"", # text_prompt
- 1.2, # cfg_alpha_slider
None, # output_audio
gr.update(interactive=False), # submit_btn
"", # share_link_box
@@ -467,15 +456,7 @@ def _on_reset():
reset_btn.click(
fn=make_on_reset(languages, language_voices),
inputs=[],
- outputs=[
- language_dropdown,
- voice_name,
- text_prompt,
- cfg_alpha_slider,
- output_audio,
- submit_btn,
- share_link_box,
- ],
+ outputs=[language_dropdown, voice_name, text_prompt, output_audio, submit_btn, share_link_box],
)
def make_load_from_share(outputs_dir: Path | None, languages: list[str], language_voices: dict[str, list[str]]):
diff --git a/mkdocs.yml b/mkdocs.yml
index 1e184439bd1..6461c65f220 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -97,7 +97,6 @@ plugins:
exclude:
- "re:vllm_omni\\._.*" # Internal modules
- "vllm_omni.diffusion.models.qwen_image" # avoid importing vllm in mkdocs building
- - "vllm_omni.diffusion.models.dreamid_omni.wan2_2" # docstring signature warnings break strict docs
- "vllm_omni.diffusion.quantization" # avoid importing vllm in mkdocs building
- "vllm_omni.quantization" # avoid importing vllm in mkdocs building
- "vllm_omni.entrypoints.async_diffusion" # avoid importing vllm in mkdocs building
diff --git a/pyproject.toml b/pyproject.toml
index 8346693f129..e49aa6e3251 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,12 +55,6 @@ dev = [
"pyttsx3>=2.99",
"opencc>=1.2.0",
"mistune>=3.2.0", # for example tests
- "torchmetrics>=1.4.0", # for accuracy similarity metrics
- "jiwer>=3.0.0",
- "zhon>=2.0.0",
- "zhconv>=1.4.2",
- "scipy>=1.10.0",
- "funasr>=1.0.0",
]
demo = [
@@ -121,13 +115,12 @@ exclude = [
[tool.ruff.lint]
select = [
- "E", # pycodestyle errors
- "W", # pycodestyle warnings
- "F", # pyflakes
- "I", # isort (handled separately, but included for compatibility)
- "N", # pep8-naming
- "UP", # pyupgrade
- "TID251", # flake8-tidy-imports.banned-api
+ "E", # pycodestyle errors
+ "W", # pycodestyle warnings
+ "F", # pyflakes
+ "I", # isort (handled separately, but included for compatibility)
+ "N", # pep8-naming
+ "UP", # pyupgrade
]
ignore = [
"E203", # whitespace before ':' (conflicts with black)
@@ -142,9 +135,6 @@ ignore = [
"examples/**" = ["E501"] # Allow long lines in examples
"tests/**" = ["E501"] # Allow long lines in tests
-[tool.ruff.lint.flake8-tidy-imports.banned-api]
-"librosa".msg = "The librosa module is banned, use vllm.multimodal helpers instead"
-
[tool.mypy]
python_version = "3.12, 3.13"
warn_return_any = true
@@ -173,8 +163,7 @@ addopts = [
markers = [
# ci/cd required
"core_model: L1&L2 tests (run in each PR)",
- "advanced_model: L3 level tests (run on each merge)",
- "full_model: L4 level tests (run nightly)",
+ "advanced_model: L3&L4 level tests (run in each merge or nightly)",
# function module markers
"diffusion: Diffusion model tests",
"omni: Omni model tests",
@@ -193,7 +182,6 @@ markers = [
"H100: Tests that require H100 GPU",
"L4: Tests that require L4 GPU",
"MI325: Tests that require MI325 GPU (AMD/ROCm)",
- "B60: Tests that require Intel Arc Pro B60 XPU",
"S5000: Tests that require S5000 GPU (Moore Threads/MUSA)",
"A2: Tests that require A2 NPU",
"A3: Tests that require A3 NPU",
@@ -236,5 +224,3 @@ ue = "ue"
semantics = "semantics"
fullset = "fullset"
Vai = "Vai"
-tockens = "tockens"
-CANN = "CANN"
diff --git a/recipes/LTX/LTX-2.3.md b/recipes/LTX/LTX-2.3.md
deleted file mode 100644
index 8d92562fb04..00000000000
--- a/recipes/LTX/LTX-2.3.md
+++ /dev/null
@@ -1,112 +0,0 @@
-# LTX-2.3 Text-to-Video with Audio on 1x GPU (96GB VRAM)
-
-> 22B parameter text-to-video + audio generation model served via vLLM-Omni
-
-## Summary
-
-- Vendor: Lightricks
-- Model: `dg845/LTX-2.3-Diffusers`
-- Task: Text-to-video with synchronized audio generation
-- Mode: Online serving (pure diffusion)
-- Maintainer: @oglok
-
-## When to use this recipe
-
-Use this recipe when you want to serve LTX-2.3 for text-to-video generation
-with audio. The model generates videos up to 20+ seconds at 768x512 resolution
-with 48kHz audio, all from a single text prompt. Requires a GPU with at least
-96GB VRAM due to the 22B parameter transformer (~44GB weights) plus text
-encoder, VAE, and vocoder components.
-
-## References
-
-- Model:
-- Requires `diffusers >= 0.38.0` (install from git: `pip install git+https://github.com/huggingface/diffusers.git`)
-
-## Serving
-
-### Command
-
-```bash
-vllm serve dg845/LTX-2.3-Diffusers \
- --omni \
- --model-class-name LTX23Pipeline \
- --stage-init-timeout 600
-```
-
-### Verification
-
-```bash
-# Health check
-curl http://localhost:8000/health
-
-# Generate a 3-second video (81 frames at 24fps)
-curl -X POST http://localhost:8000/v1/videos \
- -F "prompt=A majestic bald eagle soaring over a misty mountain valley at dawn, golden sunlight breaking through clouds" \
- -F "negative_prompt=blurry, low quality, distorted, watermark" \
- -F "model=dg845/LTX-2.3-Diffusers" \
- -F "num_frames=81" \
- -F "fps=24" \
- -F "size=768x512" \
- -F "num_inference_steps=30" \
- -F "guidance_scale=4.0" \
- -F "seed=42"
-
-# Generate a 10-second video (241 frames)
-curl -X POST http://localhost:8000/v1/videos \
- -F "prompt=A cozy Japanese ramen shop at night in the rain, steam rising from bowls, neon signs reflecting on wet cobblestone streets" \
- -F "model=dg845/LTX-2.3-Diffusers" \
- -F "num_frames=241" \
- -F "fps=24" \
- -F "size=768x512" \
- -F "num_inference_steps=30" \
- -F "guidance_scale=4.0"
-
-# Generate a 20-second video (481 frames)
-curl -X POST http://localhost:8000/v1/videos \
- -F "prompt=An underwater coral reef teeming with tropical fish, sea turtles gliding gracefully, National Geographic documentary style" \
- -F "model=dg845/LTX-2.3-Diffusers" \
- -F "num_frames=481" \
- -F "fps=24" \
- -F "size=768x512" \
- -F "num_inference_steps=30" \
- -F "guidance_scale=4.0"
-```
-
-### Notes
-
-- Memory usage: Model loads at ~36 GiB, peaks at ~62 GiB during inference
-- Key flags:
- - `--stage-init-timeout 600`: Required for the initial `torch.compile` warmup (~90-140 seconds on first request)
- - `--model-class-name LTX23Pipeline`: Selects the LTX-2.3 pipeline (not LTX-2)
-- Audio: 48kHz AAC via BWE vocoder, automatically synced with video
-- CPU offloading: Text encoder (Gemma-3-12B), connectors, VAE, audio VAE, and vocoder stay on CPU and are moved to GPU only when needed
-- Supported resolutions: 768x512, 512x384 (must be divisible by 32)
-- Frame rate: 24 fps
-- Duration: Controlled by `num_frames` (frames = duration_seconds * 24 + 1)
-- Known limitations:
- - No image-to-video support yet (LTX23ImageToVideoPipeline is a placeholder)
- - No CFG-parallel support (single-GPU only)
- - Requires `diffusers >= 0.38.0` (not yet on PyPI, install from git)
-
-## Hardware Support
-
-## GPU
-
-### 1x NVIDIA RTX PRO 6000 Blackwell (96GB)
-
-#### Environment
-
-- OS: Ubuntu 22.04
-- Python: 3.10+
-- Driver / runtime: CUDA 13.0, Driver 580.126.09
-- vLLM version: 0.19.x
-- vLLM-Omni version: 0.19.x
-
-### Validated configurations
-
-| Duration | Frames | Resolution | Steps | Guidance | Inference Time | Peak VRAM |
-|----------|--------|------------|-------|----------|----------------|-----------|
-| 3s | 81 | 768x512 | 30 | 4.0 | ~110s | ~62 GB |
-| 10s | 241 | 768x512 | 30 | 4.0 | ~130s | ~62 GB |
-| 20s | 481 | 768x512 | 30 | 4.0 | ~420s | ~62 GB |
diff --git a/recipes/Qwen/Qwen3-Omni.md b/recipes/Qwen/Qwen3-Omni.md
deleted file mode 100644
index f78e4dda2aa..00000000000
--- a/recipes/Qwen/Qwen3-Omni.md
+++ /dev/null
@@ -1,99 +0,0 @@
-# Qwen3-Omni for multimodal chat on 1x A100 80GB
-
-## Summary
-
-- Vendor: Qwen
-- Model: `Qwen/Qwen3-Omni-30B-A3B-Instruct`
-- Task: Multimodal chat with text, image, audio, or video input
-- Mode: Online serving with the OpenAI-compatible API
-- Maintainer: Community
-
-## When to use this recipe
-
-Use this recipe when you want a known-good starting point for serving
-`Qwen/Qwen3-Omni-30B-A3B-Instruct` with vLLM-Omni on a single 80 GB A100 and
-validate the deployment with the existing multimodal client examples in this
-repository.
-
-## References
-
-- Upstream or canonical docs:
- [`docs/user_guide/examples/online_serving/qwen3_omni.md`](../../docs/user_guide/examples/online_serving/qwen3_omni.md)
-- Related example under `examples/`:
- [`examples/online_serving/qwen3_omni/README.md`](../../examples/online_serving/qwen3_omni/README.md)
-- Related issue or discussion:
- [RFC: add recipes folder](https://github.com/vllm-project/vllm-omni/issues/2645)
-
-## Hardware Support
-
-This recipe currently documents one tested-style reference configuration for
-CUDA GPU serving. Add more sections for other hardware as community validation
-lands.
-
-## GPU
-
-### 1x A100 80GB
-
-#### Environment
-
-- OS: Linux
-- Python: 3.10+
-- Driver / runtime: NVIDIA CUDA environment with an A100 80 GB GPU
-- vLLM version: Match the repository requirements for your checkout
-- vLLM-Omni version or commit: Use the commit you are deploying from
-
-#### Command
-
-Start the server from the repository root:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
-```
-
-Async chunking is enabled by default in the bundled deployment config. For
-common runtime tuning, prefer CLI overrides instead of editing or passing a
-custom YAML file:
-
-```bash
-# Disable async chunking for /v1/realtime sessions
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --no-async-chunk
-```
-
-Use a custom deploy config only for advanced cases such as custom topology,
-connector wiring, or a larger overlay of stage defaults:
-
-```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
- --deploy-config /path/to/your_qwen3_omni_overrides.yaml
-```
-
-#### Verification
-
-Run one of the existing example clients after the server is ready:
-
-```bash
-python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \
- --model Qwen/Qwen3-Omni-30B-A3B-Instruct \
- --query-type use_image \
- --port 8091 \
- --host localhost
-```
-
-For a quick API smoke test, request text-only output:
-
-```bash
-curl http://localhost:8091/v1/chat/completions \
- -H "Content-Type: application/json" \
- -d '{
- "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "messages": [{"role": "user", "content": "Describe vLLM in brief."}],
- "modalities": ["text"]
- }'
-```
-
-#### Notes
-
-- Memory usage: Size depends on runtime options and output modalities; leave headroom for multimodal workloads. Prefer CLI overrides such as `--gpu-memory-utilization` for routine tuning.
-- Key flags: `--omni` is required; async chunking is enabled by default; use `--no-async-chunk` for realtime sessions and `--deploy-config` only for advanced custom deployments.
-- Known limitations: The `/v1/realtime` WebSocket flow is currently unsupported while async chunking is enabled. This starter recipe is intentionally narrow and focuses on the single-GPU online-serving path already documented in the repo examples.
diff --git a/recipes/README.md b/recipes/README.md
deleted file mode 100644
index 69ce4d7504d..00000000000
--- a/recipes/README.md
+++ /dev/null
@@ -1,39 +0,0 @@
-# Community Recipes
-
-This directory contains community-maintained recipes for answering a
-practical user question:
-
-> How do I run model X on hardware Y for task Z?
-
-Add recipes for this repository under this in-repo `recipes/` directory. To
-keep naming and layout consistent, organize recipes by model vendor in a way
-that is aligned with
-[`vllm-project/recipes`](https://github.com/vllm-project/recipes), but treat
-that external repository as a reference for structure rather than the place to
-add files for this repo. Use one Markdown file per model family by default.
-
-Example layout:
-
-```text
-recipes/
- Qwen/
- Qwen3-Omni.md
- Qwen3-TTS.md
- Tencent-Hunyuan/
- HunyuanVideo.md
-```
-
-## Available Recipes
-
-- [`Qwen/Qwen3-Omni.md`](./Qwen/Qwen3-Omni.md): online serving recipe for
- multimodal chat on `1x A100 80GB`
-- [`Wan-AI/Wan2.2-I2V.md`](./Wan-AI/Wan2.2-I2V.md): image-to-video serving
- recipe for Wan2.2 14B on `8x Ascend NPU (A2/A3)`
-- [`inclusionAI/Ming-flash-omni-2.0.md`](./inclusionAI/Ming-flash-omni-2.0.md):
- online serving recipe for multimodal chat (`4x H100 80GB`) and standalone TTS (`1x H100 80GB`)
-
-Within a single recipe file, include different hardware support sections such
-as `GPU`, `ROCm`, and `NPU`, and add concrete tested configurations like
-`1x A100 80GB` or `2x L40S` inside those sections when applicable.
-
-See [TEMPLATE.md](./TEMPLATE.md) for the recommended format.
diff --git a/recipes/TEMPLATE.md b/recipes/TEMPLATE.md
deleted file mode 100644
index 9bf8cb9c759..00000000000
--- a/recipes/TEMPLATE.md
+++ /dev/null
@@ -1,82 +0,0 @@
-# Recipe Title
-
-> Example: Qwen3-Omni for speech chat on 1x A100 80GB
-
-## Summary
-
-- Vendor:
-- Model:
-- Task:
-- Mode:
-- Maintainer:
-
-## When to use this recipe
-
-Briefly describe the concrete scenario this recipe covers.
-
-## References
-
-- Upstream or canonical docs:
-- Related example under `examples/`:
-- Related issue or discussion:
-
-## Hardware Support
-
-Add one section per platform, such as `GPU`, `ROCm`, or `NPU`. Under each
-platform section, document one or more tested hardware configurations.
-
-## GPU
-
-### 1x A100 80GB
-
-#### Environment
-
-- OS:
-- Python:
-- Driver / runtime:
-- vLLM version:
-- vLLM-Omni version or commit:
-
-#### Command
-
-```bash
-# Add the exact command(s) here
-```
-
-#### Verification
-
-```bash
-# Add a quick validation command or expected output here
-```
-
-#### Notes
-
-- Memory usage:
-- Key flags:
-- Known limitations:
-
-### 2x L40S
-
-Repeat the same structure for other hardware setups as needed.
-
-## ROCm
-
-### Example hardware configuration
-
-Repeat the same nested structure for ROCm setups as needed:
-
-- `#### Environment`
-- `#### Command`
-- `#### Verification`
-- `#### Notes`
-
-## NPU
-
-### Example hardware configuration
-
-Repeat the same nested structure for NPU setups as needed:
-
-- `#### Environment`
-- `#### Command`
-- `#### Verification`
-- `#### Notes`
diff --git a/recipes/Wan-AI/Wan2.2-I2V.md b/recipes/Wan-AI/Wan2.2-I2V.md
deleted file mode 100644
index 99ceac3cebe..00000000000
--- a/recipes/Wan-AI/Wan2.2-I2V.md
+++ /dev/null
@@ -1,136 +0,0 @@
-# Wan2.2 Image To Video
-
-## Summary
-
-- Vendor: Wan-AI
-- Model: `Wan-AI/Wan2.2-I2V-A14B-Diffusers`
-- Task: Image-to-video generation
-- Mode: Online serving with the OpenAI-compatible API
-- Maintainer: Community
-
-## When to use this recipe
-
-Use this recipe when you want to deploy the Wan2.2 14B image-to-video model
-with vLLM-Omni using multi-card parallelism. Two configurations are provided:
-
-1. **Distilled model (no negative-prompt / CFG computation)** — higher
- throughput, recommended when using a distilled checkpoint that does not
- require classifier-free guidance.
-2. **Official open-source model (with CFG)** — uses `--cfg 2` to run negative
- and positive samples in parallel for the original released weights.
-
-## References
-
-- Upstream model card:
-
-## Hardware Support
-
-## NPU
-
-### 8x Ascend A2 / A3
-
-#### Environment
-
-- OS: Linux
-- Python: 3.10+
-- Driver / runtime: Ascend NPU driver with CANN toolkit
-- Recommended operator library: **mindie-sd** (Ascend high-performance fused
- operators — enables `adalayernorm` and other fused kernels automatically upon
- installation)
-- vLLM version: Match the repository requirements for your checkout
-- vLLM-Omni version or commit: Use the commit you are deploying from
-
-#### Prerequisites
-
-Install the **mindie-sd** operator library to enable Ascend-optimized fused
-operators (`adalayernorm`, etc.):
-
-```bash
-git clone https://gitcode.com/Ascend/MindIE-SD.git && cd MindIE-SD
-
-# Comment out the tik_ops build step (not needed for this use case)
-sed -i 's|^\(\s*\)source ${current_script_dir}/build_tik_ops.sh|\1# source ${current_script_dir}/build_tik_ops.sh|' build/build_ops.sh
-
-python setup.py bdist_wheel
-cd dist
-pip install mindiesd-*.whl
-```
-
-After installation, enable the Laser Attention kernel for significant
-long-sequence speedups (up to ~40% at 720p in tested workloads):
-
-```bash
-export MINDIE_SD_FA_TYPE=ascend_laser_attention
-```
-
-When using HSDP with FSDP2, set the following environment variable to work
-around a PyTorch NPU multi-stream memory reuse issue
-([pytorch/pytorch#147168](https://github.com/pytorch/pytorch/issues/147168)).
-This issue has been fixed on CUDA but still applies to NPU:
-
-```bash
-export MULTI_STREAM_MEMORY_REUSE=2
-```
-
-#### Command
-
-**Distilled model (no CFG, recommended for distilled checkpoints):**
-
-```bash
-export MINDIE_SD_FA_TYPE=ascend_laser_attention
-export MULTI_STREAM_MEMORY_REUSE=2
-
-vllm serve \
- --omni Wan-AI/Wan2.2-I2V-A14B-Diffusers \
- --use-hsdp \
- --usp 8 \
- --vae-patch-parallel-size 8 \
- --vae-use-tiling
-```
-
-**Official open-source model (with CFG):**
-
-```bash
-export MINDIE_SD_FA_TYPE=ascend_laser_attention
-export MULTI_STREAM_MEMORY_REUSE=2
-
-vllm serve \
- --omni Wan-AI/Wan2.2-I2V-A14B-Diffusers \
- --use-hsdp \
- --usp 4 \
- --cfg 2 \
- --vae-patch-parallel-size 8 \
- --vae-use-tiling
-```
-
-> **Why the difference?** With `--cfg 2`, two copies of the input (positive and
-> negative prompts) are processed in parallel, effectively doubling the compute
-> for the DiT backbone. USP is therefore halved from 8 to 4 so that the total
-> parallelism across the 8 cards remains balanced (`usp * cfg = 8`).
-
-#### Verification
-
-After the server is ready, see
-[`examples/online_serving/image_to_video/README.md`](../../examples/online_serving/image_to_video/README.md)
-for complete client examples and request formats.
-
-#### Notes
-
-- **Key flags:**
- - `--omni` — enables vLLM-Omni diffusion serving.
- - `--use-hsdp` — enables Hybrid Sharded Data Parallelism for the DiT model
- weights.
- - `--usp ` — Unified Sequence Parallelism degree.
- - `--cfg ` — Classifier-Free Guidance parallelism; set to 2 for models
- that require negative-prompt computation, omit for distilled models.
- - `--vae-patch-parallel-size 8` — parallelizes VAE decoding across all 8
- cards.
- - `--vae-use-tiling` — enables tiled VAE decoding to reduce peak memory.
-- **Performance tips:**
- - Installing mindie-sd and enabling Laser Attention
- (`MINDIE_SD_FA_TYPE=ascend_laser_attention`) provides up to ~40%
- performance improvement at 720p resolution due to long-sequence attention
- optimization.
-- **Known limitations:**
- - `MULTI_STREAM_MEMORY_REUSE=2` is required on NPU when using HSDP/FSDP2
- due to a multi-stream memory reuse bug. This is not needed on CUDA.
diff --git a/recipes/inclusionAI/Ming-flash-omni-2.0.md b/recipes/inclusionAI/Ming-flash-omni-2.0.md
deleted file mode 100644
index 873158c8adc..00000000000
--- a/recipes/inclusionAI/Ming-flash-omni-2.0.md
+++ /dev/null
@@ -1,210 +0,0 @@
-# Ming-flash-omni 2.0 for omni-speech chat and standalone TTS
-
-## Summary
-
-- Vendor: inclusionAI
-- Model: `Jonathan1909/Ming-flash-omni-2.0`
-- Task: Multimodal chat with text, image, audio, or video input; standalone text-to-speech (TTS);
-and image generation
-- Mode: Online serving with the OpenAI-compatible API
-- Maintainer: Community
-
-## When to use this recipe
-
-Use this recipe when you want a known-good starting point for serving
-`Jonathan1909/Ming-flash-omni-2.0` with vLLM-Omni in one of three modes:
-
-- **Thinker only** — multimodal understanding with text output.
-- **Thinker + Talker (omni-speech)** — multimodal understanding with text and spoken output.
-- **Talker only (TTS)** — standalone text-to-speech via the OpenAI `/v1/audio/speech` endpoint.
-
-## References
-
-- Upstream model:
- [`inclusionAI/Ming`](https://github.com/inclusionAI/Ming)
-- For offline inference and additional client variants, see
- `examples/offline_inference/ming_flash_omni{,_tts}/` and
- `examples/online_serving/ming_flash_omni{,_tts}/`.
-
-
-## Hardware Support
-
-This recipe documents reference GPU configurations for the two-stage
-omni-speech deployment and the standalone TTS deployment.
-Other hardware and configurations are welcome as community validation lands.
-
-## GPU
-
-### 4x H100 80GB — omni-speech/chat (thinker + talker)
-
-The bundled `ming_flash_omni.yaml` runs the thinker with tensor parallel size
-4 on GPUs 0–3 and the talker on GPU 3.
-Adjust `devices` in the YAML to match your hardware.
-
-#### Environment
-
-- OS: Linux
-- Python: 3.10+
-- CUDA Driver Version: 590.48.01
-- CUDA 12.5
-- vLLM version: 0.19.0
-- vLLM-Omni version or commit: 0.19.0rc1
-
-#### Command
-
-Thinker only (text output):
-
-```bash
-vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091
-```
-
-Thinker + talker (text and/or audio output):
-
-```bash
-vllm serve Jonathan1909/Ming-flash-omni-2.0 \
- --omni \
- --port 8091 \
- --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml \
- --log-stats
-```
-
-`--log-stats` is optional but recommended while validating the deployment.
-
-#### Verification
-
-Text output from a multimodal (image) input:
-
-```bash
-curl http://localhost:8091/v1/chat/completions \
- -H "Content-Type: application/json" \
- -d '{
- "model": "Jonathan1909/Ming-flash-omni-2.0",
- "messages": [
- {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
- {"role": "user", "content": [
- {"type": "image_url", "image_url": {"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"}},
- {"type": "text", "text": "Describe this image in detail."}
- ]}
- ],
- "modalities": ["text"]
- }'
-```
-
-Spoken response from a text query (save the WAV bytes):
-
-```bash
-curl http://localhost:8091/v1/chat/completions \
- -H "Content-Type: application/json" \
- -d '{
- "model": "Jonathan1909/Ming-flash-omni-2.0",
- "messages": [
- {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
- {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}
- ],
- "modalities": ["audio"]
- }' | jq -r '.choices[0].message.audio.data' | base64 -d > ming_omni_parrot.wav
-```
-
-Text + audio output from an audio input (swap `audio_url` for `video_url`
-or `image_url` to exercise the other multimodal input paths):
-
-```bash
-curl http://localhost:8091/v1/chat/completions \
- -H "Content-Type: application/json" \
- -d '{
- "model": "Jonathan1909/Ming-flash-omni-2.0",
- "messages": [
- {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
- {"role": "user", "content": [
- {"type": "audio_url", "audio_url": {"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg"}},
- {"type": "text", "text": "Please recognize the language of this speech and transcribe it. Format: oral."}
- ]}
- ],
- "modalities": ["text", "audio"]
- }' | jq -r '.choices[0].message.content'
-```
-
-Streaming text output via SSE (set `"stream": true`):
-
-```bash
-curl -N http://localhost:8091/v1/chat/completions \
- -H "Content-Type: application/json" \
- -d '{
- "model": "Jonathan1909/Ming-flash-omni-2.0",
- "messages": [
- {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
- {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}
- ],
- "modalities": ["text"],
- "stream": true
- }'
-```
-
-Each SSE event carries a `data:` line with a chat-completion chunk; text
-deltas appear at `choices[0].delta.content`.
-
-#### Notes
-
-- Output modality is selected by the request body: `"modalities": ["text"]`,
- `["audio"]`, or `["text", "audio"]`. The two-stage omni-speech server must be launched
- for any request containing `audio`.
-- Reasoning mode: flip the system prompt suffix from `detailed thinking off`
- to `detailed thinking on` in any request above.
-- Memory usage: size depends on output modalities and multimodal input; leave
- headroom for video frames and audio caches.
-
-### 1x H100 80GB — standalone TTS (talker only)
-
-The bundled `ming_flash_omni_tts.yaml` runs the talker on a single GPU and exposes the OpenAI `/v1/audio/speech` endpoint.
-
-#### Environment
-
-- OS: Linux
-- Python: 3.10+
-- CUDA Driver Version: 590.48.01
-- CUDA 12.5
-- vLLM version: 0.19.0
-- vLLM-Omni version or commit: 0.19.0rc1
-
-#### Command
-
-```bash
-vllm serve Jonathan1909/Ming-flash-omni-2.0 \
- --omni \
- --stage-configs-path vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml \
- --port 8091 \
- --log-stats
-```
-
-`--log-stats` is optional but recommended while validating the deployment.
-
-#### Verification
-
-Basic curl:
-
-```bash
-curl -X POST http://localhost:8091/v1/audio/speech \
- -H "Content-Type: application/json" \
- -d '{
- "model": "Jonathan1909/Ming-flash-omni-2.0",
- "input": "我会一直在这里陪着你。",
- "response_format": "wav"
- }' --output ming_online.wav
-```
-
-Speaker selection (e.g. `lingguang`):
-
-```bash
-curl -X POST http://localhost:8091/v1/audio/speech \
- -H "Content-Type: application/json" \
- -d '{
- "model": "Jonathan1909/Ming-flash-omni-2.0",
- "input": "春天来了,万物复苏,大地一片生机盎然。田野里的油菜花开得金灿灿的,蜜蜂在花丛中忙碌地采蜜。远处的山坡上,桃花和杏花竞相绽放,粉的白的交织在一起,美不胜收。清晨的微风带着泥土的芬芳,轻轻拂过脸颊,让人感到无比惬意。孩子们在田间小路上追逐嬉戏,老人们坐在门前晒太阳,享受着这份宁静与美好。",
- "speaker": "lingguang",
- "response_format": "wav"
- }' --output ming_online_lingguang.wav
-```
-
-#### Notes
-
-- The OpenAI `instructions` field is forwarded to the talker as the caption JSON — pass a raw string for `风格` (style) only, or a JSON-encoded object for multiple entries such as `方言` (dialect) and `情感` (emotion).
diff --git a/requirements/common.txt b/requirements/common.txt
index 63e16d580ff..138a61ed222 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -1,6 +1,7 @@
# Common dependencies for all platforms
-av>=14.0.0
omegaconf>=2.3.0
+librosa>=0.11.0
+resampy>=0.4.3
diffusers>=0.36.0
accelerate==1.12.0
soundfile>=0.13.1
@@ -9,6 +10,7 @@ tqdm>=4.66.0
torchsde>=0.2.6
openai-whisper>=20250625
imageio[ffmpeg]>=2.37.2
+sox>=1.5.0
x-transformers>=2.12.2
einops>=0.8.1
prettytable>=3.8.0
diff --git a/requirements/musa.txt b/requirements/musa.txt
index c100c70cf05..112f3260465 100644
--- a/requirements/musa.txt
+++ b/requirements/musa.txt
@@ -1,6 +1,4 @@
-r common.txt
# MUSA platform dependencies
-torchada>=0.1.50
+torchada>=0.1.46
onnxruntime>=1.23.2
-mate>=0.2.0
-flash_attn_3>=0.1.4
diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py
deleted file mode 100644
index 7af6c3f8cb8..00000000000
--- a/tests/benchmarks/conftest.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""conftest.py for benchmarks unit tests.
-
-Installs lightweight mock stubs for ``vllm`` (and sub-packages) so the
-data-module unit tests can run without a full vLLM installation. Only the
-symbols actually imported by
-``vllm_omni.benchmarks.data_modules.seed_tts_dataset`` are emulated.
-"""
-
-from __future__ import annotations
-
-import sys
-import types
-from dataclasses import dataclass
-from typing import Any
-
-
-def _install_vllm_stubs() -> None:
- """Register minimal vllm stubs in sys.modules.
-
- Only installs when real vllm is unavailable. We actively probe the
- import because an empty or partial vllm may not yet have imported
- the submodules we rely on, and unconditionally registering stubs
- would shadow the real package for sibling tests (e.g.
- ``tests/benchmarks/metrics/test_metrics.py`` needs the real
- ``vllm.benchmarks.serve``).
- """
- try:
- import vllm.benchmarks.datasets # noqa: F401
- import vllm.tokenizers # noqa: F401
- except ImportError:
- pass
- else:
- return # real vllm available — do not shadow it
- if "vllm.benchmarks.datasets" in sys.modules:
- return
-
- # ------------------------------------------------------------------ #
- # vllm.benchmarks.datasets #
- # ------------------------------------------------------------------ #
- @dataclass
- class SampleRequest:
- prompt: str = ""
- prompt_len: int = 0
- expected_output_len: int = 0
- multi_modal_data: Any = None
- request_id: str = ""
-
- class BenchmarkDataset:
- def __init__(
- self,
- dataset_path: str = "",
- random_seed: int = 0,
- disable_shuffle: bool = False,
- **kwargs: Any,
- ) -> None:
- self.dataset_path = dataset_path
- self.random_seed = random_seed
- self.disable_shuffle = disable_shuffle
-
- def maybe_oversample_requests(
- self,
- out: list,
- num_requests: int,
- request_id_prefix: str,
- no_oversample: bool,
- ) -> None:
- pass
-
- # ------------------------------------------------------------------ #
- # vllm.tokenizers / vllm.tokenizers.hf #
- # ------------------------------------------------------------------ #
- class TokenizerLike:
- pass
-
- def get_cached_tokenizer(t: Any) -> Any:
- return t
-
- # ------------------------------------------------------------------ #
- # Wire up sys.modules #
- # ------------------------------------------------------------------ #
- vllm_mod = types.ModuleType("vllm")
- vllm_benchmarks = types.ModuleType("vllm.benchmarks")
- vllm_benchmarks_datasets = types.ModuleType("vllm.benchmarks.datasets")
- vllm_tokenizers = types.ModuleType("vllm.tokenizers")
- vllm_tokenizers_hf = types.ModuleType("vllm.tokenizers.hf")
-
- vllm_benchmarks_datasets.BenchmarkDataset = BenchmarkDataset # type: ignore[attr-defined]
- vllm_benchmarks_datasets.SampleRequest = SampleRequest # type: ignore[attr-defined]
- vllm_tokenizers.TokenizerLike = TokenizerLike # type: ignore[attr-defined]
- vllm_tokenizers_hf.get_cached_tokenizer = get_cached_tokenizer # type: ignore[attr-defined]
-
- sys.modules["vllm"] = vllm_mod
- sys.modules["vllm.benchmarks"] = vllm_benchmarks
- sys.modules["vllm.benchmarks.datasets"] = vllm_benchmarks_datasets
- sys.modules["vllm.tokenizers"] = vllm_tokenizers
- sys.modules["vllm.tokenizers.hf"] = vllm_tokenizers_hf
-
-
-# Install stubs immediately at collection time (before any test import).
-_install_vllm_stubs()
diff --git a/tests/benchmarks/metrics/test_metrics.py b/tests/benchmarks/metrics/test_metrics.py
deleted file mode 100644
index f531a5026a3..00000000000
--- a/tests/benchmarks/metrics/test_metrics.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-Unit tests for metrics.py
-"""
-
-import pytest
-from vllm.benchmarks.serve import TaskType
-
-from vllm_omni.benchmarks.metrics.metrics import calculate_metrics
-from vllm_omni.benchmarks.patch.patch import MixRequestFuncOutput
-
-pytestmark = [pytest.mark.core_model, pytest.mark.benchmark, pytest.mark.cpu]
-
-
-def _make_output(prompt_len: int, output_tokens: int = 10) -> MixRequestFuncOutput:
- """Build a minimal successful MixRequestFuncOutput for metrics aggregation."""
- output = MixRequestFuncOutput()
- output.success = True
- output.prompt_len = prompt_len
- output.output_tokens = output_tokens
- output.generated_text = "x" * output_tokens
- output.ttft = 0.1
- output.text_latency = 1.0
- output.latency = 1.0
- output.start_time = 0.0
- output.itl = [0.1] * max(output_tokens - 1, 0)
- output.audio_ttfp = 0.0
- output.audio_rtf = 0.0
- output.audio_duration = 0.0
- output.audio_frames = 0
- output.input_audio_duration = 0.0
- output.error = ""
- return output
-
-
-# ============================================================================
-# total_input Tests
-# ============================================================================
-
-
-def test_total_input_aggregated_from_output_prompt_len():
- """Test that total_input sums outputs[i].prompt_len, not input_requests[i].prompt_len."""
- outputs = [_make_output(4992), _make_output(3000)]
-
- metrics, _ = calculate_metrics(
- input_requests=[],
- outputs=outputs,
- dur_s=10.0,
- tokenizer=None,
- selected_percentiles=[99.0],
- goodput_config_dict={},
- task_type=TaskType.GENERATION,
- selected_percentile_metrics=[],
- max_concurrency=None,
- request_rate=float("inf"),
- benchmark_duration=10.0,
- )
-
- assert metrics.total_input == 7992, (
- "total_input should aggregate from outputs[i].prompt_len to reflect the true multimodal input token count"
- )
-
-
-if __name__ == "__main__":
- pytest.main([__file__, "-v", "-s"])
diff --git a/tests/benchmarks/patch/test_patch.py b/tests/benchmarks/patch/test_patch.py
index 35a18aea33c..39b7f84fb49 100644
--- a/tests/benchmarks/patch/test_patch.py
+++ b/tests/benchmarks/patch/test_patch.py
@@ -574,59 +574,5 @@ async def test_text_latency_value_consistency(self, mocker: MockerFixture):
)
-# ============================================================================
-# prompt_len Tests
-# ============================================================================
-
-
-@pytest.mark.asyncio
-async def test_prompt_len_assigned_from_usage(mocker: MockerFixture):
- # Arrange: request claims prompt_len=100, but server reports 4992 (multimodal).
- request_input = RequestFuncInput(
- model="test-model",
- model_name="test-model",
- prompt="test prompt",
- api_url="http://test.com/v1/chat/completions",
- prompt_len=100,
- output_len=20,
- )
-
- chunks = [
- create_sse_chunk(
- {
- "choices": [{"delta": {"content": "Hello"}}],
- "modality": "text",
- }
- ),
- create_sse_chunk(
- {
- "choices": [{"delta": {"content": " world"}}],
- "modality": "text",
- }
- ),
- # Final usage chunk emitted because stream_options.include_usage=True.
- create_sse_chunk(
- {
- "choices": [],
- "usage": {"prompt_tokens": 4992, "completion_tokens": 2, "total_tokens": 4994},
- }
- ),
- b"data: [DONE]\n\n",
- ]
-
- mock_response = MockResponse(200, chunks)
- mock_session = mocker.AsyncMock()
- mock_session.post = mocker.MagicMock(return_value=mock_response)
-
- # Act
- output = await async_request_openai_chat_omni_completions(request_input, mock_session)
-
- # Assert
- assert output.success is True
- assert output.prompt_len == 4992, (
- "prompt_len should be overridden by usage.prompt_tokens to reflect the true multimodal input token count"
- )
-
-
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
diff --git a/tests/benchmarks/test_accuracy_bench_utils.py b/tests/benchmarks/test_accuracy_bench_utils.py
index 6ceebb11b79..a0479fb1bad 100644
--- a/tests/benchmarks/test_accuracy_bench_utils.py
+++ b/tests/benchmarks/test_accuracy_bench_utils.py
@@ -1,17 +1,11 @@
# ruff: noqa: E402, I001
-import argparse
import math
-import os
import sys
-import types
from pathlib import Path
import pytest
from PIL import Image
-pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
-
-
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
@@ -42,64 +36,8 @@
summarize_generated_records as summarize_gebench_generated_records,
summarize_gebench_results,
)
-from tests.e2e.accuracy.qwen3_omni.qwen3_omni_acc_bench_core import seed_tts_bench_argv
-from tests.e2e.accuracy.qwen3_omni.run_qwen_omni_acc_benchmark import sync_dataset_env_from_ns
-from vllm_omni.benchmarks.data_modules.seed_tts_dataset import resolve_seed_tts_root
-
-
-def test_seed_tts_bench_argv_preserves_hf_repo_id_from_env(monkeypatch):
- monkeypatch.setenv("VLLM_SEED_TTS_DATASET_PATH", "zhaochenyang20/seed-tts-eval")
- monkeypatch.delenv("VLLM_SEED_TTS_REPO", raising=False)
-
- argv = seed_tts_bench_argv(locale="en")
-
- dataset_idx = argv.index("--dataset-path")
- assert argv[dataset_idx + 1] == "zhaochenyang20/seed-tts-eval"
-
-
-def test_sync_dataset_env_preserves_seed_tts_hf_repo_id(monkeypatch):
- ns = argparse.Namespace(
- daily_omni_repo=None,
- daily_omni_qa_json=None,
- daily_omni_video_dir=None,
- seed_tts_dataset_path="zhaochenyang20/seed-tts-eval",
- seed_tts_root=None,
- )
- monkeypatch.delenv("VLLM_SEED_TTS_DATASET_PATH", raising=False)
- sync_dataset_env_from_ns(ns)
-
- assert os.environ["VLLM_SEED_TTS_DATASET_PATH"] == "zhaochenyang20/seed-tts-eval"
-
-
-def test_resolve_seed_tts_root_downloads_only_requested_locale(monkeypatch, tmp_path: Path):
- downloaded_root = tmp_path / "seed_tts_cache"
- (downloaded_root / "zh" / "prompt-wavs").mkdir(parents=True)
- (downloaded_root / "zh" / "meta.lst").write_text("", encoding="utf-8")
- captured: dict[str, object] = {}
-
- def fake_snapshot_download(*, repo_id, repo_type, allow_patterns):
- captured["repo_id"] = repo_id
- captured["repo_type"] = repo_type
- captured["allow_patterns"] = allow_patterns
- return str(downloaded_root)
-
- monkeypatch.setitem(
- sys.modules,
- "huggingface_hub",
- types.SimpleNamespace(snapshot_download=fake_snapshot_download),
- )
-
- resolved = resolve_seed_tts_root(
- "zhaochenyang20/seed-tts-eval",
- explicit_root=None,
- locale="zh",
- )
-
- assert resolved == downloaded_root.resolve()
- assert captured["repo_id"] == "zhaochenyang20/seed-tts-eval"
- assert captured["repo_type"] == "dataset"
- assert captured["allow_patterns"] == ["zh/**"]
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
def test_summarize_gebench_generated_records_groups_by_type():
diff --git a/tests/benchmarks/test_bench_tts_cli.py b/tests/benchmarks/test_bench_tts_cli.py
deleted file mode 100644
index b8a487f80c6..00000000000
--- a/tests/benchmarks/test_bench_tts_cli.py
+++ /dev/null
@@ -1,139 +0,0 @@
-"""Tests for the universal benchmarks/tts/bench_tts.py CLI."""
-
-from __future__ import annotations
-
-import json
-import sys
-from pathlib import Path
-
-import pytest
-import yaml
-
-# Add benchmarks/tts to path for import
-sys.path.insert(0, str(Path(__file__).parent.parent.parent / "benchmarks" / "tts"))
-import bench_tts
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.fixture()
-def model_configs_path(tmp_path: Path) -> Path:
- cfg = {
- "models": {
- "test/ModelA": {
- "stage_config": "model_a.yaml",
- "supported_tasks": ["voice_clone", "default_voice"],
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "task_extra_body": {
- "voice_clone": {"task_type": "Base"},
- "default_voice": {"voice": "Vivian", "task_type": "CustomVoice"},
- },
- },
- "test/ModelB": {
- "stage_config": "model_b.yaml",
- "supported_tasks": ["voice_clone"],
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "task_extra_body": {"voice_clone": {}},
- },
- }
- }
- p = tmp_path / "model_configs.yaml"
- p.write_text(yaml.dump(cfg), encoding="utf-8")
- return p
-
-
-def test_load_model_configs(model_configs_path: Path) -> None:
- configs = bench_tts.load_model_configs(model_configs_path)
- assert "test/ModelA" in configs
- assert "test/ModelB" in configs
- assert configs["test/ModelA"]["supported_tasks"] == ["voice_clone", "default_voice"]
-
-
-def test_build_bench_args_voice_clone(model_configs_path: Path) -> None:
- configs = bench_tts.load_model_configs(model_configs_path)
- cmd = bench_tts.build_bench_args(
- host="localhost",
- port=8000,
- model="test/ModelA",
- task="voice_clone",
- model_cfg=configs["test/ModelA"],
- locale="en",
- num_prompts=10,
- concurrency=1,
- dataset_path="/data/seed-tts",
- wer_eval=False,
- output_dir=None,
- result_filename=None,
- extra_cli_args=[],
- )
- assert "--dataset-name" in cmd
- idx = cmd.index("--dataset-name")
- assert cmd[idx + 1] == "seed-tts"
- assert "--max-concurrency" in cmd
- assert "--extra-body" in cmd
- extra_body = json.loads(cmd[cmd.index("--extra-body") + 1])
- assert extra_body.get("task_type") == "Base"
-
-
-def test_build_bench_args_default_voice_has_voice_param(model_configs_path: Path) -> None:
- configs = bench_tts.load_model_configs(model_configs_path)
- cmd = bench_tts.build_bench_args(
- host="localhost",
- port=8000,
- model="test/ModelA",
- task="default_voice",
- model_cfg=configs["test/ModelA"],
- locale="en",
- num_prompts=10,
- concurrency=1,
- dataset_path="/data/seed-tts",
- wer_eval=False,
- output_dir=None,
- result_filename=None,
- extra_cli_args=[],
- )
- idx = cmd.index("--dataset-name")
- assert cmd[idx + 1] == "seed-tts-text"
- extra_body = json.loads(cmd[cmd.index("--extra-body") + 1])
- assert extra_body.get("voice") == "Vivian"
-
-
-def test_build_bench_args_wer_eval_adds_flag(model_configs_path: Path) -> None:
- configs = bench_tts.load_model_configs(model_configs_path)
- cmd = bench_tts.build_bench_args(
- host="localhost",
- port=8000,
- model="test/ModelA",
- task="voice_clone",
- model_cfg=configs["test/ModelA"],
- locale="en",
- num_prompts=10,
- concurrency=1,
- dataset_path="/data/seed-tts",
- wer_eval=True,
- output_dir=None,
- result_filename=None,
- extra_cli_args=[],
- )
- assert "--seed-tts-wer-eval" in cmd
-
-
-def test_unsupported_task_exits(model_configs_path: Path, capsys: pytest.CaptureFixture, mocker) -> None:
- # ModelB does not support voice_design
- mocker.patch.object(
- sys,
- "argv",
- [
- "bench_tts.py",
- "--model",
- "test/ModelB",
- "--task",
- "voice_design",
- "--model-configs",
- str(model_configs_path),
- ],
- )
- with pytest.raises(SystemExit):
- bench_tts.main()
diff --git a/tests/benchmarks/test_diffusion_backends_metrics.py b/tests/benchmarks/test_diffusion_backends_metrics.py
deleted file mode 100644
index 2d51d0f1d38..00000000000
--- a/tests/benchmarks/test_diffusion_backends_metrics.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import pytest
-
-from benchmarks.diffusion.backends import RequestFuncInput, async_request_chat_completions
-
-
-class _MockResponse:
- def __init__(self, payload: dict, status: int = 200):
- self._payload = payload
- self.status = status
-
- async def __aenter__(self):
- return self
-
- async def __aexit__(self, exc_type, exc, tb):
- return False
-
- async def json(self):
- return self._payload
-
- async def text(self):
- return str(self._payload)
-
-
-class _MockSession:
- def __init__(self, payload: dict):
- self._payload = payload
-
- def post(self, *args, **kwargs):
- return _MockResponse(self._payload)
-
-
-@pytest.mark.core_model
-@pytest.mark.benchmark
-@pytest.mark.cpu
-@pytest.mark.asyncio
-async def test_chat_completions_metrics_fallback_to_top_level():
- payload = {
- "choices": [
- {
- "message": {
- "content": [
- {
- "type": "image_url",
- "image_url": {"url": "data:image/png;base64,abc"},
- }
- ]
- }
- }
- ],
- "metrics": {
- "stage_durations": {"diffusion": 1.25},
- "peak_memory_mb": 4096.0,
- },
- }
-
- output = await async_request_chat_completions(
- RequestFuncInput(
- prompt="draw a cat",
- api_url="http://test.local/v1/chat/completions",
- model="ByteDance-Seed/BAGEL-7B-MoT",
- ),
- session=_MockSession(payload),
- )
-
- assert output.success is True
- assert output.stage_durations == {"diffusion": 1.25}
- assert output.peak_memory_mb == 4096.0
-
-
-@pytest.mark.core_model
-@pytest.mark.benchmark
-@pytest.mark.cpu
-@pytest.mark.asyncio
-async def test_chat_completions_metrics_message_level_takes_precedence():
- payload = {
- "choices": [
- {
- "message": {
- "content": [
- {
- "type": "image_url",
- "image_url": {"url": "data:image/png;base64,abc"},
- "stage_durations": {"message_stage": 0.7},
- "peak_memory_mb": 1234.0,
- }
- ]
- }
- }
- ],
- "metrics": {
- "stage_durations": {"top_level_stage": 9.9},
- "peak_memory_mb": 9999.0,
- },
- }
-
- output = await async_request_chat_completions(
- RequestFuncInput(
- prompt="draw a dog",
- api_url="http://test.local/v1/chat/completions",
- model="ByteDance-Seed/BAGEL-7B-MoT",
- ),
- session=_MockSession(payload),
- )
-
- assert output.success is True
- assert output.stage_durations == {"message_stage": 0.7}
- assert output.peak_memory_mb == 1234.0
diff --git a/tests/benchmarks/test_seed_tts_dataset_variants.py b/tests/benchmarks/test_seed_tts_dataset_variants.py
deleted file mode 100644
index 7fa5747bdfd..00000000000
--- a/tests/benchmarks/test_seed_tts_dataset_variants.py
+++ /dev/null
@@ -1,168 +0,0 @@
-"""Tests for SeedTTSTextDataset, SeedTTSTextSampleRequest, SeedTTSDesignDataset,
-and SeedTTSDesignSampleRequest.
-
-vllm stubs are installed by tests/benchmarks/conftest.py before collection.
-"""
-
-from __future__ import annotations
-
-import importlib.util
-import sys
-from pathlib import Path
-
-import pytest
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-# Load the data module directly (bypasses vllm_omni.__init__ heavy imports).
-_REPO_ROOT = Path(__file__).resolve().parents[2]
-_MODULE_PATH = _REPO_ROOT / "vllm_omni" / "benchmarks" / "data_modules" / "seed_tts_dataset.py"
-_MODULE_NAME = "vllm_omni.benchmarks.data_modules.seed_tts_dataset"
-
-if _MODULE_NAME not in sys.modules:
- _spec = importlib.util.spec_from_file_location(_MODULE_NAME, _MODULE_PATH)
- _mod = importlib.util.module_from_spec(_spec)
- sys.modules[_MODULE_NAME] = _mod
- _spec.loader.exec_module(_mod)
-
-from vllm_omni.benchmarks.data_modules.seed_tts_dataset import ( # noqa: E402
- SeedTTSDesignDataset,
- SeedTTSDesignSampleRequest,
- SeedTTSTextDataset,
- SeedTTSTextSampleRequest,
-)
-
-# ---------------------------------------------------------------------------
-# Fixtures
-# ---------------------------------------------------------------------------
-
-
-@pytest.fixture()
-def seed_tts_root(tmp_path: Path) -> Path:
- """Minimal seed-tts-style directory with 5 entries."""
- locale_dir = tmp_path / "en"
- locale_dir.mkdir()
- wav_dir = locale_dir / "prompt-wavs"
- wav_dir.mkdir()
- for i in range(5):
- (wav_dir / f"utt{i:03d}.wav").write_bytes(b"RIFF\x00\x00\x00\x00WAVE")
- meta = "\n".join(f"utt{i:03d}|ref text {i}|prompt-wavs/utt{i:03d}.wav|target text {i}" for i in range(5))
- (locale_dir / "meta.lst").write_text(meta, encoding="utf-8")
- return tmp_path
-
-
-@pytest.fixture()
-def mock_tokenizer(mocker):
- tokenizer = mocker.MagicMock()
- tokenizer.encode = lambda text, **kw: [0] * len(text.split())
- tokenizer.get_vocab.return_value = {"": 0}
- tokenizer.all_special_ids = []
- tokenizer.all_special_tokens = []
- tokenizer.vocab_size = 1
- return tokenizer
-
-
-# ---------------------------------------------------------------------------
-# Tests
-# ---------------------------------------------------------------------------
-
-
-def test_seed_tts_text_dataset_omits_ref_audio(seed_tts_root, mock_tokenizer):
- ds = SeedTTSTextDataset(
- dataset_path=str(seed_tts_root),
- random_seed=0,
- locale="en",
- disable_shuffle=True,
- )
- requests = ds.sample(mock_tokenizer, num_requests=3)
- assert len(requests) == 3
- for req in requests:
- assert isinstance(req, SeedTTSTextSampleRequest)
- assert req.seed_tts_speech_extra is None or "ref_audio" not in (req.seed_tts_speech_extra or {})
- assert req.seed_tts_ref_wav_path == ""
- assert "target text" in req.prompt
-
-
-# ---------------------------------------------------------------------------
-# SeedTTSDesignDataset tests
-# ---------------------------------------------------------------------------
-
-
-@pytest.fixture()
-def seed_tts_design_root(tmp_path: Path) -> Path:
- """seed-tts-design directory with 5-field meta.lst entries."""
- locale_dir = tmp_path / "en"
- locale_dir.mkdir()
- meta = "\n".join(
- f"des{i:03d}|||target text {i}|A warm {['female', 'male'][i % 2]} voice with neutral accent." for i in range(5)
- )
- (locale_dir / "meta.lst").write_text(meta, encoding="utf-8")
- return tmp_path
-
-
-def test_seed_tts_design_dataset_has_instructions(seed_tts_design_root, mock_tokenizer):
- ds = SeedTTSDesignDataset(
- dataset_path=str(seed_tts_design_root),
- random_seed=0,
- locale="en",
- disable_shuffle=True,
- )
- requests = ds.sample(mock_tokenizer, num_requests=3)
- assert len(requests) == 3
- for req in requests:
- assert isinstance(req, SeedTTSDesignSampleRequest)
- extra = req.seed_tts_speech_extra or {}
- assert "instructions" in extra
- assert extra["instructions"], "instructions must be non-empty"
- assert extra.get("task_type") == "VoiceDesign"
- assert "ref_audio" not in extra
- assert req.seed_tts_ref_wav_path == ""
-
-
-def test_seed_tts_design_dataset_rejects_missing_description(seed_tts_design_root, mock_tokenizer):
- """Lines without a voice_description should be skipped."""
- locale_dir = seed_tts_design_root / "en"
- # The bad line has 4 fields, not 5, so will be filtered
- meta = "bad|||target text without description\n" + "\n".join(
- f"ok|||target text {i}|A clear female voice." for i in range(9)
- )
- (locale_dir / "meta.lst").write_text(meta, encoding="utf-8")
- ds = SeedTTSDesignDataset(
- dataset_path=str(seed_tts_design_root),
- random_seed=0,
- locale="en",
- disable_shuffle=True,
- )
- requests = ds.sample(mock_tokenizer, num_requests=10, no_oversample=True)
- assert len(requests) == 9 # since we filter the bad row out and don't oversample
- for req in requests:
- assert isinstance(req, SeedTTSDesignSampleRequest)
- assert req.seed_tts_utterance_id == "ok"
-
-
-def test_attach_sets_seed_tts_row_even_without_extra_body():
- """seed_tts_row=True must be set for SeedTTSTextSampleRequest (no extra body)."""
- from vllm_omni.benchmarks.data_modules.seed_tts_dataset import SeedTTSTextSampleRequest
-
- req = SeedTTSTextSampleRequest(
- prompt="hello world",
- prompt_len=2,
- expected_output_len=100,
- multi_modal_data=None,
- request_id="test-0",
- seed_tts_speech_extra=None,
- seed_tts_ref_wav_path="",
- )
- assert req.seed_tts_speech_extra is None
- assert req.seed_tts_ref_wav_path == ""
- # The fix ensures that even with speech_extra=None, the function
- # sets seed_tts_row=True. We verify the source code has the fix.
- import inspect
-
- import vllm_omni.benchmarks.patch.patch as patch_mod
-
- src = inspect.getsource(patch_mod._attach_seed_tts_to_request_func_input)
- # seed_tts_row must be set BEFORE the 'if not ex: return' check
- row_pos = src.index("seed_tts_row")
- not_ex_pos = src.index("if not ex:")
- assert row_pos < not_ex_pos, "seed_tts_row must be set before 'if not ex: return'"
diff --git a/tests/comfyui/conftest.py b/tests/comfyui/conftest.py
index 4280d3506ff..0b4565e9465 100644
--- a/tests/comfyui/conftest.py
+++ b/tests/comfyui/conftest.py
@@ -9,8 +9,8 @@
import os
import sys
-from types import ModuleType, SimpleNamespace
from typing import BinaryIO, TypedDict
+from unittest.mock import MagicMock
def pytest_configure(config):
@@ -58,15 +58,15 @@ def save_to(self, file: str | BinaryIO):
else:
file.write(self._data)
- mock_comfy_api = ModuleType("comfy_api")
- mock_comfy_api_input = ModuleType("comfy_api.input")
+ mock_comfy_api = MagicMock()
+ mock_comfy_api_input = MagicMock()
mock_comfy_api_input.AudioInput = AudioInput
mock_comfy_api_input.VideoInput = VideoInput
mock_comfy_api.input = mock_comfy_api_input
- mock_comfy_api_latest = ModuleType("comfy_api.latest")
- mock_comfy_api_latest.Types = SimpleNamespace(VideoComponents=lambda **kwargs: kwargs)
- mock_comfy_api_latest.InputImpl = SimpleNamespace(
- VideoFromComponents=lambda _: VideoInput(b"mock_video_from_components")
+ mock_comfy_api_latest = MagicMock()
+ mock_comfy_api_latest.Types.VideoComponents = MagicMock(side_effect=lambda **kwargs: kwargs)
+ mock_comfy_api_latest.InputImpl.VideoFromComponents = MagicMock(
+ side_effect=lambda _: VideoInput(b"mock_video_from_components")
)
mock_comfy_api.latest = mock_comfy_api_latest
@@ -76,8 +76,8 @@ def mock_load(_: str | BinaryIO):
sample_rate = 24000
return waveform, sample_rate
- mock_comfy_extras = ModuleType("comfy_extras")
- mock_nodes_audio = ModuleType("comfy_extras.nodes_audio")
+ mock_comfy_extras = MagicMock()
+ mock_nodes_audio = MagicMock()
mock_nodes_audio.load = mock_load
mock_comfy_extras.nodes_audio = mock_nodes_audio
diff --git a/tests/comfyui/test_comfyui_integration.py b/tests/comfyui/test_comfyui_integration.py
index 80e86d82412..f6ce82f9b28 100644
--- a/tests/comfyui/test_comfyui_integration.py
+++ b/tests/comfyui/test_comfyui_integration.py
@@ -13,6 +13,7 @@
from enum import StrEnum, auto
from types import SimpleNamespace
from typing import Any, NamedTuple
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import requests
@@ -27,7 +28,6 @@
)
from comfyui_vllm_omni.utils.types import AutoregressionSamplingParams, DiffusionSamplingParams, WanModelSpecificParams
from PIL import Image
-from pytest_mock import MockerFixture
from vllm import SamplingParams
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -217,10 +217,9 @@ def _build_diffusion_video_output() -> OmniRequestOutput:
def _build_diffusion_image_output_for_chat_endpoint() -> OmniRequestOutput:
- request_output = SimpleNamespace(
- images=[_build_image_output(color="blue")],
- finished=True,
- )
+ request_output = MagicMock()
+ request_output.images = [_build_image_output(color="blue")]
+ request_output.finished = True
return OmniRequestOutput(
request_id="test_req_img_chat",
finished=True,
@@ -390,55 +389,51 @@ def sampling_case(request) -> SamplingCase:
@pytest.fixture
-def mock_async_omni(
- server_case: ServerCase,
- sampling_case: SamplingCase,
- monkeypatch: pytest.MonkeyPatch,
- mocker: MockerFixture,
-):
+def mock_async_omni(server_case: ServerCase, sampling_case: SamplingCase):
async def _mock_preprocess_chat(self, *args, **kwargs):
return ([{"role": "user", "content": "test"}], [{"prompt": "test prompt"}])
# Need to mock AsyncOmni itself (not only its generate method) because
# 1. The API layer uses its stage_list and stage_configs attributes
# 2. Its __init__ method has slow side effects (model & config loading).
- mock_async_omni_cls = mocker.patch("vllm_omni.entrypoints.openai.api_server.AsyncOmni")
- monkeypatch.setattr(
- "vllm_omni.entrypoints.openai.serving_chat.OmniOpenAIServingChat._preprocess_chat",
- _mock_preprocess_chat,
- )
-
- mock_instance = mocker.AsyncMock(spec=RealAsyncOmni)
- mock_instance.generate = _build_mock_outputs(server_case.outputs, sampling_case, server_case)
-
- mock_instance.stage_list = server_case.stage_list
- mock_instance.stage_configs = server_case.stage_configs
- mock_instance.output_modalities = _build_output_modalities(server_case.stage_configs)
- mock_instance.default_sampling_params_list = [
- SamplingParams() if _stage_type(stage) != "diffusion" else mocker.MagicMock()
- for stage in server_case.stage_configs
- ]
- mock_instance.errored = False
- mock_instance.dead_error = RuntimeError("Mock engine error")
- mock_instance.model_config = mocker.MagicMock(
- max_model_len=4096,
- io_processor_plugin=None,
- allowed_local_media_path=None,
- allowed_media_domains=None,
- )
- # Mimic Qwen3-TTS talker speaker config so CustomVoice validation passes.
- mock_instance.model_config.hf_config = mocker.MagicMock()
- mock_instance.model_config.hf_config.talker_config = mocker.MagicMock()
- mock_instance.model_config.hf_config.talker_config.speaker_id = {"Vivian": 0}
- mock_instance.io_processor = mocker.MagicMock()
- mock_instance.input_processor = mocker.MagicMock()
- mock_instance.shutdown = mocker.MagicMock()
- mock_instance.get_vllm_config = mocker.AsyncMock(return_value=None)
- mock_instance.get_supported_tasks = mocker.AsyncMock(return_value=["generate"])
- mock_instance.get_tokenizer = mocker.AsyncMock(return_value=None)
+ with (
+ patch("vllm_omni.entrypoints.openai.api_server.AsyncOmni") as MockAsyncOmni,
+ patch(
+ "vllm_omni.entrypoints.openai.serving_chat.OmniOpenAIServingChat._preprocess_chat",
+ new=_mock_preprocess_chat,
+ ),
+ ):
+ mock_instance = AsyncMock(spec=RealAsyncOmni)
+ mock_instance.generate = _build_mock_outputs(server_case.outputs, sampling_case, server_case)
+
+ mock_instance.stage_list = server_case.stage_list
+ mock_instance.stage_configs = server_case.stage_configs
+ mock_instance.output_modalities = _build_output_modalities(server_case.stage_configs)
+ mock_instance.default_sampling_params_list = [
+ SamplingParams() if _stage_type(stage) != "diffusion" else MagicMock()
+ for stage in server_case.stage_configs
+ ]
+ mock_instance.errored = False
+ mock_instance.dead_error = RuntimeError("Mock engine error")
+ mock_instance.model_config = MagicMock(
+ max_model_len=4096,
+ io_processor_plugin=None,
+ allowed_local_media_path=None,
+ allowed_media_domains=None,
+ )
+ # Mimic Qwen3-TTS talker speaker config so CustomVoice validation passes.
+ mock_instance.model_config.hf_config = MagicMock()
+ mock_instance.model_config.hf_config.talker_config = MagicMock()
+ mock_instance.model_config.hf_config.talker_config.speaker_id = {"Vivian": 0}
+ mock_instance.io_processor = MagicMock()
+ mock_instance.input_processor = MagicMock()
+ mock_instance.shutdown = MagicMock()
+ mock_instance.get_vllm_config = AsyncMock(return_value=None)
+ mock_instance.get_supported_tasks = AsyncMock(return_value=["generate"])
+ mock_instance.get_tokenizer = AsyncMock(return_value=None)
- mock_async_omni_cls.return_value = mock_instance
- yield mock_async_omni_cls
+ MockAsyncOmni.return_value = mock_instance
+ yield MockAsyncOmni
@pytest.fixture
@@ -588,9 +583,9 @@ async def test_image_generation_node(api_server: str, model: str, image_input: b
ServerCase(
served_model="Qwen/Qwen2.5-Omni-7B",
stage_list=[
- SimpleNamespace(is_comprehension=True, model_stage="llm"),
- SimpleNamespace(is_comprehension=False, model_stage="llm"),
- SimpleNamespace(is_comprehension=False, model_stage="llm"),
+ MagicMock(is_comprehension=True, model_stage="llm"),
+ MagicMock(is_comprehension=False, model_stage="llm"),
+ MagicMock(is_comprehension=False, model_stage="llm"),
],
stage_configs=[
_make_stage_config("llm", is_comprehension=True, model_stage="thinker"),
diff --git a/tests/config/__init__.py b/tests/config/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/tests/config/test_pipeline_registry.py b/tests/config/test_pipeline_registry.py
deleted file mode 100644
index 6cc7c9258ed..00000000000
--- a/tests/config/test_pipeline_registry.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for the central pipeline registry (2.5/N)."""
-
-from __future__ import annotations
-
-import pytest
-
-from vllm_omni.config.pipeline_registry import _OMNI_PIPELINES
-from vllm_omni.config.stage_config import (
- _PIPELINE_REGISTRY,
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
- register_pipeline,
-)
-
-
-class TestCentralRegistryDeclarations:
- """Every in-tree pipeline must be declared exactly once in the central registry."""
-
- def test_omni_entries_visible_in_registry(self):
- for key in _OMNI_PIPELINES:
- assert key in _PIPELINE_REGISTRY
-
- def test_expected_omni_pipelines_present(self):
- # Guard against accidental removal during future refactors.
- assert "qwen2_5_omni" in _OMNI_PIPELINES
- assert "qwen2_5_omni_thinker_only" in _OMNI_PIPELINES
- assert "qwen3_omni_moe" in _OMNI_PIPELINES
- assert "qwen3_tts" in _OMNI_PIPELINES
-
-
-class TestLazyLoading:
- """Pipelines are imported only on first access."""
-
- def test_contains_without_import(self):
- # ``in`` hits the lazy map, not the loaded cache.
- assert "qwen3_omni_moe" in _PIPELINE_REGISTRY
-
- def test_getitem_loads_correct_pipeline(self):
- pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
- assert pipeline.model_type == "qwen3_omni_moe"
- assert pipeline.model_arch == "Qwen3OmniMoeForConditionalGeneration"
-
- def test_unknown_model_type_returns_none_via_get(self):
- assert _PIPELINE_REGISTRY.get("not_a_real_pipeline") is None
-
- def test_unknown_model_type_raises_keyerror_via_getitem(self):
- with pytest.raises(KeyError):
- _PIPELINE_REGISTRY["not_a_real_pipeline"]
-
- def test_iteration_yields_registered_pipelines(self):
- keys = set(_PIPELINE_REGISTRY)
- assert "qwen2_5_omni" in keys
- assert "qwen3_omni_moe" in keys
-
-
-class TestDynamicRegistration:
- """``register_pipeline()`` still works for plugins and tests."""
-
- def test_register_adds_to_registry(self):
- custom = PipelineConfig(
- model_type="_test_dynamic_registration",
- model_arch="DynamicTestModel",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="test",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- final_output=True,
- ),
- ),
- )
- register_pipeline(custom)
- try:
- assert "_test_dynamic_registration" in _PIPELINE_REGISTRY
- assert _PIPELINE_REGISTRY["_test_dynamic_registration"] is custom
- finally:
- # Don't leak the test registration into other tests.
- if "_test_dynamic_registration" in _PIPELINE_REGISTRY:
- del _PIPELINE_REGISTRY["_test_dynamic_registration"]
-
- def test_dynamic_registration_overrides_lazy_entry(self):
- # Build a substitute for qwen3_omni_moe that we can distinguish.
- original = _PIPELINE_REGISTRY["qwen3_omni_moe"]
- override = PipelineConfig(
- model_type="qwen3_omni_moe",
- model_arch="OverriddenArch",
- stages=original.stages,
- )
- register_pipeline(override)
- try:
- assert _PIPELINE_REGISTRY["qwen3_omni_moe"].model_arch == "OverriddenArch"
- finally:
- # Remove the dynamic override so later tests see the original.
- if "qwen3_omni_moe" in _PIPELINE_REGISTRY._loaded:
- del _PIPELINE_REGISTRY["qwen3_omni_moe"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 77075f9525a..8e9a7bf9280 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,62 +1,3180 @@
-"""
-Root pytest entrypoint for the vLLM-Omni test suite.
-
-- `tests/conftest.py` stays thin: plugin registration + compatibility re-exports.
-- Importable utilities live under `tests/helpers/`.
-- Fixtures live under `tests/helpers/fixtures/` and are loaded via `pytest_plugins`.
-"""
-
-from __future__ import annotations
-
-pytest_plugins = (
- "tests.helpers.fixtures.env",
- "tests.helpers.fixtures.log",
- "tests.helpers.fixtures.run_args",
- "tests.helpers.fixtures.runtime",
-)
-
-
-def pytest_terminal_summary(terminalreporter, exitstatus, config):
- # Marker for Buildkite log folding before pytest summary lines.
- terminalreporter.write_line("--- Running Summary")
-
-
-# Backward-compatible lazy re-exports.
-# (Many tests still import from `tests.conftest`; migrate these imports to `tests.helpers.*` over time.)
-# Keep these lazy so conftest import does not trigger heavy helper dependencies.
-_ASSERTION_EXPORT_NAMES = (
- "assert_audio_speech_response",
- "assert_diffusion_response",
- "assert_image_diffusion_response",
- "assert_image_valid",
- "assert_omni_response",
- "assert_video_diffusion_response",
- "assert_video_valid",
-)
-_MEDIA_EXPORT_NAMES = (
- "convert_audio_bytes_to_text",
- "convert_audio_file_to_text",
- "cosine_similarity_text",
- "decode_b64_image",
- "generate_synthetic_audio",
- "generate_synthetic_image",
- "generate_synthetic_video",
-)
-_STAGE_CONFIG_EXPORT_NAMES = ("modify_stage_config",)
-_RUNTIME_EXPORT_NAMES = (
- "DiffusionResponse",
- "OmniResponse",
- "OmniRunner",
- "OmniRunnerHandler",
- "OmniServer",
- "OmniServerParams",
- "OmniServerStageCli",
- "OpenAIClientHandler",
- "dummy_messages_from_mix_data",
-)
-_LAZY_EXPORT_MODULES = {
- **{name: "tests.helpers.assertions" for name in _ASSERTION_EXPORT_NAMES},
- **{name: "tests.helpers.media" for name in _MEDIA_EXPORT_NAMES},
- **{name: "tests.helpers.stage_config" for name in _STAGE_CONFIG_EXPORT_NAMES},
- **{name: "tests.helpers.runtime" for name in _RUNTIME_EXPORT_NAMES},
+import base64
+import datetime
+import io
+import json
+import math
+import os
+import random
+import re
+import tempfile
+
+import requests
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+# Set CPU device for CI environments without GPU
+if "VLLM_TARGET_DEVICE" not in os.environ:
+ os.environ["VLLM_TARGET_DEVICE"] = "cpu"
+
+import concurrent.futures
+import gc
+import multiprocessing
+import socket
+import subprocess
+import sys
+import threading
+import time
+import uuid
+from collections.abc import Generator
+from dataclasses import dataclass
+from io import BytesIO
+from pathlib import Path
+from typing import Any, NamedTuple
+
+import cv2
+import numpy as np
+import psutil
+import pytest
+import soundfile as sf
+import torch
+import yaml
+from openai import OpenAI, omit
+from PIL import Image
+from transformers import pipeline
+from vllm import TextPrompt
+from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from vllm.logger import init_logger
+from vllm.utils.network_utils import get_open_port
+
+from vllm_omni.entrypoints.omni import Omni
+from vllm_omni.inputs.data import OmniSamplingParams
+from vllm_omni.outputs import OmniRequestOutput
+from vllm_omni.platforms import current_omni_platform
+
+logger = init_logger(__name__)
+
+PromptAudioInput = list[tuple[Any, int]] | tuple[Any, int] | None
+PromptImageInput = list[Any] | Any | None
+PromptVideoInput = list[Any] | Any | None
+
+_GENDER_PIPELINE = None
+# transformers.Pipeline is not thread-safe; concurrent e2e requests must serialize inference.
+_GENDER_PIPELINE_LOCK = threading.Lock()
+
+# int16 mono PCM from /v1/audio/speech when response_format=pcm (Qwen3-TTS code2wav output rate).
+_PCM_SPEECH_SAMPLE_RATE_HZ = 24_000
+
+
+class OmniServerParams(NamedTuple):
+ model: str
+ port: int | None = None
+ stage_config_path: str | None = None
+ server_args: list[str] | None = None
+ env_dict: dict[str, str] | None = None
+ use_omni: bool = True
+
+
+def assert_image_diffusion_response(
+    response,
+    request_config: dict[str, Any],
+    run_level: str | None = None,
+) -> None:
+    """
+    Validate an image diffusion response: non-empty, optional count check, optional size check.
+
+    Expected request_config schema:
+    {
+        "request_type": "image",
+        "extra_body": {
+            "num_outputs_per_prompt": 1,
+            "width": ...,
+            "height": ...,
+            ...
+        }
+    }
+    """
+    assert response.images is not None, "Image response is None"
+    assert len(response.images) > 0, "No images in response"
+
+    extra_body = request_config.get("extra_body") or {}
+
+    num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt")
+    if num_outputs_per_prompt is not None:
+        assert len(response.images) == num_outputs_per_prompt, (
+            f"Expected {num_outputs_per_prompt} images, got {len(response.images)}"
+        )
+
+    if run_level == "advanced_model":  # dimensions are only verified at this run level
+        width = extra_body.get("width")
+        height = extra_body.get("height")
+
+        if width is not None or height is not None:
+            for img in response.images:
+                assert_image_valid(img, width=width, height=height)
+
+
+def assert_video_diffusion_response(
+    response,
+    request_config: dict[str, Any],
+    run_level: str | None = None,
+) -> None:
+    """
+    Validate video diffusion response (``run_level`` is accepted but not consulted here).
+
+    Expected request_config schema:
+    {
+        "request_type": "video",
+        "form_data": {
+            "prompt": "...",
+            "num_frames": ...,
+            "width": ...,
+            "height": ...,
+            "fps": ...,
+            ...
+        }
+    }
+    """
+    form_data = request_config.get("form_data") or {}  # tolerate a missing or explicitly-null form_data
+
+    assert response.videos is not None, "Video response is None"
+    assert len(response.videos) > 0, "No videos in response"
+
+    expected_frames = _maybe_int(form_data.get("num_frames"))
+    expected_width = _maybe_int(form_data.get("width"))
+    expected_height = _maybe_int(form_data.get("height"))
+    expected_fps = _maybe_int(form_data.get("fps"))
+
+    for vid_bytes in response.videos:
+        assert_video_valid(
+            vid_bytes,
+            num_frames=expected_frames,
+            width=expected_width,
+            height=expected_height,
+            fps=expected_fps,
+        )
+
+
+def assert_audio_diffusion_response(
+    response,
+    request_config: dict[str, Any],
+    run_level: str | None = None,
+) -> None:
+    """
+    Placeholder for audio diffusion validation; currently always raises NotImplementedError.
+    """
+    raise NotImplementedError("Audio validation is not implemented yet")
+    # consider using assert_audio_valid (defined later in this module) once implemented
+
+
+def _maybe_int(value: Any) -> int | None:
+    """Coerce ``value`` with ``int()``, passing ``None`` through unchanged."""
+    # None means "no expectation supplied"; everything else goes through int().
+    return None if value is None else int(value)
+
+
+def assert_image_valid(image: Path | Image.Image, *, width: int | None = None, height: int | None = None):
+    """Assert ``image`` (a path or an open PIL image) loads, optionally with exact dimensions; return it."""
+    if isinstance(image, Path):
+        assert image.exists(), f"Image not found: {image}"
+        image = Image.open(image)
+    image.load()
+    assert image.width > 0 and image.height > 0
+    if width is not None:
+        assert image.width == width, f"Expected width={width}, got {image.width}"
+    if height is not None:
+        assert image.height == height, f"Expected height={height}, got {image.height}"
+    return image
+
+
+def assert_video_valid(
+    video: Path | bytes | BytesIO,
+    *,
+    num_frames: int | None = None,
+    width: int | None = None,
+    height: int | None = None,
+    fps: float | None = None,
+) -> dict[str, int | float]:
+    """Decode the video and assert optional frame-count/size/FPS expectations; return the measured values."""
+    temp_path = None
+    cap = None
+    try:
+        # Normalize input to file path
+        if isinstance(video, Path):
+            if not video.exists():
+                raise AssertionError(f"Video file not found: {video}")
+            video_path = str(video)
+        else:
+            # Create temp file for bytes/BytesIO
+            suffix = ".mp4"
+            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, mode="wb") as tmp:
+                temp_path = Path(tmp.name)  # set first so `finally` removes the file even if we raise below
+                if isinstance(video, bytes):
+                    tmp.write(video)
+                elif isinstance(video, BytesIO):
+                    tmp.write(video.getvalue())
+                else:
+                    raise TypeError(f"Unsupported video type: {type(video)}")
+            video_path = str(temp_path)
+
+        # Open video capture
+        cap = cv2.VideoCapture(video_path)
+        if not cap.isOpened():
+            raise AssertionError(f"Failed to open video: {video_path}")
+
+        # Extract container-reported properties. The frame count is deliberately not read from
+        # CAP_PROP_FRAME_COUNT; it is measured by decoding every frame in the loop below.
+        actual_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        actual_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        actual_fps = cap.get(cv2.CAP_PROP_FPS)
+
+        actual_num_frames = 0
+        while True:
+            ok, _frame = cap.read()
+            if not ok:
+                break
+            actual_num_frames += 1
+
+        # Basic validity checks
+        if actual_num_frames <= 0:
+            raise AssertionError(f"Invalid frame count: {actual_num_frames} (must be > 0)")
+        if actual_width <= 0 or actual_height <= 0:
+            raise AssertionError(f"Invalid dimensions: {actual_width}x{actual_height} (must be > 0)")
+        if actual_fps <= 0:
+            raise AssertionError(f"Invalid FPS: {actual_fps} (must be > 0)")
+
+        # Validate against expectations
+        if num_frames is not None:
+            expected_num_frames = (num_frames // 4) * 4 + 1  # NOTE(review): presumably the model's 4k+1 frame grid
+            assert actual_num_frames == expected_num_frames, (
+                f"Frame count mismatch: expected {expected_num_frames} (num_frames={num_frames}), got {actual_num_frames}"
+            )
+        if width is not None:
+            assert actual_width == width, f"Width mismatch: expected {width}px, got {actual_width}px"
+        if height is not None:
+            assert actual_height == height, f"Height mismatch: expected {height}px, got {actual_height}px"
+        if fps is not None:
+            # Use tolerance for float comparison (codec rounding)
+            assert abs(actual_fps - fps) < 0.5, f"FPS mismatch: expected {fps}, got {actual_fps:.2f}"
+
+        return {"num_frames": actual_num_frames, "width": actual_width, "height": actual_height, "fps": actual_fps}
+
+    except Exception as e:
+        print(f"ERROR: {type(e).__name__}: {e}", flush=True)
+        raise
+
+    finally:
+        # Cleanup resources
+        if cap is not None:
+            cap.release()
+        if temp_path and temp_path.exists():
+            try:
+                temp_path.unlink()
+            except OSError:
+                pass
+
+
+def assert_audio_valid(path: Path, *, sample_rate: int, channels: int, duration_s: float) -> None:
+ """Assert the WAV has the expected sample rate, channel count, and duration."""
+ assert path.exists(), f"Audio not found: {path}"
+ info = sf.info(str(path))
+ assert info.samplerate == sample_rate, f"Expected sample_rate={sample_rate}, got {info.samplerate}"
+ assert info.channels == channels, f"Expected {channels} channel(s), got {info.channels}"
+ expected_frames = int(duration_s * sample_rate)
+ assert info.frames == expected_frames, (
+ f"Expected {expected_frames} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}"
+ )
+
+
+def decode_b64_image(b64: str):
+ img = Image.open(BytesIO(base64.b64decode(b64)))
+ img.load()
+ return img
+
+
+@pytest.fixture(scope="session")
+def model_prefix() -> str:
+ """Optional model-path prefix from MODEL_PREFIX env var.
+ Useful if models are downloaded to non-default local directories.
+ """
+ prefix = os.environ.get("MODEL_PREFIX", "")
+ return f"{prefix.rstrip('/')}/" if prefix else ""
+
+
+@pytest.fixture(autouse=True)
+def default_vllm_config():
+ """Set a default VllmConfig for all tests.
+
+ This fixture is auto-used for all tests to ensure that any test
+ that directly instantiates vLLM CustomOps (e.g., RMSNorm, LayerNorm)
+ or model components has the required VllmConfig context.
+
+ This fixture is required for vLLM 0.14.0+ where CustomOp initialization
+ requires a VllmConfig context set via set_current_vllm_config().
+ """
+ from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config
+
+ # Use CPU device if no GPU is available (e.g., in CI environments)
+ has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
+ device = "cuda" if has_gpu else "cpu"
+ device_config = DeviceConfig(device=device)
+
+ with set_current_vllm_config(VllmConfig(device_config=device_config)):
+ yield
+
+
+@pytest.fixture(autouse=True)
+def clean_gpu_memory_between_tests():
+ print("\n=== PRE-TEST GPU CLEANUP ===")
+ _run_pre_test_cleanup()
+ yield
+ _run_post_test_cleanup()
+
+
+@pytest.fixture(autouse=True)
+def log_test_name_before_test(request):
+ print(f"--- Running test: {request.node.name}")
+ yield
+
+
+def _run_pre_test_cleanup(enable_force=False):
+ if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
+ print("GPU cleanup disabled")
+ return
+
+ print("Pre-test GPU status:")
+
+ num_gpus = torch.cuda.device_count()
+ if num_gpus > 0:
+ try:
+ from tests.utils import wait_for_gpu_memory_to_clear
+
+ wait_for_gpu_memory_to_clear(
+ devices=list(range(num_gpus)),
+ threshold_ratio=0.05,
+ )
+ except Exception as e:
+ print(f"Pre-test cleanup note: {e}")
+
+
+def _run_post_test_cleanup(enable_force=False):
+ if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
+ print("GPU cleanup disabled")
+ return
+
+ if torch.cuda.is_available():
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ print("Post-test GPU status:")
+ _print_gpu_processes()
+
+
+def _print_gpu_processes():
+ """Print GPU information including nvidia-smi and system processes"""
+
+ print("\n" + "=" * 80)
+ print("NVIDIA GPU Information (nvidia-smi)")
+ print("=" * 80)
+
+ try:
+ nvidia_result = subprocess.run(
+ ["nvidia-smi"],
+ capture_output=True,
+ text=True,
+ timeout=5,
+ )
+
+ if nvidia_result.returncode == 0:
+ lines = nvidia_result.stdout.strip().split("\n")
+ for line in lines[:20]:
+ print(line)
+
+ if len(lines) > 20:
+ print(f"... (showing first 20 of {len(lines)} lines)")
+ else:
+ print("nvidia-smi command failed")
+
+ except (subprocess.TimeoutExpired, FileNotFoundError):
+ print("nvidia-smi not available or timed out")
+ except Exception as e:
+ print(f"Error running nvidia-smi: {e}")
+
+ print("\n" + "=" * 80)
+ print("Detailed GPU Processes (nvidia-smi pmon)")
+ print("=" * 80)
+
+ try:
+ pmon_result = subprocess.run(
+ ["nvidia-smi", "pmon", "-c", "1"],
+ capture_output=True,
+ text=True,
+ timeout=3,
+ )
+
+ if pmon_result.returncode == 0 and pmon_result.stdout.strip():
+ print(pmon_result.stdout)
+ else:
+ print("No active GPU processes found via nvidia-smi pmon")
+
+ except Exception:
+ print("nvidia-smi pmon not available")
+
+ print("\n" + "=" * 80)
+ print("System Processes with GPU keywords")
+ print("=" * 80)
+
+
+def dummy_messages_from_mix_data(
+    system_prompt: dict[str, Any] | None = None,
+    video_data_url: Any = None,
+    audio_data_url: Any = None,
+    image_data_url: Any = None,
+    content_text: str | None = None,
+):
+    """Build OpenAI API chat messages from optional text plus video/image/audio URLs (single value or list each)."""
+
+    if content_text is not None:
+        content = [{"type": "text", "text": content_text}]
+    else:
+        content = []
+
+    media_items = []
+    if isinstance(video_data_url, list):
+        for video_url in video_data_url:
+            media_items.append((video_url, "video"))
+    else:
+        media_items.append((video_data_url, "video"))
+
+    if isinstance(image_data_url, list):
+        for url in image_data_url:
+            media_items.append((url, "image"))
+    else:
+        media_items.append((image_data_url, "image"))
+
+    if isinstance(audio_data_url, list):
+        for url in audio_data_url:
+            media_items.append((url, "audio"))
+    else:
+        media_items.append((audio_data_url, "audio"))
+
+    content.extend(
+        {"type": f"{media_type}_url", f"{media_type}_url": {"url": url}}
+        for url, media_type in media_items
+        if url is not None
+    )
+    messages = [{"role": "user", "content": content}]
+    if system_prompt is not None:
+        messages = [system_prompt] + messages
+    return messages
+
+
+def generate_synthetic_audio(
+ duration: int, # seconds
+ num_channels: int, # 1:Mono,2:Stereo 5:5.1 surround sound
+ sample_rate: int = 48000, # Default use 48000Hz.
+ save_to_file: bool = False,
+) -> dict[str, Any]:
+ """
+ Generate TTS speech with pyttsx3 and return base64 string.
+ """
+
+ import pyttsx3
+ import soundfile as sf
+
+ def _pick_voice(engine: pyttsx3.Engine) -> str | None:
+ voices = engine.getProperty("voices")
+ if not voices:
+ return None
+
+ preferred_tokens = (
+ "natural",
+ "jenny",
+ "sonia",
+ "susan",
+ "zira",
+ "aria",
+ "hazel",
+ "samantha",
+ "ava",
+ "allison",
+ "female",
+ "woman",
+ "english-us",
+ "en-us",
+ "english",
+ )
+ discouraged_tokens = (
+ "espeak",
+ "robot",
+ "mbrola",
+ "microsoft david",
+ "male",
+ "man",
+ )
+
+ best_voice = voices[0]
+ best_score = float("-inf")
+ for voice in voices:
+ voice_text = f"{getattr(voice, 'id', '')} {getattr(voice, 'name', '')}".lower()
+ voice_languages = " ".join(
+ lang.decode(errors="ignore") if isinstance(lang, bytes) else str(lang)
+ for lang in getattr(voice, "languages", [])
+ ).lower()
+ combined_text = f"{voice_text} {voice_languages}"
+ score = 0
+ for idx, token in enumerate(preferred_tokens):
+ if token in combined_text:
+ score += 20 - idx
+ for token in discouraged_tokens:
+ if token in combined_text:
+ score -= 10
+ if "english" in combined_text or "en_" in combined_text or "en-" in combined_text:
+ score += 4
+ if "en-us" in combined_text or "english-us" in combined_text:
+ score += 4
+ if score > best_score:
+ best_score = score
+ best_voice = voice
+
+ return best_voice.id
+
+ def _resample_audio(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
+ if src_sr == dst_sr or len(audio) == 0:
+ return audio.astype(np.float32)
+
+ src_len = audio.shape[0]
+ dst_len = max(1, int(round(src_len * float(dst_sr) / float(src_sr))))
+ src_idx = np.arange(src_len, dtype=np.float32)
+ dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32)
+
+ resampled_channels: list[np.ndarray] = []
+ for ch in range(audio.shape[1]):
+ resampled_channels.append(np.interp(dst_idx, src_idx, audio[:, ch]).astype(np.float32))
+ return np.stack(resampled_channels, axis=1)
+
+ def _match_channels(audio: np.ndarray, target_channels: int) -> np.ndarray:
+ current_channels = audio.shape[1]
+ if current_channels == target_channels:
+ return audio.astype(np.float32)
+ if target_channels == 1:
+ return np.mean(audio, axis=1, keepdims=True, dtype=np.float32)
+ if current_channels == 1:
+ return np.repeat(audio, target_channels, axis=1).astype(np.float32)
+
+ collapsed = np.mean(audio, axis=1, keepdims=True, dtype=np.float32)
+ return np.repeat(collapsed, target_channels, axis=1).astype(np.float32)
+
+ def _trim_silence(audio: np.ndarray, threshold: float = 0.01) -> np.ndarray:
+ if len(audio) == 0:
+ return audio
+ energy = np.max(np.abs(audio), axis=1)
+ voiced = np.where(energy > threshold)[0]
+ if len(voiced) == 0:
+ return audio
+ start = max(0, int(voiced[0]) - int(sample_rate * 0.02))
+ end = min(len(audio), int(voiced[-1]) + int(sample_rate * 0.04) + 1)
+ return audio[start:end]
+
+ def _enhance_speech(audio: np.ndarray) -> np.ndarray:
+ if len(audio) == 0:
+ return audio.astype(np.float32)
+ enhanced = audio.astype(np.float32).copy()
+ enhanced -= np.mean(enhanced, axis=0, keepdims=True, dtype=np.float32)
+ if len(enhanced) > 1:
+ preemphasis = enhanced.copy()
+ preemphasis[1:] = enhanced[1:] - 0.94 * enhanced[:-1]
+ enhanced = 0.7 * enhanced + 0.3 * preemphasis
+ # Mild dynamic-range compression for ASR/TTS robustness.
+ enhanced = np.sign(enhanced) * np.sqrt(np.abs(enhanced))
+ # Light fade to avoid clicks after trimming/repeating.
+ fade = min(len(enhanced) // 4, max(1, int(sample_rate * 0.01)))
+ if fade > 1:
+ ramp_in = np.linspace(0.0, 1.0, fade, dtype=np.float32)
+ ramp_out = np.linspace(1.0, 0.0, fade, dtype=np.float32)
+ enhanced[:fade] *= ramp_in[:, None]
+ enhanced[-fade:] *= ramp_out[:, None]
+ peak = float(np.max(np.abs(enhanced)))
+ if peak > 1e-8:
+ enhanced = enhanced / peak * 0.95
+ return enhanced.astype(np.float32)
+
+ phrase_text = "test"
+ num_samples = int(sample_rate * max(1, duration))
+ audio_data = np.zeros((num_samples, num_channels), dtype=np.float32)
+
+ engine = pyttsx3.init()
+ engine.setProperty("rate", 112)
+ engine.setProperty("volume", 1.0)
+ selected_voice = _pick_voice(engine)
+ if selected_voice is not None:
+ engine.setProperty("voice", selected_voice)
+
+ temp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+ temp_wav.close()
+
+ try:
+ engine.save_to_file(phrase_text, temp_wav.name)
+ engine.runAndWait()
+ engine.stop()
+
+ ready = False
+ for _ in range(50):
+ if os.path.exists(temp_wav.name) and os.path.getsize(temp_wav.name) > 44:
+ ready = True
+ break
+ time.sleep(0.1)
+
+ if not ready:
+ raise RuntimeError("pyttsx3 did not produce a WAV file in time.")
+
+ tts_audio, tts_sr = sf.read(temp_wav.name, dtype="float32", always_2d=True)
+ finally:
+ if os.path.exists(temp_wav.name):
+ os.unlink(temp_wav.name)
+
+ if len(tts_audio) == 0:
+ raise RuntimeError("pyttsx3 produced an empty WAV file.")
+
+ tts_audio = _resample_audio(tts_audio, tts_sr, sample_rate)
+ tts_audio = _match_channels(tts_audio, num_channels)
+ tts_audio = _trim_silence(tts_audio, threshold=0.012)
+ tts_audio = _enhance_speech(tts_audio)
+
+ lead_silence = min(int(sample_rate * 0.02), num_samples // 8)
+ pause_samples = int(sample_rate * 0.18)
+ start = lead_silence
+ phrase_len = tts_audio.shape[0]
+
+ while start < num_samples:
+ take = min(phrase_len, num_samples - start)
+ audio_data[start : start + take] = tts_audio[:take]
+ start += phrase_len + pause_samples
+
+ max_amp = float(np.max(np.abs(audio_data)))
+ if max_amp > 0:
+ audio_data = audio_data / max_amp * 0.95
+
+ audio_bytes: bytes | None = None
+ output_path: str | None = None
+ result: dict[str, Any] = {
+ "np_array": audio_data.copy(),
+ }
+
+ if save_to_file:
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_path = f"audio_{num_channels}ch_{timestamp}.wav"
+
+ try:
+ sf.write(output_path, audio_data, sample_rate, format="WAV", subtype="PCM_16")
+ print(f"Audio saved: {output_path}")
+
+ with open(output_path, "rb") as f:
+ audio_bytes = f.read()
+ except Exception as e:
+ print(f"Save failed: {e}")
+ save_to_file = False
+
+ # If not saving or save failed, create in memory
+ if not save_to_file or audio_bytes is None:
+ buffer = io.BytesIO()
+ sf.write(buffer, audio_data, sample_rate, format="WAV", subtype="PCM_16")
+ buffer.seek(0)
+ audio_bytes = buffer.read()
+
+ # Return result
+ base64_audio = base64.b64encode(audio_bytes).decode("utf-8")
+ result["base64"] = base64_audio
+ # Always include file_path to avoid KeyError in callers.
+ result["file_path"] = output_path if save_to_file and output_path else None
+
+ return result
+
+
+def _mux_mp4_bytes_with_synthetic_audio(
+ video_mp4_bytes: bytes,
+ *,
+ num_frames: int,
+ fps: float = 30.0,
+ sample_rate: int = 48000,
+) -> bytes:
+ """
+ Mux a video-only MP4 with mono TTS audio from :func:`generate_synthetic_audio` (AAC).
+
+ Audio length is at least the video duration in whole seconds (rounded up); ffmpeg
+ ``-shortest`` trims to the video when the WAV is longer.
+
+ Uses ffmpeg from ``imageio_ffmpeg`` when available, else ``ffmpeg`` on PATH.
+ If TTS or mux fails, returns ``video_mp4_bytes`` unchanged.
+
+ Mux subprocess does **not** use ``capture_output=True``: ffmpeg can block writing
+ to a full stderr pipe while :func:`subprocess.run` waits for exit (classic deadlock).
+ """
+ duration_sec = num_frames / fps if fps > 0 else 0.0
+ # generate_synthetic_audio(duration=int) uses at least 1s of buffer internally
+ duration_int = max(1, int(math.ceil(duration_sec)))
+
+ try:
+ audio_result = generate_synthetic_audio(
+ duration=duration_int,
+ num_channels=1,
+ sample_rate=sample_rate,
+ save_to_file=False,
+ )
+ audio_pcm = audio_result["np_array"]
+ except Exception as e:
+ logger.warning("Synthetic video: generate_synthetic_audio failed (%s); using video-only MP4.", e)
+ return video_mp4_bytes
+
+ try:
+ import imageio_ffmpeg
+
+ ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
+ except Exception:
+ ffmpeg_exe = "ffmpeg"
+
+ import tempfile
+
+ try:
+ with tempfile.TemporaryDirectory(prefix="syn_vid_mux_") as tmp:
+ vid_path = os.path.join(tmp, "video.mp4")
+ wav_path = os.path.join(tmp, "audio.wav")
+ out_path = os.path.join(tmp, "out.mp4")
+ with open(vid_path, "wb") as f:
+ f.write(video_mp4_bytes)
+ sf.write(wav_path, audio_pcm, sample_rate, format="WAV", subtype="PCM_16")
+ cmd = [
+ ffmpeg_exe,
+ "-y",
+ "-nostdin",
+ "-hide_banner",
+ "-loglevel",
+ "error",
+ "-i",
+ vid_path,
+ "-i",
+ wav_path,
+ "-c:v",
+ "copy",
+ "-c:a",
+ "aac",
+ "-b:a",
+ "128k",
+ "-shortest",
+ "-movflags",
+ "+faststart",
+ out_path,
+ ]
+ subprocess.run(
+ cmd,
+ check=True,
+ stdin=subprocess.DEVNULL,
+ timeout=300,
+ )
+ with open(out_path, "rb") as f:
+ return f.read()
+ except (
+ FileNotFoundError,
+ subprocess.CalledProcessError,
+ subprocess.TimeoutExpired,
+ OSError,
+ ) as e:
+ logger.warning("Synthetic video: audio mux failed (%s); using video-only MP4.", e)
+ return video_mp4_bytes
+
+
+def generate_synthetic_video(
+ width: int,
+ height: int,
+ num_frames: int,
+ save_to_file: bool = False,
+ *,
+ embed_audio: bool = False,
+) -> dict[str, Any]:
+ """Generate synthetic video with bouncing balls and base64 MP4.
+
+ When ``embed_audio`` is True, muxes mono AAC from :func:`generate_synthetic_audio`
+ (TTS + ffmpeg) into the MP4; otherwise returns video-only MP4 (faster when tests do
+ not need an audio track).
+ """
+
+ import cv2
+ import imageio
+
+ # Create random balls
+ num_balls = random.randint(3, 8)
+ balls = []
+
+ for _ in range(num_balls):
+ radius = min(width, height) // 8
+ if radius < 1:
+ raise ValueError(f"Video dimensions ({width}x{height}) are too small for synthetic video generation")
+ x = random.randint(radius, width - radius)
+ y = random.randint(radius, height - radius)
+
+ speed = random.uniform(3.0, 8.0)
+ angle = random.uniform(0, 2 * math.pi)
+ vx = speed * math.cos(angle)
+ vy = speed * math.sin(angle)
+
+ # OpenCV uses BGR format, but imageio expects RGB
+ # We'll create in BGR first, then convert to RGB later
+ color_bgr = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
+
+ balls.append({"x": x, "y": y, "vx": vx, "vy": vy, "radius": radius, "color_bgr": color_bgr})
+
+ # Generate video frames
+ video_frames = []
+
+ for frame_idx in range(num_frames):
+ # Create black background (BGR format)
+ frame_bgr = np.zeros((height, width, 3), dtype=np.uint8)
+
+ for ball in balls:
+ # Update position
+ ball["x"] += ball["vx"]
+ ball["y"] += ball["vy"]
+
+ # Boundary collision detection
+ if ball["x"] - ball["radius"] <= 0 or ball["x"] + ball["radius"] >= width:
+ ball["vx"] = -ball["vx"]
+ ball["x"] = max(ball["radius"], min(width - ball["radius"], ball["x"]))
+
+ if ball["y"] - ball["radius"] <= 0 or ball["y"] + ball["radius"] >= height:
+ ball["vy"] = -ball["vy"]
+ ball["y"] = max(ball["radius"], min(height - ball["radius"], ball["y"]))
+
+ # Use cv2 to draw circle
+ x, y = int(ball["x"]), int(ball["y"])
+ radius = ball["radius"]
+
+ # Draw solid circle (main circle)
+ cv2.circle(frame_bgr, (x, y), radius, ball["color_bgr"], -1)
+
+ # Add simple 3D effect: draw a brighter center
+ if radius > 3: # Only add highlight when radius is large enough
+ highlight_radius = max(1, radius // 2)
+ highlight_x = max(highlight_radius, min(x - radius // 4, width - highlight_radius))
+ highlight_y = max(highlight_radius, min(y - radius // 4, height - highlight_radius))
+
+ # Create highlight color (brighter)
+ highlight_color = tuple(min(c + 40, 255) for c in ball["color_bgr"])
+ cv2.circle(frame_bgr, (highlight_x, highlight_y), highlight_radius, highlight_color, -1)
+
+ # Convert BGR to RGB for imageio
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+ video_frames.append(frame_rgb)
+
+ video_array = np.array(video_frames)
+ result = {
+ "np_array": video_array,
+ }
+ saved_file_path = None
+
+ fps = 30
+ buffer = io.BytesIO()
+ writer_kwargs = {
+ "format": "mp4",
+ "fps": fps,
+ "codec": "libx264",
+ "quality": 7,
+ "pixelformat": "yuv420p",
+ "macro_block_size": 16,
+ "ffmpeg_params": [
+ "-preset",
+ "medium",
+ "-crf",
+ "23",
+ "-movflags",
+ "+faststart",
+ "-pix_fmt",
+ "yuv420p",
+ "-vf",
+ f"scale={width}:{height}",
+ ],
+ }
+
+ try:
+ with imageio.get_writer(buffer, **writer_kwargs) as writer:
+ for frame in video_frames:
+ writer.append_data(frame)
+ buffer.seek(0)
+ video_only_bytes = buffer.read()
+ except Exception as e:
+ print(f"Warning: Failed to encode synthetic video: {e}")
+ raise
+
+ if embed_audio:
+ video_bytes = _mux_mp4_bytes_with_synthetic_audio(video_only_bytes, num_frames=num_frames, fps=float(fps))
+ else:
+ video_bytes = video_only_bytes
+
+ if save_to_file:
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_path = f"video_{width}x{height}_{timestamp}.mp4"
+ try:
+ with open(output_path, "wb") as f:
+ f.write(video_bytes)
+ saved_file_path = output_path
+ print(f"Video saved to: {saved_file_path}")
+ except Exception as e:
+ print(f"Warning: Failed to save video to file {output_path}: {e}")
+
+ base64_video = base64.b64encode(video_bytes).decode("utf-8")
+
+ result["base64"] = base64_video
+ if save_to_file and saved_file_path:
+ result["file_path"] = saved_file_path
+
+ return result
+
+
+def generate_synthetic_image(width: int, height: int, save_to_file: bool = False) -> dict[str, Any]:
+ """Generate synthetic image with randomly colored squares and return base64 string."""
+ from PIL import Image, ImageDraw
+
+ # Create white background
+ image = Image.new("RGB", (width, height), (255, 255, 255))
+ draw = ImageDraw.Draw(image)
+
+ # Generate random number of squares
+ num_squares = random.randint(3, 8)
+
+ for _ in range(num_squares):
+ # Random square size
+ square_size = random.randint(min(width, height) // 8, min(width, height) // 4)
+
+ # Random position
+ x = random.randint(0, width - square_size - 1)
+ y = random.randint(0, height - square_size - 1)
+
+ # Random color
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+
+ # Random border width
+ border_width = random.randint(1, 5)
+
+ # Draw square
+ draw.rectangle([x, y, x + square_size, y + square_size], fill=color, outline=(0, 0, 0), width=border_width)
+
+ image_array = np.array(image)
+ result = {"np_array": image_array.copy()}
+
+ # Handle file saving
+ image_bytes = None
+ saved_file_path = None
+
+ if save_to_file:
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_path = f"image_{width}x{height}_{timestamp}.jpg"
+
+ try:
+ # Save image to file
+ image.save(output_path, format="JPEG", quality=85, optimize=True)
+ saved_file_path = output_path
+ print(f"Image saved to: {saved_file_path}")
+
+ # Read file for base64 encoding
+ with open(output_path, "rb") as f:
+ image_bytes = f.read()
+
+ except Exception as e:
+ print(f"Warning: Failed to save image to file {output_path}: {e}")
+ save_to_file = False
+
+ # If not saving or save failed, create in memory
+ if not save_to_file or image_bytes is None:
+ buffer = io.BytesIO()
+ image.save(buffer, format="JPEG", quality=85, optimize=True)
+ buffer.seek(0)
+ image_bytes = buffer.read()
+
+ # Generate base64
+ base64_image = base64.b64encode(image_bytes).decode("utf-8")
+
+ # Return result
+ result["base64"] = base64_image
+ if save_to_file and saved_file_path:
+ result["file_path"] = saved_file_path
+
+ return result
+
+
+def preprocess_text(text):
+ import opencc
+
+ word_to_num = {
+ "zero": "0",
+ "one": "1",
+ "two": "2",
+ "three": "3",
+ "four": "4",
+ "five": "5",
+ "six": "6",
+ "seven": "7",
+ "eight": "8",
+ "nine": "9",
+ "ten": "10",
+ }
+
+ for word, num in word_to_num.items():
+ pattern = r"\b" + re.escape(word) + r"\b"
+ text = re.sub(pattern, num, text, flags=re.IGNORECASE)
+
+ text = re.sub(r"[^\w\s]", "", text)
+ text = re.sub(r"\s+", " ", text)
+ cc = opencc.OpenCC("t2s")
+ text = cc.convert(text)
+
+ # Special handling for spaces between Chinese characters:
+ # - Keep single spaces between English words/numbers
+ # - Remove spaces only when surrounded by Chinese characters on both sides to prevent incorrect word segmentation
+ text = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", text)
+
+ return text.lower().strip()
+
+
+def cosine_similarity_text(text1: str | None, text2: str | None, n: int = 3) -> float:
+    from collections import Counter
+
+    if not text1 or not text2:
+        return 0.0
+
+    text1 = preprocess_text(text1)
+    text2 = preprocess_text(text2)
+    print(f"cosine similarity text1 is: {text1}, text2 is: {text2}")
+
+    ngrams1 = [text1[i : i + n] for i in range(len(text1) - n + 1)]  # character n-grams; empty if len < n
+    ngrams2 = [text2[i : i + n] for i in range(len(text2) - n + 1)]
+
+    counter1 = Counter(ngrams1)
+    counter2 = Counter(ngrams2)
+
+    all_ngrams = set(counter1.keys()) | set(counter2.keys())
+    vec1 = [counter1.get(ng, 0) for ng in all_ngrams]
+    vec2 = [counter2.get(ng, 0) for ng in all_ngrams]
+
+    dot_product = sum(a * b for a, b in zip(vec1, vec2))
+    norm1 = sum(a * a for a in vec1) ** 0.5
+    norm2 = sum(b * b for b in vec2) ** 0.5
+
+    if norm1 == 0 or norm2 == 0:  # one side produced no n-grams (e.g. normalized text shorter than n)
+        return 0.0
+    return dot_product / (norm1 * norm2)
+
+
+def convert_audio_to_text(audio_data):
+    """
+    Transcribe base64-encoded audio.
+
+    The decoded bytes are written to ``./test_<uuid>.wav`` in the current working
+    directory (the path is printed) and that file is passed to
+    ``convert_audio_file_to_text``; the resulting transcript is returned.
+    """
+    decoded_audio = base64.b64decode(audio_data)
+    wav_path = f"./test_{uuid.uuid4().hex}.wav"
+    with open(wav_path, "wb") as sink:
+        sink.write(decoded_audio)
+
+    print(f"audio data is saved: {wav_path}")
+    return convert_audio_file_to_text(output_path=wav_path)
+
+
+def _merge_base64_audio_to_segment(base64_list: list[str]):
+    """Concatenate base64-encoded audio chunks, in order, into one pydub AudioSegment.
+
+    Returns None when ``base64_list`` is empty.
+    """
+    from pydub import AudioSegment
+
+    combined = None
+    for encoded in base64_list:
+        # Keep only what follows the first comma, if any (e.g. a "...;base64," style header).
+        payload = encoded.split(",", 1)[-1]
+        segment = AudioSegment.from_file(io.BytesIO(base64.b64decode(payload)))
+        if combined is None:
+            combined = segment
+        else:
+            combined = combined + segment
+    return combined
+
+
+def _whisper_transcribe_in_current_process(output_path: str) -> str:
+    """Transcribe the audio file at ``output_path`` with the Whisper "small" model.
+
+    Runs in the calling process. The model is loaded on an accelerator when the
+    platform reports one, otherwise on CPU, and is released (with a cache flush on
+    accelerators) before returning. Returns the transcript, or "" when it is empty.
+    """
+    import whisper
+
+    # Multi-GPU: use last visible device to avoid colliding with default device 0; single device uses 0.
+    device_count = current_omni_platform.get_device_count() if current_omni_platform.is_available() else 0
+    use_accelerator = device_count >= 1
+    if use_accelerator:
+        torch_device = current_omni_platform.get_torch_device(device_count - 1)
+        current_omni_platform.set_device(torch_device)
+        device = str(torch_device)
+    else:
+        device = "cpu"
+
+    model = whisper.load_model("small", device=device)
+    try:
+        transcription = model.transcribe(
+            output_path,
+            temperature=0.0,
+            word_timestamps=True,
+            condition_on_previous_text=False,
+        )
+        text = transcription["text"]
+    finally:
+        # Drop the model and flush device memory even when transcription fails.
+        del model
+        gc.collect()
+        if use_accelerator:
+            current_omni_platform.synchronize()
+            current_omni_platform.empty_cache()
+
+    return text or ""
+
+
+def convert_audio_file_to_text(output_path: str) -> str:
+    """Transcribe an audio file inside a freshly spawned worker process.
+
+    A single-worker ``ProcessPoolExecutor`` using the "spawn" start method runs
+    ``_whisper_transcribe_in_current_process``; the call blocks until the
+    transcript is available and the pool is shut down before returning.
+    """
+    spawn_ctx = multiprocessing.get_context("spawn")
+    pool = concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=spawn_ctx)
+    with pool:
+        return pool.submit(_whisper_transcribe_in_current_process, output_path).result()
+
+
+def convert_audio_bytes_to_text(raw_bytes: bytes) -> str:
+    """
+    Transcribe container audio bytes (WAV, etc.) and return the recognized text.
+
+    The bytes are decoded with soundfile and re-written as a PCM_16 WAV file
+    (``./test_<uuid>.wav``) to avoid codec issues in Whisper/ffmpeg. The temporary
+    file is removed once transcription finishes, whether it succeeded or not.
+
+    Args:
+        raw_bytes: Encoded audio in any container soundfile can read.
+
+    Returns:
+        The transcript produced by ``convert_audio_file_to_text``.
+    """
+    output_path = f"./test_{uuid.uuid4().hex}.wav"
+    data, samplerate = sf.read(io.BytesIO(raw_bytes))
+    try:
+        sf.write(output_path, data, samplerate, format="WAV", subtype="PCM_16")
+        return convert_audio_file_to_text(output_path)
+    finally:
+        # Transcription is synchronous, so the file is no longer needed here;
+        # without this every call leaves a WAV behind in the working directory.
+        try:
+            os.remove(output_path)
+        except OSError:
+            pass
+
+
+def modify_stage_config(
+    yaml_path: str,
+    updates: dict[str, Any] | None = None,
+    deletes: dict[str, Any] | None = None,
+) -> str:
+    """
+    Write a modified copy of a YAML config, supporting top-level and stage-specific
+    additions, changes and deletions. Deletions are applied before updates.
+
+    Args:
+        yaml_path: Path to the YAML configuration file.
+        updates: Top-level and stage-specific values to add or update.
+            Format: {
+                'async_chunk': True,
+                'stage_args': {
+                    0: {'engine_args.max_model_len': 5800},
+                    1: {'engine_args.max_num_seqs': 2}
+                }
+            }
+            Keys may be dot-separated paths; numeric path segments index into lists.
+        deletes: Configurations to delete.
+            Format: {
+                'old_config': None,  # Delete entire key
+                'stage_args': {
+                    0: ['engine_args.old_param'],
+                    1: ['runtime.unused_setting']
+                }
+            }
+            Stages that are not found are skipped silently for deletions.
+
+    Returns:
+        str: Path to the newly created YAML file, named
+        ``<yaml_path without extension>_<time.time_ns()>.yaml``.
+
+    Raises:
+        FileNotFoundError: ``yaml_path`` does not exist.
+        ValueError: The file cannot be parsed, or ``stage_args`` is requested but missing/empty.
+        KeyError: An update targets an unknown stage ID, or a delete path has a missing parent.
+        TypeError: A path tries to descend into a value that is neither dict nor list.
+    """
+    path = Path(yaml_path)
+    if not path.exists():
+        raise FileNotFoundError(f"yaml does not exist: {path}")
+
+    try:
+        with open(yaml_path, encoding="utf-8") as f:
+            config = yaml.safe_load(f) or {}
+    except Exception as e:
+        raise ValueError(f"Cannot parse YAML file: {e}") from e
+
+    def new_container(next_key: str) -> dict | list:
+        """Return a container able to hold ``next_key``: a list for an index, else a dict."""
+        return [] if next_key.isdigit() else {}
+
+    def apply_update(config_dict: dict, key_path: str, value: Any) -> None:
+        """Set ``value`` at the dot-separated ``key_path``, creating missing containers."""
+        if "." not in key_path:
+            # Simple key (e.g. engine_input_source: [1, 2]), set directly
+            config_dict[key_path] = value
+            return
+
+        keys = key_path.split(".")
+        current = config_dict
+
+        for i, key in enumerate(keys[:-1]):
+            # Inside this loop there is always a following key, which decides
+            # what kind of container has to exist at the current position.
+            next_key = keys[i + 1]
+
+            if key.isdigit() and isinstance(current, list):
+                index = int(key)
+                # Grow the list with real containers. A None placeholder here
+                # (the previous behavior for the last traversal step) could never
+                # be descended into and always ended in a TypeError.
+                while len(current) <= index:
+                    current.append(new_container(next_key))
+                current = current[index]
+            elif isinstance(current, dict):
+                # Create the entry when missing, or replace a scalar that blocks the path
+                if not isinstance(current.get(key), (dict, list)):
+                    current[key] = new_container(next_key)
+                current = current[key]
+            else:
+                # Current is not a dict or list, cannot traverse further
+                raise TypeError(
+                    f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
+                )
+
+        # Set the final value
+        last_key = keys[-1]
+        if isinstance(current, list) and last_key.isdigit():
+            index = int(last_key)
+            while len(current) <= index:
+                current.append(None)
+            current[index] = value
+        elif isinstance(current, dict):
+            current[last_key] = value
+        else:
+            raise TypeError(f"Cannot set value at {key_path}. Current type is {type(current).__name__}, expected dict.")
+
+    def delete_by_path(config_dict: dict, path: str) -> None:
+        """Delete the entry at the dot-separated ``path``; a missing leaf is only reported."""
+        if not path:
+            return
+
+        current = config_dict
+        keys = path.split(".")
+
+        # Traverse to the parent
+        for i, key in enumerate(keys[:-1]):
+            if key.isdigit() and isinstance(current, list):
+                index = int(key)
+                if index >= len(current):
+                    raise KeyError(f"List index {index} out of bounds")
+                current = current[index]
+            elif isinstance(current, dict):
+                if key not in current:
+                    raise KeyError(f"Path {'.'.join(keys[: i + 1])} does not exist")
+                current = current[key]
+            else:
+                raise TypeError(
+                    f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
+                )
+
+        # Delete the item
+        last_key = keys[-1]
+        if isinstance(current, list) and last_key.isdigit():
+            index = int(last_key)
+            if index >= len(current):
+                raise KeyError(f"List index {index} out of bounds")
+            del current[index]
+        elif isinstance(current, dict) and last_key in current:
+            del current[last_key]
+        else:
+            print(f"Path {path} does not exist")
+
+    def require_stage_args() -> list:
+        """Return the non-empty ``stage_args`` list or raise ValueError."""
+        stage_args = config.get("stage_args", [])
+        if not stage_args:
+            raise ValueError("stage_args does not exist in config")
+        return stage_args
+
+    def find_stage(stage_args: list, stage_id: Any) -> dict | None:
+        """Return the stage whose ``stage_id`` equals ``int(stage_id)``, or None."""
+        wanted = int(stage_id)
+        for stage in stage_args:
+            if stage.get("stage_id") == wanted:
+                return stage
+        return None
+
+    # Apply deletions first
+    if deletes:
+        for key, value in deletes.items():
+            if key == "stage_args":
+                if value and isinstance(value, dict):
+                    stage_args = require_stage_args()
+                    for stage_id, delete_paths in value.items():
+                        if not delete_paths:
+                            continue
+                        target_stage = find_stage(stage_args, stage_id)
+                        if target_stage is None:
+                            continue
+                        for stage_path in delete_paths:
+                            if stage_path:  # Skip empty paths
+                                delete_by_path(target_stage, stage_path)
+            elif "." in key:
+                # Delete using dot-separated path
+                delete_by_path(config, key)
+            elif value is None and key in config:
+                # Delete entire key
+                del config[key]
+
+    # Apply updates
+    if updates:
+        for key, value in updates.items():
+            if key == "stage_args":
+                if value and isinstance(value, dict):
+                    stage_args = require_stage_args()
+                    for stage_id, stage_updates in value.items():
+                        target_stage = find_stage(stage_args, stage_id)
+                        if target_stage is None:
+                            available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s]
+                            raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}")
+                        # apply_update handles both simple keys and dot-separated paths
+                        for stage_path, val in stage_updates.items():
+                            apply_update(target_stage, stage_path, val)
+            elif "." in key:
+                # Apply using dot-separated path
+                apply_update(config, key, value)
+            else:
+                # Direct top-level key
+                config[key] = value
+
+    # Unique suffix: multiple modify_stage_config calls in one process often run
+    # within the same second (e.g. test_qwen3_omni_expansion imports both
+    # get_chunk_config and get_batch_token_config). int(time.time()) would collide
+    # and the later write would overwrite the earlier YAML on disk.
+    # splitext only strips an extension of the final path component, so a dot in
+    # a directory name (e.g. "./configs/stage") does not truncate the path.
+    base_name, _ = os.path.splitext(yaml_path)
+    output_path = f"{base_name}_{time.time_ns()}.yaml"
+
+    with open(output_path, "w", encoding="utf-8") as f:
+        yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2)
+
+    return output_path
+
+
+class OmniServer:
+    """Context manager that runs a vLLM-Omni server subprocess for tests.
+
+    ``__enter__`` launches ``vllm_omni.entrypoints.cli.main serve`` and waits until
+    the TCP port accepts connections; ``__exit__`` kills the whole process tree and
+    runs the shared cleanup helpers.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        serve_args: list[str],
+        *,
+        port: int | None = None,
+        env_dict: dict[str, str] | None = None,
+        use_omni: bool = True,
+    ) -> None:
+        """Run the cleanup helpers and record the launch parameters.
+
+        Args:
+            model: Model name or path passed to ``serve``.
+            serve_args: Extra command-line arguments appended to the serve command.
+            port: Port to listen on; a free port is picked when None.
+            env_dict: Extra environment variables for the server process.
+            use_omni: Whether to pass ``--omni`` to the CLI.
+        """
+        _run_pre_test_cleanup(enable_force=True)
+        _run_post_test_cleanup(enable_force=True)
+        cleanup_dist_env_and_memory()
+        self.model = model
+        self.serve_args = serve_args
+        self.env_dict = env_dict
+        self.use_omni = use_omni
+        self.proc: subprocess.Popen | None = None
+        self.host = "127.0.0.1"
+        self.port = get_open_port() if port is None else port
+
+    def _start_server(self) -> None:
+        """Start the vLLM-Omni server subprocess and block until its port is reachable.
+
+        Raises:
+            RuntimeError: The process exits before becoming ready, or the port is
+                not reachable within 20 minutes (the process tree is killed first).
+        """
+        env = os.environ.copy()
+        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+        if self.env_dict is not None:
+            env.update(self.env_dict)
+
+        cmd = [
+            sys.executable,
+            "-m",
+            "vllm_omni.entrypoints.cli.main",
+            "serve",
+            self.model,
+            "--host",
+            self.host,
+            "--port",
+            str(self.port),
+        ]
+        if self.use_omni:
+            cmd.append("--omni")
+        cmd += self.serve_args
+
+        print(f"Launching OmniServer with: {' '.join(cmd)}")
+        self.proc = subprocess.Popen(
+            cmd,
+            env=env,
+            cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # Set working directory to vllm-omni root
+        )
+
+        # Wait for server to be ready
+        max_wait = 1200  # 20 minutes
+        start_time = time.time()
+        while time.time() - start_time < max_wait:
+            # Check for process status
+            ret = self.proc.poll()
+            if ret is not None:
+                raise RuntimeError(f"Server processes exited with code {ret} before becoming ready.")
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+                sock.settimeout(1)
+                result = sock.connect_ex((self.host, self.port))
+                if result == 0:
+                    print(f"Server ready on {self.host}:{self.port}")
+                    return
+            time.sleep(2)
+
+        # __exit__ is not called when __enter__ raises, so the still-running
+        # server tree must be torn down here or it would leak.
+        self._kill_process_tree(self.proc.pid)
+        raise RuntimeError(f"Server failed to start within {max_wait} seconds")
+
+    def _kill_process_tree(self, pid):
+        """Terminate ``pid`` and all of its descendants, then verify they are gone.
+
+        Children get terminate() followed by kill() after 10 s, the parent is
+        handled the same way, and any PID still present afterwards is sent
+        ``kill -9``. A process that has already exited is ignored.
+        """
+        try:
+            parent = psutil.Process(pid)
+            children = parent.children(recursive=True)
+
+            # Get all PIDs first
+            all_pids = [pid] + [child.pid for child in children]
+
+            # Terminate children
+            for child in children:
+                try:
+                    child.terminate()
+                except psutil.NoSuchProcess:
+                    pass
+
+            # Wait for children
+            gone, still_alive = psutil.wait_procs(children, timeout=10)
+
+            # Kill remaining children
+            for child in still_alive:
+                try:
+                    child.kill()
+                except psutil.NoSuchProcess:
+                    pass
+
+            # Terminate parent
+            try:
+                parent.terminate()
+                parent.wait(timeout=10)
+            except (psutil.NoSuchProcess, psutil.TimeoutExpired):
+                try:
+                    parent.kill()
+                except psutil.NoSuchProcess:
+                    pass
+
+            # VERIFICATION: Check if all processes are gone
+            time.sleep(1)  # Give system time
+            alive_processes = [check_pid for check_pid in all_pids if psutil.pid_exists(check_pid)]
+
+            if alive_processes:
+                print(f"Warning: Processes still alive: {alive_processes}")
+                # Last resort: system kill
+                for alive_pid in alive_processes:
+                    try:
+                        subprocess.run(["kill", "-9", str(alive_pid)], timeout=2)
+                    except Exception as e:
+                        print(f"Cleanup failed: {e}")
+
+        except psutil.NoSuchProcess:
+            pass
+
+    def __enter__(self):
+        """Start the server and return this instance once it is ready."""
+        self._start_server()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Kill the server process tree (if started) and run the cleanup helpers."""
+        if self.proc:
+            self._kill_process_tree(self.proc.pid)
+        _run_pre_test_cleanup(enable_force=True)
+        _run_post_test_cleanup(enable_force=True)
+        cleanup_dist_env_and_memory()
+
+
+def pytest_addoption(parser):
+    """Register the ``--run-level`` command-line option (defaults to "core_model")."""
+    levels = ["core_model", "advanced_model"]
+    parser.addoption(
+        "--run-level",
+        action="store",
+        default=levels[0],
+        choices=levels,
+        help="Test level to run: L2, L3",
+    )
+
+
+@pytest.fixture(scope="session")
+def run_level(request) -> str:
+    """Session-wide value of the ``--run-level`` command-line option.
+
+    It selects which level of tests runs in this session; see
+    https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/CI_5levels/"""
+    selected_level = request.config.getoption("--run-level")
+    return selected_level
+
+
+# Serializes omni_server fixture setup/teardown within this process.
+_omni_server_lock = threading.Lock()
+
+
+@pytest.fixture(scope="module")
+def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: str) -> Generator[OmniServer, Any, None]:
+    """Start vLLM-Omni server as a subprocess with actual model weights.
+
+    Module-scoped: the server is started once per test module and parametrization
+    (``request.param`` is an ``OmniServerParams``). Multi-stage initialization can
+    take 10-20+ minutes.
+
+    For ``run_level == "advanced_model"`` with a stage config, a modified copy of
+    the config is used in which ``engine_args.load_format`` is removed from every stage.
+    """
+    with _omni_server_lock:
+        params: OmniServerParams = request.param
+        model = model_prefix + params.model
+        stage_config_path = params.stage_config_path
+        if run_level == "advanced_model" and stage_config_path is not None:
+            with open(stage_config_path, encoding="utf-8") as f:
+                cfg = yaml.safe_load(f) or {}
+            stage_ids = [stage["stage_id"] for stage in cfg.get("stage_args", []) if "stage_id" in stage]
+            stage_config_path = modify_stage_config(
+                stage_config_path,
+                deletes={"stage_args": {stage_id: ["engine_args.load_format"] for stage_id in stage_ids}},
+            )
+
+        # Work on a copy: the "+=" below must never grow the list stored on the
+        # shared parametrization object.
+        server_args = list(params.server_args or [])
+        if params.use_omni:
+            server_args = ["--stage-init-timeout", "120", *server_args]
+        if stage_config_path is not None:
+            server_args += ["--stage-configs-path", stage_config_path]
+
+        # A falsy port (None or 0) lets OmniServer pick a free port itself.
+        with OmniServer(
+            model,
+            server_args,
+            port=params.port or None,
+            env_dict=params.env_dict,
+            use_omni=params.use_omni,
+        ) as server:
+            print("OmniServer started successfully")
+            yield server
+            print("OmniServer stopping...")
+
+        print("OmniServer stopped")
+
+
+# Result of one omni (text and/or audio) request as filled in by the
+# OpenAIClientHandler response processors.
+@dataclass
+class OmniResponse:
+ # Generated text (concatenated deltas for streaming responses).
+ text_content: str | None = None
+ # Base64-encoded audio chunks collected from a streaming response.
+ audio_data: list[str] | None = None
+ # Transcript of the generated audio (produced via convert_audio_bytes_to_text).
+ audio_content: str | None = None
+ # Audio format / content-type string reported for the response.
+ audio_format: str | None = None
+ # Decoded audio bytes of the response.
+ audio_bytes: bytes | None = None
+ # Cosine similarity between audio transcript and text, when both are present.
+ similarity: float | None = None
+ # Latency measured with time.perf_counter() while processing the response.
+ e2e_latency: float | None = None
+ # True only when processing finished without raising.
+ success: bool = False
+ # Description of the processing error when success is False.
+ error_message: str | None = None
+
+
+# Result of one diffusion request; validated by assert_diffusion_response, which
+# requires at least one of images / videos / audios to be set.
+@dataclass
+class DiffusionResponse:
+ # Text returned alongside the generated media, if any.
+ text_content: str | None = None
+ # Generated images as PIL images.
+ images: list[Image.Image] | None = None
+ # Generated audio items (element type not fixed here).
+ audios: list[Any] | None = None
+ # Generated video items (element type not fixed here).
+ videos: list[Any] | None = None
+ # Request latency; NOTE(review): presumably seconds, confirm against the producer.
+ e2e_latency: float | None = None
+ # Whether the request succeeded; asserted first by assert_diffusion_response.
+ success: bool = False
+ # Error description for a failed request.
+ error_message: str | None = None
+
+
+def _load_gender_pipeline():
+    """
+    Return the process-wide gender ``audio-classification`` pipeline, creating it on first use.
+
+    The pipeline wrapper is used because it takes care of processor/model loading,
+    so this file needs no direct AutoProcessor.from_pretrained call. The pipeline
+    runs on CPU. Returns None (after printing a warning) when it cannot be created;
+    a later call will try again.
+    """
+    global _GENDER_PIPELINE
+    if _GENDER_PIPELINE is None:
+        model_name = "7wolf/wav2vec2-base-gender-classification"
+        try:
+            # device=-1 forces CPU for pipeline.
+            _GENDER_PIPELINE = pipeline(
+                task="audio-classification",
+                model=model_name,
+                device=-1,
+            )
+        except Exception as exc:  # pragma: no cover - best-effort fallback
+            print(f"Warning: failed to create gender pipeline '{model_name}': {exc}")
+            _GENDER_PIPELINE = None
+    return _GENDER_PIPELINE
+
+
+def _median_pitch_hz_from_autocorr(mono: np.ndarray, sr: int) -> float | None:
+    """
+    Coarse median fundamental frequency (Hz) from short-time autocorrelation.
+
+    The signal is cut into 40 ms Hamming-windowed frames with 50% overlap; for each
+    frame that is not near-silent, the autocorrelation peak within the 70-400 Hz lag
+    range gives one pitch estimate. Used to debias the wav2vec2 gender head on TTS,
+    which often labels lower-pitched synthetic speech as female.
+
+    Returns None when the clip is shorter than 150 ms, the lag range is empty, or
+    fewer than four frames produce an estimate.
+    """
+    signal = np.asarray(mono, dtype=np.float64)
+    signal = signal - np.mean(signal)
+    if signal.size < int(0.15 * sr):
+        return None
+
+    frame_len = int(0.04 * sr)
+    hop = max(frame_len // 2, 1)
+    f0_min_hz, f0_max_hz = 70.0, 400.0
+    # Shortest lag corresponds to the highest pitch and vice versa.
+    lag_min = max(1, int(sr / f0_max_hz))
+    lag_max = min(frame_len - 2, int(sr / f0_min_hz))
+    if lag_max <= lag_min:
+        return None
+
+    window = np.hamming(frame_len)
+    frame_pitches: list[float] = []
+    for start in range(0, int(signal.shape[0]) - frame_len, hop):
+        frame = signal[start : start + frame_len] * window
+        frame = frame - np.mean(frame)
+        rms = float(np.sqrt(np.mean(frame**2)))
+        if rms < 1e-4:
+            # Near-silent frame: no usable pitch.
+            continue
+        autocorr = np.correlate(frame, frame, mode="full")[frame_len - 1 :]
+        autocorr = autocorr / (float(autocorr[0]) + 1e-12)
+        peak_lag = lag_min + int(np.argmax(autocorr[lag_min : lag_max + 1]))
+        if peak_lag <= 0:
+            continue
+        f0 = float(sr) / float(peak_lag)
+        if f0_min_hz <= f0 <= f0_max_hz:
+            frame_pitches.append(f0)
+
+    if len(frame_pitches) < 4:
+        return None
+    return float(np.median(np.asarray(frame_pitches, dtype=np.float64)))
+
+
+def _estimate_voice_gender_from_audio(audio_bytes: bytes) -> str:
+    """
+    Estimate voice gender from audio using a small pre-trained classification model.
+
+    The clip is down-mixed to mono, linearly resampled to 16 kHz and classified
+    with the cached `audio-classification` pipeline. The top label is mapped to
+    'male' / 'female' when its score is at least 0.5; otherwise, or when the label
+    is not recognized, the result is 'unknown'. If the model is unavailable or
+    inference fails, 'unknown' is returned to keep tests stable.
+
+    Pitch assist: when the score is below 0.88, a coarse median F0 overrides the
+    label (female -> male below 165 Hz, male -> female above 230 Hz) to correct
+    systematic errors on TTS audio.
+
+    Under concurrent tests, a global lock serializes pipeline calls (the HF pipeline is not
+    thread-safe).
+
+    Raises:
+        ValueError: The decoded audio is empty. Decoding errors from soundfile
+            also propagate, since decoding happens before the best-effort section.
+    """
+    data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
+    if data.size == 0:
+        raise ValueError("Empty audio")
+    mono = np.mean(data, axis=1)
+
+    try:
+        target_sr = 16000
+        if int(sr) != target_sr and mono.size > 1:
+            # Linear-interpolation resampling to the rate used for classification.
+            src_len = int(mono.shape[0])
+            dst_len = max(1, int(round(src_len * float(target_sr) / float(sr))))
+            src_idx = np.arange(src_len, dtype=np.float32)
+            dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32)
+            mono = np.interp(dst_idx, src_idx, mono.astype(np.float32, copy=False)).astype(np.float32)
+            sr = target_sr
+
+        median_f0 = _median_pitch_hz_from_autocorr(mono, sr)
+
+        clf = _load_gender_pipeline()
+        if clf is None:
+            print("gender model not available, returning 'unknown'")
+            return "unknown"
+
+        # transformers pipeline returns a list of {label, score} (highest score first).
+        with _GENDER_PIPELINE_LOCK:
+            outputs = clf(mono, sampling_rate=sr)
+        if not outputs:
+            return "unknown"
+
+        top = outputs[0]
+        label = str(top.get("label", "")).lower()
+        conf = float(top.get("score", 0.0))
+
+        if conf < 0.5:
+            gender = "unknown"
+        # Some models use non-English labels (e.g., Russian). Normalize to 'male'/'female'.
+        # "female" must be tested first because it contains "male".
+        elif ("female" in label) or ("жен" in label):
+            gender = "female"
+        elif ("male" in label) or ("муж" in label):
+            gender = "male"
+        else:
+            gender = "unknown"
+
+        # Debias: wav2vec2 gender heads often call TTS / band-limited male speech "female".
+        # Low median F0 (~speech male range) + female label -> trust pitch when score is not overwhelming.
+        if gender == "female" and median_f0 is not None and median_f0 < 165.0 and conf < 0.88:
+            print(f"gender pitch assist: reclassifying female->male (median_f0={median_f0:.1f} Hz, conf={conf:.3f})")
+            gender = "male"
+        elif gender == "male" and median_f0 is not None and median_f0 > 230.0 and conf < 0.88:
+            print(f"gender pitch assist: reclassifying male->female (median_f0={median_f0:.1f} Hz, conf={conf:.3f})")
+            gender = "female"
+
+        print(
+            f"gender classifier: label={label}, conf={conf:.3f}, gender={gender}"
+            + (f", median_f0={median_f0:.1f}Hz" if median_f0 is not None else "")
+        )
+        return gender
+    except Exception as exc:  # pragma: no cover - best-effort fallback
+        print(f"Warning: gender classification failed, returning 'unknown': {exc}")
+        return "unknown"
+
+
+# Expected gender per preset voice name. Keys are lower-case because
+# _assert_preset_voice_gender_from_audio lower-cases the requested voice name
+# before the lookup; names that are not listed are not checked.
+_PRESET_VOICE_GENDER_MAP: dict[str, str] = {
+ "serena": "female",
+ "uncle_fu": "male",
+ "chelsie": "female",
+ "clone": "female",
+ "ethan": "male",
 }
+
+
+def _assert_preset_voice_gender_from_audio(
+    audio_bytes: bytes | None,
+    voice_name: str | None,
+) -> None:
+    """Assert that the estimated voice gender matches the preset named ``voice_name``.
+
+    Nothing is checked when the voice name or audio is missing, when the name is
+    not in ``_PRESET_VOICE_GENDER_MAP``, or when the estimate is 'unknown'.
+    """
+    if not (voice_name and audio_bytes):
+        return
+
+    key = str(voice_name).lower()
+    expected_gender = _PRESET_VOICE_GENDER_MAP.get(key)
+    if expected_gender is None:
+        return
+
+    estimated_gender = _estimate_voice_gender_from_audio(audio_bytes)
+    print(f"Preset voice gender check: preset={key!r}, estimated={estimated_gender!r}, expected={expected_gender!r}")
+    if estimated_gender == "unknown":
+        # Inconclusive estimate: do not fail the test.
+        return
+    assert estimated_gender == expected_gender, (
+        f"{voice_name!r} is expected {expected_gender}, but estimated gender is {estimated_gender!r}"
+    )
+
+
+# Minimum mean HNR (dB) accepted by _assert_pcm_int16_speech_hnr. Aligned with the
+# _compute_pcm_hnr_db docstring: clean cloned speech stays above 1.2 dB, distorted
+# speech drops below 1.0 dB.
+_MIN_PCM_SPEECH_HNR_DB = 1.0
+
+
+def _compute_pcm_hnr_db(pcm_samples: np.ndarray, sr: int = _PCM_SPEECH_SAMPLE_RATE_HZ) -> float:
+    """Mean Harmonic-to-Noise Ratio (dB) of a speech signal, as a quality measure.
+
+    The signal is analysed in 30 ms frames with 50% overlap. For every frame whose
+    peak amplitude is at least 0.01, the normalized autocorrelation maximum within
+    the 80-400 Hz lag range is converted to an HNR value; the mean over all such
+    frames is returned, or 0.0 when no frame qualifies.
+
+    Clean cloned speech has HNR > 1.2 dB; distorted speech (e.g. lost
+    ref_code decoder context) drops below 1.0 dB.
+    """
+    frame_len = int(0.03 * sr)  # 30ms frames
+    hop = frame_len // 2
+    min_lag = int(sr / 400)
+    lag_ceiling = int(sr / 80)
+    per_frame_hnr: list[float] = []
+
+    for start in range(0, len(pcm_samples) - frame_len, hop):
+        frame = pcm_samples[start : start + frame_len].astype(np.float32, copy=False)
+        frame = frame - np.mean(frame)
+        if np.max(np.abs(frame)) < 0.01:
+            # Too quiet to carry a meaningful harmonic estimate.
+            continue
+        autocorr = np.correlate(frame, frame, mode="full")[len(frame) - 1 :]
+        autocorr = autocorr / (autocorr[0] + 1e-10)
+        max_lag = min(lag_ceiling, len(autocorr))
+        if min_lag >= max_lag:
+            continue
+        peak = float(np.max(autocorr[min_lag:max_lag]))
+        if 0 < peak < 1:
+            per_frame_hnr.append(10 * np.log10(peak / (1 - peak + 1e-10)))
+
+    if not per_frame_hnr:
+        return 0.0
+    return float(np.mean(per_frame_hnr))
+
+
+def _assert_pcm_int16_speech_hnr(audio_bytes: bytes) -> None:
+    """Assert that raw int16 PCM from /v1/audio/speech reaches the minimum HNR.
+
+    The bytes must be present and int16-aligned; samples are scaled to [-1, 1)
+    before the harmonic-to-noise ratio is computed.
+    """
+    assert audio_bytes is not None and len(audio_bytes) >= 2, "missing PCM bytes"
+    assert len(audio_bytes) % 2 == 0, "PCM byte length must be aligned to int16"
+
+    int16_samples = np.frombuffer(audio_bytes, dtype=np.int16)
+    pcm_samples = int16_samples.astype(np.float32) / 32768.0
+    hnr = _compute_pcm_hnr_db(pcm_samples)
+    print(f"PCM speech HNR: {hnr:.2f} dB (threshold: {_MIN_PCM_SPEECH_HNR_DB} dB)")
+    assert hnr >= _MIN_PCM_SPEECH_HNR_DB, (
+        f"Audio distortion detected: HNR={hnr:.2f} dB < {_MIN_PCM_SPEECH_HNR_DB} dB. "
+        "Voice clone decoder may be losing ref_code speaker context on later chunks."
+    )
+
+
+def assert_omni_response(response: OmniResponse, request_config: dict[str, Any], run_level):
+    """
+    Validate an omni response.
+
+    The request must have succeeded. For ``run_level == "advanced_model"`` the
+    requested modalities (default: text and audio) must be present, configured
+    keywords must occur in the output, the preset speaker gender is checked, and
+    the audio transcript must match the text (similarity > 0.9) when both are requested.
+
+    Args:
+        response: OmniResponse object
+        request_config: Request configuration ("modalities", "speaker", "key_words")
+        run_level: Test run level
+
+    Raises:
+        AssertionError: When the response does not meet validation criteria
+    """
+    assert response.success, "The request failed."
+    e2e_latency = response.e2e_latency
+    if e2e_latency is not None:
+        print(f"the e2e latency is: {e2e_latency}")
+
+    modalities = request_config.get("modalities", ["text", "audio"])
+
+    if run_level != "advanced_model":
+        return
+
+    wants_text = "text" in modalities
+    wants_audio = "audio" in modalities
+
+    if wants_audio:
+        assert response.audio_content is not None, "No audio output is generated"
+        print(f"audio content is: {response.audio_content}")
+        speaker = request_config.get("speaker")
+        if speaker:
+            _assert_preset_voice_gender_from_audio(
+                response.audio_bytes,
+                speaker,
+            )
+
+    if wants_text:
+        assert response.text_content is not None, "No text output is generated"
+        print(f"text content is: {response.text_content}")
+
+    # Keyword check: keywords are searched in the text when text was requested,
+    # otherwise in the audio transcript.
+    keywords_dict = request_config.get("key_words", {})
+    for word_type in ("text", "image", "audio", "video"):
+        keywords = keywords_dict.get(word_type)
+        if not keywords:
+            continue
+        searched = response.text_content if wants_text else response.audio_content
+        searched_lower = searched.lower()
+        assert any(str(kw).lower() in searched_lower for kw in keywords), (
+            "The output does not contain any of the keywords."
+        )
+
+    # Verify similarity (Whisper transcript vs streamed/detokenized text)
+    if wants_text and wants_audio:
+        assert response.similarity is not None and response.similarity > 0.9, (
+            "The audio content is not same as the text"
+        )
+        print(f"similarity is: {response.similarity}")
+
+
+def assert_audio_speech_response(
+    response: OmniResponse,
+    request_config: dict[str, Any],
+    run_level: str,
+) -> None:
+    """
+    Validate a /v1/audio/speech response.
+
+    Always checks success. For ``response_format == "pcm"`` the int16 PCM HNR and
+    the reported content-type are checked; for "wav" the reported format must
+    contain "wav". For ``advanced_model`` runs with a non-PCM format, the
+    transcript must match the input text (similarity > 0.9) and the preset voice
+    gender is verified.
+    """
+    assert response.success, "The request failed."
+
+    req_fmt = request_config.get("response_format")
+    reported_fmt = response.audio_format
+
+    if req_fmt == "pcm" and response.audio_bytes:
+        _assert_pcm_int16_speech_hnr(response.audio_bytes)
+        if reported_fmt:
+            assert "pcm" in reported_fmt.lower(), f"Expected audio/pcm content-type, got {response.audio_format!r}"
+    elif req_fmt == "wav" and reported_fmt:
+        assert req_fmt in reported_fmt, (
+            f"The response audio format {response.audio_format} don't match the request audio format {req_fmt}"
+        )
+
+    e2e_latency = response.e2e_latency
+    if e2e_latency is not None:
+        print(f"the avg e2e latency is: {e2e_latency}")
+
+    if run_level != "advanced_model" or req_fmt == "pcm":
+        return
+
+    # Text–audio semantic similarity check (skipped for raw PCM: no Whisper transcript).
+    expected_text = request_config.get("input")
+    if expected_text:
+        transcript = (response.audio_content or "").strip()
+        print(f"audio content is: {transcript}")
+        print(f"input text is: {expected_text}")
+        similarity = cosine_similarity_text(transcript.lower(), expected_text.lower())
+        print(f"Cosine similarity: {similarity:.3f}")
+        assert similarity > 0.9, (
+            f"Transcript doesn't match input: similarity={similarity:.2f}, transcript='{transcript}'"
+        )
+
+    # Voice gender consistency check (preset names in ``_PRESET_VOICE_GENDER_MAP``).
+    # An 'unknown' estimate is inconclusive and does NOT fail the test.
+    _assert_preset_voice_gender_from_audio(
+        response.audio_bytes,
+        request_config.get("voice"),
+    )
+
+
+def assert_diffusion_response(response: DiffusionResponse, request_config: dict[str, Any], run_level: str = None):
+    """
+    Validate a diffusion response.
+
+    Checks success and that at least one of images / videos / audios is present,
+    then hands each present modality to its own assert function.
+
+    Args:
+        response: DiffusionResponse object.
+        request_config: Request configuration dictionary.
+        run_level: Test run level (e.g. "core_model", "advanced_model")
+
+    Raises:
+        AssertionError: When the response does not meet validation criteria
+        KeyError: When the request_config does not contain necessary parameters for validation
+    """
+    assert response.success, "The request failed."
+
+    if response.e2e_latency is not None:
+        print(f"the avg e2e is: {response.e2e_latency}")
+
+    outputs = (response.images, response.videos, response.audios)
+    has_any_content = not all(content is None for content in outputs)
+    assert has_any_content, "Response contains no images, videos, or audios"
+
+    # Every modality-specific validator takes the same keyword arguments.
+    shared_kwargs = {
+        "response": response,
+        "request_config": request_config,
+        "run_level": run_level,
+    }
+
+    if response.images is not None:
+        assert_image_diffusion_response(**shared_kwargs)
+
+    if response.videos is not None:
+        assert_video_diffusion_response(**shared_kwargs)
+
+    if response.audios is not None:
+        assert_audio_diffusion_response(**shared_kwargs)
+
+
+class OpenAIClientHandler:
+ """
+ OpenAI client handler class, encapsulating both streaming and non-streaming response processing logic.
+
+ This class integrates OpenAI API request sending, response handling, and validation functionality,
+ supporting both single request and concurrent request modes.
+ """
+
+    def __init__(
+        self,
+        host: str = "127.0.0.1",
+        port: int | None = None,
+        api_key: str = "EMPTY",
+        run_level: str | None = None,
+    ):
+        """
+        Initialize the OpenAI client.
+
+        Args:
+            host: vLLM-Omni server host address
+            port: vLLM-Omni server port. When None, a free port is looked up at
+                call time. (A ``get_open_port()`` default in the signature would be
+                evaluated only once, when the class is defined, so every handler
+                would share one port probed at import time.)
+            api_key: API key (defaults to "EMPTY")
+            run_level: Test run level stored for later response validation
+        """
+        if port is None:
+            port = get_open_port()
+        self.base_url = f"http://{host}:{port}"
+        self.client = OpenAI(base_url=f"http://{host}:{port}/v1", api_key=api_key)
+        self.run_level = run_level
+
+    def _process_stream_omni_response(self, chat_completion) -> OmniResponse:
+        """
+        Consume a streaming chat completion and collect its text and audio.
+
+        Text deltas are concatenated and audio deltas (base64 chunks) are merged
+        into one WAV, which is transcribed; when both transcript and text exist,
+        their cosine similarity is stored. Errors are captured in
+        ``error_message`` instead of being raised, leaving ``success`` False.
+
+        Args:
+            chat_completion: OpenAI streaming response object
+
+        Returns:
+            OmniResponse: Processed response object
+        """
+        result = OmniResponse()
+        start_time = time.perf_counter()
+
+        try:
+            text_parts = []
+            audio_data = []
+
+            for chunk in chat_completion:
+                # The modality is carried by the chunk, the payload by each choice delta.
+                modality = getattr(chunk, "modality", None)
+                for choice in chunk.choices:
+                    content = getattr(choice.delta, "content", None) if hasattr(choice, "delta") else None
+                    if not content:
+                        continue
+                    if modality == "audio":
+                        audio_data.append(content)
+                    elif modality == "text":
+                        text_parts.append(content)
+
+            text_content = "".join(text_parts)
+
+            # Calculate end-to-end latency
+            result.e2e_latency = time.perf_counter() - start_time
+
+            audio_content = None
+            similarity = None
+            if audio_data:
+                merged_seg = _merge_base64_audio_to_segment(audio_data)
+                wav_buf = BytesIO()
+                merged_seg.export(wav_buf, format="wav")
+                result.audio_bytes = wav_buf.getvalue()
+                audio_content = convert_audio_bytes_to_text(result.audio_bytes)
+            if audio_content and text_content:
+                similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
+
+            # Populate result object
+            result.text_content = text_content
+            result.audio_data = audio_data
+            result.audio_content = audio_content
+            result.similarity = similarity
+            result.success = True
+
+        except Exception as e:
+            result.error_message = f"Stream processing error: {str(e)}"
+            print(f"Error: {result.error_message}")
+
+        return result
+
+ def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse:
+ """
+ Process non-streaming responses.
+
+ Args:
+ chat_completion: OpenAI non-streaming response object
+ request_config: Request configuration dictionary
+
+ Returns:
+ OmniResponse: Processed response object
+ """
+ result = OmniResponse()
+ start_time = time.perf_counter()
+
+ try:
+ audio_data = None
+ text_content = None
+
+ # Iterate through all choices
+ for choice in chat_completion.choices:
+ # Process audio data
+ if hasattr(choice.message, "audio") and choice.message.audio is not None:
+ audio_message = choice.message
+ audio_data = audio_message.audio.data
+
+ # Process text content
+ if hasattr(choice.message, "content") and choice.message.content is not None:
+ text_content = choice.message.content
+
+ # Calculate end-to-end latency
+ result.e2e_latency = time.perf_counter() - start_time
+
+ # Process audio and text content
+ audio_content = None
+ similarity = None
+
+ if audio_data or text_content:
+ if audio_data:
+ result.audio_bytes = base64.b64decode(audio_data)
+ audio_content = convert_audio_bytes_to_text(result.audio_bytes)
+ if audio_content and text_content:
+ similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
+
+ # Populate result object
+ result.text_content = text_content
+ result.audio_content = audio_content
+ result.similarity = similarity
+ result.success = True
+
+ except Exception as e:
+ result.error_message = f"Non-stream processing error: {str(e)}"
+ print(f"Error: {result.error_message}")
+
+ return result
+
+ def _process_diffusion_response(self, chat_completion) -> DiffusionResponse:
+ """
+ Process diffusion responses (image generation/editing).
+
+ Args:
+ chat_completion: OpenAI response object
+
+ Returns:
+ DiffusionResponse: Processed response object
+ """
+ result = DiffusionResponse()
+ start_time = time.perf_counter()
+
+ try:
+ images = []
+ # [TODO] reading video and audio output from API response for later validation
+
+ for choice in chat_completion.choices:
+ if hasattr(choice.message, "content") and choice.message.content is not None:
+ content = choice.message.content
+ if isinstance(content, list):
+ for item in content:
+ if isinstance(item, dict):
+ image_url = item.get("image_url", {}).get("url")
+ else:
+ image_url_obj = getattr(item, "image_url", None)
+ image_url = hasattr(image_url_obj, "url", None) if image_url_obj else None
+ if image_url and image_url.startswith("data:image"):
+ b64_data = image_url.split(",", 1)[1]
+ img = decode_b64_image(b64_data)
+ images.append(img)
+
+ result.e2e_latency = time.perf_counter() - start_time
+ result.images = images if images else None
+ result.success = True
+
+ except Exception as e:
+ result.error_message = f"Diffusion response processing error: {str(e)}"
+ print(f"Error: {result.error_message}")
+
+ return result
+
+ def _process_stream_audio_speech_response(self, response, *, response_format: str | None = None) -> OmniResponse:
+ """
+ Process streaming /v1/audio/speech responses into an OmniResponse.
+
+ This mirrors _process_stream_omni_response but operates on low-level
+ audio bytes and produces an OmniResponse with audio_content filled
+ from Whisper transcription.
+ """
+ result = OmniResponse()
+ start_time = time.perf_counter()
+
+ try:
+ # Aggregate all audio bytes from the streaming response.
+ data = bytearray()
+
+ # Preferred OpenAI helper.
+ if hasattr(response, "iter_bytes") and callable(getattr(response, "iter_bytes")):
+ for chunk in response.iter_bytes():
+ if chunk:
+ data.extend(chunk)
+ else:
+ # Generic iterable-of-bytes fallback (e.g., generator or list of chunks).
+ try:
+ iterator = iter(response)
+ except TypeError:
+ iterator = None
+
+ if iterator is not None:
+ for chunk in iterator:
+ if not chunk:
+ continue
+ if isinstance(chunk, (bytes, bytearray)):
+ data.extend(chunk)
+ elif hasattr(chunk, "data"):
+ data.extend(chunk.data) # type: ignore[arg-type]
+ elif hasattr(chunk, "content"):
+ data.extend(chunk.content) # type: ignore[arg-type]
+ else:
+ raise TypeError(f"Unsupported stream chunk type: {type(chunk)}")
+ else:
+ raise TypeError(f"Unsupported audio speech streaming response type: {type(response)}")
+
+ raw_bytes = bytes(data)
+ if response_format == "pcm":
+ transcript = None
+ else:
+ transcript = convert_audio_bytes_to_text(raw_bytes)
+
+ # Populate OmniResponse.
+ result.audio_bytes = raw_bytes
+ result.audio_content = transcript
+ result.e2e_latency = time.perf_counter() - start_time
+ result.success = True
+ result.audio_format = getattr(response, "response", None)
+ if result.audio_format is not None:
+ result.audio_format = result.audio_format.headers.get("content-type", "")
+
+ except Exception as e:
+ result.error_message = f"Audio speech stream processing error: {str(e)}"
+ print(f"Error: {result.error_message}")
+
+ return result
+
+ def _process_non_stream_audio_speech_response(
+ self, response, *, response_format: str | None = None
+ ) -> OmniResponse:
+ """
+ Process non-streaming /v1/audio/speech responses into an OmniResponse.
+
+ This mirrors _process_non_stream_omni_response but for the binary
+ audio payload returned by audio.speech.create.
+ """
+ result = OmniResponse()
+ start_time = time.perf_counter()
+
+ try:
+ # OpenAI non-streaming audio.speech.create returns HttpxBinaryResponseContent (.read() or .content)
+ if hasattr(response, "read") and callable(getattr(response, "read")):
+ raw_bytes = response.read()
+ elif hasattr(response, "content"):
+ raw_bytes = response.content # type: ignore[assignment]
+ else:
+ raise TypeError(f"Unsupported audio speech response type: {type(response)}")
+
+ if response_format == "pcm":
+ transcript = None
+ else:
+ transcript = convert_audio_bytes_to_text(raw_bytes)
+
+ result.audio_bytes = raw_bytes
+ result.audio_content = transcript
+ result.e2e_latency = time.perf_counter() - start_time
+ result.success = True
+ result.audio_format = getattr(response, "response", None)
+ if result.audio_format is not None:
+ result.audio_format = result.audio_format.headers.get("content-type", "")
+
+ except Exception as e:
+ result.error_message = f"Audio speech non-stream processing error: {str(e)}"
+ print(f"Error: {result.error_message}")
+
+ return result
+
+ def send_omni_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
+ """
+ Send OpenAI requests.
+
+ Args:
+ request_config: Request configuration dictionary containing parameters like model, messages, stream.
+ Optional ``use_audio_in_video`` (bool): when true, sets
+ ``extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True}`` for Qwen-Omni video+audio
+ extraction.
+ Optional top-level ``speaker`` (str): Qwen3-Omni preset TTS speaker name; sent as
+ ``extra_body["speaker"]`` to ``chat.completions.create``.
+ request_num: Number of requests, defaults to 1 (single request)
+
+ Returns:
+ List[OmniResponse]: List of response objects
+ """
+
+ responses = []
+ stream = request_config.get("stream", False)
+ modalities = request_config.get("modalities", ["text", "audio"])
+
+ extra_body: dict[str, Any] = {}
+ if "speaker" in request_config:
+ extra_body["speaker"] = request_config["speaker"]
+ if request_config.get("use_audio_in_video"):
+ mm = dict(extra_body.get("mm_processor_kwargs") or {})
+ mm["use_audio_in_video"] = True
+ extra_body["mm_processor_kwargs"] = mm
+ extra_body_arg: dict[str, Any] | None = extra_body if extra_body else None
+
+ create_kwargs: dict[str, Any] = {
+ "model": request_config.get("model"),
+ "messages": request_config.get("messages"),
+ "stream": stream,
+ "modalities": modalities,
+ }
+ if extra_body_arg is not None:
+ create_kwargs["extra_body"] = extra_body_arg
+
+ if request_num == 1:
+ # Send single request
+ chat_completion = self.client.chat.completions.create(**create_kwargs)
+
+ if stream:
+ response = self._process_stream_omni_response(chat_completion)
+ else:
+ response = self._process_non_stream_omni_response(chat_completion)
+
+ assert_omni_response(response, request_config, run_level=self.run_level)
+ responses.append(response)
+
+ else:
+ # Send concurrent requests: run create + process in worker so e2e_latency includes full round-trip.
+ def _one_omni_request():
+ start = time.perf_counter()
+ worker_kwargs: dict[str, Any] = {
+ "model": request_config.get("model"),
+ "messages": request_config.get("messages"),
+ "modalities": modalities,
+ "stream": stream,
+ }
+ if extra_body_arg is not None:
+ worker_kwargs["extra_body"] = extra_body_arg
+ chat_completion = self.client.chat.completions.create(**worker_kwargs)
+ if stream:
+ response = self._process_stream_omni_response(chat_completion)
+ else:
+ response = self._process_non_stream_omni_response(chat_completion)
+ response.e2e_latency = time.perf_counter() - start
+ return response
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
+ futures = [executor.submit(_one_omni_request) for _ in range(request_num)]
+ for future in concurrent.futures.as_completed(futures):
+ response = future.result()
+ assert_omni_response(response, request_config, run_level=self.run_level)
+ responses.append(response)
+
+ return responses
+
+ def send_audio_speech_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
+ """
+ Call the /v1/audio/speech endpoint using the same configuration-dict
+ style as send_omni_request, but via the OpenAI Python client's
+ audio.speech APIs.
+
+ Expected keys in request_config:
+ - model: model name/path (required)
+ - input: text to synthesize (required)
+ - response_format: audio format such as "wav" or "pcm" (optional)
+ - task_type, ref_text, ref_audio: TTS-specific extras (optional, passed via extra_body)
+ - timeout: request timeout in seconds (float, optional, default 120.0)
+ - stream: whether to use streaming API (bool, optional, default False)
+ """
+ timeout = float(request_config.get("timeout", 120.0))
+
+ model = request_config["model"]
+ text_input = request_config["input"]
+ stream = bool(request_config.get("stream", False))
+ voice = request_config.get("voice", None)
+
+ # Standard OpenAI param: use omit when not provided to keep default behavior.
+ response_format = request_config.get("response_format", omit)
+
+ # Qwen3-TTS custom fields, forwarded via extra_body.
+ extra_body: dict[str, Any] = {}
+ # Keep this list aligned with vllm_omni.entrypoints.openai.protocol.audio params.
+ for key in ("task_type", "ref_text", "ref_audio", "language", "max_new_tokens"):
+ if key in request_config:
+ extra_body[key] = request_config[key]
+
+ responses: list[OmniResponse] = []
+
+ speech_fmt: str | None = None if response_format is omit else str(response_format).lower()
+
+ if request_num == 1:
+ if stream:
+ # Use streaming response helper.
+ with self.client.audio.speech.with_streaming_response.create(
+ model=model,
+ input=text_input,
+ response_format=response_format,
+ extra_body=extra_body or None,
+ timeout=timeout,
+ voice=voice,
+ ) as resp:
+ omni_resp = self._process_stream_audio_speech_response(resp, response_format=speech_fmt)
+ else:
+ # Non-streaming response.
+ resp = self.client.audio.speech.create(
+ model=model,
+ input=text_input,
+ response_format=response_format,
+ extra_body=extra_body or None,
+ timeout=timeout,
+ voice=voice,
+ )
+ omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt)
+
+ assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
+ responses.append(omni_resp)
+ return responses
+ else:
+ # request_num > 1: concurrent requests (use same params as single-request path)
+
+ if stream:
+
+ def _stream_task():
+ with self.client.audio.speech.with_streaming_response.create(
+ model=model,
+ input=text_input,
+ response_format=response_format,
+ extra_body=extra_body or None,
+ timeout=timeout,
+ voice=voice,
+ ) as resp:
+ return self._process_stream_audio_speech_response(resp, response_format=speech_fmt)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
+ futures = [executor.submit(_stream_task) for _ in range(request_num)]
+ for future in concurrent.futures.as_completed(futures):
+ omni_resp = future.result()
+ assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
+ responses.append(omni_resp)
+ else:
+ with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
+ futures = []
+ for _ in range(request_num):
+ future = executor.submit(
+ self.client.audio.speech.create,
+ model=model,
+ input=text_input,
+ response_format=response_format,
+ extra_body=extra_body or None,
+ timeout=timeout,
+ voice=voice,
+ )
+ futures.append(future)
+
+ for future in concurrent.futures.as_completed(futures):
+ resp = future.result()
+ omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt)
+ assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
+ responses.append(omni_resp)
+
+ return responses
+
+ def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
+ """
+ Send OpenAI requests for diffusion models.
+
+ Args:
+ request_config: Request configuration dictionary containing parameters like model, messages
+ request_num: Number of requests to send concurrently, defaults to 1 (single request)
+ Returns:
+ List[OmniResponse]: List of response objects
+ """
+ responses = []
+ stream = request_config.get("stream", False)
+ modalities = request_config.get("modalities", omit) # Most diffusion models don't require modalities param
+ extra_body = request_config.get("extra_body", None)
+
+ if stream:
+ raise NotImplementedError("Streaming is not currently implemented for diffusion model e2e test")
+
+ if request_num == 1:
+ # Send single request
+ chat_completion = self.client.chat.completions.create(
+ model=request_config.get("model"),
+ messages=request_config.get("messages"),
+ extra_body=extra_body,
+ modalities=modalities,
+ )
+
+ response = self._process_diffusion_response(chat_completion)
+ assert_diffusion_response(response, request_config, run_level=self.run_level)
+ responses.append(response)
+
+ else:
+ # Send concurrent requests
+ with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
+ futures = []
+
+ # Submit all request tasks
+ for _ in range(request_num):
+ future = executor.submit(
+ self.client.chat.completions.create,
+ model=request_config.get("model"),
+ messages=request_config.get("messages"),
+ modalities=modalities,
+ extra_body=extra_body,
+ )
+ futures.append(future)
+
+ # Process completed tasks
+ for future in concurrent.futures.as_completed(futures):
+ chat_completion = future.result()
+ response = self._process_diffusion_response(chat_completion)
+ assert_diffusion_response(response, request_config, run_level=self.run_level)
+ responses.append(response)
+
+ return responses
+
+ def send_video_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
+ """
+ Send native /v1/videos requests.
+ """
+ if request_num != 1:
+ raise NotImplementedError("Concurrent video diffusion requests are not currently implemented")
+
+ if request_config.get("stream", False):
+ raise NotImplementedError("Streaming is not currently implemented for video diffusion e2e test")
+
+ form_data = request_config.get("form_data")
+ if not isinstance(form_data, dict):
+ raise ValueError("Video request_config must contain 'form_data'")
+
+ if not form_data.get("prompt"):
+ raise ValueError("Video request_config['form_data'] must contain 'prompt'")
+
+ normalized_form_data = {key: str(value) for key, value in form_data.items() if value is not None}
+
+ files: dict[str, tuple[str, BytesIO, str]] = {}
+ image_reference = request_config.get("image_reference")
+ if image_reference:
+ if image_reference.startswith("data:image"):
+ header, encoded = image_reference.split(",", 1)
+ content_type = header.split(";")[0].removeprefix("data:")
+ extension = content_type.split("/")[-1]
+ file_data = base64.b64decode(encoded)
+
+ files["input_reference"] = (
+ f"reference.{extension}",
+ BytesIO(file_data),
+ content_type,
+ )
+ else:
+ normalized_form_data["image_reference"] = json.dumps({"image_url": image_reference})
+
+ result = DiffusionResponse()
+ start_time = time.perf_counter()
+
+ try:
+ create_url = self._build_url("/v1/videos")
+ response = requests.post(
+ create_url,
+ data=normalized_form_data,
+ files=files,
+ headers={"Accept": "application/json"},
+ timeout=60,
+ )
+ response.raise_for_status()
+
+ job_data = response.json()
+ video_id = job_data["id"]
+
+ self._wait_until_video_completed(video_id)
+
+ video_content = self._download_video_content(video_id)
+
+ result.success = True
+ result.videos = [video_content]
+ result.e2e_latency = time.perf_counter() - start_time
+
+ assert_diffusion_response(result, request_config, run_level=self.run_level)
+
+ except Exception as e:
+ result.success = False
+ result.error_message = f"Diffusion response processing error: {e}"
+ assert False, result.error_message
+
+ return [result]
+
+ def _wait_until_video_completed(
+ self,
+ video_id: str,
+ poll_interval_seconds: int = 2,
+ timeout_seconds: int = 300,
+ ) -> None:
+ status_url = self._build_url(f"/v1/videos/{video_id}")
+ deadline = time.monotonic() + timeout_seconds
+
+ while time.monotonic() < deadline:
+ status_resp = requests.get(
+ status_url,
+ headers={"Accept": "application/json"},
+ timeout=30,
+ )
+ status_resp.raise_for_status()
+
+ status_data = status_resp.json()
+ current_status = status_data["status"]
+
+ if current_status == "completed":
+ return
+
+ if current_status == "failed":
+ error_msg = status_data.get("last_error", "Unknown error")
+ raise RuntimeError(f"Job failed: {error_msg}")
+
+ time.sleep(poll_interval_seconds)
+
+ raise TimeoutError(f"Video job {video_id} did not complete within {timeout_seconds}s")
+
+ def _download_video_content(self, video_id: str) -> bytes:
+ download_url = self._build_url(f"/v1/videos/{video_id}/content")
+ video_resp = requests.get(download_url, stream=True, timeout=60)
+ video_resp.raise_for_status()
+
+ video_bytes = BytesIO()
+ for chunk in video_resp.iter_content(chunk_size=8192):
+ if chunk:
+ video_bytes.write(chunk)
+
+ return video_bytes.getvalue()
+
+ def _build_url(self, path: str) -> str:
+ return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
+
+
+@pytest.fixture
+def openai_client(omni_server: OmniServer, run_level: str):
+ """Create OpenAIClientHandler fixture to facilitate communication with OmniServer
+ with encapsulated request sending, concurrent requests, response handling, and validation."""
+ return OpenAIClientHandler(host=omni_server.host, port=omni_server.port, api_key="EMPTY", run_level=run_level)
+
+
+class OmniRunner:
+ """
+ Offline test runner for Omni models.
+ """
+
+ def __init__(
+ self,
+ model_name: str,
+ seed: int = 42,
+ stage_init_timeout: int = 300,
+ batch_timeout: int = 10,
+ init_timeout: int = 300,
+ shm_threshold_bytes: int = 65536,
+ log_stats: bool = False,
+ stage_configs_path: str | None = None,
+ **kwargs,
+ ) -> None:
+ """
+ Initialize an OmniRunner for testing.
+
+ Args:
+ model_name: The model name or path
+ seed: Random seed for reproducibility
+ stage_init_timeout: Timeout for initializing a single stage in seconds
+ batch_timeout: Timeout for batching in seconds
+ init_timeout: Timeout for initializing stages in seconds
+ shm_threshold_bytes: Threshold for using shared memory
+ log_stats: Enable detailed statistics logging
+ stage_configs_path: Optional path to YAML stage config file
+ **kwargs: Additional arguments passed to Omni
+ """
+ cleanup_dist_env_and_memory()
+ _run_pre_test_cleanup(enable_force=True)
+ _run_post_test_cleanup(enable_force=True)
+ self.model_name = model_name
+ self.seed = seed
+
+ self.omni = Omni(
+ model=model_name,
+ log_stats=log_stats,
+ stage_init_timeout=stage_init_timeout,
+ batch_timeout=batch_timeout,
+ init_timeout=init_timeout,
+ shm_threshold_bytes=shm_threshold_bytes,
+ stage_configs_path=stage_configs_path,
+ **kwargs,
+ )
+
+ def _estimate_prompt_len(
+ self,
+ additional_information: dict[str, Any],
+ model_name: str,
+ _cache: dict[str, Any] = {},
+ ) -> int:
+ """Estimate prompt_token_ids placeholder length for the Talker stage.
+
+ The AR Talker replaces all input embeddings via ``preprocess``, so the
+ placeholder values are irrelevant but the **length** must match the
+ embeddings that ``preprocess`` will produce.
+ """
+ try:
+ from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig
+ from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
+ Qwen3TTSTalkerForConditionalGeneration,
+ )
+
+ if model_name not in _cache:
+ from transformers import AutoTokenizer
+
+ tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
+ cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True)
+ _cache[model_name] = (tok, getattr(cfg, "talker_config", None))
+
+ tok, tcfg = _cache[model_name]
+ task_type = (additional_information.get("task_type") or ["CustomVoice"])[0]
+ return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information(
+ additional_information=additional_information,
+ task_type=task_type,
+ tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"],
+ codec_language_id=getattr(tcfg, "codec_language_id", None),
+ spk_is_dialect=getattr(tcfg, "spk_is_dialect", None),
+ )
+ except Exception as exc:
+ logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc)
+ return 2048
+
+ def get_default_sampling_params_list(self) -> list[OmniSamplingParams]:
+ """
+ Get a list of default sampling parameters for all stages.
+
+ Returns:
+ List of SamplingParams with default decoding for each stage
+ """
+ if not hasattr(self.omni, "default_sampling_params_list"):
+ raise AttributeError("Omni.default_sampling_params_list is not available")
+ return list(self.omni.default_sampling_params_list)
+
+ def get_omni_inputs(
+ self,
+ prompts: list[str] | str,
+ system_prompt: str | None = None,
+ audios: PromptAudioInput = None,
+ images: PromptImageInput = None,
+ videos: PromptVideoInput = None,
+ mm_processor_kwargs: dict[str, Any] | None = None,
+ modalities: list[str] | None = None,
+ ) -> list[TextPrompt]:
+ """
+ Construct Omni input format from prompts and multimodal data.
+
+ Args:
+ prompts: Text prompt(s) - either a single string or list of strings
+ system_prompt: Optional system prompt (defaults to Qwen system prompt)
+ audios: Audio input(s) - tuple of (audio_array, sample_rate) or list of tuples
+ images: Image input(s) - PIL Image or list of PIL Images
+ videos: Video input(s) - numpy array or list of numpy arrays
+ mm_processor_kwargs: Optional processor kwargs (e.g., use_audio_in_video)
+
+ Returns:
+ List of prompt dictionaries suitable for Omni.generate()
+ """
+ if system_prompt is None:
+ system_prompt = (
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+ "Group, capable of perceiving auditory and visual inputs, as well as "
+ "generating text and speech."
+ )
+
+ video_padding_token = "<|VIDEO|>"
+ image_padding_token = "<|IMAGE|>"
+ audio_padding_token = "<|AUDIO|>"
+
+ if "Qwen3-Omni-30B-A3B-Instruct" in self.model_name:
+ video_padding_token = "<|video_pad|>"
+ image_padding_token = "<|image_pad|>"
+ audio_padding_token = "<|audio_pad|>"
+
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ # Qwen-TTS: follow examples/offline_inference/qwen3_tts/end2end.py style.
+ # Stage 0 expects token placeholders + additional_information (text/speaker/task_type/...),
+ # and Talker replaces embeddings in preprocess based on additional_information only.
+ is_tts_model = "Qwen3-TTS" in self.model_name or "qwen3_tts" in self.model_name.lower()
+ if is_tts_model and modalities == ["audio"]:
+ tts_kw = mm_processor_kwargs or {}
+ task_type = tts_kw.get("task_type", "CustomVoice")
+ speaker = tts_kw.get("speaker", "Vivian")
+ language = tts_kw.get("language", "Auto")
+ max_new_tokens = int(tts_kw.get("max_new_tokens", 2048))
+ ref_audio = tts_kw.get("ref_audio", None)
+ ref_text = tts_kw.get("ref_text", None)
+
+ omni_inputs: list[TextPrompt] = []
+ for prompt_text in prompts:
+ text_str = str(prompt_text).strip() or " "
+ additional_information: dict[str, Any] = {
+ "task_type": [task_type],
+ "text": [text_str],
+ "language": [language],
+ "speaker": [speaker],
+ "max_new_tokens": [max_new_tokens],
+ }
+ if ref_audio is not None:
+ additional_information["ref_audio"] = [ref_audio]
+ if ref_text is not None:
+ additional_information["ref_text"] = [ref_text]
+ # Use official helper to get correct placeholder length
+ plen = self._estimate_prompt_len(additional_information, self.model_name)
+ input_dict: TextPrompt = {
+ "prompt_token_ids": [0] * plen,
+ "additional_information": additional_information,
+ }
+ omni_inputs.append(input_dict)
+ return omni_inputs
+
+ def _normalize_mm_input(mm_input, num_prompts):
+ if mm_input is None:
+ return [None] * num_prompts
+ if isinstance(mm_input, list):
+ if len(mm_input) != num_prompts:
+ raise ValueError(
+ f"Multimodal input list length ({len(mm_input)}) must match prompts length ({num_prompts})"
+ )
+ return mm_input
+ return [mm_input] * num_prompts
+
+ num_prompts = len(prompts)
+ audios_list = _normalize_mm_input(audios, num_prompts)
+ images_list = _normalize_mm_input(images, num_prompts)
+ videos_list = _normalize_mm_input(videos, num_prompts)
+
+ omni_inputs = []
+ for i, prompt_text in enumerate(prompts):
+ user_content = ""
+ multi_modal_data = {}
+
+ audio = audios_list[i]
+ if audio is not None:
+ if isinstance(audio, list):
+ for _ in audio:
+ user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>"
+ multi_modal_data["audio"] = audio
+ else:
+ user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>"
+ multi_modal_data["audio"] = audio
+
+ image = images_list[i]
+ if image is not None:
+ if isinstance(image, list):
+ for _ in image:
+ user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>"
+ multi_modal_data["image"] = image
+ else:
+ user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>"
+ multi_modal_data["image"] = image
+
+ video = videos_list[i]
+ if video is not None:
+ if isinstance(video, list):
+ for _ in video:
+ user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>"
+ multi_modal_data["video"] = video
+ else:
+ user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>"
+ multi_modal_data["video"] = video
+
+ user_content += prompt_text
+
+ full_prompt = (
+ f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+ f"<|im_start|>user\n{user_content}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+
+ input_dict: TextPrompt = {"prompt": full_prompt}
+ if multi_modal_data:
+ input_dict["multi_modal_data"] = multi_modal_data
+ if modalities:
+ input_dict["modalities"] = modalities
+ if mm_processor_kwargs:
+ input_dict["mm_processor_kwargs"] = mm_processor_kwargs
+
+ omni_inputs.append(input_dict)
+
+ return omni_inputs
+
+ def generate(
+ self,
+ prompts: list[TextPrompt],
+ sampling_params_list: list[OmniSamplingParams] | None = None,
+ ) -> list[OmniRequestOutput]:
+ """
+ Generate outputs for the given prompts.
+
+ Args:
+ prompts: List of prompt dictionaries with 'prompt' and optionally
+ 'multi_modal_data' keys
+ sampling_params_list: List of sampling parameters for each stage.
+ If None, uses default parameters.
+
+ Returns:
+ List of OmniRequestOutput objects from stages with final_output=True
+ """
+ if sampling_params_list is None:
+ sampling_params_list = self.get_default_sampling_params_list()
+
+ return self.omni.generate(prompts, sampling_params_list)
+
+ def generate_multimodal(
+ self,
+ prompts: list[str] | str,
+ sampling_params_list: list[OmniSamplingParams] | None = None,
+ system_prompt: str | None = None,
+ audios: PromptAudioInput = None,
+ images: PromptImageInput = None,
+ videos: PromptVideoInput = None,
+ mm_processor_kwargs: dict[str, Any] | None = None,
+ modalities: list[str] | None = None,
+ ) -> list[OmniRequestOutput]:
+ """
+ Convenience method to generate with multimodal inputs.
+
+ Args:
+ prompts: Text prompt(s)
+ sampling_params_list: List of sampling parameters for each stage
+ system_prompt: Optional system prompt
+ audios: Audio input(s)
+ images: Image input(s)
+ videos: Video input(s)
+ mm_processor_kwargs: Optional processor kwargs
+
+ Returns:
+ List of OmniRequestOutput objects from stages with final_output=True
+ """
+ omni_inputs = self.get_omni_inputs(
+ prompts=prompts,
+ system_prompt=system_prompt,
+ audios=audios,
+ images=images,
+ videos=videos,
+ mm_processor_kwargs=mm_processor_kwargs,
+ modalities=modalities,
+ )
+ return self.generate(omni_inputs, sampling_params_list)
+
+ def start_profile(
+ self,
+ profile_prefix: str | None = None,
+ stages: list[int] | None = None,
+ ) -> list[Any]:
+ """Start profiling specified stages.
+
+ Args:
+ profile_prefix: Optional prefix for the trace file names.
+ stages: List of stage IDs to profile. If None, profiles all stages.
+
+ Returns:
+ List of results from each stage.
+ """
+ return self.omni.start_profile(profile_prefix=profile_prefix, stages=stages)
+
+ def stop_profile(self, stages: list[int] | None = None) -> list[Any]:
+ """Stop profiling specified stages.
+
+ Args:
+ stages: List of stage IDs to profile. If None, stops all stages.
+
+ Returns:
+ List of results from each stage.
+ """
+ return self.omni.stop_profile(stages=stages)
+
+ def _cleanup_process(self):
+ try:
+ keywords = ["enginecore"]
+ matched = []
+
+ for proc in psutil.process_iter(["pid", "name", "cmdline", "username"]):
+ try:
+ cmdline = " ".join(proc.cmdline()).lower() if proc.cmdline() else ""
+ name = proc.name().lower()
+
+ is_process = any(keyword in cmdline for keyword in keywords) or any(
+ keyword in name for keyword in keywords
+ )
+
+ if is_process:
+ print(f"Found vllm process: PID={proc.pid}, cmd={cmdline[:100]}")
+ matched.append(proc)
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ pass
+
+ for proc in matched:
+ try:
+ proc.terminate()
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ pass
+
+ _, still_alive = psutil.wait_procs(matched, timeout=5)
+ for proc in still_alive:
+ try:
+ proc.kill()
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ pass
+
+ if still_alive:
+ _, stubborn = psutil.wait_procs(still_alive, timeout=3)
+ if stubborn:
+ print(f"Warning: failed to kill residual vllm pids: {[p.pid for p in stubborn]}")
+ else:
+ print(f"Force-killed residual vllm pids: {[p.pid for p in still_alive]}")
+ elif matched:
+ print(f"Terminated vllm pids: {[p.pid for p in matched]}")
+
+ except Exception as e:
+ print(f"Error in psutil vllm cleanup: {e}")
+
+ def __enter__(self):
+ """Context manager entry."""
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Context manager exit - cleanup resources."""
+ if hasattr(self.omni, "close"):
+ self.omni.close()
+ self._cleanup_process()
+ _run_pre_test_cleanup(enable_force=True)
+ _run_post_test_cleanup(enable_force=True)
+ cleanup_dist_env_and_memory()
+
+
+@pytest.fixture(scope="module")
+def omni_runner(request, model_prefix):
+ # Module-scoped fixture that yields a running OmniRunner built from the test's
+ # parametrization; setup and teardown happen while holding _omni_server_lock.
+ with _omni_server_lock:
+ # request.param is a (model, stage_config_path) pair; model_prefix is prepended to the model.
+ model, stage_config_path = request.param
+ model = model_prefix + model
+ with OmniRunner(model, seed=42, stage_configs_path=stage_config_path, stage_init_timeout=300) as runner:
+ print("OmniRunner started successfully")
+ yield runner
+ # Code after the yield runs at fixture teardown, before OmniRunner.__exit__.
+ print("OmniRunner stopping...")
+
+ print("OmniRunner stopped")
+
+
+class OmniRunnerHandler:
+ """Test helper that sends requests through a runner and validates the results.
+
+ Wraps the runner object it is given, converts the raw stage outputs of
+ ``generate_multimodal`` into ``OmniResponse`` objects and runs the response
+ assertions with ``run_level="core_model"``.
+ """
+
+ def __init__(self, omni_runner):
+ # Runner used by every request and profiling method of this handler.
+ self.runner = omni_runner
+
+ def _process_output(self, outputs: list[Any]) -> OmniResponse:
+ """Collect text and audio from stage outputs into an ``OmniResponse``.
+
+ For each element of ``outputs`` the ``final_output_type`` attribute is
+ inspected: ``"text"`` takes ``request_output.outputs[0].text`` and
+ ``"audio"`` takes ``request_output.outputs[0].multimodal_output["audio"]``.
+ There is no early exit, so if several stages report the same type the
+ last one wins.
+
+ Never raises: any exception is caught, stored in ``error_message``,
+ printed, and reported through ``success = False``.
+ """
+ result = OmniResponse()
+ try:
+ text_content = None
+ audio_content = None
+ for stage_output in outputs:
+ if getattr(stage_output, "final_output_type", None) == "text":
+ text_content = stage_output.request_output.outputs[0].text
+ if getattr(stage_output, "final_output_type", None) == "audio":
+ audio_content = stage_output.request_output.outputs[0].multimodal_output["audio"]
+
+ result.audio_content = audio_content
+ result.text_content = text_content
+ result.success = True
+
+ except Exception as e:
+ result.error_message = f"Output processing error: {str(e)}"
+ result.success = False
+ print(f"Error: {result.error_message}")
+
+ return result
+
+ def send_request(self, request_config: dict[str, Any] | None = None) -> OmniResponse:
+ """Run one multimodal generation and validate the response.
+
+ Reads ``prompts``, ``videos``, ``images``, ``audios`` and ``modalities``
+ from ``request_config`` (``modalities`` falls back to
+ ``["text", "audio"]``; a missing config is treated as an empty dict),
+ forwards them to ``runner.generate_multimodal``, converts the outputs
+ with ``_process_output`` and checks the result with
+ ``assert_omni_response`` before returning it.
+ """
+ if request_config is None:
+ request_config = {}
+ prompts = request_config.get("prompts")
+ videos = request_config.get("videos")
+ images = request_config.get("images")
+ audios = request_config.get("audios")
+ modalities = request_config.get("modalities", ["text", "audio"])
+ outputs = self.runner.generate_multimodal(
+ prompts=prompts, videos=videos, images=images, audios=audios, modalities=modalities
+ )
+ response = self._process_output(outputs)
+ assert_omni_response(response, request_config, run_level="core_model")
+ return response
+
+ def send_audio_speech_request(
+ self,
+ request_config: dict[str, Any],
+ ) -> OmniResponse:
+ """
+ Offline TTS: text -> audio via generate_multimodal, then validate with assert_audio_speech_response.
+
+ request_config must contain:
+ - 'input' or 'prompts': text to synthesize ('input' wins when it is truthy;
+ for a list only the first element is used, an empty list becomes "").
+ Optional keys, forwarded in mm_processor_kwargs only when present:
+ - 'voice' -> forwarded as 'speaker'
+ - 'task_type' -> forwarded as 'task_type'
+ - 'ref_audio' -> forwarded as 'ref_audio'
+ - 'ref_text' -> forwarded as 'ref_text'
+ - 'language' -> forwarded as 'language'
+ - 'max_new_tokens' -> forwarded as 'max_new_tokens'
+ No defaults are applied here; any defaults for missing keys come from the
+ downstream pipeline (NOTE(review): confirm the downstream default values).
+ Other keys such as 'response_format' are not read by this method; the whole
+ request_config is passed to assert_audio_speech_response.
+
+ Returns an OmniResponse whose audio_bytes hold a 16-bit PCM WAV encoding of
+ the generated audio. Raises ValueError when neither 'input' nor 'prompts'
+ is present.
+ """
+ input_text = request_config.get("input") or request_config.get("prompts")
+ if input_text is None:
+ raise ValueError("request_config must contain 'input' or 'prompts' for TTS")
+ if isinstance(input_text, list):
+ input_text = input_text[0] if input_text else ""
+
+ # Build TTS-specific kwargs passed through to get_omni_inputs for Qwen3-TTS,
+ # matching examples/offline_inference/qwen3_tts/end2end.py.
+ mm_processor_kwargs: dict[str, Any] = {}
+ if "voice" in request_config:
+ mm_processor_kwargs["speaker"] = request_config["voice"]
+ if "task_type" in request_config:
+ mm_processor_kwargs["task_type"] = request_config["task_type"]
+ if "ref_audio" in request_config:
+ mm_processor_kwargs["ref_audio"] = request_config["ref_audio"]
+ if "ref_text" in request_config:
+ mm_processor_kwargs["ref_text"] = request_config["ref_text"]
+ if "language" in request_config:
+ mm_processor_kwargs["language"] = request_config["language"]
+ if "max_new_tokens" in request_config:
+ mm_processor_kwargs["max_new_tokens"] = request_config["max_new_tokens"]
+
+ # An empty kwargs dict is passed as None rather than {}.
+ outputs = self.runner.generate_multimodal(
+ prompts=input_text,
+ modalities=["audio"],
+ mm_processor_kwargs=mm_processor_kwargs or None,
+ )
+ mm_out: dict[str, Any] | None = None
+ for stage_out in outputs:
+ if getattr(stage_out, "final_output_type", None) == "audio":
+ mm_out = stage_out.request_output.outputs[0].multimodal_output
+ break
+ if mm_out is None:
+ result = OmniResponse(success=False, error_message="No audio output from pipeline")
+ # result was built with success=False, so this assert fails the test with the
+ # message; the return is only reached when assertions are disabled (python -O).
+ assert result.success, result.error_message
+ return result
+
+ audio_data = mm_out.get("audio")
+ if audio_data is None:
+ result = OmniResponse(success=False, error_message="No audio tensor in multimodal output")
+ # Same pattern as above: fail loudly unless assertions are disabled.
+ assert result.success, result.error_message
+ return result
+
+ # The sample rate may arrive as a list (last element is used) and/or as an
+ # object exposing .item(); both are reduced to a plain int.
+ # NOTE(review): a missing "sr" entry makes int(None) raise TypeError here.
+ sr_raw = mm_out.get("sr")
+ sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw
+ sr = int(sr_val.item() if hasattr(sr_val, "item") else sr_val)
+ # A list of audio chunks is concatenated along the last dimension.
+ wav_tensor = torch.cat(audio_data, dim=-1) if isinstance(audio_data, list) else audio_data
+ wav_buf = io.BytesIO()
+ # The audio is flattened to 1-D and encoded in memory as 16-bit PCM WAV.
+ sf.write(
+ wav_buf,
+ wav_tensor.float().cpu().numpy().reshape(-1),
+ samplerate=sr,
+ format="WAV",
+ subtype="PCM_16",
+ )
+ result = OmniResponse(success=True, audio_bytes=wav_buf.getvalue(), audio_format="audio/wav")
+ assert_audio_speech_response(result, request_config, run_level="core_model")
+ return result
+
+ def start_profile(
+ self,
+ profile_prefix: str | None = None,
+ stages: list[int] | None = None,
+ ) -> list[Any]:
+ """Start profiling specified stages."""
+ # Thin pass-through to the wrapped runner.
+ return self.runner.start_profile(profile_prefix=profile_prefix, stages=stages)
+
+ def stop_profile(self, stages: list[int] | None = None) -> list[Any]:
+ """Stop profiling specified stages."""
+ # Thin pass-through to the wrapped runner.
+ return self.runner.stop_profile(stages=stages)
+
+
+@pytest.fixture
+def omni_runner_handler(omni_runner):
+ return OmniRunnerHandler(omni_runner)
diff --git a/tests/core/sched/test_chunk_scheduling_coordinator.py b/tests/core/sched/test_chunk_scheduling_coordinator.py
deleted file mode 100644
index 5e19465e224..00000000000
--- a/tests/core/sched/test_chunk_scheduling_coordinator.py
+++ /dev/null
@@ -1,690 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for OmniSchedulingCoordinator (formerly ChunkSchedulingCoordinator).
-
-These tests use mock request objects and mock queues. They do not require
-GPU, vLLM runtime, or any connector.
-"""
-
-from __future__ import annotations
-
-import unittest
-from types import SimpleNamespace
-
-import vllm_omni.core.sched.omni_scheduling_coordinator as coord_mod
-from vllm_omni.core.sched.omni_scheduling_coordinator import (
- ChunkSchedulingCoordinator,
- OmniSchedulingCoordinator,
-)
-
-# ------------------------------------------------------------------ #
-# Mock helpers
-# ------------------------------------------------------------------ #
-
-
-class _RequestStatus:
- WAITING = "waiting"
- RUNNING = "running"
- WAITING_FOR_CHUNK = "waiting_for_chunk"
- WAITING_FOR_INPUT = "waiting_for_input"
- FINISHED_STOPPED = "finished_stopped"
-
-
-# Patch RequestStatus for tests that don't import vllm
-try:
- from vllm.v1.request import RequestStatus
-except ImportError:
- RequestStatus = _RequestStatus # type: ignore[misc,assignment]
-
-if not hasattr(RequestStatus, "WAITING_FOR_INPUT"):
- coord_mod.RequestStatus = _RequestStatus # type: ignore[assignment]
- RequestStatus = _RequestStatus # type: ignore[misc,assignment]
-
-
-def _make_request(req_id: str, status: str = "waiting") -> SimpleNamespace:
- return SimpleNamespace(
- request_id=req_id,
- external_req_id=req_id,
- status=status,
- additional_information=None,
- prompt_token_ids=[],
- num_prompt_tokens=0,
- num_computed_tokens=0,
- _all_token_ids=[],
- _output_token_ids=[],
- )
-
-
-class MockQueue:
- """Simplified queue that mimics the Scheduler waiting queue interface."""
-
- def __init__(self, items: list | None = None):
- self._items: list = list(items or [])
-
- def __iter__(self):
- return iter(self._items)
-
- def __len__(self):
- return len(self._items)
-
- def __contains__(self, item):
- return item in self._items
-
- def add_request(self, request):
- self._items.append(request)
-
- def prepend_requests(self, requests):
- self._items = list(requests) + self._items
-
- def remove(self, request):
- self._items.remove(request)
-
- def remove_requests(self, requests):
- remove_set = set(id(r) for r in requests)
- self._items = [r for r in self._items if id(r) not in remove_set]
-
-
-# ------------------------------------------------------------------ #
-# Tests
-# ------------------------------------------------------------------ #
-
-
-class TestChunkCoordinatorStateTransition(unittest.TestCase):
- """Test 5: process_pending_chunks transitions WAITING_FOR_CHUNK → target."""
-
- def test_ready_request_transitions_to_waiting(self):
- coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
-
- req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_chunks(
- waiting,
- running,
- chunk_ready_req_ids={"r1"},
- chunk_finished_req_ids=set(),
- )
-
- self.assertEqual(req.status, RequestStatus.WAITING)
- self.assertIn("r1", coord.requests_with_ready_chunks)
-
- def test_non_ready_stays_waiting_for_chunk(self):
- coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
-
- req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_chunks(
- waiting,
- running,
- chunk_ready_req_ids=set(),
- chunk_finished_req_ids=set(),
- )
-
- self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
-
- def test_stage_0_is_noop(self):
- coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=0)
- req = _make_request("r1")
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_chunks(
- waiting,
- running,
- chunk_ready_req_ids={"r1"},
- chunk_finished_req_ids=set(),
- )
- self.assertNotEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
-
-
-class TestChunkCoordinatorRestoreQueues(unittest.TestCase):
- """Test 6: restore_queues returns waiting-for-chunk requests."""
-
- def test_restore(self):
- coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
-
- r1 = _make_request("r1")
- r2 = _make_request("r2")
- coord._waiting_for_chunk_waiting.append(r1)
- coord._waiting_for_chunk_running.append(r2)
-
- waiting = MockQueue()
- running: list = []
-
- coord.restore_queues(waiting, running)
-
- self.assertIn(r1, waiting)
- self.assertIn(r2, running)
- self.assertEqual(len(coord._waiting_for_chunk_waiting), 0)
- self.assertEqual(len(coord._waiting_for_chunk_running), 0)
-
-
-class TestChunkCoordinatorFinishedSignal(unittest.TestCase):
- """Test 8: chunk_finished_req_ids → finished_requests."""
-
- def test_finished_signal(self):
- coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1, async_chunk=True)
-
- req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_chunks(
- waiting,
- running,
- chunk_ready_req_ids={"r1"},
- chunk_finished_req_ids={"r1"},
- )
-
- self.assertIn("r1", coord.finished_requests)
-
-
-class TestChunkCoordinatorUpdateRequestMetadata(unittest.TestCase):
- """Test update_request_metadata applies scheduling metadata to requests."""
-
- def test_ar_mode_no_longer_sets_additional_information(self):
- """AR mode only processes scheduling metadata, not full payloads."""
- coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
-
- req = _make_request("r1")
- requests = {"r1": req}
-
- # Only scheduling metadata is passed now (full payload stays in model runner)
- request_metadata = {"r1": {"next_stage_prompt_len": 50}}
-
- coord.update_request_metadata(requests, request_metadata, model_mode="ar")
-
- # next_stage_prompt_len should update prompt_token_ids
- self.assertEqual(len(req.prompt_token_ids), 50)
- self.assertEqual(req.num_prompt_tokens, 50)
- # additional_information should NOT be set
- self.assertIsNone(getattr(req, "additional_information", None))
-
- def test_generation_mode(self):
- coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
-
- req = _make_request("r1")
- req.prompt_token_ids = [0, 0, 0]
- requests = {"r1": req}
-
- request_metadata = {
- "r1": {
- "code_predictor_codes": [10, 20, 30],
- "left_context_size": 25,
- }
- }
-
- coord.update_request_metadata(requests, request_metadata, model_mode="generation")
-
- self.assertEqual(req.prompt_token_ids, [10, 20, 30])
- self.assertEqual(req.num_computed_tokens, 0)
- self.assertIsNone(req.additional_information)
- self.assertEqual(req._omni_initial_model_buffer, {"left_context_size": 25})
-
-
-class TestChunkCoordinatorPostprocess(unittest.TestCase):
- """Test postprocess_scheduler_output clears ready chunks."""
-
- def test_clear_ready(self):
- coord = ChunkSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
- coord.requests_with_ready_chunks = {"r1", "r2"}
-
- new_req = SimpleNamespace(req_id="r1")
- cached_reqs = SimpleNamespace(req_ids=["r2"])
- scheduler_output = SimpleNamespace(
- scheduled_new_reqs=[new_req],
- scheduled_cached_reqs=cached_reqs,
- )
-
- coord.postprocess_scheduler_output(scheduler_output)
-
- self.assertEqual(coord.requests_with_ready_chunks, set())
-
-
-class TestWaitingForInputTransition(unittest.TestCase):
- """Test B8: process_pending_full_payload_inputs transitions WAITING_FOR_INPUT."""
-
- def test_transition_on_recv(self):
- coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
-
- req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_full_payload_inputs(
- waiting,
- running,
- stage_recv_req_ids={"r1"},
- )
-
- self.assertEqual(req.status, RequestStatus.WAITING)
-
- def test_stays_waiting_for_input_if_not_received(self):
- coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
-
- req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_full_payload_inputs(
- waiting,
- running,
- stage_recv_req_ids=set(),
- )
-
- self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
- self.assertEqual(len(coord._waiting_for_input), 1)
-
- def test_stage_0_is_noop(self):
- coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=0)
-
- req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_full_payload_inputs(
- waiting,
- running,
- stage_recv_req_ids={"r1"},
- )
- self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
-
- def test_restore_queues_includes_waiting_for_input(self):
- coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
-
- r1 = _make_request("r1")
- coord._waiting_for_input.append(r1)
-
- waiting = MockQueue()
- running: list = []
-
- coord.restore_queues(waiting, running)
-
- self.assertIn(r1, waiting)
- self.assertEqual(len(coord._waiting_for_input), 0)
-
- def test_full_payload_mode_auto_transitions_waiting_to_waiting_for_input(self):
- """In full_payload_mode (async_chunk=False), fresh WAITING requests on
- non-Stage-0 should be transitioned to WAITING_FOR_INPUT."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- async_chunk=False,
- )
-
- req = _make_request("r1", status=RequestStatus.WAITING)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_full_payload_inputs(
- waiting,
- running,
- stage_recv_req_ids=set(),
- )
-
- self.assertEqual(req.status, RequestStatus.WAITING_FOR_INPUT)
- self.assertEqual(len(coord._waiting_for_input), 1)
- self.assertEqual(len(coord.pending_input_registrations), 1)
-
- def test_async_chunk_mode_does_not_auto_transition(self):
- """In async_chunk mode, fresh WAITING requests should NOT be
- transitioned to WAITING_FOR_INPUT."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- async_chunk=True,
- )
-
- req = _make_request("r1", status=RequestStatus.WAITING)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_full_payload_inputs(
- waiting,
- running,
- stage_recv_req_ids=set(),
- )
-
- self.assertEqual(req.status, RequestStatus.WAITING)
-
- def test_pending_input_registrations(self):
- coord = OmniSchedulingCoordinator(scheduler_max_num_seqs=10, stage_id=1)
-
- req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_full_payload_inputs(
- waiting,
- running,
- stage_recv_req_ids=set(),
- )
-
- self.assertEqual(len(coord.pending_input_registrations), 1)
- self.assertEqual(coord.pending_input_registrations[0].request_id, "r1")
-
-
-class TestTimeoutDetection(unittest.TestCase):
- """Regression tests for orphaned pending-recv timeout detection.
-
- Covers the full lifecycle:
- 1. Request enters WAITING_FOR_CHUNK from either waiting or running queue
- 2. restore_queues() moves it back to the scheduler queue
- 3. Timeout fires via collect_timed_out_request_ids()
- 4. Scheduler removes from both queues and calls _free_request()
- """
-
- def test_waiting_since_recorded_on_chunk_wait(self):
- """_waiting_since is set when a request enters WAITING_FOR_CHUNK."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- async_chunk=True,
- )
- req = _make_request("r1", status=RequestStatus.WAITING)
- waiting = MockQueue([req])
-
- coord.process_pending_chunks(
- waiting,
- [],
- chunk_ready_req_ids=set(),
- chunk_finished_req_ids=set(),
- )
-
- self.assertIn("r1", coord._waiting_since)
- self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
-
- def test_waiting_since_cleared_on_chunk_arrival(self):
- """_waiting_since is cleared when a chunk arrives."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- async_chunk=True,
- )
- req = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
- waiting = MockQueue([req])
-
- coord.process_pending_chunks(
- waiting,
- [],
- chunk_ready_req_ids={"r1"},
- chunk_finished_req_ids=set(),
- )
-
- self.assertNotIn("r1", coord._waiting_since)
-
- def test_waiting_since_recorded_on_input_wait(self):
- """_waiting_since is set when a request enters WAITING_FOR_INPUT."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- async_chunk=False,
- )
- req = _make_request("r1", status=RequestStatus.WAITING)
- waiting = MockQueue([req])
-
- coord.process_pending_full_payload_inputs(
- waiting,
- [],
- stage_recv_req_ids=set(),
- )
-
- self.assertIn("r1", coord._waiting_since)
-
- def test_waiting_since_cleared_on_input_arrival(self):
- """_waiting_since is cleared when input data arrives."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- async_chunk=False,
- )
- req = _make_request("r1", status=RequestStatus.WAITING_FOR_INPUT)
- coord._waiting_for_input.append(req)
- coord._waiting_since["r1"] = 0.0
-
- waiting = MockQueue()
- coord.process_pending_full_payload_inputs(
- waiting,
- [],
- stage_recv_req_ids={"r1"},
- )
-
- self.assertNotIn("r1", coord._waiting_since)
- self.assertEqual(req.status, RequestStatus.WAITING)
-
- def test_collect_timed_out_request_ids_no_timeout(self):
- """No IDs returned when nothing has timed out."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- )
- import time
-
- coord._waiting_since["r1"] = time.monotonic()
-
- result = coord.collect_timed_out_request_ids(timeout_s=300.0)
- self.assertEqual(result, set())
-
- def test_collect_timed_out_request_ids_expired(self):
- """Timed-out IDs are returned and _waiting_since is cleared."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- )
- coord._waiting_since["r1"] = 0.0 # epoch → definitely expired
- coord._waiting_since["r2"] = 0.0
-
- import time
-
- coord._waiting_since["r3"] = time.monotonic() + 9999 # far future
-
- result = coord.collect_timed_out_request_ids(timeout_s=1.0)
-
- self.assertEqual(result, {"r1", "r2"})
- self.assertNotIn("r1", coord._waiting_since)
- self.assertNotIn("r2", coord._waiting_since)
- self.assertIn("r3", coord._waiting_since)
-
- def test_collect_removes_from_coordinator_queues(self):
- """Timed-out requests are defensively removed from internal queues."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- )
- r1 = _make_request("r1")
- r2 = _make_request("r2")
- coord._waiting_for_chunk_waiting.append(r1)
- coord._waiting_for_input.append(r2)
- coord._waiting_since["r1"] = 0.0
- coord._waiting_since["r2"] = 0.0
-
- result = coord.collect_timed_out_request_ids(timeout_s=1.0)
-
- self.assertEqual(result, {"r1", "r2"})
- self.assertEqual(len(coord._waiting_for_chunk_waiting), 0)
- self.assertEqual(len(coord._waiting_for_input), 0)
-
- def test_free_finished_request_clears_waiting_since(self):
- """free_finished_request clears _waiting_since."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- )
- coord._waiting_since["r1"] = 0.0
- coord.free_finished_request("r1")
- self.assertNotIn("r1", coord._waiting_since)
-
- def test_timeout_from_running_queue_full_lifecycle(self):
- """End-to-end: request from running → WAITING_FOR_CHUNK → restore →
- timeout → removed from running list.
-
- This is the critical regression case: WAITING_FOR_CHUNK requests
- that originated from self.running are placed back into self.running
- by restore_queues(), but their status remains WAITING_FOR_CHUNK.
- The scheduler must remove from BOTH queues unconditionally.
- """
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- async_chunk=True,
- )
-
- # 1) Request starts in running queue with WAITING status
- req = _make_request("r1", status=RequestStatus.WAITING)
- running = [req]
- waiting = MockQueue()
-
- # 2) process_pending_chunks: moves to WAITING_FOR_CHUNK
- coord.process_pending_chunks(
- waiting,
- running,
- chunk_ready_req_ids=set(),
- chunk_finished_req_ids=set(),
- )
- self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
- self.assertIn("r1", coord._waiting_since)
- self.assertEqual(len(coord._waiting_for_chunk_running), 1)
-
- # 3) restore_queues: back to running (status stays WAITING_FOR_CHUNK)
- coord.restore_queues(waiting, running)
- self.assertIn(req, running)
- self.assertEqual(len(coord._waiting_for_chunk_running), 0)
- self.assertEqual(req.status, RequestStatus.WAITING_FOR_CHUNK)
-
- # 4) Force timeout by setting _waiting_since to epoch
- coord._waiting_since["r1"] = 0.0
-
- timed_out_ids = coord.collect_timed_out_request_ids(timeout_s=1.0)
- self.assertEqual(timed_out_ids, {"r1"})
-
- # 5) Scheduler removes from both queues (simulating the scheduler path)
- timed_out_id_set = {id(req)}
- running = [r for r in running if id(r) not in timed_out_id_set]
- waiting.remove_requests([req])
-
- self.assertNotIn(req, running)
- self.assertEqual(len(waiting), 0)
-
- def test_timeout_from_waiting_queue_full_lifecycle(self):
- """End-to-end: request from waiting → WAITING_FOR_CHUNK → restore →
- timeout → removed from waiting queue."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=10,
- stage_id=1,
- async_chunk=True,
- )
-
- req = _make_request("r1", status=RequestStatus.WAITING)
- waiting = MockQueue([req])
- running: list = []
-
- coord.process_pending_chunks(
- waiting,
- running,
- chunk_ready_req_ids=set(),
- chunk_finished_req_ids=set(),
- )
- self.assertEqual(len(coord._waiting_for_chunk_waiting), 1)
-
- coord.restore_queues(waiting, running)
- self.assertIn(req, waiting)
-
- coord._waiting_since["r1"] = 0.0
- timed_out_ids = coord.collect_timed_out_request_ids(timeout_s=1.0)
- self.assertEqual(timed_out_ids, {"r1"})
-
- waiting.remove_requests([req])
- self.assertEqual(len(waiting), 0)
-
-
-class TestOverflowPreemption(unittest.TestCase):
- """Tests for P1-1: overflow requests must get WAITING status.
-
- Overflow happens when multiple WAITING_FOR_CHUNK requests in
- ``_waiting_for_chunk_running`` receive their chunk in the same cycle.
- ``_process_chunk_queue`` restores them to RUNNING (``continue``
- path) while RUNNING requests without chunks are moved out. If the
- net result exceeds ``scheduler_max_num_seqs``, the tail is pushed
- to ``waiting_queue`` and must have status == WAITING.
- """
-
- def test_overflow_sets_waiting_status(self):
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=1,
- stage_id=1,
- async_chunk=True,
- )
-
- # r1 is currently RUNNING in the queue.
- # r2, r3 were previously moved to _waiting_for_chunk_running.
- r1 = _make_request("r1", status=RequestStatus.RUNNING)
- r2 = _make_request("r2", status=RequestStatus.WAITING_FOR_CHUNK)
- r3 = _make_request("r3", status=RequestStatus.WAITING_FOR_CHUNK)
-
- running = [r1]
- waiting = MockQueue([])
- coord._waiting_for_chunk_running.extend([r2, r3])
-
- # restore_queues puts r2, r3 back into running
- coord.restore_queues(waiting, running)
- self.assertEqual(len(running), 3)
-
- # Now process_pending_chunks with r2, r3 chunks ready:
- # _process_chunk_queue will:
- # r1 (RUNNING) → no chunk → move to _waiting_for_chunk_running
- # r2 (WAITING_FOR_CHUNK, chunk ready) → set RUNNING, stay in running
- # r3 (WAITING_FOR_CHUNK, chunk ready) → set RUNNING, stay in running
- # running = [r2, r3], len=2 > max=1 → overflow
- coord.process_pending_chunks(
- waiting,
- running,
- chunk_ready_req_ids={"r2", "r3"},
- chunk_finished_req_ids=set(),
- )
-
- self.assertEqual(len(running), 1)
- self.assertEqual(len(waiting), 1)
- overflow_req = list(waiting)[0]
- self.assertEqual(
- overflow_req.status,
- RequestStatus.WAITING,
- f"Overflowed request should have WAITING status, got {overflow_req.status}",
- )
-
- def test_overflow_does_not_strand_request(self):
- """Without the fix, the overflowed request would keep its
- RUNNING status in the waiting queue and never be re-scheduled."""
- coord = OmniSchedulingCoordinator(
- scheduler_max_num_seqs=1,
- stage_id=1,
- async_chunk=True,
- )
-
- r1 = _make_request("r1", status=RequestStatus.WAITING_FOR_CHUNK)
- r2 = _make_request("r2", status=RequestStatus.WAITING_FOR_CHUNK)
- coord._waiting_for_chunk_running.extend([r1, r2])
-
- running: list = []
- waiting = MockQueue([])
-
- coord.restore_queues(waiting, running)
- self.assertEqual(len(running), 2)
-
- coord.process_pending_chunks(
- waiting,
- running,
- chunk_ready_req_ids={"r1", "r2"},
- chunk_finished_req_ids=set(),
- )
-
- self.assertEqual(len(running), 1)
- self.assertEqual(len(waiting), 1)
- for req in waiting:
- self.assertNotEqual(req.status, RequestStatus.RUNNING, "Overflowed request must not keep RUNNING status")
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/core/sched/test_generation_scheduler_restore.py b/tests/core/sched/test_generation_scheduler_restore.py
index 5cc1cab7025..0eae3c4db91 100644
--- a/tests/core/sched/test_generation_scheduler_restore.py
+++ b/tests/core/sched/test_generation_scheduler_restore.py
@@ -6,12 +6,9 @@
those requests are permanently orphaned.
"""
+import unittest
from collections import deque
-import pytest
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
class FakeAdapter:
"""Minimal mock of OmniChunkTransferAdapter tracking restore calls."""
@@ -38,7 +35,7 @@ def postprocess_scheduler_output(self, output):
pass
-class TestRestoreQueuesOnError:
+class TestRestoreQueuesOnError(unittest.TestCase):
"""Verify that restore_queues is called even when rewrapping raises."""
def test_requests_not_lost_on_exception(self):
@@ -51,8 +48,8 @@ def test_requests_not_lost_on_exception(self):
# Step 1: process_pending_chunks moves req-B out
adapter.process_pending_chunks(waiting=[], running=running)
- assert running == ["req-A"]
- assert len(adapter.waiting_for_chunk_running_requests) == 1
+ self.assertEqual(running, ["req-A"])
+ self.assertEqual(len(adapter.waiting_for_chunk_running_requests), 1)
# Step 2: simulate the try/except/finally pattern
try:
@@ -64,9 +61,9 @@ def test_requests_not_lost_on_exception(self):
adapter.restore_queues(waiting=[], running=running)
# Step 3: verify request is restored
- assert adapter.restore_called is True
- assert "req-B" in running
- assert len(adapter.waiting_for_chunk_running_requests) == 0
+ self.assertTrue(adapter.restore_called)
+ self.assertIn("req-B", running)
+ self.assertEqual(len(adapter.waiting_for_chunk_running_requests), 0)
def test_requests_lost_without_fix(self):
"""Demonstrate the bug: without restore in except, request is lost."""
@@ -75,7 +72,7 @@ def test_requests_lost_without_fix(self):
running = ["req-A", "req-B"]
adapter.process_pending_chunks(waiting=[], running=running)
- assert running == ["req-A"]
+ self.assertEqual(running, ["req-A"])
# Simulate the BUGGY code: except without restore
try:
@@ -84,8 +81,8 @@ def test_requests_lost_without_fix(self):
pass # Bug: no restore_queues call
# Request is lost!
- assert "req-B" not in running
- assert len(adapter.waiting_for_chunk_running_requests) == 1
+ self.assertNotIn("req-B", running)
+ self.assertEqual(len(adapter.waiting_for_chunk_running_requests), 1)
def test_happy_path_restores_via_finally(self):
"""When no exception, restore_queues is still called via finally."""
@@ -101,5 +98,9 @@ def test_happy_path_restores_via_finally(self):
finally:
adapter.restore_queues(waiting=[], running=running)
- assert adapter.restore_called is True
- assert "req-B" in running
+ self.assertTrue(adapter.restore_called)
+ self.assertIn("req-B", running)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/core/sched/test_omni_scheduler_mixin.py b/tests/core/sched/test_omni_scheduler_mixin.py
deleted file mode 100644
index e04a9c39fbc..00000000000
--- a/tests/core/sched/test_omni_scheduler_mixin.py
+++ /dev/null
@@ -1,129 +0,0 @@
-"""Unit tests for OmniSchedulerMixin streaming session replacement.
-
-These tests pin the behavior of `_replace_session_with_streaming_update` against
-current vLLM `Request` / `StreamingUpdate` (and Omni patches). When upgrading
-vLLM, failures here should highlight incompatible changes to request state or
-update payloads early.
-"""
-
-from __future__ import annotations
-
-from dataclasses import replace
-
-import pytest
-
-# Imports must run in this order: vllm_omni applies patches to vllm.v1.request before
-# Request / StreamingUpdate are bound in this module. Ruff isort would reorder them.
-# isort: off
-import vllm_omni # noqa: F401 - import for side effects (patch vLLM)
-from vllm.sampling_params import SamplingParams
-from vllm.v1.engine import EngineCoreEventType
-from vllm.v1.request import Request, RequestStatus, StreamingUpdate
-from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin
-
-# isort: on
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-class _SchedulerStub(OmniSchedulerMixin):
- """Minimal scheduler surface required by OmniSchedulerMixin."""
-
- def __init__(self, *, log_stats: bool = False) -> None:
- self.num_waiting_for_streaming_input = 0
- self.log_stats = log_stats
-
-
-def _make_request(**kwargs) -> Request:
- sp = SamplingParams(max_tokens=8)
- defaults = dict(
- request_id="req-mixin-test",
- prompt_token_ids=[1, 2, 3],
- sampling_params=sp,
- pooling_params=None,
- arrival_time=100.0,
- block_hasher=None,
- )
- defaults.update(kwargs)
- return Request(**defaults)
-
-
-def _make_update(**kwargs) -> StreamingUpdate:
- sp_new = SamplingParams(max_tokens=16)
- defaults = dict(
- mm_features=None,
- prompt_token_ids=[10, 20],
- max_tokens=32,
- arrival_time=200.0,
- sampling_params=sp_new,
- )
- defaults.update(kwargs)
- return StreamingUpdate(**defaults)
-
-
-class TestReplaceSessionWithStreamingUpdate:
- def test_resets_tokens_and_prompt_from_update(self) -> None:
- sched = _SchedulerStub()
- session = _make_request()
- session.append_output_token_ids([7, 8])
- session.num_computed_tokens = 99
- session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
-
- update = _make_update(prompt_token_ids=[40, 41, 42])
- sched.num_waiting_for_streaming_input = 3
- sched._replace_session_with_streaming_update(session, update)
-
- assert session._output_token_ids == []
- assert list(session._all_token_ids) == [40, 41, 42]
- assert session.prompt_token_ids == [40, 41, 42]
- assert session.num_computed_tokens == 0
- assert session.num_prompt_tokens == 3
- assert session.arrival_time == 200.0
- assert session.sampling_params is update.sampling_params
- assert session.status == RequestStatus.WAITING
- assert sched.num_waiting_for_streaming_input == 2
-
- def test_none_prompt_token_ids_becomes_empty(self) -> None:
- sched = _SchedulerStub()
- session = _make_request()
- session.status = RequestStatus.RUNNING
- update = _make_update(prompt_token_ids=None)
- sched._replace_session_with_streaming_update(session, update)
-
- assert session.prompt_token_ids == ()
- assert list(session._all_token_ids) == []
- assert session.num_prompt_tokens == 0
- assert sched.num_waiting_for_streaming_input == 0
-
- def test_additional_information_cleared_when_update_omits_it(self) -> None:
- sched = _SchedulerStub()
- session = _make_request()
- if not hasattr(session, "additional_information"):
- pytest.skip("Request has no additional_information (Omni patch inactive?)")
- session.additional_information = {"keep": True}
- session.status = RequestStatus.RUNNING
-
- base = _make_update()
- if not hasattr(base, "additional_information"):
- pytest.skip("StreamingUpdate has no additional_information (Omni patch inactive?)")
- update = replace(base, additional_information=None)
-
- sched._replace_session_with_streaming_update(session, update)
- assert session.additional_information is None
-
- def test_does_not_decrement_waiting_when_not_streaming_status(self) -> None:
- sched = _SchedulerStub()
- session = _make_request()
- session.status = RequestStatus.RUNNING
- sched.num_waiting_for_streaming_input = 5
- sched._replace_session_with_streaming_update(session, _make_update())
- assert sched.num_waiting_for_streaming_input == 5
-
- def test_records_queued_event_when_log_stats_enabled(self) -> None:
- sched = _SchedulerStub(log_stats=True)
- session = _make_request()
- session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
- sched._replace_session_with_streaming_update(session, _make_update())
-
- assert session.events
- assert session.events[-1].type == EngineCoreEventType.QUEUED
diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py
deleted file mode 100644
index b5d0e96d305..00000000000
--- a/tests/core/test_prefix_cache.py
+++ /dev/null
@@ -1,349 +0,0 @@
-import pytest
-import torch
-
-from vllm_omni.core.prefix_cache import OmniTensorPrefixCache
-
-DEFAULT_SEQ_LEN = 15
-NUM_BLOCKS = 10
-BLOCK_SIZE = 4
-HIDDEN_SIZE = 2
-DTYPE = torch.float32
-OTHER_DTYPE = torch.float16
-DEFAULT_SHAPE = torch.Size([NUM_BLOCKS, BLOCK_SIZE, HIDDEN_SIZE])
-
-
-class MockInputBatch:
- def __init__(self, num_computed_tokens_cpu):
- self.req_ids = ["req1", "req2"]
- self.req_id_to_index = {req_id: i for i, req_id in enumerate(self.req_ids)}
- self.num_computed_tokens_cpu = num_computed_tokens_cpu
-
- # Block table is only mocked for validation of length;
- # we don't actually need to add valid values here since
- # we patch the table when testing.
- class _DummyBlockTable:
- pass
-
- self.block_table = _DummyBlockTable()
- self.block_table.block_tables = [None]
-
-
-def get_omni_pcache_with_mm_tensors(feat_dims, seq_len) -> OmniTensorPrefixCache:
- """Build an OmniTensorPrefixCache and init mm tensors."""
- cache = get_omni_pcache()
- mm_outputs = get_multimodal_outputs(feat_dims, seq_len)
- cache.maybe_init_missing_mm_cache_keys(mm_outputs, seq_len)
- return cache
-
-
-def get_omni_pcache() -> OmniTensorPrefixCache:
- """Build an OmniTensorPrefixCache, but don't init mm tensors."""
- cache = OmniTensorPrefixCache(
- num_blocks=NUM_BLOCKS,
- block_size=BLOCK_SIZE,
- hidden_size=HIDDEN_SIZE,
- hs_dtype=DTYPE,
- )
- return cache
-
-
-def get_multimodal_outputs(feat_dims: dict[str, int], seq_len: int) -> dict[str, torch.Tensor]:
- fake_mm_inputs = {}
- for mm_key, feat_dim in feat_dims.items():
- fake_mm_inputs[mm_key] = torch.rand((seq_len, feat_dim), dtype=DTYPE)
- return fake_mm_inputs
-
-
-### Tests for initialization
-def test_initialization_simple():
- """Check default initialization only creates the hidden states."""
- cache = get_omni_pcache()
- assert isinstance(cache.hidden_states_cache, torch.Tensor)
- assert cache.hidden_states_cache.shape == DEFAULT_SHAPE
- assert len(cache.mm_outputs_cache) == 0
- assert len(cache.mm_cache_keys) == 0
-
-
-def test_initialization_with_multimodal():
- """Check initialization + registration of multimodal outputs."""
- cache = get_omni_pcache()
- feat_dims = {"foo": 100, "bar": 50, "baz": 10}
- mm_outputs = get_multimodal_outputs(
- feat_dims,
- seq_len=DEFAULT_SEQ_LEN,
- )
- # Cast one of the keys to a different dtype; the dtype of the tensor
- # that is used to initialize the cache dictates the cache dtype.
- mm_outputs["foo"] = mm_outputs["foo"].to(OTHER_DTYPE)
-
- cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN)
- assert len(cache.mm_cache_keys) == 3
- assert set(cache.mm_cache_keys) == set(feat_dims.keys())
- for mm_key in cache.mm_cache_keys:
- cache_tensor = cache.mm_outputs_cache[mm_key]
- assert isinstance(cache_tensor, torch.Tensor)
- assert cache_tensor.shape[-1] == feat_dims[mm_key]
- assert mm_outputs[mm_key].dtype == cache_tensor.dtype
-
-
-def test_init_missing_mm_cache_keys_is_idempotent():
- """Ensure that the cache doesn't reinitialize old keys."""
- cache = get_omni_pcache()
- mm_key = "foo"
- feat_dims = {mm_key: 100}
- mm_outputs = get_multimodal_outputs(
- feat_dims,
- seq_len=DEFAULT_SEQ_LEN,
- )
- cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN)
- assert len(cache.mm_cache_keys) == 1
- assert mm_key in cache.mm_cache_keys
-
- # Cache is initialized to 0 - fill it with 1s
- cache.mm_outputs_cache[mm_key].fill_(1)
-
- # Ensure that running another initialization
- # doesn't zero out our cache values
- cache.maybe_init_missing_mm_cache_keys(mm_outputs, DEFAULT_SEQ_LEN)
- assert len(cache.mm_cache_keys) == 1
- assert mm_key in cache.mm_cache_keys
- assert torch.all(cache.mm_outputs_cache[mm_key] == 1)
-
-
-### Tests for Update
-def test_update_no_multimodal():
- """Test that slot mappings act as row indices hidden states."""
- cache = get_omni_pcache()
-
- num_tokens_unpadded = 8
- slot_offset = 8
- slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded)
- new_hidden_states = torch.rand((num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE)
-
- cache.update_omni_tensor_prefix_cache(
- hidden_states=new_hidden_states,
- multimodal_outputs=None,
- num_tokens_unpadded=num_tokens_unpadded,
- slot_mapping=slot_mapping,
- )
-
- # Ensure that if we reshape our 3D cache back to 2D, we can use the
- # indices in our slot mappings to access the hidden states as expected
- hs_rows = cache.hidden_states_cache.view(NUM_BLOCKS * BLOCK_SIZE, HIDDEN_SIZE)
- for slot_idx, new_states in zip(slot_mapping, new_hidden_states):
- slot_states = hs_rows[slot_idx]
- assert torch.all(slot_states == new_states)
-
-
-@pytest.mark.parametrize(
- "feat_dims",
- [
- {"foo": 100, "bar": 100},
- {"foo": 100, "bar": 50, "baz": 10},
- ],
-)
-def test_update_with_multimodal_outputs(feat_dims):
- """Test that slot mappings are correct for multimodal tensors."""
- cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN)
-
- num_tokens_unpadded = 8
- slot_offset = 8
- slot_mapping = torch.arange(slot_offset, slot_offset + num_tokens_unpadded)
- feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()}
- mm_outputs = {key: torch.rand((num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in cache.mm_cache_keys}
- cache.update_omni_tensor_prefix_cache(
- hidden_states=None,
- multimodal_outputs=mm_outputs,
- num_tokens_unpadded=num_tokens_unpadded,
- slot_mapping=slot_mapping,
- )
-
- for mm_key in feat_dims.keys():
- assert mm_key in cache.mm_outputs_cache
- key_feat_dim = feature_dims[mm_key]
- mm_state_rows = cache.mm_outputs_cache[mm_key].view(NUM_BLOCKS * BLOCK_SIZE, key_feat_dim)
-
- # Similar to hidden states, but for each key in the dict;
- # Different tensors may have different feature dims
- new_mm_outputs = mm_outputs[mm_key]
- for slot_idx, new_output in zip(slot_mapping, new_mm_outputs):
- slot_states = mm_state_rows[slot_idx]
- assert torch.all(slot_states == new_output)
-
-
-### Tests for Merging
-def fake_get_cached_block_ids(self, req_idx, *args, **kwargs):
- """Fake block table lookup.
-
- Assumption:
- req_idx 0 is a cache hit with slots 8, 9, ..., 15
- req_idx 1 is a cache miss
- """
- assert req_idx < 2
- if req_idx == 0:
- # With the slot offset we provided (8), the corresponding
- # blocks IDs are 2 & 3 because the block size is 4.
- return torch.tensor([2, 3], dtype=torch.long)
- return torch.tensor([], dtype=torch.long)
-
-
-@pytest.mark.parametrize("num_tokens_padded", [None, 16])
-def test_get_merged_hidden_states(num_tokens_padded, mocker):
- """Ensure that hidden states are merged correctly."""
- cache = get_omni_pcache()
-
- orig_num_tokens_unpadded = 8
- slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15
- orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded)
- orig_hidden_states = torch.rand((orig_num_tokens_unpadded, HIDDEN_SIZE), dtype=DTYPE)
-
- cache.update_omni_tensor_prefix_cache(
- hidden_states=orig_hidden_states,
- multimodal_outputs=None,
- num_tokens_unpadded=orig_num_tokens_unpadded,
- slot_mapping=orig_slot_mapping,
- num_tokens_padded=num_tokens_padded,
- )
-
- # Say that we have two requests, but only one of them is a cache hit
- num_new_toks_req1 = 3
- num_new_toks_req2 = 2
- cache.add_prefix_cached_new_req_id("req1")
-
- num_scheduled_tokens = {
- "req1": num_new_toks_req1,
- "req2": num_new_toks_req2,
- }
- new_hidden_states = torch.rand(
- (num_new_toks_req1 + num_new_toks_req2, HIDDEN_SIZE),
- dtype=DTYPE,
- )
- req1_new_states = new_hidden_states[:num_new_toks_req1]
- req2_new_states = new_hidden_states[-num_new_toks_req2:]
-
- input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0]))
-
- mocker.patch(
- "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids",
- new=fake_get_cached_block_ids,
- )
- merged_states = cache.get_merged_hidden_states(
- query_start_loc=[0, num_new_toks_req1],
- input_batch=input_batch,
- hidden_states=new_hidden_states,
- num_scheduled_tokens=num_scheduled_tokens,
- )
-
- assert "req1" in merged_states and "req2" in merged_states
- req1_merged_states = merged_states["req1"]
- req2_merged_states = merged_states["req2"]
-
- # First, check the cache hit case
- assert req1_merged_states.shape == torch.Size([orig_num_tokens_unpadded + num_new_toks_req1, HIDDEN_SIZE])
- # Ensure that the req1 merged states are the cached states + the new req1 states
- assert torch.all(req1_merged_states[:orig_num_tokens_unpadded] == orig_hidden_states)
- assert torch.all(req1_merged_states[-num_new_toks_req1:] == req1_new_states)
-
- # Next, ensure that the cache miss case only has the new states
- assert req2_merged_states.shape == torch.Size([num_new_toks_req2, HIDDEN_SIZE])
- assert torch.all(req2_merged_states == req2_new_states)
-
-
-@pytest.mark.parametrize("num_tokens_padded", [None, 16])
-@pytest.mark.parametrize(
- "feat_dims",
- [
- {"foo": 100, "bar": 100},
- {"foo": 100, "bar": 50, "baz": 10},
- ],
-)
-def test_get_merged_multimodal_outputs(feat_dims, num_tokens_padded, mocker):
- cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN)
-
- orig_num_tokens_unpadded = 8
- slot_offset = 8 # We'll put our states in slots 8, 9, 10, ..., 15
- orig_slot_mapping = torch.arange(slot_offset, slot_offset + orig_num_tokens_unpadded)
- feature_dims = {key: val.shape[-1] for key, val in cache.mm_outputs_cache.items()}
- orig_mm_outputs = {
- key: torch.rand((orig_num_tokens_unpadded, feature_dims[key]), dtype=DTYPE) for key in cache.mm_cache_keys
- }
-
- cache.update_omni_tensor_prefix_cache(
- hidden_states=None,
- multimodal_outputs=orig_mm_outputs,
- num_tokens_unpadded=orig_num_tokens_unpadded,
- slot_mapping=orig_slot_mapping,
- num_tokens_padded=num_tokens_padded,
- )
-
- # Similar to hs test- say that we have two requests, but only one of them is a cache hit
- num_new_toks_req1 = 3
- num_new_toks_req2 = 2
- cache.add_prefix_cached_new_req_id("req1")
-
- num_scheduled_tokens = {
- "req1": num_new_toks_req1,
- "req2": num_new_toks_req2,
- }
-
- new_mm_outputs = {}
- for mm_key in cache.mm_cache_keys:
- new_mm_outputs[mm_key] = torch.rand(
- (num_new_toks_req1 + num_new_toks_req2, feature_dims[mm_key]),
- dtype=DTYPE,
- )
- # We also want to make sure passthrough data (outside of our keys) isn't dropped
- new_mm_outputs["passthrough_data"] = "Something else"
- # Lists are a special case because we can't split them yet if we want to match
- # the nonprefix cache behavior, because this runs before post process.
- new_mm_outputs["passthrough_list"] = ["should", "not", "split"]
-
- input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0]))
-
- mocker.patch(
- "vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids",
- new=fake_get_cached_block_ids,
- )
- merged_mm_outputs = cache.get_merged_multimodal_states(
- query_start_loc=[0, num_new_toks_req1],
- input_batch=input_batch,
- multimodal_outputs=new_mm_outputs,
- num_scheduled_tokens=num_scheduled_tokens,
- )
-
- # Ensure the passthrough data wasn't dropped
- assert "passthrough_data" in merged_mm_outputs
- assert "passthrough_list" in merged_mm_outputs
-
- for mm_key, mm_output in merged_mm_outputs.items():
- # Ensure passthrough data is just forwarded normally and not duplicated
- assert isinstance(mm_output, dict)
- assert "req1" in mm_output and "req2" in mm_output
- if mm_key == "passthrough_data":
- assert mm_key not in cache.mm_cache_keys
- assert new_mm_outputs[mm_key] == mm_output["req1"]
- assert new_mm_outputs[mm_key] == mm_output["req2"]
- elif mm_key == "passthrough_list":
- assert mm_key not in cache.mm_cache_keys
- assert new_mm_outputs[mm_key] == mm_output["req1"]
- assert new_mm_outputs[mm_key] == mm_output["req2"]
- else:
- assert mm_key in cache.mm_cache_keys
- curr_feat_dim = feature_dims[mm_key]
- # Ensure that req1 (cache hit) merged the mm data
- req1_merged_mm_outputs = mm_output["req1"]
- req1_new_mm_outputs = new_mm_outputs[mm_key][:num_new_toks_req1]
-
- assert req1_merged_mm_outputs.shape == torch.Size(
- [orig_num_tokens_unpadded + num_new_toks_req1, curr_feat_dim]
- )
- # Ensure that the req1 merged mm data are the cached data + the new data
- assert torch.all(req1_merged_mm_outputs[:orig_num_tokens_unpadded] == orig_mm_outputs[mm_key])
- assert torch.all(req1_merged_mm_outputs[-num_new_toks_req1:] == req1_new_mm_outputs)
-
- # Ensure that req2 (cache miss) only has the new mm data
- req2_merged_mm_outputs = mm_output["req2"]
- req2_new_mm_outputs = new_mm_outputs[mm_key][-num_new_toks_req2:]
-
- assert req2_merged_mm_outputs.shape == torch.Size([num_new_toks_req2, curr_feat_dim])
- assert torch.all(req2_merged_mm_outputs == req2_new_mm_outputs)
diff --git a/tests/dfx/conftest.py b/tests/dfx/conftest.py
index 12eb8e6f1b5..e54141b3442 100644
--- a/tests/dfx/conftest.py
+++ b/tests/dfx/conftest.py
@@ -1,13 +1,8 @@
import json
-import os
-import subprocess
-from datetime import datetime
from pathlib import Path
from typing import Any
-import pytest
-
-from tests.helpers.stage_config import modify_stage_config
+from tests.conftest import modify_stage_config
def load_configs(config_path: str) -> list[dict[str, Any]]:
@@ -40,70 +35,25 @@ def modify_stage(default_path, updates, deletes):
return path
-def _build_serve_args(serve_args: Any) -> list[str]:
- """Convert server_params.serve_args to a flat CLI args list."""
- if serve_args is None:
- return []
- if isinstance(serve_args, list):
- return [str(item) for item in serve_args]
- if not isinstance(serve_args, dict):
- raise TypeError(f"serve_args must be dict/list/None, got {type(serve_args).__name__}")
-
- args: list[str] = []
- for key, value in serve_args.items():
- flag = f"--{str(key).replace('_', '-')}"
- if isinstance(value, bool):
- if value:
- args.append(flag)
- continue
- if value is None:
- continue
- if isinstance(value, (dict, list)):
- args.extend([flag, json.dumps(value, ensure_ascii=False, separators=(",", ":"))])
- continue
- args.extend([flag, str(value)])
- return args
-
-
def create_unique_server_params(
configs: list[dict[str, Any]],
stage_configs_dir: Path,
-) -> list[tuple[str, str, str | None, str | None, tuple[str, ...]]]:
- """Return one row per unique server configuration (same 5-tuple shape as upstream).
-
- ``(test_name, model, deploy_yaml_path, stage_overrides_json, extra_cli_args)``.
-
- JSON ``server_params.serve_args`` (dict/list) is expanded via ``_build_serve_args``
- and **prepended** to ``extra_cli_args`` so perf / stability ``omni_server`` fixtures
- stay identical to main while still honoring ``serve_args`` in benchmark JSON.
- """
- unique_params: list[tuple[str, str, str | None, str | None, tuple[str, ...]]] = []
- seen: set[tuple[str, str, str | None, str | None, tuple[str, ...]]] = set()
+) -> list[tuple[str, str, str]]:
+ unique_params = []
+ seen = set()
for config in configs:
test_name = config["test_name"]
- server_params = config["server_params"]
- model = server_params["model"]
- stage_config_name = server_params.get("stage_config_name")
+ model = config["server_params"]["model"]
+ stage_config_name = config["server_params"].get("stage_config_name")
if stage_config_name:
stage_config_path = str(stage_configs_dir / stage_config_name)
- delete = server_params.get("delete", None)
- update = server_params.get("update", None)
+ delete = config["server_params"].get("delete", None)
+ update = config["server_params"].get("update", None)
stage_config_path = modify_stage(stage_config_path, update, delete)
else:
stage_config_path = None
- stage_overrides = server_params.get("stage_overrides")
- stage_overrides_json = json.dumps(stage_overrides) if stage_overrides else None
-
- # ``extra_cli_args`` passes raw CLI flags straight through to
- # ``vllm_omni.entrypoints.cli.main serve`` — used for flags that
- # don't map to stage-level overrides, e.g. ``--async-chunk`` /
- # ``--no-async-chunk`` toggling the deploy-level async_chunk bool.
- serve_flat = _build_serve_args(server_params.get("serve_args"))
- raw_extra = tuple(server_params.get("extra_cli_args") or ())
- extra_cli_args = tuple(serve_flat) + raw_extra
-
- server_param = (test_name, model, stage_config_path, stage_overrides_json, extra_cli_args)
+ server_param = (test_name, model, stage_config_path)
if server_param not in seen:
seen.add(server_param)
unique_params.append(server_param)
@@ -120,11 +70,7 @@ def create_test_parameter_mapping(configs: list[dict[str, Any]]) -> dict[str, di
"test_name": test_name,
"benchmark_params": [],
}
- for entry in config["benchmark_params"]:
- # Skip disabled entries
- if not entry.get("enabled", True):
- continue
- mapping[test_name]["benchmark_params"].append(entry)
+ mapping[test_name]["benchmark_params"].extend(config["benchmark_params"])
return mapping
@@ -149,146 +95,3 @@ def create_benchmark_indices(
indices.append((test_name, idx))
return indices
-
-
-def _safe_filename_token(value: Any | None, *, default: str = "na") -> str:
- """Make a single path segment safe for result filenames on common filesystems."""
- if value is None:
- return default
- s = str(value).strip()
- for bad in ("/", "\\", ":", "*", "?", '"', "<", ">", "|"):
- s = s.replace(bad, "_")
- return s if s else default
-
-
-def _resolve_baseline_value(
- baseline_raw: Any,
- *,
- sweep_index: int | None,
- max_concurrency: Any = None,
- request_rate: Any = None,
-) -> Any:
- """Pick the baseline threshold for this sweep step."""
- if baseline_raw is None:
- return 100000
- if isinstance(baseline_raw, dict):
- if max_concurrency is not None:
- for key in (max_concurrency, str(max_concurrency)):
- if key in baseline_raw:
- return baseline_raw[key]
- if request_rate is not None:
- for key in (request_rate, str(request_rate)):
- if key in baseline_raw:
- return baseline_raw[key]
- raise KeyError(
- f"baseline dict has no key for max_concurrency={max_concurrency!r} "
- f"or request_rate={request_rate!r}; keys={list(baseline_raw.keys())!r}"
- )
- if isinstance(baseline_raw, (list, tuple)):
- if sweep_index is None:
- raise ValueError("list baseline requires sweep_index")
- if not (0 <= sweep_index < len(baseline_raw)):
- raise IndexError(f"baseline list len={len(baseline_raw)} has no index {sweep_index}")
- return baseline_raw[sweep_index]
- return baseline_raw
-
-
-def _baseline_thresholds_for_step(
- baseline_data: dict[str, Any],
- *,
- sweep_index: int | None = None,
- max_concurrency: Any = None,
- request_rate: Any = None,
-) -> dict[str, Any]:
- """Resolve baseline config to one threshold per metric for this iteration."""
- return {
- metric_name: _resolve_baseline_value(
- baseline_raw,
- sweep_index=sweep_index,
- max_concurrency=max_concurrency,
- request_rate=request_rate,
- )
- for metric_name, baseline_raw in baseline_data.items()
- }
-
-
-def run_benchmark(
- args: list[str],
- test_name: str,
- flow: Any,
- dataset_name: str,
- num_prompt: int,
- *,
- baseline_config: dict[str, Any] | None = None,
- sweep_index: int | None = None,
- request_rate: Any | None = None,
- max_concurrency: Any | None = None,
- random_input_len: Any | None = None,
- random_output_len: Any | None = None,
-) -> dict[str, Any]:
- """Run one ``vllm bench serve --omni`` iteration and return parsed metrics."""
- current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
- ri = _safe_filename_token(random_input_len)
- ro = _safe_filename_token(random_output_len)
- result_filename = f"result_{test_name}_{dataset_name}_{flow}_{num_prompt}_in{ri}_out{ro}_{current_dt}.json"
- if "--result-filename" in args:
- print(f"The result file will be overwritten by {result_filename}")
- command = (
- ["vllm", "bench", "serve", "--omni"]
- + args
- + [
- "--num-warmups",
- "2",
- "--save-result",
- "--result-dir",
- os.environ.get("BENCHMARK_DIR", "tests"),
- "--result-filename",
- result_filename,
- ]
- )
- process = subprocess.Popen(
- command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, universal_newlines=True
- )
-
- for line in iter(process.stdout.readline, ""):
- print(line, end=" ")
-
- for line in iter(process.stderr.readline, ""):
- print(line, end=" ")
-
- if "--result-dir" in command:
- index = command.index("--result-dir")
- result_dir = command[index + 1]
- else:
- result_dir = "./"
-
- result_path = os.path.join(result_dir, result_filename)
- with open(result_path, encoding="utf-8") as f:
- result = json.load(f)
-
- if baseline_config:
- result["baseline"] = _baseline_thresholds_for_step(
- baseline_config,
- sweep_index=sweep_index,
- request_rate=request_rate,
- max_concurrency=max_concurrency,
- )
- else:
- result["baseline"] = {}
- if random_input_len is not None:
- result["random_input_len"] = random_input_len
- if random_output_len is not None:
- result["random_output_len"] = random_output_len
- with open(result_path, "w", encoding="utf-8") as f:
- json.dump(result, f, ensure_ascii=False, indent=2)
- return result
-
-
-def pytest_addoption(parser: pytest.Parser) -> None:
- """Register shared CLI options for DFX benchmark suites."""
- parser.addoption(
- "--test-config-file",
- action="store",
- default=None,
- help=("Path to benchmark config JSON. Example: --test-config-file tests/dfx/perf/tests/test_tts.json"),
- )
diff --git a/tests/dfx/perf/scripts/diffusion_result_template.json b/tests/dfx/perf/scripts/diffusion_result_template.json
deleted file mode 100644
index 86bdf1bc7aa..00000000000
--- a/tests/dfx/perf/scripts/diffusion_result_template.json
+++ /dev/null
@@ -1,86 +0,0 @@
-[
- {
- "test_name": null,
- "backend": null,
- "timestamp": null,
- "server_params": {
- "model": null,
- "serve_args": {
- "enable-diffusion-pipeline-profiler": false
- }
- },
- "benchmark_params": {
- "name": null,
- "dataset": null,
- "task": null,
- "width": 0,
- "height": 0,
- "num-inference-steps": 0,
- "num-prompts": 0,
- "max-concurrency": 0,
- "num-input-images": 0,
- "enable-negative-prompt": false,
- "baseline": {
- "throughput_qps": 0,
- "latency_mean": 0,
- "peak_memory_mb_max": 0,
- "peak_memory_mb_mean": 0
- }
- },
- "result": {
- "duration": 0,
- "completed_requests": 0,
- "failed_requests": 0,
- "throughput_qps": 0,
- "latency_mean": 0,
- "latency_median": 0,
- "latency_p99": 0,
- "latency_p95": 0,
- "latency_p50": 0,
- "peak_memory_mb_max": 0,
- "peak_memory_mb_mean": 0,
- "peak_memory_mb_median": 0,
- "stage_durations_mean": {},
- "stage_durations_p50": {},
- "stage_durations_p99": {},
- "backend": null,
- "model": null,
- "dataset": null,
- "task": null
- },
- "log_file": null,
- "Model": null,
- "Framework": null,
- "Hardware": null,
- "Deployment": null,
- "Task": null,
- "Dataset": null,
- "resolution": null,
- "Parallelism": null,
- "max_concurrency": 0,
- "Cache": null,
- "Quantization": null,
- "offload": null,
- "compile": null,
- "Attn_backend": null,
- "num_inference_steps": 0,
- "completed": 0,
- "failed": 0,
- "throughput_qps": 0,
- "latency_mean": 0,
- "latency_median": 0,
- "latency_p99": 0,
- "latency_p95": 0,
- "latency_p50": 0,
- "peak_memory_mb_max": 0,
- "peak_memory_mb_mean": 0,
- "peak_memory_mb_median": 0,
- "stage_durations_mean": {},
- "stage_durations_p50": {},
- "stage_durations_p99": {},
- "commit_sha": null,
- "build_id": null,
- "build_url": null,
- "source_file": null
- }
-]
diff --git a/tests/dfx/perf/scripts/result_omni_template.json b/tests/dfx/perf/scripts/result_omni_template.json
deleted file mode 100644
index 1d61321407e..00000000000
--- a/tests/dfx/perf/scripts/result_omni_template.json
+++ /dev/null
@@ -1,55 +0,0 @@
-{
- "date": null,
- "endpoint_type": null,
- "backend": null,
- "label": null,
- "model_id": null,
- "tokenizer_id": null,
- "num_prompts": 0,
- "request_rate": null,
- "burstiness": 0,
- "max_concurrency": 0,
- "duration": 0,
- "completed": 0,
- "failed": 0,
- "total_input_tokens": 0,
- "total_output_tokens": 0,
- "request_throughput": 0,
- "request_goodput": null,
- "output_throughput": 0,
- "total_token_throughput": 0,
- "total_audio_duration_s": 0,
- "total_audio_frames": 0,
- "audio_throughput": 0,
- "max_output_tokens_per_s": 0,
- "max_concurrent_requests": 0,
- "rtfx": 0,
- "mean_ttft_ms": 0,
- "median_ttft_ms": 0,
- "p99_ttft_ms": 0,
- "mean_tpot_ms": 0,
- "median_tpot_ms": 0,
- "p99_tpot_ms": 0,
- "mean_itl_ms": 0,
- "median_itl_ms": 0,
- "p99_itl_ms": 0,
- "mean_e2el_ms": 0,
- "median_e2el_ms": 0,
- "p99_e2el_ms": 0,
- "mean_audio_rtf": 0,
- "median_audio_rtf": 0,
- "p99_audio_rtf": 0,
- "mean_audio_ttfp_ms": 0,
- "median_audio_ttfp_ms": 0,
- "p99_audio_ttfp_ms": 0,
- "mean_audio_duration_s": 0,
- "median_audio_duration_s": 0,
- "p99_audio_duration_s": 0,
- "baseline": {
- "mean_ttft_ms": 0,
- "mean_audio_ttfp_ms": 0,
- "mean_audio_rtf": 0
- },
- "random_input_len": 0,
- "random_output_len": 0
-}
diff --git a/tests/dfx/perf/scripts/run_benchmark.py b/tests/dfx/perf/scripts/run_benchmark.py
index 9036508cb1c..9e375fa9fec 100644
--- a/tests/dfx/perf/scripts/run_benchmark.py
+++ b/tests/dfx/perf/scripts/run_benchmark.py
@@ -8,6 +8,7 @@
import pytest
+from tests.conftest import OmniServer
from tests.dfx.conftest import (
create_benchmark_indices,
create_test_parameter_mapping,
@@ -15,44 +16,17 @@
get_benchmark_params_for_server,
load_configs,
)
-from tests.helpers.runtime import OmniServer
-
-pytestmark = [pytest.mark.full_model, pytest.mark.omni]
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-def _get_config_file_from_argv() -> str | None:
- """Read ``--test-config-file`` from ``sys.argv`` at import time so parametrization can use it."""
- import sys
-
- for i, arg in enumerate(sys.argv):
- if arg == "--test-config-file" and i + 1 < len(sys.argv):
- return sys.argv[i + 1]
- if arg.startswith("--test-config-file="):
- return arg.split("=", 1)[1]
- return None
-
-
-_PERF_TESTS_DIR = Path(__file__).resolve().parent.parent / "tests"
-_DEFAULT_CONFIG_FILE = str(_PERF_TESTS_DIR / "test_qwen_omni.json")
-
-CONFIG_FILE_PATH = _get_config_file_from_argv()
-if CONFIG_FILE_PATH is None:
- print(
- "No --test-config-file in argv, using default: tests/dfx/perf/tests/test_qwen_omni.json "
- "(override with e.g. --test-config-file tests/dfx/perf/tests/test_tts.json)"
- )
- CONFIG_FILE_PATH = _DEFAULT_CONFIG_FILE
-
+CONFIG_FILE_PATH = str(Path(__file__).parent.parent / "tests" / "test.json")
BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)
-OMNI_RESULT_TEMPLATE_PATH = Path(__file__).parent / "result_omni_template.json"
-DEPLOY_CONFIGS_DIR = Path(__file__).parent.parent / "deploy"
-test_params = create_unique_server_params(BENCHMARK_CONFIGS, DEPLOY_CONFIGS_DIR)
+STAGE_CONFIGS_DIR = Path(__file__).parent.parent / "stage_configs"
+test_params = create_unique_server_params(BENCHMARK_CONFIGS, STAGE_CONFIGS_DIR)
server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS)
_omni_server_lock = threading.Lock()
@@ -65,19 +39,13 @@ def omni_server(request):
Multi-stage initialization can take 10-20+ minutes.
"""
with _omni_server_lock:
- test_name, model, stage_config_path, stage_overrides, extra_cli_args = request.param
+ test_name, model, stage_config_path = request.param
print(f"Starting OmniServer with test: {test_name}, model: {model}")
- server_args = ["--stage-init-timeout", "600", "--init-timeout", "900"]
- # --deploy-config and --stage-overrides compose at the CLI (see vllm_omni/entrypoints/utils.py):
- # deploy-config sets the base; stage-overrides are applied on top. Both can be set.
+ server_args = ["--stage-init-timeout", "120"]
if stage_config_path:
- server_args = ["--deploy-config", stage_config_path] + server_args
- if stage_overrides:
- server_args = ["--stage-overrides", stage_overrides] + server_args
- if extra_cli_args:
- server_args = list(extra_cli_args) + server_args
+ server_args = ["--stage-configs-path", stage_config_path] + server_args
with OmniServer(model, server_args) as server:
server.test_name = test_name
print("OmniServer started successfully")
@@ -87,41 +55,16 @@ def omni_server(request):
print("OmniServer stopped")
-def _safe_filename_token(value: Any | None, *, default: str = "na") -> str:
- """Make a single path segment safe for result filenames on common filesystems."""
- if value is None:
- return default
- s = str(value).strip()
- for bad in ("/", "\\", ":", "*", "?", '"', "<", ">", "|"):
- s = s.replace(bad, "_")
- return s if s else default
-
-
def run_benchmark(
args: list,
test_name: str,
flow,
dataset_name: str,
num_prompt,
- *,
- baseline_config: dict[str, Any] | None = None,
- sweep_index: int | None = None,
- request_rate: Any | None = None,
- max_concurrency: Any | None = None,
- random_input_len: Any | None = None,
- random_output_len: Any | None = None,
) -> Any:
- """Run a single benchmark iteration and return the parsed result JSON.
-
- After ``vllm bench`` writes the JSON, ``result["baseline"]`` holds the same
- per-metric resolved thresholds as ``assert_result`` (via ``_baseline_thresholds_for_step``).
- When ``random_input_len`` / ``random_output_len`` are set, they are also written into the result JSON;
- omitted keys when not configured.
- """
+ """Run a single benchmark iteration and return the parsed result JSON."""
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
- ri = _safe_filename_token(random_input_len)
- ro = _safe_filename_token(random_output_len)
- result_filename = f"result_{test_name}_{dataset_name}_{flow}_{num_prompt}_in{ri}_out{ro}_{current_dt}.json"
+ result_filename = f"result_{test_name}_{dataset_name}_{flow}_{num_prompt}_{current_dt}.json"
if "--result-filename" in args:
print(f"The result file will be overwritten by {result_filename}")
command = (
@@ -151,34 +94,8 @@ def run_benchmark(
else:
result_dir = "./"
- result_path = os.path.join(result_dir, result_filename)
- if not os.path.exists(result_path):
- with open(OMNI_RESULT_TEMPLATE_PATH, encoding="utf-8") as f:
- template_result: dict[str, Any] = json.load(f)
- Path(result_path).parent.mkdir(parents=True, exist_ok=True)
- with open(result_path, "w", encoding="utf-8") as f:
- json.dump(template_result, f, ensure_ascii=False, indent=2)
- print(f"Benchmark result file not generated, fallback to template: {result_path}")
- result = template_result
- else:
- with open(result_path, encoding="utf-8") as f:
- result = json.load(f)
-
- if baseline_config:
- result["baseline"] = _baseline_thresholds_for_step(
- baseline_config,
- sweep_index=sweep_index,
- request_rate=request_rate,
- max_concurrency=max_concurrency,
- )
- else:
- result["baseline"] = {}
- if random_input_len is not None:
- result["random_input_len"] = random_input_len
- if random_output_len is not None:
- result["random_output_len"] = random_output_len
- with open(result_path, "w", encoding="utf-8") as f:
- json.dump(result, f, ensure_ascii=False, indent=2)
+ with open(os.path.join(result_dir, result_filename), encoding="utf-8") as f:
+ result = json.load(f)
return result
@@ -248,25 +165,6 @@ def _resolve_baseline_value(
return baseline_raw
-def _baseline_thresholds_for_step(
- baseline_data: dict[str, Any],
- *,
- sweep_index: int | None = None,
- max_concurrency: Any = None,
- request_rate: Any = None,
-) -> dict[str, Any]:
- """Resolve ``test.json`` ``baseline`` block to one threshold per metric (same as ``assert_result``)."""
- return {
- metric_name: _resolve_baseline_value(
- baseline_raw,
- sweep_index=sweep_index,
- max_concurrency=max_concurrency,
- request_rate=request_rate,
- )
- for metric_name, baseline_raw in baseline_data.items()
- }
-
-
def assert_result(
result,
params,
@@ -296,7 +194,6 @@ def assert_result(
print(f"ERROR: Test results exceeded baseline: {metric_name}: {current_value} < {baseline_value}")
-@pytest.mark.benchmark
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
@pytest.mark.parametrize("benchmark_params", benchmark_indices, indirect=True)
def test_performance_benchmark(omni_server, benchmark_params):
@@ -333,7 +230,7 @@ def to_list(value, default=None):
raise ValueError("The number of prompts does not match the QPS or max_concurrency")
args = ["--host", host, "--port", str(port)]
- exclude_keys = {"request_rate", "baseline", "num_prompts", "max_concurrency", "task", "enabled", "eval_phase"}
+ exclude_keys = {"request_rate", "baseline", "num_prompts", "max_concurrency"}
for key, value in params.items():
if key in exclude_keys or value is None:
@@ -358,12 +255,6 @@ def to_list(value, default=None):
flow=qps,
dataset_name=dataset_name,
num_prompt=num_prompt,
- baseline_config=params.get("baseline"),
- sweep_index=i,
- request_rate=qps,
- max_concurrency=None,
- random_input_len=params.get("random_input_len"),
- random_output_len=params.get("random_output_len"),
)
assert_result(
result,
@@ -382,12 +273,6 @@ def to_list(value, default=None):
flow=concurrency,
dataset_name=dataset_name,
num_prompt=num_prompt,
- baseline_config=params.get("baseline"),
- sweep_index=i,
- request_rate=None,
- max_concurrency=concurrency,
- random_input_len=params.get("random_input_len"),
- random_output_len=params.get("random_output_len"),
)
assert_result(
result,
diff --git a/tests/dfx/perf/scripts/run_diffusion_benchmark.py b/tests/dfx/perf/scripts/run_diffusion_benchmark.py
index 7513c2d3f98..1bd9bf1a143 100644
--- a/tests/dfx/perf/scripts/run_diffusion_benchmark.py
+++ b/tests/dfx/perf/scripts/run_diffusion_benchmark.py
@@ -1,16 +1,15 @@
"""
Performance benchmark CI runner for diffusion models.
-This runner separates two concepts:
+Supports vLLM-Omni server backend:
+ - vllm-omni (default): starts DiffusionServer via vllm_omni.entrypoints.cli.main,
+ benchmarks with diffusion_benchmark_serving.py --backend vllm-omni
-1. ``server_type``: how the serving process is started.
- Currently only ``vllm-omni`` is supported here.
-2. ``benchmark_backend``: which serving API the benchmark client calls.
- Examples: ``vllm-omni`` for ``/v1/chat/completions`` and ``v1/videos``
- for async video jobs.
+A config JSON file is REQUIRED via --config-file:
+ pytest run_diffusion_benchmark.py --config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
-A config JSON file is REQUIRED via --test-config-file:
- pytest run_diffusion_benchmark.py --test-config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
+Each JSON config entry may set a "server_type" field; only "vllm-omni" is accepted
+(also the default when the field is omitted), any other value raises ValueError.
All benchmark results for a session are consolidated into a single JSON file under
BENCHMARK_RESULT_DIR (override via the DIFFUSION_BENCHMARK_DIR environment variable).
@@ -28,16 +27,13 @@
import time
from datetime import datetime
from pathlib import Path
-from typing import Any, cast
+from typing import Any
import psutil
import pytest
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-os.environ.setdefault("DIFFUSION_ATTENTION_BACKEND", "FLASH_ATTN")
# ---------------------------------------------------------------------------
# Paths
@@ -54,21 +50,19 @@
# Populated lazily after CONFIG_FILE_PATH is resolved.
_SESSION_TIMESTAMP = datetime.now().strftime("%Y%m%d-%H%M%S")
_RESULT_LOCK = threading.Lock()
-_BRANCHPOINT_COMMIT_SHA: str | None = None
-DIFFUSION_RESULT_TEMPLATE_PATH = Path(__file__).parent / "diffusion_result_template.json"
def _get_config_file_from_argv() -> str | None:
- """Read --test-config-file from sys.argv at import time so pytest parametrize can use it.
+ """Read --config-file from sys.argv at import time so pytest parametrize can use it.
pytest_addoption (below) registers the same flag so pytest does not reject it.
- Supports both ``--test-config-file path`` and ``--test-config-file=path`` forms.
+ Supports both ``--config-file path`` and ``--config-file=path`` forms.
Returns None if the flag is not present; callers must handle the missing case.
"""
for i, arg in enumerate(sys.argv):
- if arg == "--test-config-file" and i + 1 < len(sys.argv):
+ if arg == "--config-file" and i + 1 < len(sys.argv):
return sys.argv[i + 1]
- if arg.startswith("--test-config-file="):
+ if arg.startswith("--config-file="):
return arg.split("=", 1)[1]
return None
@@ -116,7 +110,7 @@ def load_configs(config_path: str) -> list[dict[str, Any]]:
BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)
_config_stem = Path(CONFIG_FILE_PATH).stem # e.g. "test_qwen_image_vllm_omni"
-AGGREGATED_RESULT_FILE = BENCHMARK_RESULT_DIR / f"diffusion_result_{_config_stem}_{_SESSION_TIMESTAMP}.json"
+AGGREGATED_RESULT_FILE = BENCHMARK_RESULT_DIR / f"benchmark_results_{_config_stem}_{_SESSION_TIMESTAMP}.json"
def _append_to_aggregated_file(record: dict[str, Any]) -> None:
@@ -137,6 +131,19 @@ def _append_to_aggregated_file(record: dict[str, Any]) -> None:
json.dump(records, f, indent=2, ensure_ascii=False)
+# Register --config-file so pytest accepts it. NOTE(review): hook is documented for conftest.py/plugins - confirm.
+def pytest_addoption(parser: pytest.Parser) -> None:
+ parser.addoption(
+ "--config-file",
+ action="store",
+ default=None,
+ help=(
+ "Path to the benchmark config JSON file (required). "
+ "Example: --config-file tests/dfx/perf/tests/test_qwen_image_vllm_omni.json"
+ ),
+ )
+
+
_server_lock = threading.Lock()
# ---------------------------------------------------------------------------
@@ -225,13 +232,13 @@ class DiffusionServer:
def __init__(
self,
- server_cfg: dict[str, Any],
+ model: str,
+ serve_args: list[str],
*,
port: int | None = None,
) -> None:
- self.server_cfg: dict[str, Any] = server_cfg
- self.model = server_cfg["model"]
- self.serve_args = server_cfg["serve_args"]
+ self.model = model
+ self.serve_args = serve_args
self.host = "127.0.0.1"
self.port = port if port is not None else _get_open_port()
self.proc: subprocess.Popen | None = None
@@ -292,95 +299,6 @@ def _build_serve_args(serve_args_dict: dict[str, Any]) -> list[str]:
return args
-def _get_branchpoint_commit_sha() -> str:
- """Return the branch-point commit SHA against main.
-
- Uses git command: ``git merge-base HEAD origin/main``.
- """
- global _BRANCHPOINT_COMMIT_SHA
- if _BRANCHPOINT_COMMIT_SHA is not None:
- return _BRANCHPOINT_COMMIT_SHA
-
- repo_root = Path(__file__).parent.parent.parent.parent
- try:
- sha = (
- subprocess.check_output(
- ["git", "merge-base", "HEAD", "origin/main"],
- cwd=str(repo_root),
- stderr=subprocess.STDOUT,
- text=True,
- )
- .strip()
- .splitlines()[0]
- )
- _BRANCHPOINT_COMMIT_SHA = sha
- except Exception as e:
- print(f"Warning: failed to get branch-point commit SHA: {e}")
- _BRANCHPOINT_COMMIT_SHA = ""
- return _BRANCHPOINT_COMMIT_SHA
-
-
-def _to_resolution_string(params: dict[str, Any]) -> str:
- width = params.get("width", "unknown width")
- height = params.get("height", "unknown height")
- return f"{width}x{height}"
-
-
-def _to_parallelism_string(framework: str, serve_args_dict: dict[str, Any]) -> str:
- parts: list[str] = []
- if framework == "vllm-omni":
- keys = [
- "num-gpus",
- "usp",
- "ulysses-degree",
- "ring",
- "ring-degree",
- "cfg-parallel-size",
- "vae-patch-parallel-size",
- "vae-use-tiling",
- "tensor-parallel-size",
- ]
- for key in keys:
- if key in serve_args_dict:
- parts.append(f"{key}={serve_args_dict[key]}")
- return ",".join(parts) if parts else "none"
-
-
-def _to_cache_string(framework: str, serve_args_dict: dict[str, Any]) -> str:
- if framework == "vllm-omni":
- if "cache-backend" in serve_args_dict:
- return str(serve_args_dict["cache-backend"])
- return "disabled"
-
-
-def _to_offload_string(framework: str, serve_args_dict: dict[str, Any]) -> str:
- selected: list[str] = []
- if framework == "vllm-omni":
- offload_keys = [
- "enable-cpu-offload",
- "enable-layerwise-offload",
- ]
- for key in offload_keys:
- if key in serve_args_dict:
- selected.append(key)
- return f"enabled({';'.join(selected)})" if selected else "disabled"
-
-
-def _to_compile_value(framework: str, serve_args_dict: dict[str, Any]) -> str:
- if framework == "vllm-omni":
- if "enforce-eager" in serve_args_dict:
- return "disabled"
- return "enabled"
- return "disabled"
-
-
-def _to_quantization_value(framework: str, serve_args_dict: dict[str, Any]) -> str:
- if framework == "vllm-omni":
- quant = serve_args_dict.get("quantization")
- return str(quant) if quant else "disabled"
- return "disabled"
-
-
def _unique_server_params(configs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Return one server-config dict per unique test_name."""
seen: set[str] = set()
@@ -390,18 +308,15 @@ def _unique_server_params(configs: list[dict[str, Any]]) -> list[dict[str, Any]]
if test_name in seen:
continue
seen.add(test_name)
- server_type = cfg.get("server_type", "vllm-omni")
- if server_type != "vllm-omni":
- raise ValueError(f"Unsupported server_type in config: {server_type}")
- serve_args_dict = cfg["server_params"].get("serve_args", {})
+ if cfg.get("server_type", "vllm-omni") != "vllm-omni":
+ raise ValueError(f"Unsupported server_type in config: {cfg.get('server_type')}")
result.append(
{
"test_name": test_name,
- "server_type": server_type,
+ "server_type": "vllm-omni",
"model": cfg["server_params"]["model"],
- "serve_args_dict": serve_args_dict,
- "serve_args": _build_serve_args(serve_args_dict),
- "benchmark_backend": cfg.get("benchmark_backend"),
+ "serve_args": _build_serve_args(cfg["server_params"].get("serve_args", {})),
+ "benchmark_backend": "vllm-omni",
"server_params": cfg["server_params"],
}
)
@@ -419,7 +334,9 @@ def _test_param_mapping(configs: list[dict[str, Any]]) -> dict[str, list[dict]]:
def _make_server(server_cfg: dict[str, Any]) -> DiffusionServer:
"""Factory: return a vLLM-Omni diffusion server instance for the config."""
- return DiffusionServer(server_cfg=server_cfg)
+ model = server_cfg["model"]
+ serve_args = server_cfg["serve_args"]
+ return DiffusionServer(model=model, serve_args=serve_args)
# ---------------------------------------------------------------------------
@@ -447,6 +364,7 @@ def diffusion_server(request):
print(f"\nStarting {server_type} server for test: {test_name}")
with _make_server(server_cfg) as server:
server.test_name = test_name
+ server.server_params = server_cfg["server_params"]
print(f"{server_type} server started successfully")
yield server
print(f"{server_type} server stopping…")
@@ -484,25 +402,22 @@ def run_benchmark(
params: dict[str, Any],
test_name: str,
backend: str = "vllm-omni",
- server_cfg: dict[str, Any] | None = None,
- source_file: str = "",
+ server_params: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Run diffusion_benchmark_serving.py as a subprocess and return parsed metrics.
The raw metrics are written to a temporary file by the subprocess. After
the run completes the metrics are merged with full metadata (test_name,
- backend, benchmark_params, timestamp, flat reporting fields) and appended
- to the session-wide aggregated JSON file (AGGREGATED_RESULT_FILE). The
- temporary file is removed afterwards. Subprocess stdout/stderr are tee'd
- to a .log file under BENCHMARK_RESULT_DIR/logs/; its path is stored in
- the record.
+ backend, benchmark_params, timestamp) and appended to the session-wide
+ aggregated JSON file (AGGREGATED_RESULT_FILE). The temporary file is
+ removed afterwards. Subprocess stdout/stderr are tee'd to a .log file
+ under BENCHMARK_RESULT_DIR/logs/; its path is stored in the record.
"""
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = BENCHMARK_RESULT_DIR / "logs"
log_dir.mkdir(parents=True, exist_ok=True)
- backend_label = backend.replace("/", "_")
- log_file = log_dir / f"{test_name}_{backend_label}_{timestamp}.log"
+ log_file = log_dir / f"{test_name}_{backend}_{timestamp}.log"
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", prefix="diffusion_bench_tmp_", delete=False) as tmp:
tmp_result_file = Path(tmp.name)
@@ -569,17 +484,10 @@ def run_benchmark(
if process.returncode != 0:
tmp_result_file.unlink(missing_ok=True)
- print(f"ERROR:Benchmark script exited with code {process.returncode}")
+ raise RuntimeError(f"Benchmark script exited with code {process.returncode}")
if not tmp_result_file.exists():
- with open(DIFFUSION_RESULT_TEMPLATE_PATH, encoding="utf-8") as f:
- template_payload = json.load(f)
- # Template schema is fixed and owned by this repo:
- # ``diffusion_result_template.json`` is a one-item list and metrics live at [0]["result"].
- template_metrics: dict[str, Any] = template_payload[0]["result"]
- with open(tmp_result_file, "w", encoding="utf-8") as f:
- json.dump(template_metrics, f, ensure_ascii=False, indent=2)
- print(f"Benchmark result file not generated, fallback to template: {tmp_result_file}")
+ raise FileNotFoundError(f"Benchmark result file not found: {tmp_result_file}")
try:
with open(tmp_result_file, encoding="utf-8") as f:
@@ -587,57 +495,14 @@ def run_benchmark(
finally:
tmp_result_file.unlink(missing_ok=True)
- server_cfg = server_cfg or {}
- server_type = cast(str, server_cfg.get("server_type", "vllm-omni"))
- serve_args_dict = server_cfg.get("serve_args_dict", {})
- if not isinstance(serve_args_dict, dict):
- serve_args_dict = {}
-
- completed = metrics.get("completed_requests", metrics.get("completed", 0))
- failed = metrics.get("failed_requests", metrics.get("failed", 0))
-
record: dict[str, Any] = {
"test_name": test_name,
"backend": backend,
"timestamp": timestamp,
- "server_params": server_cfg.get("server_params"),
+ "server_params": server_params,
"benchmark_params": params,
"result": metrics,
"log_file": str(log_file),
- "Model": model,
- "Framework": server_type,
- "API Backend": backend,
- "Hardware": "",
- "Deployment": "",
- "Task": params.get("task", "t2i"),
- "Dataset": params.get("dataset", "random"),
- "resolution": _to_resolution_string(params),
- "Parallelism": _to_parallelism_string(server_type, serve_args_dict),
- "max_concurrency": params.get("max-concurrency", ""),
- "Cache": _to_cache_string(server_type, serve_args_dict),
- "Quantization": _to_quantization_value(server_type, serve_args_dict),
- "offload": _to_offload_string(server_type, serve_args_dict),
- "compile": _to_compile_value(server_type, serve_args_dict),
- "Attn_backend": os.environ.get("DIFFUSION_ATTENTION_BACKEND", ""),
- "num_inference_steps": params.get("num-inference-steps", ""),
- "completed": completed,
- "failed": failed,
- "throughput_qps": metrics.get("throughput_qps"),
- "latency_mean": metrics.get("latency_mean"),
- "latency_median": metrics.get("latency_median"),
- "latency_p99": metrics.get("latency_p99"),
- "latency_p95": metrics.get("latency_p95"),
- "latency_p50": metrics.get("latency_p50"),
- "peak_memory_mb_max": metrics.get("peak_memory_mb_max"),
- "peak_memory_mb_mean": metrics.get("peak_memory_mb_mean"),
- "peak_memory_mb_median": metrics.get("peak_memory_mb_median"),
- "stage_durations_mean": metrics.get("stage_durations_mean"),
- "stage_durations_p50": metrics.get("stage_durations_p50"),
- "stage_durations_p99": metrics.get("stage_durations_p99"),
- "commit_sha": _get_branchpoint_commit_sha(),
- "build_id": os.environ.get("BUILDKITE_BUILD_ID", ""),
- "build_url": os.environ.get("BUILDKITE_BUILD_URL", ""),
- "source_file": source_file,
}
_append_to_aggregated_file(record)
print(f"\n Result appended to: {AGGREGATED_RESULT_FILE}")
@@ -666,27 +531,11 @@ def assert_result(result: dict[str, Any], params: dict[str, Any]) -> None:
assert current <= threshold, f"{metric}: {current:.4f} > baseline {threshold}"
-def _default_benchmark_backend_for_task(task: str) -> str:
- """Return the default client-side benchmark backend for a diffusion task."""
- if task in {"t2v", "i2v", "ti2v"}:
- return "v1/videos"
- if task in {"t2i", "i2i", "ti2i"}:
- return "vllm-omni"
- raise ValueError(f"Unsupported task for benchmark backend resolution: {task}")
-
-
-def _resolve_benchmark_backend(server_cfg: dict[str, Any], params: dict[str, Any]) -> str:
- """Resolve which serving API the benchmark client should call."""
- configured = server_cfg.get("benchmark_backend")
- if configured:
- return cast(str, configured)
- return _default_benchmark_backend_for_task(cast(str, params.get("task", "t2i")))
-
-
# ---------------------------------------------------------------------------
# Test entry point
# ---------------------------------------------------------------------------
-@pytest.mark.benchmark
+
+
@pytest.mark.parametrize(
"diffusion_server",
server_params,
@@ -707,8 +556,7 @@ def test_diffusion_performance_benchmark(diffusion_server, benchmark_params):
"""
test_name = benchmark_params["test_name"]
params = benchmark_params["params"]
- server_cfg = getattr(diffusion_server, "server_cfg", {})
- backend = _resolve_benchmark_backend(server_cfg, params)
+ backend = diffusion_server.server_type # "vllm-omni"
result = run_benchmark(
host=diffusion_server.host,
@@ -717,8 +565,7 @@ def test_diffusion_performance_benchmark(diffusion_server, benchmark_params):
params=params,
test_name=test_name,
backend=backend,
- server_cfg=server_cfg,
- source_file=cast(str, CONFIG_FILE_PATH),
+ server_params=diffusion_server.server_params,
)
print(f"\n{'=' * 60}")
diff --git a/tests/dfx/perf/stage_configs/qwen3_omni.yaml b/tests/dfx/perf/stage_configs/qwen3_omni.yaml
new file mode 100644
index 00000000000..2add22b8732
--- /dev/null
+++ b/tests/dfx/perf/stage_configs/qwen3_omni.yaml
@@ -0,0 +1,101 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs.
+async_chunk: false
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 64
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent # Output hidden states for talker
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ hf_config_name: thinker_config
+ tensor_parallel_size: 1
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 64
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent # Output codec codes for code2wav
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ distributed_executor_backend: "mp"
+ hf_config_name: talker_config
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+ # final_output: true
+ # final_output_type: text
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: False
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 2
+    stage_type: llm  # llm stage type is reused here; worker_type below selects the generation (non-AR) worker
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 64
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: audio # Final output: audio waveform
+ gpu_memory_utilization: 0.1
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 100000
+ hf_config_name: thinker_config
+ engine_input_source: [1]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
diff --git a/tests/dfx/perf/stage_configs/qwen3_tts.yaml b/tests/dfx/perf/stage_configs/qwen3_tts.yaml
new file mode 100644
index 00000000000..dd69b248d1a
--- /dev/null
+++ b/tests/dfx/perf/stage_configs/qwen3_tts.yaml
@@ -0,0 +1,96 @@
+# Stage config for running Qwen3-TTS with 2-stage architecture
+# Stage 0: Talker (text -> 8-layer RVQ codec codes)
+# Stage 1: Code2Wav (codec codes -> audio waveform)
+#
+# The following config has been verified on 1x H100-80G GPU.
+async_chunk: true
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ is_comprehension: true
+ runtime:
+ devices: "0"
+ engine_args:
+ max_num_seqs: 4
+ model_stage: qwen3_tts
+ model_arch: Qwen3TTSTalkerForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ enforce_eager: false
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: latent
+ gpu_memory_utilization: 0.3
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 512
+ max_model_len: 4096
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: false
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 1
+ stage_type: llm
+ runtime:
+ devices: "0"
+ engine_args:
+ max_num_seqs: 4
+ model_stage: code2wav
+ model_arch: Qwen3TTSCode2Wav
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: audio
+ gpu_memory_utilization: 0.2
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 8192
+ max_model_len: 32768
+ engine_input_source: [0]
+ final_output: true
+ final_output_type: audio
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ tts_args:
+ max_instructions_length: 500
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: true
+ repetition_penalty: 1.0
+
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 4
+
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ codec_streaming: true
+ connector_get_sleep_s: 0.01
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/tests/dfx/perf/tests/test.json b/tests/dfx/perf/tests/test.json
new file mode 100644
index 00000000000..fe7e3804698
--- /dev/null
+++ b/tests/dfx/perf/tests/test.json
@@ -0,0 +1,236 @@
+[
+ {
+ "test_name": "test_qwen3_omni",
+ "server_params": {
+ "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "stage_config_name": "qwen3_omni.yaml"
+ },
+ "benchmark_params": [
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [
+ 10,
+ 40,
+ 100
+ ],
+ "max_concurrency": [
+ 1,
+ 4,
+ 10
+ ],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [1000, 3000, 5000],
+ "mean_audio_ttfp_ms": [8000, 10000, 13000],
+ "mean_audio_rtf": [0.2, 0.25, 0.45]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [
+ 10,
+ 40,
+ 100
+ ],
+ "request_rate": [
+ 0.1,
+ 0.3,
+ 0.5
+ ],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 3,
+ "random_mm_num_mm_items_range_ratio": 0,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1,
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(32, 32, 1)": 0.5,
+ "(0, 1, 1)": 0.1,
+ "(32, 32, 2)": 0.4
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [2000, 4000, 6000],
+ "mean_audio_ttfp_ms": [10000, 13000, 15000],
+ "mean_audio_rtf": [0.25, 0.35, 0.45]
+ }
+ },
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [
+ 4,
+ 16
+ ],
+ "max_concurrency": [
+ 1,
+ 4
+ ],
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [1000, 3000],
+ "mean_audio_ttfp_ms": [30000, 60000],
+ "mean_audio_rtf": [0.35, 0.45]
+ }
+ }
+ ]
+ },
+ {
+ "test_name": "test_qwen3_omni_chunk",
+ "server_params": {
+ "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "stage_config_name": "qwen3_omni.yaml",
+ "update": {
+ "async_chunk": true,
+ "stage_args": {
+ "0": {
+ "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
+ },
+ "1": {
+ "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
+ }
+ }
+ },
+ "delete": {
+ "stage_args": {
+ "2": [
+ "custom_process_input_func"
+ ]
+ }
+ }
+ },
+ "benchmark_params": [
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [
+ 10,
+ 40,
+ 100
+ ],
+ "max_concurrency": [
+ 1,
+ 4,
+ 10
+ ],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [1000, 3000, 5000],
+ "mean_audio_ttfp_ms": [1000, 3000, 5000],
+ "mean_audio_rtf": [0.2, 0.35, 0.6]
+ }
+ },
+ {
+ "dataset_name": "random-mm",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [
+ 10,
+ 40,
+ 100
+ ],
+ "request_rate": [
+ 0.1,
+ 0.3,
+ 0.5
+ ],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "random_range_ratio": 0.0,
+ "ignore_eos": true,
+ "random_mm_base_items_per_request": 3,
+ "random_mm_num_mm_items_range_ratio": 0,
+ "random_mm_limit_mm_per_prompt": {
+ "image": 1,
+ "video": 1,
+ "audio": 1
+ },
+ "random_mm_bucket_config": {
+ "(32, 32, 1)": 0.5,
+ "(0, 1, 1)": 0.1,
+ "(32, 32, 2)": 0.4
+ },
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [2000, 4000, 6000],
+ "mean_audio_ttfp_ms": [2000, 4000, 6000],
+ "mean_audio_rtf": [0.25, 0.4, 0.7]
+ }
+ },
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "num_prompts": [
+ 4,
+ 16
+ ],
+ "max_concurrency": [
+ 1,
+ 4
+ ],
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_ttft_ms": [1000, 3000],
+ "mean_audio_ttfp_ms": [1000, 3000],
+ "mean_audio_rtf": [0.35, 0.45]
+ }
+ }
+ ]
+ },
+ {
+ "test_name": "test_qwen3_tts",
+ "server_params": {
+ "model": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
+ },
+ "benchmark_params": [
+ {
+ "dataset_name": "random",
+ "backend": "openai-audio-speech",
+ "endpoint": "/v1/audio/speech",
+ "num_prompts": [
+ 10,
+ 40
+ ],
+ "max_concurrency": [
+ 1,
+ 4
+ ],
+ "random_input_len": 100,
+ "random_output_len": 100,
+ "extra_body": {
+ "voice": "Vivian",
+ "language": "English"
+ },
+ "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
+ "baseline": {
+ "mean_audio_ttfp_ms": [6000, 6000],
+ "mean_audio_rtf": [0.3, 0.3]
+ }
+ }
+ ]
+ }
+]
diff --git a/tests/dfx/perf/tests/test_ltx2_vllm_omni.json b/tests/dfx/perf/tests/test_ltx2_vllm_omni.json
deleted file mode 100644
index 4a6f9e3501f..00000000000
--- a/tests/dfx/perf/tests/test_ltx2_vllm_omni.json
+++ /dev/null
@@ -1,217 +0,0 @@
-[
- {
- "test_name": "test_ltx2_baseline_eager",
- "description": "Single-device baseline with enforce-eager (no torch.compile)",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Lightricks/LTX-2",
- "serve_args": {
- "enforce-eager": true,
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "256x256_145f_steps6",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 256,
- "height": 256,
- "num-frames": 145,
- "fps": 24,
- "num-inference-steps": 6,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- },
- {
- "name": "480x768_41f_steps20",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 768,
- "height": 480,
- "num-frames": 41,
- "fps": 24,
- "num-inference-steps": 20,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- }
- ]
- },
-
- {
- "test_name": "test_ltx2_torch_compile",
- "description": "Single-device with torch.compile (default, no enforce-eager)",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Lightricks/LTX-2",
- "serve_args": {
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "256x256_145f_steps6",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 256,
- "height": 256,
- "num-frames": 145,
- "fps": 24,
- "num-inference-steps": 6,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- },
- {
- "name": "480x768_41f_steps20",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 768,
- "height": 480,
- "num-frames": 41,
- "fps": 24,
- "num-inference-steps": 20,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- }
- ]
- },
-
- {
- "test_name": "test_ltx2_cfg2_eager",
- "description": "CFG-parallel=2 with enforce-eager",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Lightricks/LTX-2",
- "serve_args": {
- "cfg-parallel-size": 2,
- "enforce-eager": true,
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "256x256_145f_steps6",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 256,
- "height": 256,
- "num-frames": 145,
- "fps": 24,
- "num-inference-steps": 6,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- },
- {
- "name": "480x768_41f_steps20",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 768,
- "height": 480,
- "num-frames": 41,
- "fps": 24,
- "num-inference-steps": 20,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- }
- ]
- },
-
- {
- "test_name": "test_ltx2_cfg2_compile",
- "description": "CFG-parallel=2 with torch.compile",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Lightricks/LTX-2",
- "serve_args": {
- "cfg-parallel-size": 2,
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "256x256_145f_steps6",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 256,
- "height": 256,
- "num-frames": 145,
- "fps": 24,
- "num-inference-steps": 6,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- },
- {
- "name": "480x768_41f_steps20",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 768,
- "height": 480,
- "num-frames": 41,
- "fps": 24,
- "num-inference-steps": 20,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- }
- ]
- },
-
- {
- "test_name": "test_ltx2_cache_dit_eager",
- "description": "CacheDiT with enforce-eager",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Lightricks/LTX-2",
- "serve_args": {
- "cache-backend": "cache_dit",
- "enforce-eager": true,
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "256x256_145f_steps6",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 256,
- "height": 256,
- "num-frames": 145,
- "fps": 24,
- "num-inference-steps": 6,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- },
- {
- "name": "480x768_41f_steps20",
- "dataset": "random",
- "task": "t2v",
- "backend": "v1/videos",
- "width": 768,
- "height": 480,
- "num-frames": 41,
- "fps": 24,
- "num-inference-steps": 20,
- "num-prompts": 3,
- "max-concurrency": 1,
- "enable-negative-prompt": true
- }
- ]
- }
-]
diff --git a/tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json
deleted file mode 100644
index 7d1fbbfa704..00000000000
--- a/tests/dfx/perf/tests/test_qwen_image_edit_2509_vllm_omni.json
+++ /dev/null
@@ -1,167 +0,0 @@
-[
- {
- "test_name": "test_qwen_image_edit_2509_single_device",
- "description": "Single-device baseline (two input images)",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Qwen/Qwen-Image-Edit-2509",
- "serve_args": {
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "512x512_steps20_i2i_2img",
- "dataset": "random",
- "task": "i2i",
- "width": 512,
- "height": 512,
- "num-inference-steps": 20,
- "num-prompts": 10,
- "max-concurrency": 1,
- "num-input-images": 2,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.05,
- "latency_mean": 18,
- "peak_memory_mb_max": 78500,
- "peak_memory_mb_mean": 78500
- }
- },
- {
- "name": "1536x1536_steps35_i2i_2img",
- "dataset": "random",
- "task": "i2i",
- "width": 1536,
- "height": 1536,
- "num-inference-steps": 35,
- "num-prompts": 10,
- "max-concurrency": 1,
- "num-input-images": 2,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.01,
- "latency_mean": 70,
- "peak_memory_mb_max": 81000,
- "peak_memory_mb_mean": 81000
- }
- }
- ]
- },
- {
- "test_name": "test_qwen_image_edit_2509_ulysses2_cfg2_vae_patch4",
- "description": "Ulysses SP=2 + CFG=2 + VAE patch parallel=4",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Qwen/Qwen-Image-Edit-2509",
- "serve_args": {
- "ulysses-degree": 2,
- "cfg-parallel-size": 2,
- "vae-patch-parallel-size": 4,
- "vae-use-tiling": true,
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "512x512_steps20_i2i_2img",
- "dataset": "random",
- "task": "i2i",
- "width": 512,
- "height": 512,
- "num-inference-steps": 20,
- "num-prompts": 10,
- "max-concurrency": 1,
- "num-input-images": 2,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.1,
- "latency_mean": 12,
- "peak_memory_mb_max": 69000,
- "peak_memory_mb_mean": 69000
- }
- },
- {
- "name": "1536x1536_steps35_i2i_2img",
- "dataset": "random",
- "task": "i2i",
- "width": 1536,
- "height": 1536,
- "num-inference-steps": 35,
- "num-prompts": 10,
- "max-concurrency": 1,
- "num-input-images": 2,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.03,
- "latency_mean": 28,
- "peak_memory_mb_max": 69000,
- "peak_memory_mb_mean": 69000
- }
- }
- ]
- },
- {
- "test_name": "test_qwen_image_edit_2509_ulysses2_cfg2_cache_dit",
- "description": "Ulysses SP=2 + CFG=2 + CacheDiT",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Qwen/Qwen-Image-Edit-2509",
- "serve_args": {
- "ulysses-degree": 2,
- "cfg-parallel-size": 2,
- "cache-backend": "cache_dit",
- "cache-config": {
- "Fn_compute_blocks": 1,
- "Bn_compute_blocks": 0,
- "max_warmup_steps": 4,
- "residual_diff_threshold": 0.24,
- "max_continuous_cached_steps": 3,
- "enable_taylorseer": false,
- "taylorseer_order": 1,
- "scm_steps_mask_policy": null,
- "scm_steps_policy": "dynamic"
- },
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "512x512_steps20_i2i_2img",
- "dataset": "random",
- "task": "i2i",
- "width": 512,
- "height": 512,
- "num-inference-steps": 20,
- "num-prompts": 10,
- "max-concurrency": 1,
- "num-input-images": 2,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.10,
- "latency_mean": 12,
- "peak_memory_mb_max": 73000,
- "peak_memory_mb_mean": 73000
- }
- },
- {
- "name": "1536x1536_steps35_i2i_2img",
- "dataset": "random",
- "task": "i2i",
- "width": 1536,
- "height": 1536,
- "num-inference-steps": 35,
- "num-prompts": 10,
- "max-concurrency": 1,
- "num-input-images": 2,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.05,
- "latency_mean": 20,
- "peak_memory_mb_max": 81000,
- "peak_memory_mb_mean": 81000
- }
- }
- ]
- }
-]
diff --git a/tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json
deleted file mode 100644
index f68201db5f5..00000000000
--- a/tests/dfx/perf/tests/test_qwen_image_edit_vllm_omni.json
+++ /dev/null
@@ -1,161 +0,0 @@
-[
- {
- "test_name": "test_qwen_image_edit_single_device",
- "description": "Single-device baseline",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Qwen/Qwen-Image-Edit",
- "serve_args": {
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "512x512_steps20_i2i",
- "dataset": "random",
- "task": "i2i",
- "width": 512,
- "height": 512,
- "num-inference-steps": 20,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.05,
- "latency_mean": 15.0,
- "peak_memory_mb_max": 72500,
- "peak_memory_mb_mean": 72500
- }
- },
- {
- "name": "1536x1536_steps35_i2i",
- "dataset": "random",
- "task": "i2i",
- "width": 1536,
- "height": 1536,
- "num-inference-steps": 35,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.01,
- "latency_mean": 65.6,
- "peak_memory_mb_max": 80777,
- "peak_memory_mb_mean": 80777
- }
- }
- ]
- },
- {
- "test_name": "test_qwen_image_edit_ulysses2_cfg2_vae_patch4",
- "description": "Ulysses SP=2 + CFG=2 + VAE patch parallel=4",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Qwen/Qwen-Image-Edit",
- "serve_args": {
- "ulysses-degree": 2,
- "cfg-parallel-size": 2,
- "vae-patch-parallel-size": 4,
- "vae-use-tiling": true,
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "512x512_steps20_i2i",
- "dataset": "random",
- "task": "i2i",
- "width": 512,
- "height": 512,
- "num-inference-steps": 20,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.10,
- "latency_mean": 7.2,
- "peak_memory_mb_max": 68100,
- "peak_memory_mb_mean": 68100
- }
- },
- {
- "name": "1536x1536_steps35_i2i",
- "dataset": "random",
- "task": "i2i",
- "width": 1536,
- "height": 1536,
- "num-inference-steps": 35,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.03,
- "latency_mean": 24.0,
- "peak_memory_mb_max": 68100,
- "peak_memory_mb_mean": 68100
- }
- }
- ]
- },
- {
- "test_name": "test_qwen_image_edit_ulysses2_cfg2_cache_dit",
- "description": "Ulysses SP=2 + CFG=2 + CacheDiT",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Qwen/Qwen-Image-Edit",
- "serve_args": {
- "ulysses-degree": 2,
- "cfg-parallel-size": 2,
- "cache-backend": "cache_dit",
- "cache-config": {
- "Fn_compute_blocks": 1,
- "Bn_compute_blocks": 0,
- "max_warmup_steps": 4,
- "residual_diff_threshold": 0.24,
- "max_continuous_cached_steps": 3,
- "enable_taylorseer": false,
- "taylorseer_order": 1,
- "scm_steps_mask_policy": null,
- "scm_steps_policy": "dynamic"
- },
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "512x512_steps20_i2i",
- "dataset": "random",
- "task": "i2i",
- "width": 512,
- "height": 512,
- "num-inference-steps": 20,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.1,
- "latency_mean": 6.5,
- "peak_memory_mb_max": 72600,
- "peak_memory_mb_mean": 72600
- }
- },
- {
- "name": "1536x1536_steps35_i2i",
- "dataset": "random",
- "task": "i2i",
- "width": 1536,
- "height": 1536,
- "num-inference-steps": 35,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.05,
- "latency_mean": 16.0,
- "peak_memory_mb_max": 81000,
- "peak_memory_mb_mean": 81000
- }
- }
- ]
- }
-]
diff --git a/tests/dfx/perf/tests/test_qwen_image_layered_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_layered_vllm_omni.json
deleted file mode 100644
index 3cf13509c8d..00000000000
--- a/tests/dfx/perf/tests/test_qwen_image_layered_vllm_omni.json
+++ /dev/null
@@ -1,49 +0,0 @@
-[
- {
- "test_name": "test_qwen_image_layered_single_device",
- "description": "Single-device baseline",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Qwen/Qwen-Image-Layered",
- "serve_args": {
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "640x640_steps20_i2i",
- "dataset": "random",
- "task": "i2i",
- "width": 640,
- "height": 640,
- "num-inference-steps": 20,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.02,
- "latency_mean": 40.0,
- "peak_memory_mb_max": 70000,
- "peak_memory_mb_mean": 70000
- }
- },
- {
- "name": "1024x1024_steps35_i2i",
- "dataset": "random",
- "task": "i2i",
- "width": 1024,
- "height": 1024,
- "num-inference-steps": 35,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.005,
- "latency_mean": 80.0,
- "peak_memory_mb_max": 70000,
- "peak_memory_mb_mean": 70000
- }
- }
- ]
- }
-]
diff --git a/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
index cdd0cac2c03..387e874ad5f 100644
--- a/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
+++ b/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
@@ -44,15 +44,19 @@
}
]
},
+
{
- "test_name": "test_qwen_image_single_device_step_execution",
- "description": "Single-device baseline (no parallelism) with step execution",
+ "test_name": "test_qwen_image_ulysses2_cfg2_vae_patch4",
+ "description": "Ulysses SP=2 + CFG-parallel=2 + VAE Patch Parallel=4",
"server_type": "vllm-omni",
"server_params": {
"model": "Qwen/Qwen-Image",
"serve_args": {
- "enable-diffusion-pipeline-profiler": true,
- "step-execution": true
+ "ulysses-degree": 2,
+ "cfg-parallel-size": 2,
+ "vae-patch-parallel-size": 4,
+ "vae-use-tiling": true,
+ "enable-diffusion-pipeline-profiler": true
}
},
"benchmark_params": [
@@ -67,44 +71,11 @@
"max-concurrency": 1,
"enable-negative-prompt": true,
"baseline": {
- "throughput_qps": 0.30,
- "latency_mean": 3.50,
- "peak_memory_mb_mean": 67000
+ "throughput_qps": 0.1,
+ "latency_mean": 2.34,
+ "peak_memory_mb_mean": 61000
}
},
- {
- "name": "1536x1536_steps35",
- "dataset": "random",
- "task": "t2i",
- "width": 1536,
- "height": 1536,
- "num-inference-steps": 35,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.037,
- "latency_mean": 27.0,
- "peak_memory_mb_mean": 74000
- }
- }
- ]
- },
- {
- "test_name": "test_qwen_image_ulysses2_cfg2_vae_patch4",
- "description": "Ulysses SP=2 + CFG-parallel=2 + VAE Patch Parallel=4",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Qwen/Qwen-Image",
- "serve_args": {
- "ulysses-degree": 2,
- "cfg-parallel-size": 2,
- "vae-patch-parallel-size": 4,
- "vae-use-tiling": true,
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
{
"name": "1536x1536_steps35",
"dataset": "random",
@@ -123,6 +94,7 @@
}
]
},
+
{
"test_name": "test_qwen_image_ulysses2_cfg2_cache_dit",
"description": "Ulysses SP=2 + CFG-parallel=2 + CacheDiT acceleration",
diff --git a/tests/dfx/perf/tests/test_qwen_omni.json b/tests/dfx/perf/tests/test_qwen_omni.json
deleted file mode 100644
index eda9720c417..00000000000
--- a/tests/dfx/perf/tests/test_qwen_omni.json
+++ /dev/null
@@ -1,315 +0,0 @@
-[
- {
- "test_name": "test_qwen3_omni",
- "server_params": {
- "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "extra_cli_args": ["--no-async-chunk"]
- },
- "benchmark_params": [
- {
- "dataset_name": "random",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [4, 16, 32, 64, 128],
- "max_concurrency": [1, 4, 8, 16, 32],
- "random_input_len": 2500,
- "random_output_len": 900,
- "ignore_eos": true,
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [1000, 3000, 5000, 7000, 9000],
- "mean_audio_ttfp_ms": [30000, 60000, 90000, 120000, 150000],
- "mean_audio_rtf": [0.35, 0.45, 0.55, 0.65, 0.75]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [10],
- "request_rate": [0.1],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "random_mm_base_items_per_request": 1,
- "random_mm_num_mm_items_range_ratio": 0.5,
- "random_mm_limit_mm_per_prompt": {
- "audio": 1
- },
- "random_mm_bucket_config": {
- "(0, 60, 3)": 1.0
- },
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [2000],
- "mean_audio_ttfp_ms": [10000],
- "mean_audio_rtf": [0.25]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [40],
- "request_rate": [0.5],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "random_mm_base_items_per_request": 2,
- "random_mm_num_mm_items_range_ratio": 0.5,
- "random_mm_limit_mm_per_prompt": {
- "image": 1,
- "video": 1
- },
- "random_mm_bucket_config": {
- "(256, 256, 1)": 0.5,
- "(720, 1280, 2)": 0.5
- },
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [6000],
- "mean_audio_ttfp_ms": [15000],
- "mean_audio_rtf": [0.45]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [100],
- "request_rate": [1.0],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "random_mm_base_items_per_request": 3,
- "random_mm_num_mm_items_range_ratio": 0.5,
- "random_mm_limit_mm_per_prompt": {
- "image": 1,
- "video": 1,
- "audio": 1
- },
- "random_mm_bucket_config": {
- "(256, 256, 1)": 0.34,
- "(720, 1280, 2)": 0.33,
- "(0, 60, 3)": 0.33
- },
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [12000],
- "mean_audio_ttfp_ms": [18000],
- "mean_audio_rtf": [0.9]
- }
- }
- ]
- },
- {
- "test_name": "test_qwen3_omni_chunk",
- "server_params": {
- "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "extra_cli_args": ["--async-chunk"]
- },
- "benchmark_params": [
- {
- "dataset_name": "random",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [4, 16, 32, 64],
- "max_concurrency": [1, 4, 8, 16],
- "random_input_len": 2500,
- "random_output_len": 900,
- "ignore_eos": true,
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [1000, 3000, 5000, 7000],
- "mean_audio_ttfp_ms": [1000, 3000, 5000, 7000],
- "mean_audio_rtf": [0.2, 0.35, 0.6, 0.85]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [10],
- "request_rate": [0.1],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "random_mm_base_items_per_request": 1,
- "random_mm_num_mm_items_range_ratio": 0.5,
- "random_mm_limit_mm_per_prompt": {
- "audio": 1
- },
- "random_mm_bucket_config": {
- "(0, 60, 3)": 1.0
- },
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [2000],
- "mean_audio_ttfp_ms": [2000],
- "mean_audio_rtf": [0.25]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [40],
- "request_rate": [0.5],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "random_mm_base_items_per_request": 2,
- "random_mm_num_mm_items_range_ratio": 0.5,
- "random_mm_limit_mm_per_prompt": {
- "image": 1,
- "video": 1
- },
- "random_mm_bucket_config": {
- "(256, 256, 1)": 0.5,
- "(720, 1280, 2)": 0.5
- },
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [6000],
- "mean_audio_ttfp_ms": [6000],
- "mean_audio_rtf": [0.7]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [100],
- "request_rate": [1.0],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "random_mm_base_items_per_request": 3,
- "random_mm_num_mm_items_range_ratio": 0.5,
- "random_mm_limit_mm_per_prompt": {
- "image": 1,
- "video": 1,
- "audio": 1
- },
- "random_mm_bucket_config": {
- "(256, 256, 1)": 0.34,
- "(720, 1280, 2)": 0.33,
- "(0, 60, 3)": 0.33
- },
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_ttft_ms": [12000],
- "mean_audio_ttfp_ms": [12000],
- "mean_audio_rtf": [1.0]
- }
- },
- {
- "dataset_name": "random",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [4, 16, 32, 64, 128],
- "max_concurrency": [1, 4, 8, 16, 32],
- "random_input_len": 2500,
- "random_output_len": 900,
- "ignore_eos": true,
- "extra_body": {
- "modalities": ["text"]
- },
- "percentile-metrics": "ttft,tpot,itl,e2el",
- "baseline": {
- "mean_ttft_ms": [1000, 3000, 5000, 7000, 9000]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [10],
- "request_rate": [0.1],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "extra_body": {
- "modalities": ["text"]
- },
- "random_mm_base_items_per_request": 1,
- "random_mm_num_mm_items_range_ratio": 0.5,
- "random_mm_limit_mm_per_prompt": {
- "audio": 1
- },
- "random_mm_bucket_config": {
- "(0, 60, 3)": 1.0
- },
- "percentile-metrics": "ttft,tpot,itl,e2el",
- "baseline": {
- "mean_ttft_ms": [2000]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [40],
- "request_rate": [0.5],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "extra_body": {
- "modalities": ["text"]
- },
- "random_mm_base_items_per_request": 2,
- "random_mm_num_mm_items_range_ratio": 0.5,
- "random_mm_limit_mm_per_prompt": {
- "image": 1,
- "video": 1
- },
- "random_mm_bucket_config": {
- "(256, 256, 1)": 0.5,
- "(720, 1280, 2)": 0.5
- },
- "percentile-metrics": "ttft,tpot,itl,e2el",
- "baseline": {
- "mean_ttft_ms": [6000]
- }
- },
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "num_prompts": [100],
- "request_rate": [1.0],
- "random_input_len": 100,
- "random_output_len": 100,
- "random_range_ratio": 0.0,
- "ignore_eos": true,
- "extra_body": {
- "modalities": ["text"]
- },
- "random_mm_base_items_per_request": 3,
- "random_mm_num_mm_items_range_ratio": 0.5,
- "random_mm_limit_mm_per_prompt": {
- "image": 1,
- "video": 1,
- "audio": 1
- },
- "random_mm_bucket_config": {
- "(256, 256, 1)": 0.34,
- "(720, 1280, 2)": 0.33,
- "(0, 60, 3)": 0.33
- },
- "percentile-metrics": "ttft,tpot,itl,e2el",
- "baseline": {
- "mean_ttft_ms": [6000]
- }
- }
- ]
- }
-]
diff --git a/tests/dfx/perf/tests/test_runner_metadata.py b/tests/dfx/perf/tests/test_runner_metadata.py
deleted file mode 100644
index 1276a847069..00000000000
--- a/tests/dfx/perf/tests/test_runner_metadata.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""Tests for DFX runner metadata field exclusion."""
-
-import json
-
-import pytest
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def test_task_excluded_from_cli_args():
- """'task' field must not become --task CLI arg."""
- params = {
- "task": "voice_clone",
- "dataset_name": "seed-tts",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "percentile-metrics": "audio_rtf,audio_ttfp",
- "baseline": {"mean_audio_rtf": [0.5]},
- }
- exclude_keys = {"request_rate", "baseline", "num_prompts", "max_concurrency", "task", "enabled", "eval_phase"}
- args = []
- for key, value in params.items():
- if key in exclude_keys or value is None:
- continue
- arg_name = f"--{key.replace('_', '-')}"
- if isinstance(value, bool) and value:
- args.append(arg_name)
- elif isinstance(value, dict):
- args.extend([arg_name, json.dumps(value)])
- elif not isinstance(value, bool):
- args.extend([arg_name, str(value)])
- assert "--task" not in args
- assert "--enabled" not in args
- assert "--dataset-name" in args
-
-
-def test_enabled_false_entry_is_skipped():
- """benchmark_params entry with enabled=false should be skipped."""
- import sys
- from pathlib import Path
-
- sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
- from tests.dfx.conftest import create_test_parameter_mapping
-
- configs = [
- {
- "test_name": "test_model",
- "server_params": {"model": "some/model"},
- "benchmark_params": [
- {
- "task": "voice_clone",
- "enabled": True,
- "dataset_name": "seed-tts",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "num_prompts": [10],
- "max_concurrency": [1],
- "percentile-metrics": "audio_rtf",
- "baseline": {},
- },
- {
- "task": "voice_design",
- "enabled": False,
- "dataset_name": "seed-tts-design",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "num_prompts": [5],
- "max_concurrency": [1],
- "percentile-metrics": "audio_rtf",
- "baseline": {},
- },
- ],
- }
- ]
- mapping = create_test_parameter_mapping(configs)
- params = mapping["test_model"]["benchmark_params"]
- # Only the enabled=True entry should appear
- assert len(params) == 1
- assert params[0].get("task") == "voice_clone"
diff --git a/tests/dfx/perf/tests/test_tts.json b/tests/dfx/perf/tests/test_tts.json
deleted file mode 100644
index 06c9c4d2384..00000000000
--- a/tests/dfx/perf/tests/test_tts.json
+++ /dev/null
@@ -1,155 +0,0 @@
-[
- {
- "test_name": "test_qwen3_tts_base",
- "server_params": {
- "model": "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
- },
- "benchmark_params": [
- {
- "task": "voice_clone",
- "eval_phase": "latency",
- "enabled": false,
- "dataset_name": "seed-tts",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "num_prompts": [20],
- "max_concurrency": [1],
- "seed_tts_locale": "en",
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "median_audio_ttfp_ms": [350],
- "median_audio_rtf": [0.25]
- }
- },
- {
- "task": "voice_clone",
- "eval_phase": "throughput",
- "enabled": false,
- "dataset_name": "seed-tts",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "num_prompts": [80],
- "max_concurrency": [8],
- "seed_tts_locale": "en",
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "median_audio_ttfp_ms": [3500],
- "median_audio_rtf": [0.75],
- "audio_throughput": [10.0]
- }
- },
- {
- "task": "voice_clone",
- "eval_phase": "quality",
- "enabled": false,
- "dataset_name": "seed-tts",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "num_prompts": [200],
- "max_concurrency": [4],
- "seed_tts_locale": "en",
- "seed_tts_wer_eval": true,
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_audio_rtf": [0.45]
- }
- }
- ]
- },
- {
- "test_name": "test_qwen3_tts_customvoice",
- "server_params": {
- "model": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
- },
- "benchmark_params": [
- {
- "task": "default_voice",
- "eval_phase": "latency",
- "dataset_name": "seed-tts-text",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "dataset_path": "benchmarks/build_dataset/seed_tts_smoke",
- "num_prompts": [20],
- "max_concurrency": [1],
- "seed_tts_locale": "en",
- "extra_body": {"voice": "Vivian", "language": "English", "task_type": "CustomVoice"},
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "median_audio_ttfp_ms": [150],
- "median_audio_rtf": [0.15]
- }
- },
- {
- "task": "default_voice",
- "eval_phase": "throughput",
- "dataset_name": "seed-tts-text",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "dataset_path": "benchmarks/build_dataset/seed_tts_smoke",
- "num_prompts": [80],
- "max_concurrency": [8],
- "seed_tts_locale": "en",
- "extra_body": {"voice": "Vivian", "language": "English", "task_type": "CustomVoice"},
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "median_audio_ttfp_ms": [1500],
- "median_audio_rtf": [0.30],
- "audio_throughput": [30.0]
- }
- },
- {
- "task": "default_voice",
- "eval_phase": "quality",
- "enabled": false,
- "dataset_name": "seed-tts-text",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "dataset_path": "benchmarks/build_dataset/seed_tts_smoke",
- "num_prompts": [200],
- "max_concurrency": [4],
- "seed_tts_locale": "en",
- "extra_body": {"voice": "Vivian", "language": "English", "task_type": "CustomVoice"},
- "seed_tts_wer_eval": true,
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "mean_audio_rtf": [0.35]
- }
- },
- {
- "task": "voice_design",
- "eval_phase": "latency",
- "dataset_name": "seed-tts-design",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "dataset_path": "benchmarks/build_dataset/seed_tts_design",
- "num_prompts": [20],
- "max_concurrency": [1],
- "seed_tts_locale": "en",
- "extra_body": {"task_type": "VoiceDesign", "language": "English"},
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "median_audio_ttfp_ms": [150],
- "median_audio_rtf": [0.15]
- }
- },
- {
- "task": "voice_design",
- "eval_phase": "throughput",
- "dataset_name": "seed-tts-design",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "dataset_path": "benchmarks/build_dataset/seed_tts_design",
- "num_prompts": [80],
- "max_concurrency": [8],
- "seed_tts_locale": "en",
- "extra_body": {"task_type": "VoiceDesign", "language": "English"},
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration",
- "baseline": {
- "median_audio_ttfp_ms": [1500],
- "median_audio_rtf": [0.35],
- "audio_throughput": [25.0]
- }
- }
- ]
- }
-]
diff --git a/tests/dfx/perf/tests/test_wan22_i2v_vllm_omni.json b/tests/dfx/perf/tests/test_wan22_i2v_vllm_omni.json
deleted file mode 100644
index 58a17c980bd..00000000000
--- a/tests/dfx/perf/tests/test_wan22_i2v_vllm_omni.json
+++ /dev/null
@@ -1,107 +0,0 @@
-[
- {
- "test_name": "test_wan22_i2v_single_device",
- "description": "Single-device baseline",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
- "serve_args": {
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "832x480_frames81_steps4",
- "dataset": "random",
- "task": "i2v",
- "num-prompts": 10,
- "max-concurrency": 1,
- "num-input-images": 1,
- "seed": 42,
- "enable-negative-prompt": true,
- "random-request-config": [
- {
- "width": 832,
- "height": 480,
- "num_inference_steps": 4,
- "num_frames": 81,
- "fps": 16,
- "weight": 1
- }
- ],
- "baseline": {
- "throughput_qps": 0.034,
- "latency_mean": 26.0,
- "peak_memory_mb_mean": 80000
- }
- }
- ]
- },
- {
- "test_name": "test_wan22_i2v_usp2_vae_patch2_hsdp_slicing",
- "description": "USP=2 + VAE patch parallel=2 + HSDP + VAE slicing",
- "server_type": "vllm-omni",
- "server_params": {
- "model": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
- "serve_args": {
- "usp": 2,
- "vae-patch-parallel-size": 2,
- "use-hsdp": true,
- "vae-use-slicing": true,
- "enable-diffusion-pipeline-profiler": true
- }
- },
- "benchmark_params": [
- {
- "name": "832x480_frames81_steps4",
- "dataset": "random",
- "task": "i2v",
- "num-prompts": 10,
- "max-concurrency": 1,
- "num-input-images": 1,
- "seed": 42,
- "enable-negative-prompt": true,
- "random-request-config": [
- {
- "width": 832,
- "height": 480,
- "num_inference_steps": 4,
- "num_frames": 81,
- "fps": 16,
- "weight": 1
- }
- ],
- "baseline": {
- "throughput_qps": 0.042,
- "latency_mean": 21.6,
- "peak_memory_mb_mean": 55300
- }
- },
- {
- "name": "1280x720_frames121_steps4",
- "dataset": "random",
- "task": "i2v",
- "num-prompts": 10,
- "max-concurrency": 1,
- "num-input-images": 1,
- "seed": 42,
- "enable-negative-prompt": true,
- "random-request-config": [
- {
- "width": 1280,
- "height": 720,
- "num_inference_steps": 4,
- "num_frames": 121,
- "fps": 16,
- "weight": 1
- }
- ],
- "baseline": {
- "throughput_qps": 0.0085,
- "latency_mean": 101.6,
- "peak_memory_mb_mean": 65200
- }
- }
- ]
- }
-]
diff --git a/tests/dfx/stability/conftest.py b/tests/dfx/stability/conftest.py
index 30718d4bf5a..3a0aee7608f 100644
--- a/tests/dfx/stability/conftest.py
+++ b/tests/dfx/stability/conftest.py
@@ -3,79 +3,123 @@
resource monitoring is started before each test and finalized after each test,
so each stability test case gets its own HTML report (one report per case).
No need to wrap pytest with `bash resource_monitor.sh run -- pytest ...`.
-
-Duration-based benchmark helper functions are hosted in ``helpers.py``,
-while this file focuses on pytest fixtures and setup/teardown.
"""
-from __future__ import annotations
-
+import os
import subprocess
import sys
import threading
+import time
+from pathlib import Path
import pytest
-from tests.dfx.conftest import get_benchmark_params_for_server
-from tests.dfx.stability.helpers import (
- finalize_resource_monitor,
- report_latest_gpu_samples,
- start_resource_monitor,
- wait_for_run_dir,
-)
-from tests.helpers.runtime import OmniServer
-
-DEFAULT_STABILITY_SERVER_TIMEOUT_ARGS = ["--stage-init-timeout", "600", "--init-timeout", "900"]
-
-_omni_server_lock = threading.Lock()
-
-
-@pytest.fixture(scope="module")
-def omni_server(request: pytest.FixtureRequest):
- """Start OmniServer for stability tests, with per-module timeout override."""
- timeout_args = getattr(request.module, "STABILITY_SERVER_TIMEOUT_ARGS", DEFAULT_STABILITY_SERVER_TIMEOUT_ARGS)
- with _omni_server_lock:
- # Same 5-tuple and CLI composition as ``tests/dfx/perf/scripts/run_benchmark.py`` on main;
- # ``serve_args`` from JSON are folded into ``extra_cli_args`` inside
- # ``create_unique_server_params``.
- test_name, model, deploy_path, stage_overrides, extra_cli_args = request.param
-
- print(f"Starting OmniServer with test: {test_name}, model: {model}")
- server_args = list(timeout_args)
- if deploy_path:
- server_args = ["--deploy-config", deploy_path] + server_args
- if stage_overrides:
- server_args = ["--stage-overrides", stage_overrides] + server_args
- if extra_cli_args:
- server_args = list(extra_cli_args) + server_args
- with OmniServer(model, server_args) as server:
- server.test_name = test_name
- print("OmniServer started successfully")
- yield server
- print("OmniServer stopping...")
- print("OmniServer stopped")
-
-
-@pytest.fixture
-def stability_benchmark_params(request: pytest.FixtureRequest, omni_server):
- test_name, param_index = request.param
- if test_name != omni_server.test_name:
- pytest.skip(f"Skipping parameter for {test_name} - current server is {omni_server.test_name}")
-
- server_to_benchmark_mapping = getattr(request.module, "server_to_benchmark_mapping", None)
- if server_to_benchmark_mapping is None:
- raise ValueError("server_to_benchmark_mapping must be defined in the test module")
-
- all_params = get_benchmark_params_for_server(test_name, server_to_benchmark_mapping)
- if not all_params:
- raise ValueError(f"No benchmark parameters found for test: {test_name}")
- if param_index >= len(all_params):
- raise ValueError(f"No benchmark parameters found for index {param_index} in test: {test_name}")
-
- current = param_index + 1
- total = len(all_params)
- print(f"\n Running benchmark {current}/{total} for {test_name}")
- return {"test_name": test_name, "params": all_params[param_index]}
+STABILITY_DIR = Path(__file__).resolve().parent
+RESOURCE_MONITOR_SCRIPT = STABILITY_DIR / "scripts" / "resource_monitor.sh"
+REPO_ROOT = STABILITY_DIR.parent.parent.parent
+
+
+def _start_resource_monitor():
+ """Start `resource_monitor.sh start` in the background and return `Popen` or `None`."""
+ if not RESOURCE_MONITOR_SCRIPT.is_file():
+ return None
+ try:
+ proc = subprocess.Popen(
+ ["bash", str(RESOURCE_MONITOR_SCRIPT), "start", "--backend", "gpu"],
+ cwd=str(REPO_ROOT),
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.PIPE,
+ start_new_session=True,
+ )
+ try:
+ proc.wait(timeout=2)
+ if proc.returncode != 0:
+ stderr = proc.stderr.read().decode("utf-8", errors="ignore") if proc.stderr else ""
+ if stderr.strip():
+ sys.stderr.write(f"[Stability] Resource monitor failed to start: {stderr.strip()}\n")
+ return None
+ except subprocess.TimeoutExpired:
+ pass
+ return proc
+ except (FileNotFoundError, OSError):
+ return None
+
+
+def _get_monitor_data_root() -> Path:
+ data_root = os.environ.get("RESOURCE_MONITOR_DATA_ROOT") or os.environ.get("GPU_MONITOR_DATA_ROOT")
+ if data_root:
+ return Path(data_root)
+ return STABILITY_DIR / "gpu_monitor_data"
+
+
+def _wait_for_run_dir(timeout_sec: int = 10) -> Path | None:
+ data_root = _get_monitor_data_root()
+ run_id_file = data_root / "current_run_id"
+ deadline = time.time() + timeout_sec
+ while time.time() < deadline:
+ if run_id_file.is_file():
+ run_id = run_id_file.read_text(encoding="utf-8").strip()
+ if run_id:
+ run_dir = data_root / run_id
+ if run_dir.is_dir():
+ return run_dir
+ time.sleep(0.5)
+ return None
+
+
+def _report_latest_gpu_samples(stop_event: threading.Event) -> None:
+ """Periodically print the latest sampled GPU line."""
+ log_interval = int(
+ os.environ.get("RESOURCE_MONITOR_LOG_INTERVAL") or os.environ.get("GPU_MONITOR_LOG_INTERVAL") or "15"
+ )
+ log_interval = max(log_interval, 1)
+ last_line = ""
+
+ time.sleep(min(log_interval, 5))
+ while not stop_event.wait(log_interval):
+ run_dir = _wait_for_run_dir(timeout_sec=1)
+ if run_dir is None:
+ continue
+ csv_file = run_dir / "gpu_metrics.csv"
+ if not csv_file.is_file():
+ continue
+ try:
+ lines = csv_file.read_text(encoding="utf-8").splitlines()
+ except OSError:
+ continue
+ if len(lines) <= 1:
+ continue
+ latest = lines[-1].strip()
+ if latest and latest != last_line:
+ last_line = latest
+ sys.stderr.write(f"[GPU] {latest}\n")
+
+
+def _finalize_resource_monitor() -> str | None:
+ """
+ Run `resource_monitor.sh finalize` for the current run and generate the report.
+ Returns the bundle dir path (for this test case's report) if successful, else None.
+ """
+ if not RESOURCE_MONITOR_SCRIPT.is_file():
+ return None
+ try:
+ result = subprocess.run(
+ ["bash", str(RESOURCE_MONITOR_SCRIPT), "finalize", "--backend", "gpu"],
+ cwd=str(REPO_ROOT),
+ capture_output=True,
+ text=True,
+ timeout=60,
+ check=False,
+ )
+ if result.returncode != 0:
+ return None
+ for line in (result.stdout or "").splitlines():
+ if line.startswith("GPU_MONITOR_BUNDLE_DIR=") or line.startswith("RESOURCE_MONITOR_BUNDLE_DIR="):
+ _, _, value = line.partition("=")
+ return value.strip() if value else None
+ return None
+ except (FileNotFoundError, OSError, subprocess.TimeoutExpired):
+ return None
@pytest.fixture(autouse=True)
@@ -84,19 +128,19 @@ def stability_resource_monitor_per_test(request: pytest.FixtureRequest):
For each test under this directory: start GPU monitor before the test,
then finalize after the test so this case gets its own report.html.
"""
- proc = start_resource_monitor()
+ proc = _start_resource_monitor()
stop_event = threading.Event()
reporter: threading.Thread | None = None
if proc is not None:
reporter = threading.Thread(
- target=report_latest_gpu_samples,
+ target=_report_latest_gpu_samples,
args=(stop_event,),
name="stability-resource-monitor-reporter",
daemon=True,
)
reporter.start()
- run_dir = wait_for_run_dir(timeout_sec=5)
+ run_dir = _wait_for_run_dir(timeout_sec=5)
node_name = request.node.name
if run_dir is not None:
sys.stderr.write(f"[Stability] Resource monitor started for test: {node_name} | run dir: {run_dir}\n")
@@ -117,7 +161,7 @@ def stability_resource_monitor_per_test(request: pytest.FixtureRequest):
except subprocess.TimeoutExpired:
proc.kill()
proc.wait()
- bundle_dir = finalize_resource_monitor()
+ bundle_dir = _finalize_resource_monitor()
node_name = request.node.name
if bundle_dir:
sys.stderr.write(f"[Stability] Report for test «{node_name}»: {bundle_dir}/report.html\n")
diff --git a/tests/dfx/stability/helpers.py b/tests/dfx/stability/helpers.py
deleted file mode 100644
index 3a873f69ca4..00000000000
--- a/tests/dfx/stability/helpers.py
+++ /dev/null
@@ -1,504 +0,0 @@
-"""Stability helpers for resource monitoring and benchmark execution."""
-
-from __future__ import annotations
-
-import json
-import os
-import random
-import re
-import shlex
-import subprocess
-import sys
-import tempfile
-import threading
-import time
-from collections.abc import Callable
-from pathlib import Path
-from typing import Any
-
-from tests.dfx.conftest import run_benchmark
-
-STABILITY_DIR = Path(__file__).resolve().parent
-RESOURCE_MONITOR_SCRIPT = STABILITY_DIR / "scripts" / "resource_monitor.sh"
-REPO_ROOT = STABILITY_DIR.parent.parent.parent
-_BUCKET_KEY_PATTERN = re.compile(r"^\(\s*([^,]+)\s*,\s*([^,]+)\s*,\s*([^,]+)\s*\)$")
-
-RunOneBatchFn = Callable[
- [str, int, str, dict[str, Any], int, float | None, int | None, str, int],
- dict[str, Any],
-]
-
-
-def start_resource_monitor():
- """Start `resource_monitor.sh start` in the background and return `Popen` or `None`."""
- if not RESOURCE_MONITOR_SCRIPT.is_file():
- return None
- try:
- proc = subprocess.Popen(
- ["bash", str(RESOURCE_MONITOR_SCRIPT), "start", "--backend", "gpu"],
- cwd=str(REPO_ROOT),
- stdout=subprocess.DEVNULL,
- stderr=subprocess.PIPE,
- start_new_session=True,
- )
- try:
- proc.wait(timeout=2)
- if proc.returncode != 0:
- stderr = proc.stderr.read().decode("utf-8", errors="ignore") if proc.stderr else ""
- if stderr.strip():
- sys.stderr.write(f"[Stability] Resource monitor failed to start: {stderr.strip()}\n")
- return None
- except subprocess.TimeoutExpired:
- pass
- return proc
- except (FileNotFoundError, OSError):
- return None
-
-
-def get_monitor_data_root() -> Path:
- data_root = os.environ.get("RESOURCE_MONITOR_DATA_ROOT") or os.environ.get("GPU_MONITOR_DATA_ROOT")
- if data_root:
- return Path(data_root)
- return STABILITY_DIR / "gpu_monitor_data"
-
-
-def wait_for_run_dir(timeout_sec: int = 10) -> Path | None:
- data_root = get_monitor_data_root()
- run_id_file = data_root / "current_run_id"
- deadline = time.time() + timeout_sec
- while time.time() < deadline:
- if run_id_file.is_file():
- run_id = run_id_file.read_text(encoding="utf-8").strip()
- if run_id:
- run_dir = data_root / run_id
- if run_dir.is_dir():
- return run_dir
- time.sleep(0.5)
- return None
-
-
-def report_latest_gpu_samples(stop_event: threading.Event) -> None:
- """Periodically print the latest sampled GPU line."""
- log_interval = int(
- os.environ.get("RESOURCE_MONITOR_LOG_INTERVAL") or os.environ.get("GPU_MONITOR_LOG_INTERVAL") or "15"
- )
- log_interval = max(log_interval, 1)
- last_line = ""
-
- time.sleep(min(log_interval, 5))
- while not stop_event.wait(log_interval):
- run_dir = wait_for_run_dir(timeout_sec=1)
- if run_dir is None:
- continue
- csv_file = run_dir / "gpu_metrics.csv"
- if not csv_file.is_file():
- continue
- try:
- lines = csv_file.read_text(encoding="utf-8").splitlines()
- except OSError:
- continue
- if len(lines) <= 1:
- continue
- latest = lines[-1].strip()
- if latest and latest != last_line:
- last_line = latest
- sys.stderr.write(f"[GPU] {latest}\n")
-
-
-def finalize_resource_monitor() -> str | None:
- """
- Run `resource_monitor.sh finalize` for the current run and generate the report.
- Returns the bundle dir path (for this test case's report) if successful, else None.
- """
- if not RESOURCE_MONITOR_SCRIPT.is_file():
- return None
- try:
- result = subprocess.run(
- ["bash", str(RESOURCE_MONITOR_SCRIPT), "finalize", "--backend", "gpu"],
- cwd=str(REPO_ROOT),
- capture_output=True,
- text=True,
- timeout=60,
- check=False,
- )
- if result.returncode != 0:
- return None
- for line in (result.stdout or "").splitlines():
- if line.startswith("GPU_MONITOR_BUNDLE_DIR=") or line.startswith("RESOURCE_MONITOR_BUNDLE_DIR="):
- _, _, value = line.partition("=")
- return value.strip() if value else None
- return None
- except (FileNotFoundError, OSError, subprocess.TimeoutExpired):
- return None
-
-
-def _normalize_bench_metrics(raw: dict[str, Any]) -> dict[str, Any]:
- completed = int(raw.get("completed", raw.get("completed_requests", 0) or 0))
- failed = int(raw.get("failed", raw.get("failed_requests", 0) or 0))
- duration = float(raw.get("duration", 0.0) or 0.0)
- errors = list(raw.get("errors") or [])
- if failed and not errors:
- errors = [f"{failed} benchmark request(s) failed"]
- return {"completed": completed, "failed": failed, "duration": duration, "errors": errors}
-
-
-def _build_base_args(params: dict[str, Any], host: str, port: int) -> list[str]:
- exclude = {
- "request_rate",
- "max_concurrency",
- "num_prompts",
- "baseline",
- "duration_sec",
- "num_prompts_per_batch",
- }
- args = ["--host", host, "--port", str(port)]
- for key, value in params.items():
- if key in exclude or value is None:
- continue
- arg_name = f"--{key.replace('_', '-')}"
- if isinstance(value, bool) and value:
- args.append(arg_name)
- elif isinstance(value, dict):
- args.extend([arg_name, json.dumps(value, ensure_ascii=False, separators=(",", ":"))])
- elif not isinstance(value, bool):
- args.extend([arg_name, str(value)])
- return args
-
-
-def _build_diffusion_cmd(
- host: str,
- port: int,
- model: str,
- params: dict[str, Any],
- num_prompts: int,
- request_rate: float | None,
- max_concurrency: int | None,
- output_path: Path,
- diffusion_benchmark_script: Path,
-) -> list[str]:
- skip_keys = {
- "request_rate",
- "max_concurrency",
- "num_prompts",
- "baseline",
- "duration_sec",
- "num_prompts_per_batch",
- }
- cmd: list[str] = [
- sys.executable,
- "-u",
- str(diffusion_benchmark_script),
- "--host",
- host,
- "--port",
- str(port),
- "--model",
- model,
- "--output-file",
- str(output_path),
- ]
- for key, value in params.items():
- if key in skip_keys or value is None:
- continue
- flag = f"--{str(key).replace('_', '-')}"
- if isinstance(value, bool) and value:
- cmd.append(flag)
- elif isinstance(value, bool):
- continue
- elif isinstance(value, (dict, list)):
- cmd.extend([flag, json.dumps(value, ensure_ascii=False, separators=(",", ":"))])
- else:
- cmd.extend([flag, str(value)])
-
- cmd.extend(["--num-prompts", str(num_prompts)])
- if request_rate is not None:
- cmd.extend(["--request-rate", str(request_rate)])
- else:
- cmd.extend(["--max-concurrency", str(max_concurrency), "--request-rate", "inf"])
- return cmd
-
-
-def _sample_int_from_range_spec(value: Any, rng: random.Random) -> Any:
- """Resolve one value that may be scalar or range spec into an int."""
- if isinstance(value, int):
- return value
-
- if isinstance(value, (list, tuple)) and len(value) == 2 and all(isinstance(v, int) for v in value):
- low, high = int(value[0]), int(value[1])
- if low > high:
- low, high = high, low
- return rng.randint(low, high)
-
- if isinstance(value, dict) and {"min", "max"} <= set(value):
- low, high = int(value["min"]), int(value["max"])
- if low > high:
- low, high = high, low
- return rng.randint(low, high)
-
- if isinstance(value, str):
- raw = value.strip()
- if raw.isdigit():
- return int(raw)
- if "-" in raw:
- parts = [p.strip() for p in raw.split("-", 1)]
- if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit():
- low, high = int(parts[0]), int(parts[1])
- if low > high:
- low, high = high, low
- return rng.randint(low, high)
-
- return value
-
-
-def _sample_bucket_key(raw_key: str, rng: random.Random) -> str:
- """Sample bucket tuple keys that use range syntax, e.g. ``(128-512, 128-512, 1)``."""
- match = _BUCKET_KEY_PATTERN.match(raw_key.strip())
- if not match:
- return raw_key
-
- sampled_parts: list[int] = []
- for token in match.groups():
- sampled = _sample_int_from_range_spec(token.strip(), rng)
- if not isinstance(sampled, int):
- return raw_key
- sampled_parts.append(sampled)
-
- # For video buckets (height>0 and num_frames>1), enforce even H/W to avoid
- # ffmpeg yuv420p encoding/decoding failures ("Could not open video stream").
- if sampled_parts[0] > 0 and sampled_parts[2] > 1:
- sampled_parts[0] = max(2, sampled_parts[0] - (sampled_parts[0] % 2))
- sampled_parts[1] = max(2, sampled_parts[1] - (sampled_parts[1] % 2))
-
- return f"({sampled_parts[0]}, {sampled_parts[1]}, {sampled_parts[2]})"
-
-
-def _sample_stability_batch_params(params: dict[str, Any], batch_index: int) -> dict[str, Any]:
- """Materialize per-batch random values for configured range fields."""
- sampled = dict(params)
- rng = random.Random(time.time_ns() + batch_index)
-
- for field_name in (
- "random_input_len",
- "random_output_len",
- "random_mm_base_items_per_request",
- "width",
- "height",
- ):
- if field_name in sampled:
- sampled[field_name] = _sample_int_from_range_spec(sampled[field_name], rng)
-
- bucket_config = sampled.get("random_mm_bucket_config")
- if isinstance(bucket_config, dict):
- sampled_bucket_config: dict[str, float] = {}
- for raw_key, probability in bucket_config.items():
- sampled_key = _sample_bucket_key(str(raw_key), rng)
- sampled_bucket_config[sampled_key] = sampled_bucket_config.get(sampled_key, 0.0) + float(probability)
- sampled["random_mm_bucket_config"] = sampled_bucket_config
-
- return sampled
-
-
-def _run_one_vllm_bench_batch(
- host: str,
- port: int,
- _model: str,
- params: dict[str, Any],
- num_prompts: int,
- request_rate: float | None,
- max_concurrency: int | None,
- result_dir: str,
- batch_index: int,
-) -> dict[str, Any]:
- base = _build_base_args(params, host, port)
- if request_rate is not None:
- args = base + ["--request-rate", str(request_rate), "--num-prompts", str(num_prompts)]
- flow = request_rate
- else:
- args = base + [
- "--max-concurrency",
- str(max_concurrency),
- "--num-prompts",
- str(num_prompts),
- "--request-rate",
- "inf",
- ]
- flow = max_concurrency
-
- # Print the exact per-batch benchmark CLI (randomized params are already materialized).
- preview_cmd = ["vllm", "bench", "serve", "--omni", *args]
- print(f"\n[Stability][Batch {batch_index}] Benchmark command:")
- print(shlex.join(preview_cmd))
-
- dataset_name = params.get("dataset_name", "random")
- old_benchmark_dir = os.environ.get("BENCHMARK_DIR")
- try:
- os.environ["BENCHMARK_DIR"] = result_dir
- result = run_benchmark(
- args=args,
- test_name="stability",
- flow=flow,
- dataset_name=dataset_name,
- num_prompt=num_prompts,
- random_input_len=params.get("random_input_len"),
- random_output_len=params.get("random_output_len"),
- )
- return _normalize_bench_metrics(result)
- except (FileNotFoundError, OSError) as exc:
- return {
- "completed": 0,
- "failed": 1,
- "duration": 0.0,
- "errors": [f"Benchmark batch failed: {type(exc).__name__}: {exc}"],
- }
- finally:
- if old_benchmark_dir is not None:
- os.environ["BENCHMARK_DIR"] = old_benchmark_dir
- elif "BENCHMARK_DIR" in os.environ:
- os.environ.pop("BENCHMARK_DIR")
-
-
-def _run_one_diffusion_batch(
- host: str,
- port: int,
- model: str,
- params: dict[str, Any],
- num_prompts: int,
- request_rate: float | None,
- max_concurrency: int | None,
- _result_dir: str,
- _batch_index: int,
-) -> dict[str, Any]:
- diffusion_benchmark_script = Path(REPO_ROOT / "benchmarks" / "diffusion" / "diffusion_benchmark_serving.py")
- with tempfile.NamedTemporaryFile(mode="w", suffix=".json", prefix="stability_diffusion_", delete=False) as tmp:
- out_path = Path(tmp.name)
- try:
- cmd = _build_diffusion_cmd(
- host,
- port,
- model,
- params,
- num_prompts,
- request_rate,
- max_concurrency,
- out_path,
- diffusion_benchmark_script,
- )
- proc = subprocess.run(
- cmd,
- cwd=str(REPO_ROOT),
- capture_output=True,
- text=True,
- )
- if proc.stdout:
- print(proc.stdout, end="" if proc.stdout.endswith("\n") else "\n")
- if proc.stderr:
- print(proc.stderr, end="" if proc.stderr.endswith("\n") else "\n")
- if proc.returncode != 0:
- return {
- "completed": 0,
- "failed": 1,
- "duration": 0.0,
- "errors": [f"diffusion_benchmark_serving.py exited {proc.returncode}"],
- }
- if not out_path.is_file():
- return {
- "completed": 0,
- "failed": 1,
- "duration": 0.0,
- "errors": [f"Missing benchmark output: {out_path}"],
- }
- with open(out_path, encoding="utf-8") as file:
- metrics = json.load(file)
- return _normalize_bench_metrics(metrics)
- except (FileNotFoundError, OSError, json.JSONDecodeError) as exc:
- return {
- "completed": 0,
- "failed": 1,
- "duration": 0.0,
- "errors": [f"Diffusion batch failed: {type(exc).__name__}: {exc}"],
- }
- finally:
- out_path.unlink(missing_ok=True)
-
-
-def merge_batch_results(batch_results: list[dict[str, Any]], total_duration_sec: float) -> dict[str, Any]:
- if not batch_results:
- return {"completed": 0, "failed": 0, "duration": total_duration_sec, "errors": []}
-
- completed = sum(result.get("completed", 0) for result in batch_results)
- failed = sum(result.get("failed", 0) for result in batch_results)
- merged: dict[str, Any] = {
- "completed": completed,
- "failed": failed,
- "duration": total_duration_sec,
- "errors": [],
- }
- for result in batch_results:
- merged["errors"].extend(result.get("errors") or [])
- return merged
-
-
-def print_merged_report(result: dict[str, Any]) -> None:
- fmt = "{:<40} {:<10}"
- fmt_float = "{:<40} {:<10.2f}"
- completed = result.get("completed", 0)
- failed = result.get("failed", 0)
- duration = float(result.get("duration", 0.0) or 0.0)
- print("\n============ Stability Benchmark Summary ============")
- print(fmt.format("Successful requests:", completed))
- print(fmt.format("Failed requests:", failed))
- print(fmt_float.format("Total duration (s):", duration))
- print("==================================================\n")
-
-
-def run_stability_benchmark_loop(
- host: str,
- port: int,
- model: str,
- duration_sec: int | float,
- params: dict[str, Any],
- *,
- request_rate: float | None,
- max_concurrency: int | None,
- result_dir: str,
- num_prompts_per_batch: int,
- run_one_batch: RunOneBatchFn,
- result_filename: str | None = None,
-) -> dict[str, Any]:
- if (request_rate is None) == (max_concurrency is None):
- raise ValueError("Exactly one of request_rate or max_concurrency must be specified")
-
- start_time = time.perf_counter()
- batch_results: list[dict[str, Any]] = []
- batch_index = 0
-
- while True:
- if (time.perf_counter() - start_time) >= duration_sec:
- break
- sampled_params = _sample_stability_batch_params(params, batch_index)
- result = run_one_batch(
- host,
- port,
- model,
- sampled_params,
- num_prompts_per_batch,
- request_rate,
- max_concurrency,
- result_dir,
- batch_index,
- )
- batch_results.append(result)
- batch_index += 1
- if (time.perf_counter() - start_time) >= duration_sec:
- break
-
- total_duration = time.perf_counter() - start_time
- merged = merge_batch_results(batch_results, total_duration)
- print_merged_report(merged)
-
- if result_filename and result_dir:
- result_path = Path(result_dir) / result_filename
- with open(result_path, "w", encoding="utf-8") as file:
- json.dump(merged, file, indent=2, ensure_ascii=False)
-
- return merged
diff --git a/tests/dfx/stability/scripts/test_benchmark_stability.py b/tests/dfx/stability/scripts/test_benchmark_stability.py
new file mode 100644
index 00000000000..e8568652d18
--- /dev/null
+++ b/tests/dfx/stability/scripts/test_benchmark_stability.py
@@ -0,0 +1,286 @@
+"""
+Stability test cases: start OmniServer first, then run benchmark traffic with either
+`request-rate` or `max-concurrency` for a fixed duration. No new requests are sent
+after the duration is reached, and the test asserts that there are no failed requests.
+
+The overall flow matches the perf logic: `load_configs`, `modify_stage`,
+`create_unique_server_params`, `create_test_parameter_mapping`,
+`get_benchmark_params_for_server`, `create_benchmark_indices`, and the
+`omni_server` fixture are aligned with perf. Only the benchmark execution
+(`run_stability_benchmark`, which is duration-based here) and the test cases differ.
+
+All test-specific parameters, such as `duration_sec`, `request_rate` /
+`max_concurrency`, and `num_prompts_per_batch`, are configured in
+`tests/dfx/stability/tests/test.json` and are no longer overridden
+through environment variables.
+"""
+
+import json
+import os
+import threading
+import time
+from pathlib import Path
+from typing import Any
+
+import pytest
+
+from tests.conftest import OmniServer
+from tests.dfx.conftest import (
+ create_benchmark_indices,
+ create_test_parameter_mapping,
+ create_unique_server_params,
+ get_benchmark_params_for_server,
+ load_configs,
+)
+from tests.dfx.perf.scripts.run_benchmark import run_benchmark
+
+STABILITY_DIR = Path(__file__).resolve().parent.parent
+STAGE_CONFIGS_DIR = STABILITY_DIR / "stage_configs"
+CONFIG_FILE_PATH = str(STABILITY_DIR / "tests" / "test.json")
+DEFAULT_NUM_PROMPTS_PER_BATCH = 20
+
+
+try:
+ BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)
+except FileNotFoundError:
+ BENCHMARK_CONFIGS = []
+
+test_params = create_unique_server_params(BENCHMARK_CONFIGS, STAGE_CONFIGS_DIR) if BENCHMARK_CONFIGS else []
+server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS) if BENCHMARK_CONFIGS else {}
+
+_omni_server_lock = threading.Lock()
+
+
+benchmark_indices = create_benchmark_indices(BENCHMARK_CONFIGS, server_to_benchmark_mapping)
+
+
+def _build_base_args(params: dict[str, Any], host: str, port: int) -> list[str]:
+ exclude = {
+ "request_rate",
+ "max_concurrency",
+ "num_prompts",
+ "baseline",
+ "duration_sec",
+ "num_prompts_per_batch",
+ }
+ args = ["--host", host, "--port", str(port)]
+ for key, value in params.items():
+ if key in exclude or value is None:
+ continue
+ arg_name = f"--{key.replace('_', '-')}"
+ if isinstance(value, bool) and value:
+ args.append(arg_name)
+ elif isinstance(value, dict):
+ args.extend([arg_name, json.dumps(value, ensure_ascii=False, separators=(",", ":"))])
+ elif not isinstance(value, bool):
+ args.extend([arg_name, str(value)])
+ return args
+
+
+def _run_one_benchmark_batch(
+ host: str,
+ port: int,
+ params: dict[str, Any],
+ num_prompts: int,
+ request_rate: float | None,
+ max_concurrency: int | None,
+ result_dir: str,
+ batch_index: int,
+) -> dict[str, Any]:
+ base = _build_base_args(params, host, port)
+ if request_rate is not None:
+ args = base + ["--request-rate", str(request_rate), "--num-prompts", str(num_prompts)]
+ flow = request_rate
+ else:
+ args = base + [
+ "--max-concurrency",
+ str(max_concurrency),
+ "--num-prompts",
+ str(num_prompts),
+ "--request-rate",
+ "inf",
+ ]
+ flow = max_concurrency
+
+ dataset_name = params.get("dataset_name", "random")
+ old_benchmark_dir = os.environ.get("BENCHMARK_DIR")
+ try:
+ os.environ["BENCHMARK_DIR"] = result_dir
+ result = run_benchmark(
+ args=args,
+ test_name="stability",
+ flow=flow,
+ dataset_name=dataset_name,
+ num_prompt=num_prompts,
+ )
+ return result
+ except (FileNotFoundError, OSError) as e:
+ # Surface batch failure so the stability test does not false-pass when
+ # run_benchmark fails before writing JSON (e.g. command not found).
+ return {
+ "completed": 0,
+ "failed": 1,
+ "duration": 0.0,
+ "errors": [f"Benchmark batch failed: {type(e).__name__}: {e}"],
+ }
+ finally:
+ if old_benchmark_dir is not None:
+ os.environ["BENCHMARK_DIR"] = old_benchmark_dir
+ elif "BENCHMARK_DIR" in os.environ:
+ os.environ.pop("BENCHMARK_DIR")
+
+
+def _merge_batch_results(batch_results: list[dict[str, Any]], total_duration_sec: float) -> dict[str, Any]:
+ if not batch_results:
+ return {"completed": 0, "failed": 0, "duration": total_duration_sec, "errors": []}
+
+ completed = sum(r.get("completed", 0) for r in batch_results)
+ failed = sum(r.get("failed", 0) for r in batch_results)
+ merged: dict[str, Any] = {
+ "completed": completed,
+ "failed": failed,
+ "duration": total_duration_sec,
+ "errors": [],
+ }
+ for r in batch_results:
+ merged["errors"].extend(r.get("errors") or [])
+ return merged
+
+
+def _print_merged_report(result: dict[str, Any]) -> None:
+ """Print the final summary: successful requests, failed requests, and total duration only."""
+ fmt = "{:<40} {:<10}"
+ fmt_float = "{:<40} {:<10.2f}"
+ completed = result.get("completed", 0)
+ failed = result.get("failed", 0)
+ duration = float(result.get("duration", 0.0) or 0.0)
+ print("\n============ Stability Benchmark Summary ============")
+ print(fmt.format("Successful requests:", completed))
+ print(fmt.format("Failed requests:", failed))
+ print(fmt_float.format("Total duration (s):", duration))
+ print("==================================================\n")
+
+
+def run_stability_benchmark(
+ host: str,
+ port: int,
+ duration_sec: int | float,
+ params: dict[str, Any],
+ *,
+ request_rate: float | None = None,
+ max_concurrency: int | None = None,
+ result_filename: str | None = None,
+ result_dir: str = "./",
+ num_prompts_per_batch: int = DEFAULT_NUM_PROMPTS_PER_BATCH,
+) -> dict[str, Any]:
+ if (request_rate is None) == (max_concurrency is None):
+ raise ValueError("Exactly one of request_rate or max_concurrency must be specified")
+
+ start_time = time.perf_counter()
+ batch_results: list[dict[str, Any]] = []
+ batch_index = 0
+
+ while True:
+ if (time.perf_counter() - start_time) >= duration_sec:
+ break
+ result = _run_one_benchmark_batch(
+ host=host,
+ port=port,
+ params=params,
+ num_prompts=num_prompts_per_batch,
+ request_rate=request_rate,
+ max_concurrency=max_concurrency,
+ result_dir=result_dir,
+ batch_index=batch_index,
+ )
+ batch_results.append(result)
+ batch_index += 1
+ if (time.perf_counter() - start_time) >= duration_sec:
+ break
+
+ total_duration = time.perf_counter() - start_time
+ merged = _merge_batch_results(batch_results, total_duration)
+ _print_merged_report(merged)
+
+ if result_filename and result_dir:
+ result_path = Path(result_dir) / result_filename
+ with open(result_path, "w", encoding="utf-8") as f:
+ json.dump(merged, f, indent=2, ensure_ascii=False)
+
+ return merged
+
+
+@pytest.fixture(scope="module")
+def omni_server(request):
+ """Start vLLM-Omni server as a subprocess with actual model weights.
+ Uses module scope so the server starts only once per test module.
+ Multi-stage initialization can take 10-20+ minutes.
+ """
+ with _omni_server_lock:
+ test_name, model, stage_config_path = request.param
+
+ print(f"Starting OmniServer with test: {test_name}, model: {model}")
+
+ with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server:
+ server.test_name = test_name
+ print("OmniServer started successfully")
+ yield server
+ print("OmniServer stopping...")
+
+ print("OmniServer stopped")
+
+
+@pytest.fixture(params=benchmark_indices)
+def stability_benchmark_params(request, omni_server):
+ """Benchmark parameters fixture with proper parametrization (same as perf)."""
+ test_name, param_index = request.param
+
+ if test_name != omni_server.test_name:
+ pytest.skip(f"Skipping parameter for {test_name} - current server is {omni_server.test_name}")
+
+ all_params = get_benchmark_params_for_server(test_name, server_to_benchmark_mapping)
+
+ if not all_params:
+ raise ValueError(f"No benchmark parameters found for test: {test_name}")
+
+ if param_index >= len(all_params):
+ raise ValueError(f"No benchmark parameters found for index {param_index} in test: {test_name}")
+
+ current = param_index + 1
+ total = len(all_params)
+ print(f"\n Running benchmark {current}/{total} for {test_name}")
+
+ return {"test_name": test_name, "params": all_params[param_index]}
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+@pytest.mark.parametrize("stability_benchmark_params", benchmark_indices, indirect=True)
+def test_benchmark_stability(omni_server, stability_benchmark_params):
+ """Run the benchmark for a fixed duration using request-rate or max-concurrency and assert zero failed requests."""
+ test_name = stability_benchmark_params["test_name"]
+ params = stability_benchmark_params["params"]
+ duration_sec = params.get("duration_sec", 300)
+ num_prompts_per_batch = params.get("num_prompts_per_batch", DEFAULT_NUM_PROMPTS_PER_BATCH)
+ request_rate = params.get("request_rate")
+ max_concurrency = params.get("max_concurrency")
+
+ bench_params = {
+ k: v
+ for k, v in params.items()
+ if k not in ("duration_sec", "request_rate", "max_concurrency", "num_prompts_per_batch")
+ }
+
+ result = run_stability_benchmark(
+ host=omni_server.host,
+ port=omni_server.port,
+ duration_sec=duration_sec,
+ params=bench_params,
+ request_rate=request_rate,
+ max_concurrency=max_concurrency,
+ result_dir=str(STABILITY_DIR),
+ num_prompts_per_batch=num_prompts_per_batch,
+ )
+
+ assert result.get("failed", 0) == 0, f"[{test_name}] Failed requests detected: {result.get('errors', [])}"
+ assert result.get("completed", 0) > 0, f"[{test_name}] No requests completed"
diff --git a/tests/dfx/stability/scripts/test_stability_qwen3_omni.py b/tests/dfx/stability/scripts/test_stability_qwen3_omni.py
deleted file mode 100644
index d1c2af8cf08..00000000000
--- a/tests/dfx/stability/scripts/test_stability_qwen3_omni.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""
-Qwen3-Omni stability: OmniServer + ``vllm bench serve --omni`` for a fixed duration.
-
-Configuration: ``tests/dfx/stability/tests/test_qwen3_omni.json``.
-"""
-
-from __future__ import annotations
-
-from pathlib import Path
-
-import pytest
-
-from tests.dfx.conftest import (
- create_benchmark_indices,
- create_test_parameter_mapping,
- create_unique_server_params,
- load_configs,
-)
-from tests.dfx.stability.helpers import _run_one_vllm_bench_batch, run_stability_benchmark_loop
-
-STABILITY_DIR = Path(__file__).resolve().parent.parent
-DEPLOY_CONFIGS_DIR = STABILITY_DIR / "deploy"
-CONFIG_FILE_PATH = str(STABILITY_DIR / "tests" / "test_qwen3_omni.json")
-DEFAULT_NUM_PROMPTS_PER_BATCH = 20
-STABILITY_SERVER_TIMEOUT_ARGS = ["--stage-init-timeout", "600"]
-
-try:
- BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)
-except FileNotFoundError:
- BENCHMARK_CONFIGS = []
-
-test_params = create_unique_server_params(BENCHMARK_CONFIGS, DEPLOY_CONFIGS_DIR) if BENCHMARK_CONFIGS else []
-server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS) if BENCHMARK_CONFIGS else {}
-benchmark_indices = create_benchmark_indices(BENCHMARK_CONFIGS, server_to_benchmark_mapping)
-
-
-@pytest.mark.slow
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-@pytest.mark.parametrize("stability_benchmark_params", benchmark_indices, indirect=True)
-def test_stability_qwen3_omni(omni_server, stability_benchmark_params):
- test_name = stability_benchmark_params["test_name"]
- params = stability_benchmark_params["params"]
- duration_sec = params.get("duration_sec", 300)
- num_prompts_per_batch = params.get("num_prompts_per_batch", DEFAULT_NUM_PROMPTS_PER_BATCH)
- request_rate = params.get("request_rate")
- max_concurrency = params.get("max_concurrency")
-
- bench_params = {
- k: v
- for k, v in params.items()
- if k not in ("duration_sec", "request_rate", "max_concurrency", "num_prompts_per_batch")
- }
-
- result = run_stability_benchmark_loop(
- host=omni_server.host,
- port=omni_server.port,
- model=omni_server.model,
- duration_sec=duration_sec,
- params=bench_params,
- request_rate=request_rate,
- max_concurrency=max_concurrency,
- result_dir=str(STABILITY_DIR),
- num_prompts_per_batch=num_prompts_per_batch,
- run_one_batch=_run_one_vllm_bench_batch,
- )
-
- assert result.get("failed", 0) == 0, f"[{test_name}] Failed requests detected: {result.get('errors', [])}"
- assert result.get("completed", 0) > 0, f"[{test_name}] No requests completed"
diff --git a/tests/dfx/stability/scripts/test_stability_qwen3_tts.py b/tests/dfx/stability/scripts/test_stability_qwen3_tts.py
deleted file mode 100644
index beccd67d964..00000000000
--- a/tests/dfx/stability/scripts/test_stability_qwen3_tts.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""
-Qwen3-TTS stability: OmniServer + ``vllm bench serve --omni`` for a fixed duration.
-
-Configuration: ``tests/dfx/stability/tests/test_qwen3_tts.json``.
-"""
-
-from __future__ import annotations
-
-from pathlib import Path
-
-import pytest
-
-from tests.dfx.conftest import (
- create_benchmark_indices,
- create_test_parameter_mapping,
- create_unique_server_params,
- load_configs,
-)
-from tests.dfx.stability.helpers import _run_one_vllm_bench_batch, run_stability_benchmark_loop
-
-STABILITY_DIR = Path(__file__).resolve().parent.parent
-DEPLOY_CONFIGS_DIR = STABILITY_DIR / "deploy"
-CONFIG_FILE_PATH = str(STABILITY_DIR / "tests" / "test_qwen3_tts.json")
-DEFAULT_NUM_PROMPTS_PER_BATCH = 20
-STABILITY_SERVER_TIMEOUT_ARGS = ["--stage-init-timeout", "600"]
-
-try:
- BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)
-except FileNotFoundError:
- BENCHMARK_CONFIGS = []
-
-test_params = create_unique_server_params(BENCHMARK_CONFIGS, DEPLOY_CONFIGS_DIR) if BENCHMARK_CONFIGS else []
-server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS) if BENCHMARK_CONFIGS else {}
-benchmark_indices = create_benchmark_indices(BENCHMARK_CONFIGS, server_to_benchmark_mapping)
-
-
-@pytest.mark.slow
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-@pytest.mark.parametrize("stability_benchmark_params", benchmark_indices, indirect=True)
-def test_stability_qwen3_tts(omni_server, stability_benchmark_params):
- test_name = stability_benchmark_params["test_name"]
- params = stability_benchmark_params["params"]
- duration_sec = params.get("duration_sec", 300)
- num_prompts_per_batch = params.get("num_prompts_per_batch", DEFAULT_NUM_PROMPTS_PER_BATCH)
- request_rate = params.get("request_rate")
- max_concurrency = params.get("max_concurrency")
-
- bench_params = {
- k: v
- for k, v in params.items()
- if k not in ("duration_sec", "request_rate", "max_concurrency", "num_prompts_per_batch")
- }
-
- result = run_stability_benchmark_loop(
- host=omni_server.host,
- port=omni_server.port,
- model=omni_server.model,
- duration_sec=duration_sec,
- params=bench_params,
- request_rate=request_rate,
- max_concurrency=max_concurrency,
- result_dir=str(STABILITY_DIR),
- num_prompts_per_batch=num_prompts_per_batch,
- run_one_batch=_run_one_vllm_bench_batch,
- )
-
- assert result.get("failed", 0) == 0, f"[{test_name}] Failed requests detected: {result.get('errors', [])}"
- assert result.get("completed", 0) > 0, f"[{test_name}] No requests completed"
diff --git a/tests/dfx/stability/scripts/test_stability_qwen_image.py b/tests/dfx/stability/scripts/test_stability_qwen_image.py
deleted file mode 100644
index a90e2092f5e..00000000000
--- a/tests/dfx/stability/scripts/test_stability_qwen_image.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""
-Qwen-Image stability: OmniServer (diffusion) + ``diffusion_benchmark_serving.py``.
-
-Configuration: ``tests/dfx/stability/tests/test_qwen_image.json``.
-"""
-
-from __future__ import annotations
-
-from pathlib import Path
-
-import pytest
-
-from tests.dfx.conftest import (
- create_benchmark_indices,
- create_test_parameter_mapping,
- create_unique_server_params,
- load_configs,
-)
-from tests.dfx.stability.helpers import _run_one_diffusion_batch, run_stability_benchmark_loop
-
-STABILITY_DIR = Path(__file__).resolve().parent.parent
-DEPLOY_CONFIGS_DIR = STABILITY_DIR / "deploy"
-CONFIG_FILE_PATH = str(STABILITY_DIR / "tests" / "test_qwen_image.json")
-DEFAULT_NUM_PROMPTS_PER_BATCH = 20
-STABILITY_SERVER_TIMEOUT_ARGS = ["--stage-init-timeout", "600", "--init-timeout", "900"]
-
-try:
- BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)
-except FileNotFoundError:
- BENCHMARK_CONFIGS = []
-
-test_params = create_unique_server_params(BENCHMARK_CONFIGS, DEPLOY_CONFIGS_DIR) if BENCHMARK_CONFIGS else []
-server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS) if BENCHMARK_CONFIGS else {}
-benchmark_indices = create_benchmark_indices(BENCHMARK_CONFIGS, server_to_benchmark_mapping)
-
-
-@pytest.mark.slow
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-@pytest.mark.parametrize("stability_benchmark_params", benchmark_indices, indirect=True)
-def test_stability_qwen_image(omni_server, stability_benchmark_params):
- test_name = stability_benchmark_params["test_name"]
- params = stability_benchmark_params["params"]
- duration_sec = params.get("duration_sec", 300)
- num_prompts_per_batch = params.get("num_prompts_per_batch", DEFAULT_NUM_PROMPTS_PER_BATCH)
- request_rate = params.get("request_rate")
- max_concurrency = params.get("max_concurrency")
-
- bench_params = {
- k: v
- for k, v in params.items()
- if k not in ("duration_sec", "request_rate", "max_concurrency", "num_prompts_per_batch")
- }
-
- result = run_stability_benchmark_loop(
- host=omni_server.host,
- port=omni_server.port,
- model=omni_server.model,
- duration_sec=duration_sec,
- params=bench_params,
- request_rate=request_rate,
- max_concurrency=max_concurrency,
- result_dir=str(STABILITY_DIR),
- num_prompts_per_batch=num_prompts_per_batch,
- run_one_batch=_run_one_diffusion_batch,
- )
-
- assert result.get("failed", 0) == 0, f"[{test_name}] Failed requests detected: {result.get('errors', [])}"
- assert result.get("completed", 0) > 0, f"[{test_name}] No requests completed"
diff --git a/tests/dfx/stability/scripts/test_stability_wan22.py b/tests/dfx/stability/scripts/test_stability_wan22.py
deleted file mode 100644
index afe9c4d0ca7..00000000000
--- a/tests/dfx/stability/scripts/test_stability_wan22.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""
-Wan2.2 T2V stability: OmniServer (diffusion) + ``diffusion_benchmark_serving.py`` / ``v1/videos``.
-
-Configuration: ``tests/dfx/stability/tests/test_wan22.json``.
-"""
-
-from __future__ import annotations
-
-from pathlib import Path
-
-import pytest
-
-from tests.dfx.conftest import (
- create_benchmark_indices,
- create_test_parameter_mapping,
- create_unique_server_params,
- load_configs,
-)
-from tests.dfx.stability.helpers import _run_one_diffusion_batch, run_stability_benchmark_loop
-
-STABILITY_DIR = Path(__file__).resolve().parent.parent
-DEPLOY_CONFIGS_DIR = STABILITY_DIR / "deploy"
-CONFIG_FILE_PATH = str(STABILITY_DIR / "tests" / "test_wan22.json")
-DEFAULT_NUM_PROMPTS_PER_BATCH = 20
-STABILITY_SERVER_TIMEOUT_ARGS = ["--stage-init-timeout", "600", "--init-timeout", "900"]
-
-try:
- BENCHMARK_CONFIGS = load_configs(CONFIG_FILE_PATH)
-except FileNotFoundError:
- BENCHMARK_CONFIGS = []
-
-test_params = create_unique_server_params(BENCHMARK_CONFIGS, DEPLOY_CONFIGS_DIR) if BENCHMARK_CONFIGS else []
-server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS) if BENCHMARK_CONFIGS else {}
-benchmark_indices = create_benchmark_indices(BENCHMARK_CONFIGS, server_to_benchmark_mapping)
-
-
-@pytest.mark.slow
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-@pytest.mark.parametrize("stability_benchmark_params", benchmark_indices, indirect=True)
-def test_stability_wan22(omni_server, stability_benchmark_params):
- test_name = stability_benchmark_params["test_name"]
- params = stability_benchmark_params["params"]
- duration_sec = params.get("duration_sec", 300)
- num_prompts_per_batch = params.get("num_prompts_per_batch", DEFAULT_NUM_PROMPTS_PER_BATCH)
- request_rate = params.get("request_rate")
- max_concurrency = params.get("max_concurrency")
-
- bench_params = {
- k: v
- for k, v in params.items()
- if k not in ("duration_sec", "request_rate", "max_concurrency", "num_prompts_per_batch")
- }
-
- result = run_stability_benchmark_loop(
- host=omni_server.host,
- port=omni_server.port,
- model=omni_server.model,
- duration_sec=duration_sec,
- params=bench_params,
- request_rate=request_rate,
- max_concurrency=max_concurrency,
- result_dir=str(STABILITY_DIR),
- num_prompts_per_batch=num_prompts_per_batch,
- run_one_batch=_run_one_diffusion_batch,
- )
-
- assert result.get("failed", 0) == 0, f"[{test_name}] Failed requests detected: {result.get('errors', [])}"
- assert result.get("completed", 0) > 0, f"[{test_name}] No requests completed"
diff --git a/tests/dfx/stability/stage_configs/qwen3_omni.yaml b/tests/dfx/stability/stage_configs/qwen3_omni.yaml
new file mode 100644
index 00000000000..802f8dd2494
--- /dev/null
+++ b/tests/dfx/stability/stage_configs/qwen3_omni.yaml
@@ -0,0 +1,101 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs.
+async_chunk: false  # baseline; test_qwen3_omni_stability_async_chunk in ../tests/test.json sets this to true via its "update" block
+stage_args:
+  - stage_id: 0
+    stage_type: llm  # Use llm stage type to launch OmniLLM
+    runtime:
+      devices: "0"  # thinker has device 0 to itself; stages 1 and 2 both use device 1
+      max_batch_size: 64
+    engine_args:
+      model_stage: thinker
+      model_arch: Qwen3OmniMoeForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      gpu_memory_utilization: 0.9
+      enforce_eager: false
+      trust_remote_code: true
+      engine_output_type: latent  # Output hidden states for talker
+      distributed_executor_backend: "mp"
+      enable_prefix_caching: false
+      max_num_batched_tokens: 32768
+      hf_config_name: thinker_config
+      tensor_parallel_size: 1
+    final_output: true  # this stage's text is a final output; stage 2 contributes the audio one
+    final_output_type: text
+    is_comprehension: true
+    default_sampling_params:
+      temperature: 0.4
+      top_p: 0.9
+      top_k: 1
+      max_tokens: 2048
+      seed: 42  # same fixed seed in all three stages
+      detokenize: true
+      repetition_penalty: 1.05
+
+  - stage_id: 1
+    stage_type: llm  # Use llm stage type to launch OmniLLM
+    runtime:
+      devices: "1"  # shared with stage 2 (code2wav)
+      max_batch_size: 64
+    engine_args:
+      model_stage: talker
+      model_arch: Qwen3OmniMoeForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      gpu_memory_utilization: 0.6  # stage 2 on the same device is budgeted 0.1
+      enforce_eager: false
+      trust_remote_code: true
+      engine_output_type: latent  # Output codec codes for code2wav
+      enable_prefix_caching: false
+      max_num_batched_tokens: 32768
+      distributed_executor_backend: "mp"
+      hf_config_name: talker_config
+    engine_input_source: [0]  # fed by stage 0 (thinker)
+    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+    # final_output: true
+    # final_output_type: text
+    default_sampling_params:
+      temperature: 0.9
+      top_k: 50
+      max_tokens: 4096
+      seed: 42
+      detokenize: false
+      repetition_penalty: 1.05
+      stop_token_ids: [2150]
+
+  - stage_id: 2
+    stage_type: llm  # Use llm stage type to launch OmniLLM
+    runtime:
+      devices: "1"  # shared with stage 1 (talker)
+      max_batch_size: 64
+    engine_args:
+      model_stage: code2wav
+      model_arch: Qwen3OmniMoeForConditionalGeneration
+      worker_type: generation
+      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+      enforce_eager: true
+      trust_remote_code: true
+      async_scheduling: false
+      enable_prefix_caching: false
+      engine_output_type: audio  # Final output: audio waveform
+      gpu_memory_utilization: 0.1  # stage 1 on the same device is budgeted 0.6
+      distributed_executor_backend: "mp"
+      max_num_batched_tokens: 1000000
+      hf_config_name: thinker_config  # NOTE(review): code2wav reuses thinker_config, not a stage-specific config; confirm intentional
+    engine_input_source: [1]  # fed by stage 1 (talker)
+    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav  # listed under "delete" for the async-chunk case in ../tests/test.json
+    final_output: true
+    final_output_type: audio
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 65536
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.1
diff --git a/tests/dfx/stability/tests/test.json b/tests/dfx/stability/tests/test.json
new file mode 100644
index 00000000000..95993c9c556
--- /dev/null
+++ b/tests/dfx/stability/tests/test.json
@@ -0,0 +1,86 @@
+[
+ {
+ "test_name": "test_qwen3_omni_stability",
+ "server_params": {
+ "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "stage_config_name": "qwen3_omni.yaml"
+ },
+ "benchmark_params": [
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "duration_sec": 300,
+ "request_rate": 0.2,
+ "num_prompts_per_batch": 10,
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration"
+ },
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "duration_sec": 300,
+ "max_concurrency": 2,
+ "num_prompts_per_batch": 10,
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration"
+ }
+ ]
+ },
+ {
+ "test_name": "test_qwen3_omni_stability_async_chunk",
+ "server_params": {
+ "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "stage_config_name": "qwen3_omni.yaml",
+ "update": {
+ "async_chunk": true,
+ "stage_args": {
+ "0": {
+ "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
+ },
+ "1": {
+ "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
+ }
+ }
+ },
+ "delete": {
+ "stage_args": {
+ "2": [
+ "custom_process_input_func"
+ ]
+ }
+ }
+ },
+ "benchmark_params": [
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "duration_sec": 300,
+ "request_rate": 0.2,
+ "num_prompts_per_batch": 10,
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration"
+ },
+ {
+ "dataset_name": "random",
+ "backend": "openai-chat-omni",
+ "endpoint": "/v1/chat/completions",
+ "duration_sec": 300,
+ "max_concurrency": 2,
+ "num_prompts_per_batch": 10,
+ "random_input_len": 2500,
+ "random_output_len": 900,
+ "ignore_eos": true,
+ "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration"
+ }
+ ]
+ }
+]
diff --git a/tests/dfx/stability/tests/test_qwen3_omni.json b/tests/dfx/stability/tests/test_qwen3_omni.json
deleted file mode 100644
index a16ab805cc6..00000000000
--- a/tests/dfx/stability/tests/test_qwen3_omni.json
+++ /dev/null
@@ -1,97 +0,0 @@
-[
- {
- "test_name": "test_qwen3_omni_stability",
- "server_params": {
- "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_overrides": {
- "2": {
- "max_num_batched_tokens": 1000000
- }
- }
- },
- "benchmark_params": [
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "duration_sec": 86400,
- "request_rate": 0.3,
- "num_prompts_per_batch": 10,
- "random_input_len": {
- "min": 0,
- "max": 8000
- },
- "random_output_len": {
- "min": 0,
- "max": 1000
- },
- "random_range_ratio": 0.0,
- "random_mm_base_items_per_request": {
- "min": 0,
- "max": 6
- },
- "random_mm_num_mm_items_range_ratio": 0.0,
- "random_mm_limit_mm_per_prompt": {
- "image": 2,
- "video": 2,
- "audio": 2
- },
- "random_mm_bucket_config": {
- "(128-1024, 128-1024, 1)": 0.34,
- "(256-1080, 256-1920, 2-16)": 0.33,
- "(0, 1-60, 1-3)": 0.33
- },
- "ignore_eos": true,
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration"
- }
- ]
- },
- {
- "test_name": "test_qwen3_omni_stability_async_chunk",
- "server_params": {
- "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_overrides": {
- "2": {
- "max_num_batched_tokens": 1000000
- }
- },
- "extra_cli_args": ["--async-chunk"]
- },
- "benchmark_params": [
- {
- "dataset_name": "random-mm",
- "backend": "openai-chat-omni",
- "endpoint": "/v1/chat/completions",
- "duration_sec": 86400,
- "max_concurrency": 2,
- "num_prompts_per_batch": 10,
- "random_input_len": {
- "min": 0,
- "max": 8000
- },
- "random_output_len": {
- "min": 0,
- "max": 1000
- },
- "random_range_ratio": 0.0,
- "random_mm_base_items_per_request": {
- "min": 0,
- "max": 6
- },
- "random_mm_num_mm_items_range_ratio": 0.0,
- "random_mm_limit_mm_per_prompt": {
- "image": 2,
- "video": 2,
- "audio": 2
- },
- "random_mm_bucket_config": {
- "(128-1024, 128-1024, 1)": 0.34,
- "(256-1080, 256-1920, 2-16)": 0.33,
- "(0, 1-60, 1-3)": 0.33
- },
- "ignore_eos": true,
- "percentile-metrics": "ttft,tpot,itl,e2el,audio_rtf,audio_ttfp,audio_duration"
- }
- ]
- }
-]
diff --git a/tests/dfx/stability/tests/test_qwen3_tts.json b/tests/dfx/stability/tests/test_qwen3_tts.json
deleted file mode 100644
index fbf30d88ab2..00000000000
--- a/tests/dfx/stability/tests/test_qwen3_tts.json
+++ /dev/null
@@ -1,56 +0,0 @@
-[
- {
- "test_name": "test_qwen3_tts_stability",
- "server_params": {
- "model": "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
- },
- "benchmark_params": [
- {
- "dataset_name": "random",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "duration_sec": 86400,
- "request_rate": 0.3,
- "num_prompts_per_batch": 10,
- "random_input_len": {
- "min": 0,
- "max": 1000
- },
- "random_output_len": {
- "min": 0,
- "max": 1000
- },
- "random_range_ratio": 0.0,
- "extra_body": {
- "voice": "Vivian",
- "language": "English"
- },
- "ignore_eos": true,
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration"
- },
- {
- "dataset_name": "random",
- "backend": "openai-audio-speech",
- "endpoint": "/v1/audio/speech",
- "duration_sec": 86400,
- "max_concurrency": 2,
- "num_prompts_per_batch": 10,
- "random_input_len": {
- "min": 0,
- "max": 1000
- },
- "random_output_len": {
- "min": 0,
- "max": 1000
- },
- "random_range_ratio": 0.0,
- "extra_body": {
- "voice": "Vivian",
- "language": "English"
- },
- "ignore_eos": true,
- "percentile-metrics": "ttft,e2el,audio_rtf,audio_ttfp,audio_duration"
- }
- ]
- }
-]
diff --git a/tests/dfx/stability/tests/test_qwen_image.json b/tests/dfx/stability/tests/test_qwen_image.json
deleted file mode 100644
index f3dd93f6f25..00000000000
--- a/tests/dfx/stability/tests/test_qwen_image.json
+++ /dev/null
@@ -1,28 +0,0 @@
-[
- {
- "test_name": "test_qwen_image_stability",
- "server_params": {
- "model": "Qwen/Qwen-Image"
- },
- "benchmark_params": [
- {
- "dataset": "random",
- "task": "t2i",
- "backend": "vllm-omni",
- "duration_sec": 86400,
- "max_concurrency": 1,
- "num_prompts_per_batch": 10,
- "width": {
- "min": 512,
- "max": 2048
- },
- "height": {
- "min": 512,
- "max": 2048
- },
- "num_inference_steps": 50,
- "enable_negative_prompt": true
- }
- ]
- }
-]
diff --git a/tests/dfx/stability/tests/test_wan22.json b/tests/dfx/stability/tests/test_wan22.json
deleted file mode 100644
index c787ce96a07..00000000000
--- a/tests/dfx/stability/tests/test_wan22.json
+++ /dev/null
@@ -1,31 +0,0 @@
-[
- {
- "test_name": "test_wan22_i2v_stability_v1_videos",
- "server_params": {
- "model": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
- "serve_args": {
- "ulysses-degree": 2,
- "vae-patch-parallel-size": 2,
- "tensor-parallel-size": 1,
- "use-hsdp": true,
- "vae-use-slicing": true,
- "vae-use-tiling": true
- }
- },
- "benchmark_params": [
- {
- "dataset": "random",
- "task": "i2v",
- "backend": "v1/videos",
- "duration_sec": 86400,
- "max_concurrency": 1,
- "num_prompts_per_batch": 20,
- "enable_negative_prompt": true,
- "random_request_config": [
- {"width": 832, "height": 480, "num_inference_steps": 4, "num_frames": 81, "fps": 16, "weight": 0.5},
- {"width": 1280, "height": 720, "num_inference_steps": 4, "num_frames": 121, "fps": 16, "weight": 0.5}
- ]
- }
- ]
- }
-]
diff --git a/tests/diffusion/cache/test_cache_dit.py b/tests/diffusion/cache/test_cache_dit.py
deleted file mode 100644
index 0b7ef723585..00000000000
--- a/tests/diffusion/cache/test_cache_dit.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-Model specific tests for CacheDiT enablement.
-"""
-
-from unittest.mock import Mock, patch
-
-import pytest
-
-import vllm_omni.diffusion.cache.cache_dit_backend as cd_backend
-from vllm_omni.diffusion.data import DiffusionCacheConfig
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-SEPARATE_CFG_ENABLERS = [
- cd_backend.enable_cache_for_ltx2,
- cd_backend.enable_cache_for_wan22,
- cd_backend.enable_cache_for_longcat_image,
-]
-
-SAMPLE_CACHE_CONFIG = DiffusionCacheConfig()
-
-
-@pytest.mark.parametrize("enabler", SEPARATE_CFG_ENABLERS)
-@patch("vllm_omni.diffusion.cache.cache_dit_backend.BlockAdapter")
-@patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit")
-def test_separate_cfg(mock_cache_dit, mock_block_adapter, enabler):
- """Ensure that custom enablers for models with separate CFG pass
- the param through to cache_dit correctly.
-
- Regression test for: https://github.com/vllm-project/vllm-omni/pull/2860
- """
- mock_pipeline = Mock()
- enabler(mock_pipeline, SAMPLE_CACHE_CONFIG)
-
- mock_cache_dit.enable_cache.assert_called_once()
- adapter_kwargs = mock_block_adapter.call_args.kwargs
- assert adapter_kwargs["has_separate_cfg"] is True
diff --git a/tests/diffusion/cache/test_teacache_extractors.py b/tests/diffusion/cache/test_teacache_extractors.py
index 4bb958a36c1..5ba52ddfe2d 100644
--- a/tests/diffusion/cache/test_teacache_extractors.py
+++ b/tests/diffusion/cache/test_teacache_extractors.py
@@ -21,13 +21,12 @@
import pytest
import torch
-from tests.helpers.mark import hardware_test
-from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_context, extract_flux2_klein_context
+from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_klein_context
from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import (
Flux2Transformer2DModel,
)
-pytestmark = [pytest.mark.core_model]
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@pytest.fixture(scope="function", autouse=True)
@@ -114,7 +113,6 @@ def sample_inputs(self):
def get_sample_inputs(self, sample_inputs):
return sample_inputs
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_modulated_input_shape(self, flux2_klein_module, sample_inputs):
"""Test that modulated_input has correct shape matching the model's inner_dim.
@@ -128,19 +126,16 @@ def test_modulated_input_shape(self, flux2_klein_module, sample_inputs):
inner_dim = flux2_klein_module.inner_dim
assert context.modulated_input.shape == (batch_size, img_seq_len, inner_dim)
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_run_transformer_blocks_callable(self, flux2_klein_module, sample_inputs):
"""Test that run_transformer_blocks is callable."""
context = extract_flux2_klein_context(flux2_klein_module, **sample_inputs)
assert callable(context.run_transformer_blocks)
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_postprocess_callable(self, flux2_klein_module, sample_inputs):
"""Test that postprocess is callable."""
context = extract_flux2_klein_context(flux2_klein_module, **sample_inputs)
assert callable(context.postprocess)
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_extra_states_contains_full_transformer(self, flux2_klein_module, sample_inputs):
"""Test that extra_states contains run_flux2_full_transformer_with_single."""
context = extract_flux2_klein_context(flux2_klein_module, **sample_inputs)
@@ -159,7 +154,6 @@ def test_without_guidance(self, flux2_klein_module, sample_inputs):
assert context is not None
assert context.temb is not None
- @pytest.mark.cpu
def test_invalid_module_raises_error(self):
"""Test that invalid module without transformer_blocks raises ValueError."""
invalid_module = Mock()
@@ -174,106 +168,3 @@ def test_invalid_module_raises_error(self):
img_ids=torch.randint(0, 64, (1, 1024, 4)),
txt_ids=torch.randint(0, 64, (1, 512, 4)),
)
-
-
-class TestFlux2Extractor(BaseExtractorTest):
- """Test extract_flux2_context function."""
-
- def get_extractor(self):
- return extract_flux2_context
-
- @pytest.fixture
- def flux2_module(self):
- """Create a minimal Flux2Transformer2DModel for testing."""
- from vllm_omni.diffusion.models.flux2.flux2_transformer import Flux2Transformer2DModel
-
- model = Flux2Transformer2DModel(
- num_layers=2,
- num_single_layers=2,
- num_attention_heads=48,
- attention_head_dim=128,
- joint_attention_dim=15360,
- )
- return model
-
- def get_module(self, flux2_module):
- return flux2_module
-
- @pytest.fixture
- def sample_inputs(self):
- """Create sample input tensors for Flux2.
-
- Note: hidden_states uses in_channels=128 (default for Flux2),
- not inner_dim=6144. The x_embedder projects from 128 -> 6144.
- encoder_hidden_states uses joint_attention_dim=15360 (model default),
- which then gets projected to inner_dim=6144 by context_embedder.
- """
- batch_size = 1
- img_seq_len = 1024
- txt_seq_len = 512
- in_channels = 128 # Model default in_channels
- txt_dim = 15360 # Model default joint_attention_dim
-
- return {
- "hidden_states": torch.randn(batch_size, img_seq_len, in_channels),
- "encoder_hidden_states": torch.randn(batch_size, txt_seq_len, txt_dim),
- "timestep": torch.tensor([500]),
- "img_ids": torch.randint(0, 64, (batch_size, img_seq_len, 4)),
- "txt_ids": torch.randint(0, 64, (batch_size, txt_seq_len, 4)),
- "guidance": torch.tensor([3.5]),
- }
-
- def get_sample_inputs(self, sample_inputs):
- return sample_inputs
-
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
- def test_modulated_input_shape(self, flux2_module, sample_inputs):
- """Test that modulated_input has correct shape matching the model's inner_dim.
-
- Note: After x_embedder projection, hidden_states are projected from
- in_channels (128) to inner_dim (6144), so modulated_input should match
- the projected shape, not the input shape.
- """
- context = extract_flux2_klein_context(flux2_module, **sample_inputs)
-
- batch_size, img_seq_len, _ = sample_inputs["hidden_states"].shape
- inner_dim = flux2_module.inner_dim
- assert context.modulated_input.shape == (batch_size, img_seq_len, inner_dim)
-
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
- def test_run_transformer_blocks_callable(self, flux2_module, sample_inputs):
- """Test that run_transformer_blocks is callable."""
- context = extract_flux2_context(flux2_module, **sample_inputs)
- assert callable(context.run_transformer_blocks)
-
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
- def test_postprocess_callable(self, flux2_module, sample_inputs):
- """Test that postprocess is callable."""
- context = extract_flux2_context(flux2_module, **sample_inputs)
- assert callable(context.postprocess)
-
- def test_without_guidance(self, flux2_module, sample_inputs):
- """Test context extraction works without guidance (no CFG)."""
- inputs = sample_inputs.copy()
- inputs["guidance"] = None
-
- context = extract_flux2_context(flux2_module, **inputs)
-
- assert context is not None
- assert context.temb is not None
-
- @pytest.mark.cpu
- def test_invalid_module_raises_error(self):
- """Test that invalid module without transformer_blocks raises ValueError."""
- invalid_module = Mock()
- invalid_module.transformer_blocks = []
-
- with pytest.raises(ValueError, match="Module must have transformer_blocks"):
- extract_flux2_context(
- invalid_module,
- hidden_states=torch.randn(1, 1024, 6144),
- encoder_hidden_states=torch.randn(1, 512, 15360),
- timestep=torch.tensor([500]),
- img_ids=torch.randint(0, 64, (1, 1024, 4)),
- txt_ids=torch.randint(0, 64, (1, 512, 4)),
- )
diff --git a/tests/diffusion/distributed/test_autoencoder_kl_wan.py b/tests/diffusion/distributed/test_autoencoder_kl_wan.py
deleted file mode 100644
index 2ea1c1214b8..00000000000
--- a/tests/diffusion/distributed/test_autoencoder_kl_wan.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import pytest
-import torch
-
-from vllm_omni.diffusion.distributed.autoencoders import autoencoder_kl_wan as wan_vae_module
-from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import OmniAutoencoderKLWan
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-class _DummyOmniAutoencoderKLWan(OmniAutoencoderKLWan):
- def __init__(self, *, dtype: torch.dtype):
- torch.nn.Module.__init__(self)
- self.register_parameter("dummy_weight", torch.nn.Parameter(torch.ones(1, dtype=dtype)))
-
-
-def test_wan_vae_execution_context_handles_fp32():
- model = _DummyOmniAutoencoderKLWan(dtype=torch.float32)
- with model._execution_context():
- output = model.dummy_weight + 1
- assert output.dtype == torch.float32
-
-
-def test_wan_vae_execution_context_handles_bf16():
- model = _DummyOmniAutoencoderKLWan(dtype=torch.bfloat16)
- with model._execution_context():
- output = model.dummy_weight + 1
- assert output.dtype == torch.bfloat16
-
-
-def test_wan_vae_execution_context_uses_platform_autocast(mocker):
- sentinel = object()
- platform = mocker.Mock()
- platform.create_autocast_context.return_value = sentinel
- mocker.patch.object(wan_vae_module, "current_omni_platform", platform)
-
- model = _DummyOmniAutoencoderKLWan(dtype=torch.bfloat16)
-
- assert model._execution_context() is sentinel
- platform.create_autocast_context.assert_called_once_with(
- device_type=model.dummy_weight.device.type,
- dtype=torch.bfloat16,
- enabled=True,
- )
diff --git a/tests/diffusion/distributed/test_autoencoder_kl_wan_encode.py b/tests/diffusion/distributed/test_autoencoder_kl_wan_encode.py
deleted file mode 100644
index 7a18fa66da3..00000000000
--- a/tests/diffusion/distributed/test_autoencoder_kl_wan_encode.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Unit tests for DistributedAutoencoderKLWan encode parallel (CPU-only)."""
-
-import pytest
-import torch
-
-pytestmark = [pytest.mark.cpu, pytest.mark.core_model]
-
-
-class _DummyConfig:
- def __init__(self, patch_size=None, scale_factor_temporal=4):
- self.patch_size = patch_size
- self.scale_factor_temporal = scale_factor_temporal
-
-
-class _DummyWanVae:
- """Minimal mock of DistributedAutoencoderKLWan for testing encode_tile_split."""
-
- def __init__(
- self,
- config=None,
- spatial_compression_ratio=8,
- tile_sample_min_height=256,
- tile_sample_min_width=256,
- tile_sample_stride_height=192,
- tile_sample_stride_width=192,
- ):
- self.config = config or _DummyConfig()
- self.spatial_compression_ratio = spatial_compression_ratio
- self.tile_sample_min_height = tile_sample_min_height
- self.tile_sample_min_width = tile_sample_min_width
- self.tile_sample_stride_height = tile_sample_stride_height
- self.tile_sample_stride_width = tile_sample_stride_width
- self.dtype = torch.float32
-
- # Mock caches
- self._enc_feat_map = None
- self._enc_conv_idx = [0]
-
- def clear_cache(self):
- self._enc_feat_map = None
- self._enc_conv_idx = [0]
-
- def encoder(self, x, feat_cache=None, feat_idx=None): # noqa: ARG002
- # Simple mock: just return the input
- return x
-
- def quant_conv(self, x):
- return x
-
- def blend_v(self, _a, b, _blend_extent):
- return b
-
- def blend_h(self, _a, b, _blend_extent):
- return b
-
-
-def _import_encode_tile_split():
- """Import the encode_tile_split method from the module."""
- from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import (
- DistributedAutoencoderKLWan,
- )
-
- return DistributedAutoencoderKLWan.encode_tile_split
-
-
-def _import_encode_tile_exec():
- from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import (
- DistributedAutoencoderKLWan,
- )
-
- return DistributedAutoencoderKLWan.encode_tile_exec
-
-
-def _import_encode_tile_merge():
- from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import (
- DistributedAutoencoderKLWan,
- )
-
- return DistributedAutoencoderKLWan.encode_tile_merge
-
-
-class TestEncodeTileSplit:
- """Tests for encode_tile_split method."""
-
- def test_basic_split_without_patch_size(self):
- """Test basic tile splitting without patch_size."""
- encode_tile_split = _import_encode_tile_split()
-
- vae = _DummyWanVae(
- config=_DummyConfig(patch_size=None, scale_factor_temporal=4),
- spatial_compression_ratio=8,
- tile_sample_min_height=256,
- tile_sample_min_width=256,
- tile_sample_stride_height=192,
- tile_sample_stride_width=192,
- )
-
- # Input: (B, C, T, H, W) = (1, 3, 5, 256, 256)
- x = torch.randn(1, 3, 5, 256, 256)
-
- tiletask_list, grid_spec = encode_tile_split(vae, x)
-
- # With stride 192 and input size 256, we should get:
- # Height: ceil(256/192) = 2 positions (0, 192) but 192+256 > 256, so only 1
- # Actually for i in range(0, 256, 192): i = 0, 192 but 192 is out of bounds
- # So we get 1x1 grid
- assert len(tiletask_list) >= 1
- assert grid_spec.grid_shape[0] >= 1
- assert grid_spec.grid_shape[1] >= 1
-
- # Check temporal chunking: 5 frames -> 1 + (5-1)//4 = 2 chunks
- first_task = tiletask_list[0]
- assert len(first_task.tensor) == 2 # 2 temporal chunks
-
- def test_split_with_patch_size_scales_coordinates(self):
- """Test that patch_size properly scales tile coordinates."""
- encode_tile_split = _import_encode_tile_split()
-
- # Without patch_size
- vae_no_patch = _DummyWanVae(
- config=_DummyConfig(patch_size=None, scale_factor_temporal=4),
- spatial_compression_ratio=8,
- tile_sample_min_height=256,
- tile_sample_min_width=256,
- tile_sample_stride_height=128,
- tile_sample_stride_width=128,
- )
-
- # With patch_size=2 (simulating patchified input)
- vae_with_patch = _DummyWanVae(
- config=_DummyConfig(patch_size=2, scale_factor_temporal=4),
- spatial_compression_ratio=8,
- tile_sample_min_height=256,
- tile_sample_min_width=256,
- tile_sample_stride_height=128,
- tile_sample_stride_width=128,
- )
-
- # Same patchified input size
- x = torch.randn(1, 3, 5, 256, 256)
-
- tasks_no_patch, _ = encode_tile_split(vae_no_patch, x)
- tasks_with_patch, _ = encode_tile_split(vae_with_patch, x)
-
- # With patch_size=2, stride becomes 128//2=64, so more tiles
- assert len(tasks_with_patch) >= len(tasks_no_patch)
-
- def test_temporal_compression_from_config(self):
- """Test that temporal compression ratio is read from config."""
- encode_tile_split = _import_encode_tile_split()
-
- # temporal_compression=4 (default)
- vae_4x = _DummyWanVae(
- config=_DummyConfig(scale_factor_temporal=4),
- tile_sample_min_height=512,
- tile_sample_min_width=512,
- tile_sample_stride_height=512,
- tile_sample_stride_width=512,
- )
-
- # temporal_compression=2
- vae_2x = _DummyWanVae(
- config=_DummyConfig(scale_factor_temporal=2),
- tile_sample_min_height=512,
- tile_sample_min_width=512,
- tile_sample_stride_height=512,
- tile_sample_stride_width=512,
- )
-
- # 9 frames input
- x = torch.randn(1, 3, 9, 512, 512)
-
- tasks_4x, _ = encode_tile_split(vae_4x, x)
- tasks_2x, _ = encode_tile_split(vae_2x, x)
-
- # With 4x compression: 1 + (9-1)//4 = 3 chunks
- assert len(tasks_4x[0].tensor) == 3
-
- # With 2x compression: 1 + (9-1)//2 = 5 chunks
- assert len(tasks_2x[0].tensor) == 5
-
- def test_grid_spec_latent_dimensions(self):
- """Test that grid_spec contains correct latent dimensions."""
- encode_tile_split = _import_encode_tile_split()
-
- vae = _DummyWanVae(
- config=_DummyConfig(patch_size=None),
- spatial_compression_ratio=8,
- tile_sample_min_height=512,
- tile_sample_min_width=512,
- tile_sample_stride_height=512,
- tile_sample_stride_width=512,
- )
-
- # Input: 512x512 with compression 8 -> 64x64 latent
- x = torch.randn(1, 3, 5, 512, 512)
-
- _, grid_spec = encode_tile_split(vae, x)
-
- assert grid_spec.tile_spec["latent_height"] == 64
- assert grid_spec.tile_spec["latent_width"] == 64
-
-
-class TestEncodeTileExec:
- """Tests for encode_tile_exec method."""
-
- def test_basic_exec(self):
- """Test basic tile execution."""
- encode_tile_exec = _import_encode_tile_exec()
-
- vae = _DummyWanVae()
-
- from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import (
- TileTask,
- )
-
- # Create a simple task with 2 temporal chunks
- tile1 = torch.randn(1, 3, 1, 32, 32)
- tile2 = torch.randn(1, 3, 4, 32, 32)
- task = TileTask(tile_id=0, grid_coord=(0, 0), tensor=[tile1, tile2])
-
- result = encode_tile_exec(vae, task)
-
- # Result should concatenate temporal dimension
- assert result.shape[2] == 5 # 1 + 4 frames
-
-
-class TestEncodeTileMerge:
- """Tests for encode_tile_merge method."""
-
- def test_basic_merge(self):
- """Test basic tile merging."""
- encode_tile_merge = _import_encode_tile_merge()
-
- vae = _DummyWanVae()
-
- from vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor import (
- GridSpec,
- )
-
- # Create 2x2 grid of tiles
- tile_00 = torch.ones(1, 16, 2, 32, 32) * 0
- tile_01 = torch.ones(1, 16, 2, 32, 32) * 1
- tile_10 = torch.ones(1, 16, 2, 32, 32) * 2
- tile_11 = torch.ones(1, 16, 2, 32, 32) * 3
-
- coord_tensor_map = {
- (0, 0): tile_00,
- (0, 1): tile_01,
- (1, 0): tile_10,
- (1, 1): tile_11,
- }
-
- grid_spec = GridSpec(
- split_dims=(3, 4),
- grid_shape=(2, 2),
- tile_spec={
- "latent_height": 48,
- "latent_width": 48,
- "blend_height": 8,
- "blend_width": 8,
- "tile_latent_stride_height": 24,
- "tile_latent_stride_width": 24,
- },
- )
-
- result = encode_tile_merge(vae, coord_tensor_map, grid_spec)
-
- # Output should be (1, 16, 2, 48, 48)
- assert result.shape == (1, 16, 2, 48, 48)
diff --git a/tests/diffusion/distributed/test_cfg_parallel.py b/tests/diffusion/distributed/test_cfg_parallel.py
index bf709618de2..79dbe9e6dd6 100644
--- a/tests/diffusion/distributed/test_cfg_parallel.py
+++ b/tests/diffusion/distributed/test_cfg_parallel.py
@@ -2,9 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for CFG (Classifier-Free Guidance) parallel functionality.
-This test verifies that predict_noise_maybe_with_cfg and
-predict_noise_with_multi_branch_cfg produce numerically equivalent results
-with and without CFG parallel using fixed random inputs.
+This test verifies that predict_noise_maybe_with_cfg produces numerically
+equivalent results with and without CFG parallel using fixed random inputs.
"""
import os
@@ -430,340 +429,3 @@ def test_predict_noise_without_cfg(dtype: torch.dtype):
assert noise_pred.shape == (1, 4, 16, 16)
print(f"✓ Test passed: predict_noise without CFG (dtype={dtype})")
-
-
-class MultiBranchTestPipeline(CFGParallelMixin):
- """Test pipeline with custom 3-branch combine logic (like OmniGen2)."""
-
- def __init__(self, in_channels: int = 4, hidden_dim: int = 128, seed: int = 42):
- torch.manual_seed(seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(seed)
-
- self.transformer = SimpleTransformer(in_channels, hidden_dim)
-
- torch.manual_seed(seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(seed)
- for param in self.transformer.parameters():
- torch.nn.init.normal_(param, mean=0.0, std=0.02)
-
- def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False):
- """N-branch combine with weighted sum for testing.
-
- - 2-branch: standard CFG formula (true_cfg_scale is float)
- - 3-branch: OmniGen2-style dual guidance scale (true_cfg_scale is dict)
- - 4-branch: DreamID-style weighted sum (true_cfg_scale is dict)
- """
- if len(predictions) == 4:
- text_scale = true_cfg_scale["text"]
- image_scale = true_cfg_scale["image"]
- vid_ref_scale = true_cfg_scale["vid_ref"]
- pos, neg, vid_neg, audio_neg = predictions
- combined = (
- audio_neg
- + vid_ref_scale * (vid_neg - audio_neg)
- + image_scale * (neg - vid_neg)
- + text_scale * (pos - neg)
- )
- elif len(predictions) == 3:
- text_scale = true_cfg_scale["text"]
- image_scale = true_cfg_scale["image"]
- pos, ref, uncond = predictions
- combined = uncond + image_scale * (ref - uncond) + text_scale * (pos - ref)
- else:
- pos, neg = predictions[0], predictions[1]
- combined = neg + true_cfg_scale * (pos - neg)
-
- if cfg_normalize:
- combined = self.cfg_normalize_function(pos, combined)
- return combined
-
-
-def _test_multi_branch_parallel_worker(
- local_rank: int,
- world_size: int,
- cfg_parallel_size: int,
- dtype: torch.dtype,
- test_config: dict,
- result_queue: torch.multiprocessing.Queue,
-):
- """Worker function for multi-branch CFG parallel test."""
- device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
- current_omni_platform.set_device(device)
-
- update_environment_variables(
- {
- "RANK": str(local_rank),
- "LOCAL_RANK": str(local_rank),
- "WORLD_SIZE": str(world_size),
- "MASTER_ADDR": "localhost",
- "MASTER_PORT": "29504",
- }
- )
-
- init_distributed_environment()
- initialize_model_parallel(cfg_parallel_size=cfg_parallel_size)
-
- cfg_rank = get_classifier_free_guidance_rank()
- cfg_world_size = get_classifier_free_guidance_world_size()
- assert cfg_world_size == cfg_parallel_size
-
- pipeline = MultiBranchTestPipeline(
- in_channels=test_config["channels"],
- hidden_dim=test_config["hidden_dim"],
- seed=test_config["model_seed"],
- )
- pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype)
- pipeline.transformer.eval()
-
- n_branches = test_config["n_branches"]
- batch_size = test_config["batch_size"]
- channels = test_config["channels"]
- height = test_config["height"]
- width = test_config["width"]
-
- # Create N branch inputs with distinct seeds
- branches_kwargs = []
- for b in range(n_branches):
- torch.manual_seed(test_config["input_seed"] + b)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(test_config["input_seed"] + b)
- x = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device)
- branches_kwargs.append({"x": x})
-
- with torch.no_grad():
- noise_pred = pipeline.predict_noise_with_multi_branch_cfg(
- do_true_cfg=True,
- true_cfg_scale=test_config["cfg_scale"],
- branches_kwargs=branches_kwargs,
- cfg_normalize=test_config["cfg_normalize"],
- )
-
- assert noise_pred is not None
- result_queue.put((cfg_rank, noise_pred.cpu()))
-
- destroy_distributed_env()
-
-
-def _test_multi_branch_sequential_worker(
- local_rank: int,
- world_size: int,
- dtype: torch.dtype,
- test_config: dict,
- result_queue: torch.multiprocessing.Queue,
-):
- """Worker function for sequential multi-branch CFG test (baseline)."""
- device = torch.device(f"{current_omni_platform.device_type}:{local_rank}")
- current_omni_platform.set_device(device)
-
- update_environment_variables(
- {
- "RANK": str(local_rank),
- "LOCAL_RANK": str(local_rank),
- "WORLD_SIZE": str(world_size),
- "MASTER_ADDR": "localhost",
- "MASTER_PORT": "29505",
- }
- )
-
- init_distributed_environment()
- initialize_model_parallel(cfg_parallel_size=1)
-
- cfg_world_size = get_classifier_free_guidance_world_size()
- assert cfg_world_size == 1
-
- pipeline = MultiBranchTestPipeline(
- in_channels=test_config["channels"],
- hidden_dim=test_config["hidden_dim"],
- seed=test_config["model_seed"],
- )
- pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype)
- pipeline.transformer.eval()
-
- n_branches = test_config["n_branches"]
- batch_size = test_config["batch_size"]
- channels = test_config["channels"]
- height = test_config["height"]
- width = test_config["width"]
-
- branches_kwargs = []
- for b in range(n_branches):
- torch.manual_seed(test_config["input_seed"] + b)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(test_config["input_seed"] + b)
- x = torch.randn(batch_size, channels, height, width, dtype=dtype, device=device)
- branches_kwargs.append({"x": x})
-
- with torch.no_grad():
- noise_pred = pipeline.predict_noise_with_multi_branch_cfg(
- do_true_cfg=True,
- true_cfg_scale=test_config["cfg_scale"],
- branches_kwargs=branches_kwargs,
- cfg_normalize=test_config["cfg_normalize"],
- )
-
- assert noise_pred is not None
- result_queue.put(noise_pred.cpu())
-
- destroy_distributed_env()
-
-
-@pytest.mark.parametrize(
- "cfg_parallel_size,n_branches",
- [
- (2, 2), # 2 branches on 2 GPUs: [[0],[1]]
- (2, 3), # 3 branches on 2 GPUs: [[0,2],[1]]
- (3, 3), # 3 branches on 3 GPUs: [[0],[1],[2]]
- (2, 4), # 4 branches on 2 GPUs: [[0,2],[1,3]]
- ],
-)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("batch_size", [2])
-@pytest.mark.parametrize("cfg_normalize", [False, True])
-def test_predict_noise_with_multi_branch_cfg(
- cfg_parallel_size: int,
- n_branches: int,
- dtype: torch.dtype,
- batch_size: int,
- cfg_normalize: bool,
-):
- """
- Test that predict_noise_with_multi_branch_cfg produces identical results
- with and without CFG parallel for N-branch models.
-
- Args:
- cfg_parallel_size: Number of GPUs for CFG parallel
- n_branches: Number of CFG branches
- dtype: Data type for computation
- batch_size: Batch size for testing
- cfg_normalize: Whether to normalize CFG output
- """
- available_gpus = current_omni_platform.get_device_count()
- if available_gpus < cfg_parallel_size:
- pytest.skip(f"Test requires {cfg_parallel_size} GPUs but only {available_gpus} available")
-
- if n_branches == 2:
- cfg_scale = 5.0
- elif n_branches == 3:
- cfg_scale = {"text": 5.0, "image": 2.0}
- else:
- cfg_scale = {"text": 5.0, "image": 2.0, "vid_ref": 1.5}
-
- test_config = {
- "batch_size": batch_size,
- "channels": 4,
- "height": 16,
- "width": 16,
- "hidden_dim": 128,
- "cfg_scale": cfg_scale,
- "cfg_normalize": cfg_normalize,
- "model_seed": 42,
- "input_seed": 123,
- "n_branches": n_branches,
- }
-
- mp_context = torch.multiprocessing.get_context("spawn")
- manager = mp_context.Manager()
- baseline_queue = manager.Queue()
- cfg_parallel_queue = manager.Queue()
-
- # Run baseline (sequential, cfgp=1)
- torch.multiprocessing.spawn(
- _test_multi_branch_sequential_worker,
- args=(1, dtype, test_config, baseline_queue),
- nprocs=1,
- )
-
- # Run CFG parallel
- torch.multiprocessing.spawn(
- _test_multi_branch_parallel_worker,
- args=(cfg_parallel_size, cfg_parallel_size, dtype, test_config, cfg_parallel_queue),
- nprocs=cfg_parallel_size,
- )
-
- baseline_output = baseline_queue.get()
- cfg_parallel_outputs = [cfg_parallel_queue.get() for _ in range(cfg_parallel_size)]
- cfg_parallel_outputs.sort(key=lambda item: item[0])
- cfg_parallel_output = cfg_parallel_outputs[0][1]
-
- # All ranks should produce identical output
- for cfg_rank, rank_output in cfg_parallel_outputs[1:]:
- torch.testing.assert_close(
- rank_output,
- cfg_parallel_output,
- rtol=0,
- atol=0,
- msg=f"Multi-branch CFG parallel ranks differ (rank 0 vs rank {cfg_rank})",
- )
-
- assert baseline_output.shape == cfg_parallel_output.shape, (
- f"Shape mismatch: baseline {baseline_output.shape} vs CFG parallel {cfg_parallel_output.shape}"
- )
-
- if dtype == torch.float32:
- rtol, atol = 1e-5, 1e-5
- elif dtype == torch.bfloat16:
- rtol, atol = 1e-2, 1e-2
- else:
- rtol, atol = 1e-3, 1e-3
-
- torch.testing.assert_close(
- cfg_parallel_output,
- baseline_output,
- rtol=rtol,
- atol=atol,
- msg=(
- f"Multi-branch CFG parallel output differs from sequential\n"
- f" n_branches={n_branches}, cfg_parallel_size={cfg_parallel_size}\n"
- f" dtype={dtype}, cfg_normalize={cfg_normalize}\n"
- f" Max diff: {(cfg_parallel_output - baseline_output).abs().max().item():.6e}"
- ),
- )
-
- print(
- f"✓ Test passed: multi_branch n_branches={n_branches}, "
- f"cfg_size={cfg_parallel_size}, dtype={dtype}, cfg_normalize={cfg_normalize}"
- )
-
-
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
-def test_multi_branch_without_cfg(dtype: torch.dtype):
- """
- Test predict_noise_with_multi_branch_cfg when do_true_cfg=False.
-
- When CFG is disabled, only the first branch (positive) should be computed.
- This test runs on a single GPU without distributed environment.
- """
- available_gpus = current_omni_platform.get_device_count()
- if available_gpus < 1:
- pytest.skip("Test requires at least 1 GPU")
-
- device = torch.device(f"{current_omni_platform.device_type}:0")
- current_omni_platform.set_device(device)
-
- pipeline = MultiBranchTestPipeline(in_channels=4, hidden_dim=128, seed=42)
- pipeline.transformer = pipeline.transformer.to(device=device, dtype=dtype)
- pipeline.transformer.eval()
-
- # Create 3 branch inputs (only first should be used)
- branches_kwargs = []
- for b in range(3):
- torch.manual_seed(123 + b)
- if torch.cuda.is_available():
- torch.cuda.manual_seed_all(123 + b)
- x = torch.randn(1, 4, 16, 16, dtype=dtype, device=device)
- branches_kwargs.append({"x": x})
-
- with torch.no_grad():
- noise_pred = pipeline.predict_noise_with_multi_branch_cfg(
- do_true_cfg=False, # No CFG
- true_cfg_scale=5.0,
- branches_kwargs=branches_kwargs,
- cfg_normalize=False,
- )
-
- assert noise_pred is not None
- assert noise_pred.shape == (1, 4, 16, 16)
-
- print(f"✓ Test passed: multi_branch predict_noise without CFG (dtype={dtype})")
diff --git a/tests/diffusion/distributed/test_distributed_vae_executor.py b/tests/diffusion/distributed/test_distributed_vae_executor.py
index b2ee7c10d33..42e9f3300bc 100644
--- a/tests/diffusion/distributed/test_distributed_vae_executor.py
+++ b/tests/diffusion/distributed/test_distributed_vae_executor.py
@@ -1,4 +1,4 @@
-from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
import pytest
import torch
@@ -11,8 +11,6 @@
TileTask,
)
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
class E2EOperator:
"""tiles with (2, 3) -- (H,W)"""
@@ -61,31 +59,40 @@ def merge(self, coord_tensor_map, grid_spec):
class DummyMixin(DistributedVaeMixin):
def __init__(self):
self.use_tiling = True
- self.distributed_executor = SimpleNamespace(parallel_size=2, group=None)
+ self.distributed_decoder = MagicMock()
+ self.distributed_decoder.parallel_size = 2
+ self.distributed_decoder.group = None
@pytest.fixture(autouse=True)
-def mock_dist(monkeypatch: pytest.MonkeyPatch):
- monkeypatch.setattr(dist, "get_world_size", lambda *args, **kwargs: 2)
- monkeypatch.setattr(dist, "get_rank", lambda *args, **kwargs: 0)
- monkeypatch.setattr(dist, "is_initialized", lambda: True)
- monkeypatch.setattr(dist, "all_reduce", lambda *args, **kwargs: None)
- monkeypatch.setattr(dist, "gather", lambda *args, **kwargs: None)
- monkeypatch.setattr(dist, "broadcast", lambda *args, **kwargs: None)
+def mock_dist():
+ with (
+ patch.object(dist, "get_world_size", return_value=2),
+ patch.object(dist, "get_rank", return_value=0),
+ patch.object(dist, "is_initialized", return_value=True),
+ patch.object(dist, "all_reduce", return_value=None),
+ patch.object(dist, "gather", return_value=None),
+ patch.object(dist, "broadcast", return_value=None),
+ ):
+ yield
@pytest.fixture(autouse=True)
-def mock_dit_group(monkeypatch: pytest.MonkeyPatch):
- monkeypatch.setattr(
+def mock_dit_group():
+ with patch(
"vllm_omni.diffusion.distributed.autoencoders.distributed_vae_executor.get_dit_group",
- lambda: None,
- )
+ new=MagicMock(return_value=None),
+ ):
+ yield
@pytest.fixture(autouse=True)
-def mock_dist_vae_executor(monkeypatch: pytest.MonkeyPatch):
- monkeypatch.setattr(DistributedVaeExecutor, "gather_tensors", lambda self, x: [x])
- monkeypatch.setattr(DistributedVaeExecutor, "broadcast_tensor", lambda self, x: x)
+def mock_dist_vae_executor():
+ with (
+ patch.object(DistributedVaeExecutor, "gather_tensors", side_effect=lambda x: [x]),
+ patch.object(DistributedVaeExecutor, "broadcast_tensor", side_effect=lambda x: x),
+ ):
+ yield
# ============================
diff --git a/tests/diffusion/distributed/test_ulysses_uaa_perf.py b/tests/diffusion/distributed/test_ulysses_uaa_perf.py
index 2a16a9ae578..c8b07ba152a 100644
--- a/tests/diffusion/distributed/test_ulysses_uaa_perf.py
+++ b/tests/diffusion/distributed/test_ulysses_uaa_perf.py
@@ -17,7 +17,6 @@
import torch
import torch.distributed as dist
-from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.attention.parallel.ulysses import (
_all_gather_int,
_ulysses_all_to_all_any_o,
@@ -70,8 +69,6 @@ def world_size(self) -> int:
@pytest.mark.parametrize("case", PERF_CASES)
-@pytest.mark.core_model
-@hardware_test(res={"cuda": "L4"}, num_cards=4)
def test_ulysses_advanced_uaa_comm_overhead(case: _PerfCase) -> None:
available_gpus = current_omni_platform.get_device_count()
if available_gpus < case.world_size:
diff --git a/tests/diffusion/hooks/test_hook_registry.py b/tests/diffusion/hooks/test_hook_registry.py
deleted file mode 100644
index 6c8535cfec4..00000000000
--- a/tests/diffusion/hooks/test_hook_registry.py
+++ /dev/null
@@ -1,164 +0,0 @@
-"""
-Tests for hook registry.
-
-NOTE: The hook registry is also tested indirectly through a lot of
-other tests, e.g., tests/diffusion/distributed/test_sp_plan_hooks.py
-"""
-
-from typing import Any
-
-import pytest
-from torch import nn
-
-from vllm_omni.diffusion.hooks.base import HookRegistry, ModelHook
-
-DEFAULT_OUT = "ECHO"
-OVERRIDE_OUT = "OVERRIDE"
-INPUT_KWARG = "inp"
-
-
-class EchoModule(nn.Module):
- """Just echo the input."""
-
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- super().__init__(*args, **kwargs)
-
- def forward(self, *args, **kwargs):
- input_val = kwargs[INPUT_KWARG]
- return input_val + DEFAULT_OUT
-
-
-class AppendHook(ModelHook):
- """Append an echo value to the input string on pre / post forward."""
-
- def __init__(self, echo_val: str):
- self.echo_val = echo_val
-
- def pre_forward(self, module: nn.Module, *args, **kwargs):
- input_val = kwargs[INPUT_KWARG]
- return (), {INPUT_KWARG: input_val + self.echo_val}
-
- def post_forward(self, module: nn.Module, output):
- return output + self.echo_val
-
-
-class OverrideAppendHook(AppendHook):
- """Same as AppendHook, but replace the forward call with a different string."""
-
- def new_forward(self, module: nn.Module, *args, **kwargs):
- return kwargs[INPUT_KWARG] + OVERRIDE_OUT
-
-
-def test_register_no_fwd_override_hooks():
- """Ensure registration is correct with no forward hooks."""
- mod = EchoModule()
- registry = HookRegistry.get_or_create(mod)
- first_hook = AppendHook("1")
- second_hook = AppendHook("2")
- sorted_no_fwd_hooks = [first_hook, second_hook]
-
- # Will add and sort the hook by key
- registry.register_hook(name="b", hook=second_hook)
- registry.register_hook(name="a", hook=first_hook)
-
- assert len(registry._hooks) == 2
- assert len(registry._sorted_hooks) == 2
- assert registry._new_fwd_impl_hook is None
- # Ensure registering a new hook sorting alphabetically
- for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks):
- assert actual_hook is expected_hook
-
-
-def test_register_with_forward_hooks():
- """Ensure registration is correct with a forward hooks."""
- mod = EchoModule()
- registry = HookRegistry.get_or_create(mod)
- first_hook = AppendHook("1")
- second_hook = AppendHook("2")
- exec_hook = OverrideAppendHook("3")
- sorted_no_fwd_hooks = [first_hook, second_hook]
-
- # Will add and sort the hook by key
- registry.register_hook(name="b", hook=second_hook)
- registry.register_hook(name="a", hook=first_hook)
- registry.register_hook(name="c", hook=exec_hook)
-
- assert len(registry._hooks) == 3
- assert len(registry._sorted_hooks) == 3
- assert registry._new_fwd_impl_hook is exec_hook
- # Ensure registering a new hook sorting alphabetically
- for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks):
- assert actual_hook is expected_hook
-
-
-def test_register_fails_with_multiple_forward_hooks():
- """Ensure registration only allows one hook overriding new_forward"""
- mod = EchoModule()
- registry = HookRegistry.get_or_create(mod)
-
- registry.register_hook(name="foo", hook=OverrideAppendHook("1"))
- with pytest.raises(RuntimeError):
- registry.register_hook(name="bar", hook=OverrideAppendHook("2"))
-
-
-def test_remove_hooks():
- """Ensure removal sorts hooks."""
- mod = EchoModule()
- registry = HookRegistry.get_or_create(mod)
-
- first_hook = AppendHook("1")
- second_hook = AppendHook("2")
- exec_hook = OverrideAppendHook("3")
-
- registry.register_hook(name="b", hook=second_hook)
- registry.register_hook(name="a", hook=first_hook)
- registry.register_hook(name="c", hook=exec_hook)
- # Explicitly reorder our hooks to be in the wrong order, since register
- # forces them to be sorted too. Ensure that remove the hook will also
- # enforce the sorted order.
- registry._sorted_hooks = [second_hook, first_hook]
-
- assert registry._new_fwd_impl_hook is exec_hook
- registry.remove_hook("c")
- assert registry._new_fwd_impl_hook is None
-
- sorted_no_fwd_hooks = [first_hook, second_hook]
- for actual_hook, expected_hook in zip(registry._sorted_hooks, sorted_no_fwd_hooks):
- assert actual_hook is expected_hook
-
-
-def test_dispatch_no_fwd_override_hooks():
- """Ensure dispatch runs hooks in deterministic sorted order."""
- mod = EchoModule()
- registry = HookRegistry.get_or_create(mod)
-
- first_hook = AppendHook("1")
- second_hook = AppendHook("2")
-
- # Register will sort the hooks, so hook 1 will run first
- # on preprocess and last in post process
- registry.register_hook(name="2", hook=second_hook)
- registry.register_hook(name="1", hook=first_hook)
- res = registry.dispatch(inp="")
- assert isinstance(res, str)
- assert res == f"12{DEFAULT_OUT}21"
-
-
-def test_dispatch_with_fwd_hooks():
- """Ensure dispatch runs hooks in deterministic sorted order."""
- mod = EchoModule()
- registry = HookRegistry.get_or_create(mod)
-
- first_hook = AppendHook("1")
- second_hook = AppendHook("2")
- exec_hook = OverrideAppendHook("3")
-
- # Register will sort the hooks, so hook 1 will run first on preprocess and last in
- # post process. Since the override hook mutates forward, it will run last even
- # though the name of the exec_hook is alphabetically before the second hook.
- registry.register_hook(name="c", hook=second_hook)
- registry.register_hook(name="a", hook=first_hook)
- registry.register_hook(name="b", hook=exec_hook)
- res = registry.dispatch(inp="")
- assert isinstance(res, str)
- assert res == f"123{OVERRIDE_OUT}321"
diff --git a/tests/diffusion/layers/test_norm.py b/tests/diffusion/layers/test_norm.py
deleted file mode 100644
index e420415285d..00000000000
--- a/tests/diffusion/layers/test_norm.py
+++ /dev/null
@@ -1,453 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for LayerNorm and RMSNorm custom ops in diffusion layers."""
-
-import pytest
-import torch
-
-pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
-
-
-# ── Import tests ──
-
-
-def test_layernorm_import():
- """Verify LayerNorm can be imported from the norm module."""
- from vllm_omni.diffusion.layers.norm import LayerNorm # noqa: F401
-
-
-def test_rmsnorm_import():
- """Verify RMSNorm can be imported from the norm module."""
- from vllm_omni.diffusion.layers.norm import RMSNorm # noqa: F401
-
-
-# ── LayerNorm tests ──
-
-
-def test_layernorm_forward_shape():
- """LayerNorm produces correct output shapes."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- batch = 2
- seq_len = 4
- norm = LayerNorm(dim)
-
- x = torch.randn(batch, seq_len, dim)
- out = norm(x)
-
- assert out.shape == (batch, seq_len, dim)
-
-
-def test_layernorm_forward_shape_2d():
- """LayerNorm works with 2D input tensors."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- batch = 2
- norm = LayerNorm(dim)
-
- x = torch.randn(batch, dim)
- out = norm(x)
-
- assert out.shape == (batch, dim)
-
-
-def test_layernorm_preserves_dtype_fp32():
- """LayerNorm preserves float32 dtype."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- norm = LayerNorm(dim)
-
- x = torch.randn(2, 4, dim, dtype=torch.float32)
- out = norm(x)
-
- assert out.dtype == torch.float32
-
-
-def test_layernorm_preserves_dtype_fp16():
- """LayerNorm preserves float16 dtype."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- norm = LayerNorm(dim)
-
- x = torch.randn(2, 4, dim, dtype=torch.float16)
- out = norm(x)
-
- assert out.dtype == torch.float16
-
-
-def test_layernorm_preserves_dtype_bf16():
- """LayerNorm preserves bfloat16 dtype."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- norm = LayerNorm(dim)
-
- x = torch.randn(2, 4, dim, dtype=torch.bfloat16)
- out = norm(x)
-
- assert out.dtype == torch.bfloat16
-
-
-def test_layernorm_without_elementwise_affine():
- """LayerNorm works without elementwise_affine (no learned parameters)."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- norm = LayerNorm(dim, elementwise_affine=False)
-
- assert norm.weight is None
- assert norm.bias is None
-
- x = torch.randn(2, 4, dim)
- out = norm(x)
-
- assert out.shape == (2, 4, dim)
-
-
-def test_layernorm_custom_eps():
- """LayerNorm accepts custom epsilon value."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- eps = 1e-5
- norm = LayerNorm(dim, eps=eps)
-
- assert norm.eps == eps
-
-
-def test_layernorm_has_learnable_parameters():
- """LayerNorm has learnable weight and bias by default."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- norm = LayerNorm(dim)
-
- assert norm.weight is not None
- assert norm.bias is not None
- assert norm.weight.shape == (dim,)
- assert norm.bias.shape == (dim,)
-
-
-def test_layernorm_matches_fp32_reference():
- """Verify LayerNorm produces identical output to FP32 nn.LayerNorm."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- eps = 1e-6
- torch.manual_seed(42)
-
- ours = LayerNorm(dim, eps=eps)
- ref = torch.nn.LayerNorm(dim, eps=eps)
-
- # Copy weights
- ref.weight.data.copy_(ours.weight.data)
- ref.bias.data.copy_(ours.bias.data)
-
- x = torch.randn(2, 4, dim)
-
- out_ours = ours(x)
- out_ref = ref(x.float()).to(x.dtype)
-
- torch.testing.assert_close(out_ours, out_ref, atol=1e-5, rtol=1e-5)
-
-
-def test_layernorm_matches_diffusers_fp32layernorm():
- """Verify LayerNorm produces identical output to diffusers FP32LayerNorm."""
- from diffusers.models.normalization import FP32LayerNorm
-
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- eps = 1e-6
- torch.manual_seed(42)
-
- ours = LayerNorm(dim, eps=eps)
- ref = FP32LayerNorm(dim, eps=eps)
-
- # Copy weights
- ref.weight.data.copy_(ours.weight.data)
- ref.bias.data.copy_(ours.bias.data)
-
- # Test with fp16 input to verify FP32 computation
- x = torch.randn(2, 4, dim, dtype=torch.float16)
-
- out_ours = ours(x)
- out_ref = ref(x)
-
- torch.testing.assert_close(out_ours, out_ref, atol=1e-3, rtol=1e-3)
-
-
-# ── RMSNorm tests ──
-
-
-def test_rmsnorm_forward_shape():
- """RMSNorm produces correct output shapes."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- batch = 2
- seq_len = 4
- norm = RMSNorm(hidden_size)
-
- x = torch.randn(batch, seq_len, hidden_size)
- out = norm(x)
-
- assert out.shape == (batch, seq_len, hidden_size)
-
-
-def test_rmsnorm_forward_shape_2d():
- """RMSNorm works with 2D input tensors."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- batch = 2
- norm = RMSNorm(hidden_size)
-
- x = torch.randn(batch, hidden_size)
- out = norm(x)
-
- assert out.shape == (batch, hidden_size)
-
-
-def test_rmsnorm_preserves_dtype_fp32():
- """RMSNorm preserves float32 dtype."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- norm = RMSNorm(hidden_size)
-
- x = torch.randn(2, 4, hidden_size, dtype=torch.float32)
- out = norm(x)
-
- assert out.dtype == torch.float32
-
-
-def test_rmsnorm_preserves_dtype_fp16():
- """RMSNorm preserves float16 dtype."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- norm = RMSNorm(hidden_size)
-
- x = torch.randn(2, 4, hidden_size, dtype=torch.float16)
- out = norm(x)
-
- assert out.dtype == torch.float16
-
-
-def test_rmsnorm_preserves_dtype_bf16():
- """RMSNorm preserves bfloat16 dtype."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- norm = RMSNorm(hidden_size)
-
- x = torch.randn(2, 4, hidden_size, dtype=torch.bfloat16)
- out = norm(x)
-
- assert out.dtype == torch.bfloat16
-
-
-def test_rmsnorm_custom_eps():
- """RMSNorm accepts custom epsilon value."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- eps = 1e-5
- norm = RMSNorm(hidden_size, eps=eps)
-
- assert norm.variance_epsilon == eps
-
-
-def test_rmsnorm_has_weight_parameter():
- """RMSNorm has learnable weight parameter initialized to ones."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- norm = RMSNorm(hidden_size)
-
- assert norm.weight is not None
- assert norm.weight.shape == (hidden_size,)
- torch.testing.assert_close(norm.weight, torch.ones(hidden_size))
-
-
-def test_rmsnorm_numerical_correctness():
- """Verify RMSNorm produces numerically correct output."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- eps = 1e-6
- torch.manual_seed(42)
-
- norm = RMSNorm(hidden_size, eps=eps)
- x = torch.randn(2, 4, hidden_size)
-
- # Compute expected output manually
- x_fp32 = x.to(torch.float32)
- variance = x_fp32.pow(2).mean(-1, keepdim=True)
- expected = x_fp32 * torch.rsqrt(variance + eps)
- expected = norm.weight.to(torch.float32) * expected
- expected = expected.to(x.dtype)
-
- out = norm(x)
-
- torch.testing.assert_close(out, expected, atol=1e-5, rtol=1e-5)
-
-
-def test_rmsnorm_matches_reference_implementation():
- """Verify RMSNorm matches a reference implementation."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- def reference_rmsnorm(x, weight, eps):
- """Reference RMSNorm implementation."""
- input_dtype = x.dtype
- x = x.to(torch.float32)
- variance = x.pow(2).mean(-1, keepdim=True)
- out = x * torch.rsqrt(variance + eps)
- out = weight.to(torch.float32) * out
- return out.to(input_dtype)
-
- hidden_size = 128
- eps = 1e-6
- torch.manual_seed(123)
-
- norm = RMSNorm(hidden_size, eps=eps)
-
- # Test with various dtypes
- for dtype in [torch.float32, torch.float16, torch.bfloat16]:
- x = torch.randn(4, 8, hidden_size, dtype=dtype)
- expected = reference_rmsnorm(x, norm.weight, eps)
- out = norm(x)
- torch.testing.assert_close(out, expected, atol=1e-3, rtol=1e-3)
-
-
-# ── CustomOp dispatch tests ──
-
-
-def test_layernorm_inherits_from_customop():
- """LayerNorm inherits from CustomOp for platform dispatch."""
- from vllm_omni.diffusion.layers.custom_op import CustomOp
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- norm = LayerNorm(64)
- assert isinstance(norm, CustomOp)
-
-
-def test_rmsnorm_inherits_from_customop():
- """RMSNorm inherits from CustomOp for platform dispatch."""
- from vllm_omni.diffusion.layers.custom_op import CustomOp
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- norm = RMSNorm(64)
- assert isinstance(norm, CustomOp)
-
-
-def test_layernorm_has_platform_methods():
- """LayerNorm has forward methods for each platform."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- norm = LayerNorm(64)
-
- assert hasattr(norm, "forward_cuda")
- assert hasattr(norm, "forward_hip")
- assert hasattr(norm, "forward_xpu")
- assert hasattr(norm, "forward_npu")
- assert hasattr(norm, "forward_native")
-
-
-def test_rmsnorm_has_platform_methods():
- """RMSNorm has forward methods for each platform."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- norm = RMSNorm(64)
-
- assert hasattr(norm, "forward_cuda")
- assert hasattr(norm, "forward_hip")
- assert hasattr(norm, "forward_xpu")
- assert hasattr(norm, "forward_npu")
- assert hasattr(norm, "forward_native")
-
-
-def test_layernorm_forward_native_directly():
- """LayerNorm.forward_native can be called directly."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- norm = LayerNorm(dim)
- x = torch.randn(2, 4, dim)
-
- out = norm.forward_native(x)
-
- assert out.shape == (2, 4, dim)
-
-
-def test_rmsnorm_forward_native_directly():
- """RMSNorm.forward_native can be called directly."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- norm = RMSNorm(hidden_size)
- x = torch.randn(2, 4, hidden_size)
-
- out = norm.forward_native(x)
-
- assert out.shape == (2, 4, hidden_size)
-
-
-# ── Edge case tests ──
-
-
-def test_layernorm_with_large_dim():
- """LayerNorm works with large hidden dimensions."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 4096
- norm = LayerNorm(dim)
- x = torch.randn(1, 16, dim)
-
- out = norm(x)
-
- assert out.shape == (1, 16, dim)
-
-
-def test_rmsnorm_with_large_dim():
- """RMSNorm works with large hidden dimensions."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 4096
- norm = RMSNorm(hidden_size)
- x = torch.randn(1, 16, hidden_size)
-
- out = norm(x)
-
- assert out.shape == (1, 16, hidden_size)
-
-
-def test_layernorm_with_single_element_batch():
- """LayerNorm works with batch size of 1."""
- from vllm_omni.diffusion.layers.norm import LayerNorm
-
- dim = 64
- norm = LayerNorm(dim)
- x = torch.randn(1, 1, dim)
-
- out = norm(x)
-
- assert out.shape == (1, 1, dim)
-
-
-def test_rmsnorm_with_single_element_batch():
- """RMSNorm works with batch size of 1."""
- from vllm_omni.diffusion.layers.norm import RMSNorm
-
- hidden_size = 64
- norm = RMSNorm(hidden_size)
- x = torch.randn(1, 1, hidden_size)
-
- out = norm(x)
-
- assert out.shape == (1, 1, hidden_size)
diff --git a/tests/diffusion/layers/test_rotary_emb_equivalence.py b/tests/diffusion/layers/test_rotary_emb_equivalence.py
deleted file mode 100644
index 2fbb7a31f5a..00000000000
--- a/tests/diffusion/layers/test_rotary_emb_equivalence.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Numerical equivalence tests for rotary embedding implementations (#2436).
-
-Verifies that the optimized stack+flatten RoPE produces bit-identical results
-to the original strided-slice implementation across various tensor shapes and
-dtypes, ensuring the refactor is safe.
-"""
-
-from __future__ import annotations
-
-import pytest
-import torch
-
-
-def _apply_rotary_emb_helios_original(
- hidden_states: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> torch.Tensor:
- """Original Helios RoPE using strided slice assignment (pre-#2436)."""
- x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
- cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
- out = torch.empty_like(hidden_states)
- out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
- out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
- return out.type_as(hidden_states)
-
-
-def _apply_rotary_emb_helios_optimized(
- hidden_states: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> torch.Tensor:
- """Optimized Helios RoPE using stack+flatten (post-#2436)."""
- x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
- cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
- rotated = torch.stack(
- (
- x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2],
- x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2],
- ),
- dim=-1,
- )
- return rotated.flatten(-2, -1).type_as(hidden_states)
-
-
-def _make_inputs(
- batch: int,
- seq_len: int,
- num_heads: int,
- head_dim: int,
- dtype: torch.dtype = torch.float32,
-) -> tuple[torch.Tensor, torch.Tensor]:
- """Generate random hidden_states and freqs_cis for testing."""
- torch.manual_seed(42)
- hidden_states = torch.randn(batch, seq_len, num_heads, head_dim, dtype=dtype)
- # freqs_cis: [B, seq, head_dim*2] — cos and sin concatenated along last dim
- freqs_cis = torch.randn(batch, seq_len, head_dim * 2, dtype=dtype)
- return hidden_states, freqs_cis
-
-
-class TestHeliosRoPEEquivalence:
- """Verify optimized Helios RoPE is numerically identical to original."""
-
- @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
- def test_equivalence_across_dtypes(self, dtype: torch.dtype) -> None:
- """Optimized output must be bit-identical to original across dtypes."""
- hidden, freqs = _make_inputs(2, 16, 8, 64, dtype=dtype)
- original = _apply_rotary_emb_helios_original(hidden, freqs)
- optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
- torch.testing.assert_close(optimized, original, atol=0, rtol=0)
-
- @pytest.mark.parametrize(
- "batch,seq_len,num_heads,head_dim",
- [
- (1, 8, 1, 32), # minimal: single batch, single head
- (2, 16, 8, 64), # typical transformer config
- (1, 8192, 4, 64), # video-scale patch tokens (720p DiT)
- (4, 32, 16, 128), # large head_dim
- ],
- )
- def test_equivalence_across_shapes(self, batch: int, seq_len: int, num_heads: int, head_dim: int) -> None:
- """Equivalence must hold across different tensor shapes."""
- hidden, freqs = _make_inputs(batch, seq_len, num_heads, head_dim)
- original = _apply_rotary_emb_helios_original(hidden, freqs)
- optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
- torch.testing.assert_close(optimized, original, atol=0, rtol=0)
-
- def test_output_contiguous(self) -> None:
- """Optimized output should be contiguous in memory."""
- hidden, freqs = _make_inputs(2, 16, 8, 64)
- optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
- assert optimized.is_contiguous()
-
- def test_output_shape_preserved(self) -> None:
- """Output shape must match input shape."""
- hidden, freqs = _make_inputs(2, 16, 8, 64)
- optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
- assert optimized.shape == hidden.shape
-
- def test_output_dtype_preserved(self) -> None:
- """Output dtype must match input dtype."""
- hidden, freqs = _make_inputs(2, 16, 8, 64, dtype=torch.float16)
- optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
- assert optimized.dtype == hidden.dtype
-
- def test_odd_head_dim_raises(self) -> None:
- """Odd head_dim should fail at unflatten (not a valid RoPE config)."""
- hidden = torch.randn(1, 4, 2, 63)
- freqs = torch.randn(1, 4, 126)
- with pytest.raises(RuntimeError):
- _apply_rotary_emb_helios_optimized(hidden, freqs)
diff --git a/tests/diffusion/lora/helpers.py b/tests/diffusion/lora/helpers.py
deleted file mode 100644
index 8b9b1ef4d20..00000000000
--- a/tests/diffusion/lora/helpers.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Shared test helpers for diffusion LoRA tests."""
-
-from __future__ import annotations
-
-import torch
-from vllm.model_executor.layers.linear import LinearBase
-
-
-class FakeLinearBase(LinearBase):
- """Minimal LinearBase stub for LoRA layer discovery."""
-
- def __init__(self):
- torch.nn.Module.__init__(self)
-
-
-class DummyBaseLayerWithLoRA(torch.nn.Module):
- """Fake LoRA wrapper that records set/reset/create calls."""
-
- def __init__(self, base_layer: torch.nn.Module):
- super().__init__()
- self.base_layer = base_layer
-
- self.set_calls: list[
- tuple[list[torch.Tensor | None] | torch.Tensor, list[torch.Tensor | None] | torch.Tensor]
- ] = []
- self.reset_calls: int = 0
- self.create_calls: int = 0
-
- def set_lora(self, index: int, lora_a, lora_b):
- assert index == 0
- self.set_calls.append((lora_a, lora_b))
-
- def reset_lora(self, index: int):
- assert index == 0
- self.reset_calls += 1
-
- def create_lora_weights(self, max_loras, lora_config, model_config):
- self.create_calls += 1
-
-
-def fake_replace_submodule(
- root: torch.nn.Module,
- module_name: str,
- submodule: torch.nn.Module,
- replace_calls: list[str] | None = None,
-) -> None:
- """Replace a submodule by traversing dotted paths correctly."""
- if replace_calls is not None:
- replace_calls.append(module_name)
- parts = module_name.split(".")
- parent = root
- for attr in parts[:-1]:
- parent = getattr(parent, attr)
- setattr(parent, parts[-1], submodule)
diff --git a/tests/diffusion/lora/test_lora_manager.py b/tests/diffusion/lora/test_lora_manager.py
index 785f5d84217..8d4a1487fd0 100644
--- a/tests/diffusion/lora/test_lora_manager.py
+++ b/tests/diffusion/lora/test_lora_manager.py
@@ -7,12 +7,8 @@
import torch
from vllm.lora.lora_weights import LoRALayerWeights
from vllm.lora.utils import get_supported_lora_modules
+from vllm.model_executor.layers.linear import LinearBase
-from tests.diffusion.lora.helpers import (
- DummyBaseLayerWithLoRA,
- FakeLinearBase,
- fake_replace_submodule,
-)
from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
from vllm_omni.lora.request import LoRARequest
@@ -37,9 +33,35 @@ def reset_lora(self, index: int):
self.reset_calls += 1
-# Aliases for backward compatibility within this file
-_FakeLinearBase = FakeLinearBase
-_DummyBaseLayerWithLoRA = DummyBaseLayerWithLoRA
+class _FakeLinearBase(LinearBase):
+ def __init__(self):
+ torch.nn.Module.__init__(self)
+
+
+class _DummyBaseLayerWithLoRA(torch.nn.Module):
+ def __init__(self, base_layer: torch.nn.Module):
+ super().__init__()
+ self.base_layer = base_layer
+
+ self.set_calls: list[
+ tuple[list[torch.Tensor | None] | torch.Tensor, list[torch.Tensor | None] | torch.Tensor]
+ ] = []
+ self.reset_calls: int = 0
+ self.create_calls: int = 0
+
+ def set_lora(self, index: int, lora_a, lora_b):
+ assert index == 0
+ self.set_calls.append((lora_a, lora_b))
+
+ def reset_lora(self, index: int):
+ assert index == 0
+ self.reset_calls += 1
+
+ def create_lora_weights(self, max_loras, lora_config, model_config):
+ # Needs to be callable for scale test when rank changes, but not
+ # actually used since we mock everything and check everything based
+ # on set calls.
+ self.create_calls += 1
class _DummyPipeline(torch.nn.Module):
@@ -533,45 +555,3 @@ def _fake_load(_req: LoRARequest):
req1 = _dummy_lora_request(1)
with pytest.raises(ValueError):
manager.add_adapter(req1)
-
-
-def test_lora_manager_discovers_bagel_component(monkeypatch):
- """Verify that _replace_layers_with_lora finds layers under 'bagel'."""
- import vllm_omni.diffusion.lora.manager as manager_mod
-
- monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", _DummyBaseLayerWithLoRA)
-
- def _fake_from_layer_diffusion(*, layer: torch.nn.Module, **_kwargs):
- if isinstance(layer, _FakeLinearBase):
- return _DummyBaseLayerWithLoRA(layer)
- return layer
-
- replace_calls: list[str] = []
-
- monkeypatch.setattr(manager_mod, "from_layer_diffusion", _fake_from_layer_diffusion)
- monkeypatch.setattr(
- manager_mod,
- "replace_submodule",
- lambda root, name, sub: fake_replace_submodule(root, name, sub, replace_calls),
- )
-
- # Pipeline with a 'bagel' component (no 'transformer')
- pipeline = torch.nn.Module()
- pipeline.bagel = torch.nn.Module()
- pipeline.bagel.language_model = torch.nn.Module()
- pipeline.bagel.language_model.qkv_proj = _FakeLinearBase()
-
- manager = DiffusionLoRAManager(
- pipeline=pipeline,
- device=torch.device("cpu"),
- dtype=torch.bfloat16,
- max_cached_adapters=1,
- )
-
- peft_helper = type("_PH", (), {"r": 1})()
- manager._replace_layers_with_lora(peft_helper)
-
- assert "language_model.qkv_proj" in replace_calls
- assert "bagel.language_model.qkv_proj" in manager._lora_modules
- # Verify the module was actually replaced in the tree (not just recorded)
- assert isinstance(pipeline.bagel.language_model.qkv_proj, _DummyBaseLayerWithLoRA)
diff --git a/tests/diffusion/models/bagel/test_bagel_lora.py b/tests/diffusion/models/bagel/test_bagel_lora.py
deleted file mode 100644
index c285758fe86..00000000000
--- a/tests/diffusion/models/bagel/test_bagel_lora.py
+++ /dev/null
@@ -1,248 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for BAGEL LoRA support across Stage 0 (Thinker) and Stage 1 (DiT)."""
-
-from __future__ import annotations
-
-import json
-from pathlib import Path
-
-import pytest
-import torch
-from safetensors.torch import save_file
-
-from tests.diffusion.lora.helpers import (
- DummyBaseLayerWithLoRA,
- FakeLinearBase,
- fake_replace_submodule,
-)
-from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
-from vllm_omni.lora.request import LoRARequest
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-_FakeLinearBase = FakeLinearBase
-
-
-# ---------------------------------------------------------------------------
-# Stage 0 (Thinker / AR) -- packed_modules_mapping on the AR model class
-# ---------------------------------------------------------------------------
-
-
-class TestStage0ThinkerLoRA:
- """Validate that OmniBagelForConditionalGeneration declares correct LoRA metadata."""
-
- def test_omni_bagel_supports_lora(self):
- from vllm_omni.model_executor.models.bagel.bagel import (
- OmniBagelForConditionalGeneration,
- )
-
- assert getattr(OmniBagelForConditionalGeneration, "supports_lora", False) is True
-
- def test_omni_bagel_packed_modules_mapping_complete(self):
- from vllm_omni.model_executor.models.bagel.bagel import (
- OmniBagelForConditionalGeneration,
- )
-
- mapping = OmniBagelForConditionalGeneration.packed_modules_mapping
- # Standard Qwen2 projections
- assert mapping["qkv_proj"] == ["q_proj", "k_proj", "v_proj"]
- assert mapping["gate_up_proj"] == ["gate_proj", "up_proj"]
- # MoE generation-mode projections
- assert mapping["qkv_proj_moe_gen"] == [
- "q_proj_moe_gen",
- "k_proj_moe_gen",
- "v_proj_moe_gen",
- ]
- assert mapping["mlp_moe_gen.gate_up_proj"] == [
- "mlp_moe_gen.gate_proj",
- "mlp_moe_gen.up_proj",
- ]
-
-
-# ---------------------------------------------------------------------------
-# Stage 1 (DiT / Diffusion) -- DiffusionLoRAManager with bagel component
-# ---------------------------------------------------------------------------
-
-
-class TestStage1DiTLoRA:
- """Validate DiffusionLoRAManager discovers BAGEL's packed modules."""
-
- def test_diffusion_lora_manager_discovers_bagel_packed_modules(self):
- """Manager should derive packed→sublayer mapping from stacked_params_mapping."""
- pipeline = torch.nn.Module()
- pipeline.bagel = torch.nn.Module()
-
- # Simulate a submodule that exposes stacked_params_mapping
- # (as Bagel does after load_weights())
- language_model = torch.nn.Module()
- language_model.stacked_params_mapping = [
- (".qkv_proj_moe_gen", ".q_proj_moe_gen", "q"),
- (".qkv_proj_moe_gen", ".k_proj_moe_gen", "k"),
- (".qkv_proj_moe_gen", ".v_proj_moe_gen", "v"),
- (".qkv_proj", ".q_proj", "q"),
- (".qkv_proj", ".k_proj", "k"),
- (".qkv_proj", ".v_proj", "v"),
- (".gate_up_proj", ".gate_proj", 0),
- (".gate_up_proj", ".up_proj", 1),
- ]
- pipeline.bagel.language_model = language_model
-
- manager = DiffusionLoRAManager(
- pipeline=pipeline,
- device=torch.device("cpu"),
- dtype=torch.bfloat16,
- max_cached_adapters=1,
- )
-
- mapping = manager._packed_modules_mapping
- assert mapping["qkv_proj"] == ["q_proj", "k_proj", "v_proj"]
- assert mapping["qkv_proj_moe_gen"] == [
- "q_proj_moe_gen",
- "k_proj_moe_gen",
- "v_proj_moe_gen",
- ]
- assert mapping["gate_up_proj"] == ["gate_proj", "up_proj"]
-
- def test_diffusion_lora_manager_replaces_bagel_packed_layer_via_sublayer_target(self, monkeypatch):
- """Targeting sublayer 'q_proj' should replace the fused 'qkv_proj' under bagel."""
- import vllm_omni.diffusion.lora.manager as manager_mod
-
- monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", DummyBaseLayerWithLoRA)
-
- def _fake_from_layer_diffusion(*, layer, **_kwargs):
- return DummyBaseLayerWithLoRA(layer)
-
- replace_calls: list[str] = []
-
- monkeypatch.setattr(manager_mod, "from_layer_diffusion", _fake_from_layer_diffusion)
- monkeypatch.setattr(
- manager_mod,
- "replace_submodule",
- lambda root, name, sub: fake_replace_submodule(root, name, sub, replace_calls),
- )
-
- # Build pipeline with bagel component
- pipeline = torch.nn.Module()
- pipeline.bagel = torch.nn.Module()
- lm = torch.nn.Module()
- lm.stacked_params_mapping = [
- (".qkv_proj", ".q_proj", "q"),
- (".qkv_proj", ".k_proj", "k"),
- (".qkv_proj", ".v_proj", "v"),
- ]
- lm.attn = torch.nn.Module()
- lm.attn.qkv_proj = _FakeLinearBase()
- pipeline.bagel.language_model = lm
-
- manager = DiffusionLoRAManager(
- pipeline=pipeline,
- device=torch.device("cpu"),
- dtype=torch.bfloat16,
- max_cached_adapters=1,
- )
-
- # Treat qkv_proj as 3-slice packed layer
- monkeypatch.setattr(manager, "_get_packed_modules_list", lambda _module: ["q", "k", "v"])
-
- # Target sublayer "q_proj" -- manager should replace the packed "qkv_proj"
- peft_helper = type("_PH", (), {"r": 1, "target_modules": ["q_proj"]})()
- manager._replace_layers_with_lora(peft_helper)
-
- assert "language_model.attn.qkv_proj" in replace_calls
- assert "bagel.language_model.attn.qkv_proj" in manager._lora_modules
- # Verify the module was actually replaced in the tree (not just recorded)
- assert isinstance(pipeline.bagel.language_model.attn.qkv_proj, DummyBaseLayerWithLoRA)
-
-
-# ---------------------------------------------------------------------------
-# Round-trip: synthetic checkpoint → set_active_adapter → verify weights
-# ---------------------------------------------------------------------------
-
-
-def _write_synthetic_lora(
- adapter_dir: Path,
- module_name: str,
- rank: int,
- in_dim: int,
- out_dim: int,
-) -> str:
- """Write a minimal LoRA adapter (safetensors + config) to *adapter_dir*."""
- adapter_dir.mkdir(parents=True, exist_ok=True)
- lora_a = torch.ones((rank, in_dim), dtype=torch.float32)
- lora_b = torch.ones((out_dim, rank), dtype=torch.float32) * 2.0
- save_file(
- {
- f"base_model.model.{module_name}.lora_A.weight": lora_a,
- f"base_model.model.{module_name}.lora_B.weight": lora_b,
- },
- str(adapter_dir / "adapter_model.safetensors"),
- )
- (adapter_dir / "adapter_config.json").write_text(
- json.dumps({"r": rank, "lora_alpha": rank, "target_modules": [module_name]}),
- encoding="utf-8",
- )
- return str(adapter_dir)
-
-
-class TestBagelLoRARoundTrip:
- """End-to-end: synthetic checkpoint → load → activate → verify weights in fused layer."""
-
- def test_set_active_adapter_loads_and_activates_bagel_lora(self, tmp_path, monkeypatch):
- """Full round-trip through set_active_adapter for a bagel component module."""
- import vllm_omni.diffusion.lora.manager as manager_mod
-
- monkeypatch.setattr(manager_mod, "BaseLayerWithLoRA", DummyBaseLayerWithLoRA)
-
- # Build pipeline with bagel.language_model.foo (simple non-packed layer)
- pipeline = torch.nn.Module()
- pipeline.bagel = torch.nn.Module()
- lm = torch.nn.Module()
- lm.foo = _FakeLinearBase()
- pipeline.bagel.language_model = lm
-
- def _fake_from_layer(*, layer, **_kwargs):
- if isinstance(layer, FakeLinearBase):
- return DummyBaseLayerWithLoRA(layer)
- return layer
-
- monkeypatch.setattr(manager_mod, "from_layer_diffusion", _fake_from_layer)
- monkeypatch.setattr(
- manager_mod,
- "replace_submodule",
- lambda root, name, sub: fake_replace_submodule(root, name, sub),
- )
-
- manager = DiffusionLoRAManager(
- pipeline=pipeline,
- device=torch.device("cpu"),
- dtype=torch.bfloat16,
- max_cached_adapters=1,
- )
-
- # Write synthetic adapter targeting bagel.language_model.foo
- module_name = "bagel.language_model.foo"
- rank = 2
- in_dim = 4
- out_dim = 4
- lora_dir = _write_synthetic_lora(tmp_path / "lora", module_name, rank, in_dim, out_dim)
-
- lora_request = LoRARequest(
- lora_name="test_bagel",
- lora_int_id=42,
- lora_path=lora_dir,
- )
-
- # Full round-trip: load from disk → replace layer → activate weights
- manager.set_active_adapter(lora_request, lora_scale=0.5)
-
- # Verify the layer was replaced and weights were set
- replaced_layer = pipeline.bagel.language_model.foo
- assert isinstance(replaced_layer, DummyBaseLayerWithLoRA), "Layer should be wrapped with LoRA"
- assert len(replaced_layer.set_calls) == 1, "set_lora should have been called once"
-
- lora_a, lora_b = replaced_layer.set_calls[0]
- # A weights should be ones (as written)
- assert torch.all(lora_a == 1.0), f"lora_a should be all ones, got {lora_a}"
- # B weights should be 2.0 * scale(0.5) = 1.0
- assert torch.allclose(lora_b, torch.ones_like(lora_b)), f"lora_b should be 2.0 * 0.5 = 1.0, got {lora_b}"
diff --git a/tests/diffusion/models/bagel/test_trajectory_recording.py b/tests/diffusion/models/bagel/test_trajectory_recording.py
deleted file mode 100644
index 345eac10784..00000000000
--- a/tests/diffusion/models/bagel/test_trajectory_recording.py
+++ /dev/null
@@ -1,244 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for BAGEL trajectory recording in the denoising loop."""
-
-import types
-from dataclasses import dataclass
-
-import pytest
-import torch
-from pytest_mock import MockerFixture
-
-from vllm_omni.diffusion.models.bagel.bagel_transformer import (
- Bagel,
- NaiveCache,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-NUM_TOKENS = 8
-HIDDEN_DIM = 16
-NUM_TIMESTEPS = 5
-# generate_image uses timesteps[:-1], so actual steps = NUM_TIMESTEPS - 1
-EXPECTED_STEPS = NUM_TIMESTEPS - 1
-
-
-def _make_mock_bagel(mocker: MockerFixture):
- """Create a mock Bagel with forward returning constant velocity."""
- mock = mocker.MagicMock(spec=Bagel)
- mock._sp_size = 1
-
- # forward returns a small constant velocity so x_t changes each step
- def fake_forward(self, x_t, **kwargs):
- return torch.ones_like(x_t) * 0.1
-
- mock.forward = types.MethodType(fake_forward, mock)
- # _merge_naive_caches is called in the batched CFG path
- mock._merge_naive_caches = types.MethodType(lambda self, caches: NaiveCache(1), mock)
-
- # Bind the real generate_image to our mock
- mock.generate_image = types.MethodType(Bagel.generate_image, mock)
- return mock
-
-
-def _make_generate_args(num_tokens=NUM_TOKENS, hidden_dim=HIDDEN_DIM, cfg=False):
- """Tensor arguments for generate_image.
-
- Args:
- cfg: If True, enable batched CFG path (cfg_text_scale > 1.0).
- """
- seq_len = num_tokens + 2 # packed_seqlens includes 2 extra tokens
- base = dict(
- packed_text_ids=torch.zeros(2, dtype=torch.long),
- packed_text_indexes=torch.tensor([0, 1], dtype=torch.long),
- packed_init_noises=torch.randn(num_tokens, hidden_dim),
- packed_vae_position_ids=torch.arange(num_tokens, dtype=torch.long),
- packed_vae_token_indexes=torch.arange(2, seq_len, dtype=torch.long),
- packed_seqlens=torch.tensor([seq_len], dtype=torch.int),
- packed_position_ids=torch.arange(seq_len, dtype=torch.long),
- packed_indexes=torch.arange(seq_len, dtype=torch.long),
- past_key_values=NaiveCache(1),
- key_values_lens=torch.tensor([0], dtype=torch.int),
- packed_key_value_indexes=torch.zeros(0, dtype=torch.long),
- num_timesteps=NUM_TIMESTEPS,
- timestep_shift=1.0,
- cfg_text_scale=1.0,
- cfg_img_scale=1.0,
- )
- if cfg:
- base |= dict(
- cfg_text_scale=4.0,
- cfg_text_packed_query_indexes=torch.arange(seq_len, dtype=torch.long),
- cfg_text_packed_position_ids=torch.arange(seq_len, dtype=torch.long),
- cfg_text_past_key_values=NaiveCache(1),
- cfg_text_key_values_lens=torch.tensor([0], dtype=torch.int),
- cfg_text_packed_key_value_indexes=torch.zeros(0, dtype=torch.long),
- )
- return base
-
-
-@pytest.fixture(params=[False, True], ids=["no_cfg", "batched_cfg"])
-def bagel_and_args(
- request,
- monkeypatch: pytest.MonkeyPatch,
- mocker: MockerFixture,
-):
- """Mock Bagel instance and generate_image arguments.
-
- Parametrized over CFG mode so every test runs on both the no-CFG
- and batched-CFG code paths.
- """
- cfg = request.param
- monkeypatch.setattr(
- "vllm_omni.diffusion.models.bagel.bagel_transformer.get_classifier_free_guidance_world_size",
- lambda: 1,
- )
- yield _make_mock_bagel(mocker), _make_generate_args(cfg=cfg)
-
-
-class TestTrajectoryRecording:
- """Tests for trajectory latent/timestep recording in generate_image."""
-
- def test_trajectory_disabled_returns_none(self, bagel_and_args):
- bagel, args = bagel_and_args
-
- unpacked, trajectory_latents, trajectory_timesteps, trajectory_log_probs = bagel.generate_image(
- **args, return_trajectory_latents=False
- )
-
- assert isinstance(unpacked, (list, tuple))
- assert len(unpacked) == 1 # one sequence
- assert trajectory_latents is None
- assert trajectory_timesteps is None
- assert trajectory_log_probs is None
-
- def test_trajectory_enabled_returns_correct_count(self, bagel_and_args):
- bagel, args = bagel_and_args
-
- _, trajectory_latents, trajectory_timesteps, trajectory_log_probs = bagel.generate_image(
- **args, return_trajectory_latents=True
- )
-
- assert trajectory_latents is not None
- assert trajectory_timesteps is not None
- assert len(trajectory_latents) == EXPECTED_STEPS
- assert len(trajectory_timesteps) == EXPECTED_STEPS
- # log_probs is None without a scheduler (default ODE path)
- assert trajectory_log_probs is None
-
- def test_trajectory_latents_shape_matches_input(self, bagel_and_args):
- bagel, args = bagel_and_args
- expected_shape = args["packed_init_noises"].shape
-
- _, trajectory_latents, *_ = bagel.generate_image(**args, return_trajectory_latents=True)
-
- for i, lat in enumerate(trajectory_latents):
- assert lat.shape == expected_shape, f"Step {i}: expected {expected_shape}, got {lat.shape}"
-
- def test_trajectory_latents_are_distinct(self, bagel_and_args):
- bagel, args = bagel_and_args
-
- _, trajectory_latents, *_ = bagel.generate_image(**args, return_trajectory_latents=True)
-
- for i in range(1, len(trajectory_latents)):
- assert not torch.equal(trajectory_latents[i], trajectory_latents[i - 1]), (
- f"Steps {i - 1} and {i} should differ"
- )
-
- def test_trajectory_timesteps_are_decreasing(self, bagel_and_args):
- bagel, args = bagel_and_args
-
- _, _, trajectory_timesteps, _ = bagel.generate_image(**args, return_trajectory_latents=True)
-
- for i in range(1, len(trajectory_timesteps)):
- assert trajectory_timesteps[i] < trajectory_timesteps[i - 1], (
- f"Timestep {i} ({trajectory_timesteps[i]:.4f}) should be less than "
- f"timestep {i - 1} ({trajectory_timesteps[i - 1]:.4f})"
- )
-
- def test_trajectory_final_latent_matches_output(self, bagel_and_args):
- bagel, args = bagel_and_args
-
- unpacked, trajectory_latents, *_ = bagel.generate_image(**args, return_trajectory_latents=True)
-
- # Reconstruct the full final latent from unpacked pieces
- final_latent = torch.cat(unpacked, dim=0)
- assert torch.allclose(trajectory_latents[-1], final_latent, atol=1e-6), (
- "Last trajectory latent should match the final output"
- )
-
-
-# ---------------------------------------------------------------------------
-# Mock scheduler for log-prob tests
-# ---------------------------------------------------------------------------
-
-
-@dataclass
-class _MockStepOutput:
- prev_sample: torch.Tensor
- log_prob: torch.Tensor
-
-
-class _MockScheduler:
- """Minimal scheduler: Euler step + constant log-prob per step."""
-
- def step(self, model_output, sigma, sample, dt, **kwargs):
- prev_sample = sample - model_output * dt
- log_prob = torch.tensor(-1.0)
- return _MockStepOutput(prev_sample=prev_sample, log_prob=log_prob)
-
-
-class TestTrajectoryLogProbs:
- """Tests for log-prob recording when a scheduler is provided."""
-
- @pytest.fixture()
- def bagel_scheduler_args(
- self,
- monkeypatch: pytest.MonkeyPatch,
- mocker: MockerFixture,
- ):
- monkeypatch.setattr(
- "vllm_omni.diffusion.models.bagel.bagel_transformer.get_classifier_free_guidance_world_size",
- lambda: 1,
- )
- yield _make_mock_bagel(mocker), _make_generate_args(), _MockScheduler()
-
- def test_log_probs_recorded_with_scheduler(self, bagel_scheduler_args):
- bagel, args, scheduler = bagel_scheduler_args
-
- _, _, _, trajectory_log_probs = bagel.generate_image(
- **args, return_trajectory_latents=True, scheduler=scheduler
- )
-
- assert trajectory_log_probs is not None
- assert len(trajectory_log_probs) == EXPECTED_STEPS
-
- def test_log_probs_are_finite(self, bagel_scheduler_args):
- bagel, args, scheduler = bagel_scheduler_args
-
- _, _, _, trajectory_log_probs = bagel.generate_image(
- **args, return_trajectory_latents=True, scheduler=scheduler
- )
-
- for i, lp in enumerate(trajectory_log_probs):
- assert torch.isfinite(lp).all(), f"Step {i}: log_prob is not finite"
-
- def test_log_probs_none_without_scheduler(self, bagel_scheduler_args):
- bagel, args, _ = bagel_scheduler_args
-
- _, _, _, trajectory_log_probs = bagel.generate_image(**args, return_trajectory_latents=True, scheduler=None)
-
- assert trajectory_log_probs is None
-
- def test_scheduler_updates_latents(self, bagel_scheduler_args):
- """Verify the scheduler's prev_sample is used (not the raw Euler step)."""
- bagel, args, scheduler = bagel_scheduler_args
-
- _, traj_with_sched, *_ = bagel.generate_image(**args, return_trajectory_latents=True, scheduler=scheduler)
- _, traj_without, *_ = bagel.generate_image(**args, return_trajectory_latents=True, scheduler=None)
-
- # Mock scheduler does the same Euler step, so latents should match
- for i in range(len(traj_with_sched)):
- assert torch.allclose(traj_with_sched[i], traj_without[i], atol=1e-5), (
- f"Step {i}: scheduler and ODE paths should produce same latents"
- )
diff --git a/tests/diffusion/models/dmd2/__init__.py b/tests/diffusion/models/dmd2/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py
deleted file mode 100644
index e270390bd99..00000000000
--- a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import MagicMock, patch
-
-import pytest
-import torch
-
-from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline
-from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline
-from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-# DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests).
-_DMD2_BASE = {
- WanT2VDMD2Pipeline: Wan22Pipeline,
- WanI2VDMD2Pipeline: Wan22I2VPipeline,
- LTX2T2VDMD2Pipeline: LTX2Pipeline,
- LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline,
-}
-
-
-def _make_pipeline(cls):
- """Run the DMD2 __init__ with the base pipeline mocked out (no model weights loaded)."""
-
- base = _DMD2_BASE[cls]
- od_config = MagicMock()
- od_config.model = "/nonexistent"
-
- def _mock_base_init(self, *a, **kw):
- self.od_config = od_config
-
- with patch.object(base, "__init__", _mock_base_init):
- pipeline = object.__new__(cls)
- torch.nn.Module.__init__(pipeline)
- cls.__init__(pipeline, od_config=od_config)
- return pipeline
-
-
-def _make_request(prompts=None, **sp_kwargs) -> OmniDiffusionRequest:
- sp = OmniDiffusionSamplingParams(**sp_kwargs)
- return OmniDiffusionRequest(
- prompts=prompts or [{"prompt": "a cat dancing"}],
- sampling_params=sp,
- )
-
-
-@pytest.fixture(
- params=list(_DMD2_BASE.keys()),
- ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"],
-)
-def pipeline(request):
- return _make_pipeline(request.param)
-
-
-# ---------------------------------------------------------------------------
-# num_inference_steps
-# ---------------------------------------------------------------------------
-
-
-def test_num_inference_steps_forced_to_dmd2_value(pipeline):
- req = _make_request(num_inference_steps=40)
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps
-
-
-def test_num_inference_steps_already_correct(pipeline):
- req = _make_request(num_inference_steps=pipeline.num_inference_steps)
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps
-
-
-# ---------------------------------------------------------------------------
-# guidance_scale
-# ---------------------------------------------------------------------------
-
-
-def test_guidance_scale_forced_to_one(pipeline):
- req = _make_request(guidance_scale=5.0, guidance_scale_provided=True)
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale
- assert req.sampling_params.guidance_scale_provided is False
-
-
-def test_guidance_scale_already_correct(pipeline):
- req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=False)
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale
-
-
-def test_guidance_scale_provided_flag_cleared(pipeline):
- """guidance_scale_provided=True must be cleared even if scale is already dmd2_guidance_scale."""
- req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=True)
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.guidance_scale_provided is False
-
-
-def test_guidance_scale_2_cleared(pipeline):
- req = _make_request(guidance_scale_2=3.0)
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.guidance_scale_2 is None
-
-
-def test_guidance_scale_2_unset_unchanged(pipeline):
- req = _make_request()
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.guidance_scale_2 is None
-
-
-def test_true_cfg_scale_cleared(pipeline):
- req = _make_request(true_cfg_scale=2.0)
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.true_cfg_scale is None
-
-
-def test_do_classifier_free_guidance_forced_false(pipeline):
- req = _make_request(do_classifier_free_guidance=True)
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.do_classifier_free_guidance is False
-
-
-def test_is_cfg_negative_forced_false(pipeline):
- req = _make_request(is_cfg_negative=True)
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.is_cfg_negative is False
-
-
-def test_negative_prompt_stripped_from_prompt_dict(pipeline):
- req = _make_request(prompts=[{"prompt": "a cat", "negative_prompt": "blurry"}])
- pipeline._sanitize_dmd2_request(req)
- assert "negative_prompt" not in req.prompts[0]
- assert req.prompts[0]["prompt"] == "a cat"
-
-
-def test_no_negative_prompt_unchanged(pipeline):
- req = _make_request(prompts=[{"prompt": "a cat"}])
- pipeline._sanitize_dmd2_request(req)
- assert req.prompts[0] == {"prompt": "a cat"}
-
-
-def test_string_prompt_not_mutated(pipeline):
- """String prompts (not dicts) must pass through unchanged."""
- req = _make_request(prompts=["a cat dancing"])
- pipeline._sanitize_dmd2_request(req)
- assert req.prompts == ["a cat dancing"]
-
-
-def test_multiple_prompts_all_sanitized(pipeline):
- req = _make_request(
- prompts=[
- {"prompt": "a cat", "negative_prompt": "blurry"},
- {"prompt": "a dog", "negative_prompt": "ugly"},
- ]
- )
- pipeline._sanitize_dmd2_request(req)
- for p in req.prompts:
- assert "negative_prompt" not in p
-
-
-# ---------------------------------------------------------------------------
-# Clean request — nothing changes
-# ---------------------------------------------------------------------------
-
-
-def test_clean_request_no_changes(pipeline):
- req = _make_request(
- guidance_scale=pipeline.dmd2_guidance_scale,
- guidance_scale_provided=False,
- do_classifier_free_guidance=False,
- is_cfg_negative=False,
- )
- pipeline._sanitize_dmd2_request(req)
- assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale
- assert req.sampling_params.guidance_scale_provided is False
- assert req.sampling_params.guidance_scale_2 is None
- assert req.sampling_params.true_cfg_scale is None
- assert req.sampling_params.do_classifier_free_guidance is False
- assert req.sampling_params.is_cfg_negative is False
diff --git a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py
deleted file mode 100644
index 32d00dbf18e..00000000000
--- a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import MagicMock, patch
-
-import pytest
-import torch
-
-from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline
-from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline
-from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-_DMD2_TIMESTEPS = [999, 937, 833, 624]
-
-# DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests).
-_DMD2_BASE = {
- WanT2VDMD2Pipeline: Wan22Pipeline,
- WanI2VDMD2Pipeline: Wan22I2VPipeline,
- LTX2T2VDMD2Pipeline: LTX2Pipeline,
- LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline,
-}
-
-
-def _make_pipeline(cls):
- """Run the DMD2 __init__ (including __init_dmd2__) with the base pipeline mocked."""
-
- base = _DMD2_BASE[cls]
- od_config = MagicMock()
- od_config.model = "/nonexistent"
-
- def _mock_base_init(self, *a, **kw):
- self.od_config = od_config # __init_dmd2__ needs this
-
- with patch.object(base, "__init__", _mock_base_init):
- pipeline = object.__new__(cls)
- torch.nn.Module.__init__(pipeline)
- cls.__init__(pipeline, od_config=od_config)
- return pipeline
-
-
-def _make_request(**sp_kwargs) -> OmniDiffusionRequest:
- sp = OmniDiffusionSamplingParams(**sp_kwargs)
- return OmniDiffusionRequest(prompts=[{"prompt": "a cat"}], sampling_params=sp)
-
-
-@pytest.fixture(
- params=list(_DMD2_BASE.keys()),
- ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"],
-)
-def pipeline(request):
- return _make_pipeline(request.param)
-
-
-# ---------------------------------------------------------------------------
-# forward() timestep injection
-# ---------------------------------------------------------------------------
-
-
-def _fake_parent_forward(self, req, *args, num_inference_steps=40, **kwargs):
- """Stub that calls set_timesteps as the real parent does."""
- self.scheduler.set_timesteps(num_inference_steps, device="cpu")
- return MagicMock()
-
-
-def test_forward_timesteps_match_dmd2_schedule(pipeline):
- """After forward() runs, scheduler.timesteps must equal the DMD2 training schedule."""
- parent = _DMD2_BASE[type(pipeline)]
-
- # Baseline: calling set_timesteps(40) without the DMD2 override gives a different schedule
- pipeline.scheduler.set_timesteps(40, device="cpu")
- default_timesteps = pipeline.scheduler.timesteps.long().tolist()
- assert default_timesteps == _DMD2_TIMESTEPS, (
- "DMD2EulerScheduler should always return DMD2 timesteps regardless of num_steps"
- )
-
- with patch.object(parent, "forward", _fake_parent_forward):
- pipeline.forward(_make_request())
-
- assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS
-
-
-def test_forward_timesteps_idempotent_across_calls(pipeline):
- """Successive forward() calls must not cause scheduler state to drift."""
- parent = _DMD2_BASE[type(pipeline)]
-
- with patch.object(parent, "forward", _fake_parent_forward):
- pipeline.forward(_make_request())
- pipeline.forward(_make_request())
-
- assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS
diff --git a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
index c613bb0b4c8..a2d1fe6abd3 100644
--- a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
+++ b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
@@ -1,8 +1,8 @@
+from unittest.mock import MagicMock, patch
+
import pytest
import torch
-from pytest_mock import MockerFixture
-from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.models.flux2.flux2_transformer import (
Flux2PosEmbed,
Flux2Transformer2DModel,
@@ -11,24 +11,19 @@
# Initialize TP group before tests
@pytest.fixture(scope="function", autouse=True)
-def setup_tp_group(mocker: MockerFixture):
+def setup_tp_group():
"""Set up TP group for each test function"""
- mocker.patch(
- "vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size",
- return_value=2,
- )
- mock_get_tp_group = mocker.patch("vllm.distributed.parallel_state.get_tp_group")
- mock_tp_group = mocker.MagicMock()
- mock_tp_group.world_size = 2
- mock_get_tp_group.return_value = mock_tp_group
- yield
+ with patch("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", return_value=2):
+ with patch("vllm.distributed.parallel_state.get_tp_group") as mock_get_tp_group:
+ mock_tp_group = MagicMock()
+ mock_tp_group.world_size = 2
+ mock_get_tp_group.return_value = mock_tp_group
+ yield
class TestFlux2TransformerWeightLoading:
"""Test Flux2Transformer weight loading functionality"""
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_weight_loading_tp2(self, setup_tp_group):
"""Verify weights load correctly with TP=2"""
# Prepare test data
@@ -83,8 +78,6 @@ def test_weight_loading_tp2(self, setup_tp_group):
class TestFlux2RopePositionEmbedding:
"""Test Flux2 RoPE position embedding functionality"""
- @pytest.mark.core_model
- @pytest.mark.cpu
def test_rope_position_embedding(self):
"""Verify RoPE produces correct embeddings for 4D coordinates"""
# Prepare test data - use model default configuration
@@ -139,8 +132,6 @@ def test_rope_position_embedding(self):
class TestFlux2PackedModuleMapping:
"""Test Flux2 packed module mapping functionality"""
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_packed_module_mapping(self, setup_tp_group):
"""Verify to_qkv packing matches HF checkpoint"""
model = Flux2Transformer2DModel(
@@ -217,8 +208,6 @@ def test_packed_module_mapping(self, setup_tp_group):
f"add_kv_proj weight dimension should be {expected_add_kv_shape}, got {attn_block.add_kv_proj.weight.shape}"
)
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_packed_mapping_edge_cases(self, setup_tp_group):
"""Test edge cases for packed mapping"""
model = Flux2Transformer2DModel(
diff --git a/tests/diffusion/models/glm_image/test_glm_image_sp.py b/tests/diffusion/models/glm_image/test_glm_image_sp.py
deleted file mode 100644
index 06a1a116dff..00000000000
--- a/tests/diffusion/models/glm_image/test_glm_image_sp.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for GLM-Image Sequence Parallelism support."""
-
-import pytest
-
-from vllm_omni.diffusion.data import DiffusionParallelConfig
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.fixture(scope="function", autouse=True)
-def setup_sp_groups(mocker):
- """Set up SP and TP groups for each test function."""
- mock_get_sp_group = mocker.patch("vllm_omni.diffusion.distributed.parallel_state.get_sp_group")
- mocker.patch("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", return_value=1)
- mock_get_tp_group = mocker.patch("vllm.distributed.parallel_state.get_tp_group")
-
- mock_sp_group = mocker.MagicMock()
- mock_sp_group.world_size = 4
- mock_get_sp_group.return_value = mock_sp_group
-
- mock_tp_group = mocker.MagicMock()
- mock_tp_group.world_size = 1
- mock_get_tp_group.return_value = mock_tp_group
- yield
-
-
-def test_glm_image_sp_plan_defined():
- """Test that _sp_plan is properly defined on GlmImageTransformer2DModel."""
- from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
- GlmImageTransformer2DModel,
- )
-
- assert hasattr(GlmImageTransformer2DModel, "_sp_plan")
- plan = GlmImageTransformer2DModel._sp_plan
- assert plan is not None
-
- # Verify plan structure
- assert "prepare" in plan
- assert "proj_out" in plan
-
-
-def test_glm_image_sp_plan_valid():
- """Validate _sp_plan structure."""
- from vllm_omni.diffusion.distributed.sp_plan import validate_sp_plan
- from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
- GlmImageTransformer2DModel,
- )
-
- plan = GlmImageTransformer2DModel._sp_plan
- validate_sp_plan(plan)
-
-
-def test_glm_image_prepare_module_exists():
- """Test that GlmImagePrepare module exists."""
- from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
- GlmImagePrepare,
- )
-
- assert GlmImagePrepare is not None
-
-
-def test_glm_image_attention_accepts_parallel_config():
- """Test that GlmImageAttention accepts parallel_config parameter."""
- from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
- GlmImageAttention,
- )
-
- parallel_config = DiffusionParallelConfig(
- ulysses_degree=2,
- ring_degree=2,
- tensor_parallel_size=1,
- sequence_parallel_size=4,
- )
-
- attn = GlmImageAttention(
- dim=2560,
- num_heads=64,
- head_dim=40,
- parallel_config=parallel_config,
- )
-
- assert attn.parallel_config is not None
- assert attn.parallel_config.sequence_parallel_size == 4
-
-
-def test_glm_image_transformer_block_accepts_parallel_config():
- """Test that GlmImageTransformerBlock accepts parallel_config parameter."""
- from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
- GlmImageTransformerBlock,
- )
-
- parallel_config = DiffusionParallelConfig(
- ulysses_degree=2,
- ring_degree=2,
- tensor_parallel_size=1,
- sequence_parallel_size=4,
- )
-
- block = GlmImageTransformerBlock(
- dim=2560,
- num_attention_heads=64,
- attention_head_dim=40,
- time_embed_dim=512,
- parallel_config=parallel_config,
- )
-
- assert block.attn1.parallel_config is not None
- assert block.attn1.parallel_config.sequence_parallel_size == 4
-
-
-def test_glm_image_has_sp_support():
- """Test that GLM-Image has SP support implemented."""
- from vllm_omni.diffusion.models.glm_image.glm_image_transformer import (
- GlmImageTransformer2DModel,
- )
-
- # Check that the model has parallel_config support
- assert hasattr(GlmImageTransformer2DModel, "__init__")
-
- # Verify the model can be instantiated with SP config
-
- # This test just verifies the structure exists
- # Actual SP testing requires multi-GPU setup
-
-
-@pytest.mark.cuda
-@pytest.mark.sp
-def test_glm_image_sp_inference():
- """Test SP inference (requires multi-GPU setup)."""
- pytest.skip("Requires multi-GPU SP setup")
diff --git a/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py b/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py
deleted file mode 100644
index 51f6a85f580..00000000000
--- a/tests/diffusion/models/hunyuan_image3/test_hunyuan_image3_sampler.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for HunyuanImage3 AR sampler logic (stage transitions,
-ratio restriction, comprehension blocking)."""
-
-import pytest
-import torch
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-# Fake token IDs for testing (avoid importing the real model).
-END_OF_THINK = 100
-RECAPTION = 101
-END_OF_RECAPTION = 102
-ANSWER = 103
-BOI = 104
-SIZE_TOKEN = 105
-EOS = 106
-RATIO_START = 200
-RATIO_END = 210
-RATIO_OTHER_START = 220
-RATIO_OTHER_END = 223
-
-
-class FakeSamplerModel:
- """Minimal stub that replicates the sampler-relevant attributes of
- HunyuanImage3ForConditionalGeneration without loading real weights."""
-
- def __init__(self, *, is_comprehension: bool = False):
- self._is_comprehension = is_comprehension
- self._eos_token_id = EOS
- self._end_of_think_id = END_OF_THINK
- self._recaption_id = RECAPTION
- self._end_of_recaption_id = END_OF_RECAPTION
- self._answer_id = ANSWER
- self._mrope_boi_token_id = BOI
- self._size_token_id = SIZE_TOKEN
- self._start_ratio_id = RATIO_START
- self._end_ratio_id = RATIO_END
- self._ratio_other_slices = [(RATIO_OTHER_START, RATIO_OTHER_END + 1)]
- self._all_ratio_ids = set(range(RATIO_START, RATIO_END + 1))
- self._all_ratio_ids.update(range(RATIO_OTHER_START, RATIO_OTHER_END + 1))
-
- self._stage_transitions: dict[int, list[int]] = {}
- if not is_comprehension:
- self._stage_transitions[END_OF_THINK] = [RECAPTION]
- self._stage_transitions[END_OF_RECAPTION] = [ANSWER, BOI, SIZE_TOKEN]
-
- self._blocked_token_ids: set[int] = set()
- if is_comprehension:
- self._blocked_token_ids.update([BOI, SIZE_TOKEN])
- self._blocked_token_ids.update(self._all_ratio_ids)
-
- # Bind the real methods from the model class.
- from vllm_omni.model_executor.models.hunyuan_image3.hunyuan_image3 import (
- HunyuanImage3ForConditionalGeneration as _Real,
- )
-
- _get_forced_token = _Real._get_forced_token
- _apply_ratio_restriction = _Real._apply_ratio_restriction
-
-
-class TestGetForcedToken:
- """Tests for the stateless _get_forced_token method."""
-
- def setup_method(self):
- self.model = FakeSamplerModel(is_comprehension=False)
-
- def test_no_trigger_returns_none(self):
- assert self.model._get_forced_token([1, 2, 3]) is None
-
- def test_empty_history_returns_none(self):
- assert self.model._get_forced_token([]) is None
-
- def test_end_of_think_forces_recaption(self):
- assert self.model._get_forced_token([END_OF_THINK]) == RECAPTION
-
- def test_end_of_think_completed(self):
- assert self.model._get_forced_token([END_OF_THINK, RECAPTION]) is None
-
- def test_end_of_recaption_forces_answer(self):
- tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION]
- assert self.model._get_forced_token(tokens) == ANSWER
-
- def test_end_of_recaption_forces_boi_after_answer(self):
- tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER]
- assert self.model._get_forced_token(tokens) == BOI
-
- def test_end_of_recaption_forces_size_after_boi(self):
- tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI]
- assert self.model._get_forced_token(tokens) == SIZE_TOKEN
-
- def test_full_sequence_complete(self):
- tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION, ANSWER, BOI, SIZE_TOKEN]
- assert self.model._get_forced_token(tokens) is None
-
- def test_diverged_history_returns_none(self):
- tokens = [END_OF_RECAPTION, 999] # 999 != ANSWER
- assert self.model._get_forced_token(tokens) is None
-
- def test_later_trigger_takes_precedence(self):
- tokens = [END_OF_THINK, RECAPTION, END_OF_RECAPTION]
- assert self.model._get_forced_token(tokens) == ANSWER
-
- def test_trigger_with_extra_tokens_before(self):
- tokens = [1, 2, 3, END_OF_THINK]
- assert self.model._get_forced_token(tokens) == RECAPTION
-
-
-class TestComprehensionBlocking:
- """Tests for comprehension mode token blocking."""
-
- def test_blocked_tokens_masked(self):
- model = FakeSamplerModel(is_comprehension=True)
- vocab_size = 300
- logits = torch.zeros(1, vocab_size)
- logits[0, BOI] = 5.0
- logits[0, SIZE_TOKEN] = 3.0
- logits[0, RATIO_START] = 2.0
- min_score = torch.finfo(logits.dtype).min
-
- for tid in model._blocked_token_ids:
- if tid < vocab_size:
- logits[0, tid] = min_score
-
- assert logits[0, BOI].item() == min_score
- assert logits[0, SIZE_TOKEN].item() == min_score
- assert logits[0, RATIO_START].item() == min_score
-
- def test_non_blocked_tokens_preserved(self):
- model = FakeSamplerModel(is_comprehension=True)
- vocab_size = 300
- logits = torch.zeros(1, vocab_size)
- logits[0, 50] = 7.0
- min_score = torch.finfo(logits.dtype).min
-
- for tid in model._blocked_token_ids:
- if tid < vocab_size:
- logits[0, tid] = min_score
-
- assert logits[0, 50].item() == 7.0
-
-
-class TestRatioRestriction:
- """Tests for _apply_ratio_restriction (greedy: only argmax ratio survives)."""
-
- def test_greedy_selects_single_ratio_token(self):
- model = FakeSamplerModel(is_comprehension=False)
- vocab_size = 300
- logits = torch.zeros(1, vocab_size)
- logits[0, RATIO_START + 3] = 10.0
- logits[0, RATIO_START + 1] = 5.0
- logits[0, 50] = 20.0 # non-ratio, should be masked
- min_score = torch.finfo(logits.dtype).min
-
- model._apply_ratio_restriction(logits, 0, min_score)
-
- assert logits[0, RATIO_START + 3].item() == 0
- assert logits[0, RATIO_START + 1].item() == min_score
- assert logits[0, 50].item() == min_score
-
- def test_extra_ratio_slices_considered(self):
- model = FakeSamplerModel(is_comprehension=False)
- vocab_size = 300
- logits = torch.zeros(1, vocab_size)
- logits[0, RATIO_OTHER_START] = 15.0
- logits[0, RATIO_START] = 5.0
- min_score = torch.finfo(logits.dtype).min
-
- model._apply_ratio_restriction(logits, 0, min_score)
-
- assert logits[0, RATIO_OTHER_START].item() == 0
- assert logits[0, RATIO_START].item() == min_score
-
-
-class TestForceEosAfterRatio:
- """Tests that a ratio token as last_token forces EOS."""
-
- def test_ratio_token_forces_eos(self):
- model = FakeSamplerModel(is_comprehension=False)
- vocab_size = 300
- logits = torch.randn(1, vocab_size)
- min_score = torch.finfo(logits.dtype).min
-
- logits[0].fill_(min_score)
- logits[0, model._eos_token_id] = 0
-
- assert logits[0, EOS].item() == 0
- non_eos_max = logits[0, :EOS].max().item()
- assert non_eos_max == min_score
diff --git a/tests/diffusion/models/hunyuan_image3/test_hunyuan_fused_moe.py b/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py
similarity index 85%
rename from tests/diffusion/models/hunyuan_image3/test_hunyuan_fused_moe.py
rename to tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py
index 626f78eed9c..2cda9116c7d 100644
--- a/tests/diffusion/models/hunyuan_image3/test_hunyuan_fused_moe.py
+++ b/tests/diffusion/models/hunyuan_image_3/test_hunyuan_fused_moe.py
@@ -12,7 +12,7 @@ class TestSetForwardContextNumTokens:
def test_sets_num_tokens_when_context_available(self, mocker):
"""num_tokens should be set on ForwardContext when available."""
- import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
mock_ctx = mocker.MagicMock()
del mock_ctx.in_profile_run # simulate missing attr
@@ -26,7 +26,7 @@ def test_sets_num_tokens_when_context_available(self, mocker):
def test_sets_in_profile_run_only_if_missing(self, mocker):
"""in_profile_run should not be overwritten if already set."""
- import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
mock_ctx = mocker.MagicMock()
mock_ctx.in_profile_run = True # already set
@@ -40,7 +40,7 @@ def test_sets_in_profile_run_only_if_missing(self, mocker):
def test_noop_when_context_unavailable(self, mocker):
"""Should do nothing when ForwardContext is not available."""
- import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
mocker.patch.object(hunyuan_moe._vllm_fc, "is_forward_context_available", return_value=False)
mock_get = mocker.patch.object(hunyuan_moe._vllm_fc, "get_forward_context")
@@ -55,11 +55,11 @@ class TestHunyuanFusedMoEPlatformDispatch:
def test_default_platform_uses_default_impl_qualname(self, mocker):
"""HunyuanFusedMoE should resolve the impl class from the platform hook."""
- import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
mock_platform = mocker.MagicMock()
mock_platform.get_diffusion_model_impl_qualname.return_value = (
- "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
+ "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
)
mocker.patch.object(
@@ -71,7 +71,7 @@ def test_default_platform_uses_default_impl_qualname(self, mocker):
mock_impl = mocker.MagicMock()
mock_resolve.return_value = mock_impl
- from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import (
+ from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import (
HunyuanFusedMoE,
)
@@ -80,7 +80,7 @@ def test_default_platform_uses_default_impl_qualname(self, mocker):
mock_platform.prepare_diffusion_op_runtime.assert_called_once_with("hunyuan_fused_moe")
mock_platform.get_diffusion_model_impl_qualname.assert_called_once_with("hunyuan_fused_moe")
mock_resolve.assert_called_once_with(
- "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
+ "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
)
mock_impl.assert_called_once_with(prefix="")
@@ -90,7 +90,7 @@ class TestHunyuanFusedMoEFactory:
def test_new_delegates_to_impl_class(self, mocker):
"""HunyuanFusedMoE(prefix=..., **kwargs) should instantiate and return impl instance."""
- import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
class MockImpl:
def __init__(self, *, prefix: str = "", **kwargs):
@@ -104,7 +104,7 @@ def __init__(self, *, prefix: str = "", **kwargs):
mock_impl_class = mocker.MagicMock(return_value=MockImpl(prefix="test", a=1))
mocker.patch.object(hunyuan_moe, "resolve_obj_by_qualname", return_value=mock_impl_class)
- from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import (
+ from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import (
HunyuanFusedMoE,
)
@@ -119,7 +119,7 @@ def __init__(self, *, prefix: str = "", **kwargs):
def test_make_expert_params_mapping_delegates_to_impl(self, mocker):
"""make_expert_params_mapping should delegate to impl class method."""
- import vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe as hunyuan_moe
+ import vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe as hunyuan_moe
expected_mapping = [("a", "b", 0, "c")]
mock_platform = mocker.MagicMock()
@@ -130,7 +130,7 @@ def test_make_expert_params_mapping_delegates_to_impl(self, mocker):
mock_impl_class.make_expert_params_mapping = mocker.MagicMock(return_value=expected_mapping)
mocker.patch.object(hunyuan_moe, "resolve_obj_by_qualname", return_value=mock_impl_class)
- from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import (
+ from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import (
HunyuanFusedMoE,
)
diff --git a/tests/diffusion/models/ltx2/test_ltx2_3_pipeline.py b/tests/diffusion/models/ltx2/test_ltx2_3_pipeline.py
deleted file mode 100644
index 665126df737..00000000000
--- a/tests/diffusion/models/ltx2/test_ltx2_3_pipeline.py
+++ /dev/null
@@ -1,230 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Unit tests for LTX-2.3 pipeline integration.
-
-These tests verify:
-- Pipeline is properly registered in the diffusion registry
-- Post-process function is registered
-- Cache-DiT enablers are registered
-- Pipeline does NOT inherit from LTX2Pipeline
-- Vocoder sample rate detection logic
-- Re-export module works correctly
-"""
-
-import json
-import os
-import tempfile
-
-import pytest
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-class TestPipelineIndependence:
- """Verify LTX23Pipeline is fully independent from LTX2Pipeline."""
-
- def test_ltx23_pipeline_does_not_inherit_from_ltx2(self):
- """LTX23Pipeline must NOT inherit from LTX2Pipeline."""
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import LTX23Pipeline
-
- assert not issubclass(LTX23Pipeline, LTX2Pipeline), (
- "LTX23Pipeline should be fully independent and not inherit from LTX2Pipeline"
- )
-
- def test_ltx23_pipeline_is_nn_module(self):
- """LTX23Pipeline must be an nn.Module."""
- import torch.nn as nn
-
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import LTX23Pipeline
-
- assert issubclass(LTX23Pipeline, nn.Module)
-
- def test_ltx23_pipeline_has_progress_bar(self):
- """LTX23Pipeline must mix in ProgressBarMixin."""
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import LTX23Pipeline
- from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
-
- assert issubclass(LTX23Pipeline, ProgressBarMixin)
-
-
-class TestRegistryIntegration:
- """Verify all LTX-2.3 pipeline variants are registered."""
-
- def test_pipeline_models_registered(self):
- """LTX-2.3 pipeline variants must be in _DIFFUSION_MODELS."""
- from vllm_omni.diffusion.registry import _DIFFUSION_MODELS
-
- expected = [
- "LTX23Pipeline",
- "LTX23ImageToVideoPipeline",
- ]
- for name in expected:
- assert name in _DIFFUSION_MODELS, f"{name} not found in _DIFFUSION_MODELS"
-
- def test_pipeline_module_paths(self):
- """Registry entries must point to the correct modules."""
- from vllm_omni.diffusion.registry import _DIFFUSION_MODELS
-
- # T2V -> pipeline_ltx2_3
- assert _DIFFUSION_MODELS["LTX23Pipeline"] == ("ltx2", "pipeline_ltx2_3", "LTX23Pipeline")
-
- # I2V -> pipeline_ltx2_3_image2video
- assert _DIFFUSION_MODELS["LTX23ImageToVideoPipeline"] == (
- "ltx2",
- "pipeline_ltx2_3_image2video",
- "LTX23ImageToVideoPipeline",
- )
-
- def test_post_process_funcs_registered(self):
- """Pipeline variants must map to get_ltx2_post_process_func."""
- from vllm_omni.diffusion.registry import _DIFFUSION_POST_PROCESS_FUNCS
-
- expected = [
- "LTX23Pipeline",
- "LTX23ImageToVideoPipeline",
- ]
- for name in expected:
- assert name in _DIFFUSION_POST_PROCESS_FUNCS, f"{name} not in _DIFFUSION_POST_PROCESS_FUNCS"
- assert _DIFFUSION_POST_PROCESS_FUNCS[name] == "get_ltx2_post_process_func"
-
- def test_cache_dit_enablers_registered(self):
- """Pipeline variants must be registered in CUSTOM_DIT_ENABLERS."""
- from vllm_omni.diffusion.cache.cache_dit_backend import CUSTOM_DIT_ENABLERS
-
- expected = [
- "LTX23Pipeline",
- "LTX23ImageToVideoPipeline",
- ]
- for name in expected:
- assert name in CUSTOM_DIT_ENABLERS, f"{name} not in CUSTOM_DIT_ENABLERS"
-
-
-class TestVocoderSampleRateDetection:
- """Test _detect_vocoder_output_sample_rate logic."""
-
- def test_detects_48khz_from_config(self):
- """Should detect output_sampling_rate=48000 from vocoder/config.json."""
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import _detect_vocoder_output_sample_rate
-
- with tempfile.TemporaryDirectory() as tmpdir:
- vocoder_dir = os.path.join(tmpdir, "vocoder")
- os.makedirs(vocoder_dir)
- with open(os.path.join(vocoder_dir, "config.json"), "w") as f:
- json.dump({"output_sampling_rate": 48000, "input_sampling_rate": 16000}, f)
-
- result = _detect_vocoder_output_sample_rate(tmpdir)
- assert result == 48000
-
- def test_returns_none_for_no_output_sr(self):
- """Should return None if vocoder config has no output_sampling_rate."""
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import _detect_vocoder_output_sample_rate
-
- with tempfile.TemporaryDirectory() as tmpdir:
- vocoder_dir = os.path.join(tmpdir, "vocoder")
- os.makedirs(vocoder_dir)
- with open(os.path.join(vocoder_dir, "config.json"), "w") as f:
- json.dump({"sampling_rate": 16000}, f)
-
- result = _detect_vocoder_output_sample_rate(tmpdir)
- assert result is None
-
- def test_returns_none_for_missing_directory(self):
- """Should return None if vocoder directory doesn't exist."""
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import _detect_vocoder_output_sample_rate
-
- result = _detect_vocoder_output_sample_rate("/nonexistent/path")
- assert result is None
-
-
-class TestPostProcessFunction:
- """Test the post-process function factory."""
-
- def test_post_process_includes_audio_sample_rate(self):
- """Post-process func should include audio_sample_rate when detected."""
- import torch
-
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import get_ltx2_post_process_func
-
- with tempfile.TemporaryDirectory() as tmpdir:
- vocoder_dir = os.path.join(tmpdir, "vocoder")
- os.makedirs(vocoder_dir)
- with open(os.path.join(vocoder_dir, "config.json"), "w") as f:
- json.dump({"output_sampling_rate": 48000}, f)
-
- # Create a minimal od_config mock
- class MockConfig:
- model = tmpdir
-
- func = get_ltx2_post_process_func(MockConfig())
-
- video = torch.zeros(1, 3, 4, 64, 64)
- audio = torch.zeros(1, 1, 48000)
- result = func((video, audio))
-
- assert isinstance(result, dict)
- assert "video" in result
- assert "audio" in result
- assert result["audio_sample_rate"] == 48000
-
- def test_post_process_without_vocoder_config(self):
- """Post-process func should work without vocoder config (no audio_sample_rate key)."""
- import torch
-
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import get_ltx2_post_process_func
-
- class MockConfig:
- model = "/nonexistent/path"
-
- func = get_ltx2_post_process_func(MockConfig())
-
- video = torch.zeros(1, 3, 4, 64, 64)
- audio = torch.zeros(1, 1, 16000)
- result = func((video, audio))
-
- assert isinstance(result, dict)
- assert "video" in result
- assert "audio" in result
- assert "audio_sample_rate" not in result
-
-
-class TestReExportModule:
- """Test that pipeline_ltx2_3_image2video.py correctly re-exports."""
-
- def test_i2v_classes_importable(self):
- """I2V classes must be importable from the re-export module."""
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3_image2video import LTX23ImageToVideoPipeline
-
- assert LTX23ImageToVideoPipeline is not None
-
- def test_post_process_func_importable(self):
- """get_ltx2_post_process_func must be importable from re-export module."""
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3_image2video import get_ltx2_post_process_func
-
- assert callable(get_ltx2_post_process_func)
-
- def test_i2v_classes_are_same_as_direct_import(self):
- """Re-exported classes must be the same objects as direct imports."""
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import LTX23ImageToVideoPipeline as Direct
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3_image2video import (
- LTX23ImageToVideoPipeline as ReExported,
- )
-
- assert Direct is ReExported
-
-
-class TestInitExports:
- """Test that __init__.py exports all LTX-2.3 classes."""
-
- def test_all_ltx23_classes_exported(self):
- """All LTX-2.3 pipeline classes must be in the ltx2 package __all__."""
- from vllm_omni.diffusion.models import ltx2
-
- expected_classes = [
- "LTX23Pipeline",
- "LTX23ImageToVideoPipeline",
- ]
- for name in expected_classes:
- assert hasattr(ltx2, name), f"{name} not exported from ltx2 package"
- assert name in ltx2.__all__, f"{name} not in ltx2.__all__"
diff --git a/tests/diffusion/models/ltx2/test_ltx2_cfg_parallel_adaptation.py b/tests/diffusion/models/ltx2/test_ltx2_cfg_parallel_adaptation.py
deleted file mode 100644
index bbfe63dfa58..00000000000
--- a/tests/diffusion/models/ltx2/test_ltx2_cfg_parallel_adaptation.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from types import SimpleNamespace
-
-import pytest
-import torch
-
-from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _make_pipeline(sequence_parallel_size: int = 1) -> LTX2Pipeline:
- pipeline = object.__new__(LTX2Pipeline)
- torch.nn.Module.__init__(pipeline)
- pipeline.audio_vae_temporal_compression_ratio = 4
- pipeline.audio_vae_mel_compression_ratio = 4
- pipeline.od_config = SimpleNamespace(parallel_config=SimpleNamespace(sequence_parallel_size=sequence_parallel_size))
- # Mock audio_vae with identity normalization (mean=0, std=1) so
- # _normalize_audio_latents is a no-op and test values are preserved.
- pipeline.audio_vae = SimpleNamespace(
- latents_mean=torch.tensor(0.0),
- latents_std=torch.tensor(1.0),
- )
- return pipeline
-
-
-def test_prepare_audio_latents_pads_packed_sequence_dim_for_provided_latents():
- pipeline = _make_pipeline(sequence_parallel_size=4)
- latents = torch.arange(40, dtype=torch.float32).view(1, 10, 4)
-
- padded, original_num_frames, padded_num_frames = pipeline.prepare_audio_latents(
- batch_size=1,
- num_channels_latents=2,
- num_mel_bins=8,
- audio_latent_length=10,
- dtype=torch.float32,
- device=torch.device("cpu"),
- latents=latents,
- )
-
- assert original_num_frames == 10
- assert padded_num_frames == 12
- assert padded.shape == (1, 12, 4)
- torch.testing.assert_close(padded[:, :10], latents)
- torch.testing.assert_close(padded[:, 10:], torch.zeros(1, 2, 4))
-
-
-def test_unpad_audio_latents_restores_original_frames_before_unpack():
- pipeline = _make_pipeline()
- original = torch.arange(40, dtype=torch.float32).view(1, 10, 4)
- padded = torch.cat([original, torch.full((1, 2, 4), 999.0)], dim=1)
-
- unpadded = pipeline._unpad_audio_latents(padded, 10)
- unpacked = pipeline._unpack_audio_latents(unpadded, latent_length=10, num_mel_bins=2)
- expected = pipeline._unpack_audio_latents(original, latent_length=10, num_mel_bins=2)
-
- assert unpacked.shape == (1, 2, 10, 2)
- assert not (unpacked == 999.0).any()
- torch.testing.assert_close(unpacked, expected)
diff --git a/tests/diffusion/models/ltx2/test_ltx2_hsdp.py b/tests/diffusion/models/ltx2/test_ltx2_hsdp.py
deleted file mode 100644
index 4dd07e1bf82..00000000000
--- a/tests/diffusion/models/ltx2/test_ltx2_hsdp.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import pytest
-import torch.nn as nn
-
-from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def test_ltx2_exposes_hsdp_shard_conditions_for_transformer_blocks():
- model = object.__new__(LTX2VideoTransformer3DModel)
- nn.Module.__init__(model)
- model.transformer_blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
- model.norm_out = nn.LayerNorm(4)
-
- conditions = getattr(model, "_hsdp_shard_conditions", None)
-
- assert conditions is not None
- assert len(conditions) == 1
-
- matched = []
- for name, module in model.named_modules():
- if any(cond(name, module) for cond in conditions):
- matched.append(name)
-
- assert matched == ["transformer_blocks.0", "transformer_blocks.1"]
diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py
deleted file mode 100644
index 873b52bf7a6..00000000000
--- a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import json
-from pathlib import Path
-from types import SimpleNamespace
-
-import numpy as np
-import pytest
-from PIL import Image
-
-from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import (
- get_qwen_image_edit_plus_pre_process_func,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
-
-
-def test_qwen_image_edit_plus_rejects_too_many_input_images(tmp_path: Path):
- vae_dir = tmp_path / "vae"
- vae_dir.mkdir()
- # Keep the mock config intentionally minimal: this test only needs the
- # fields touched during pre-process initialization.
- (vae_dir / "config.json").write_text(json.dumps({"z_dim": 16}))
-
- pre_process = get_qwen_image_edit_plus_pre_process_func(SimpleNamespace(model=str(tmp_path)))
- image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
- request = SimpleNamespace(
- prompts=[
- {
- "prompt": "combine",
- "multi_modal_data": {"image": [image, image, image, image, image]},
- }
- ],
- sampling_params=SimpleNamespace(height=None, width=None),
- )
-
- with pytest.raises(ValueError, match=r"At most 4 images are supported by this model"):
- pre_process(request)
diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py b/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py
deleted file mode 100644
index f5676a0056f..00000000000
--- a/tests/diffusion/models/qwen_image/test_qwen_image_max_sequence_length.py
+++ /dev/null
@@ -1,260 +0,0 @@
-import inspect
-from types import SimpleNamespace
-
-import pytest
-import torch
-from torch import nn
-
-from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import (
- QwenImagePipeline,
-)
-from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit import (
- QwenImageEditPipeline,
-)
-from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import (
- QwenImageEditPlusPipeline,
-)
-from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_layered import (
- QwenImageLayeredPipeline,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-class _RejectingTextEncoder:
- dtype = torch.float32
-
- def __call__(self, *args, **kwargs):
- raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length")
-
-
-class _FakeModelInputs:
- def __init__(self, total_sequence_length: int):
- attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long)
- self.input_ids = attention_mask.clone()
- self.attention_mask = attention_mask
- self.pixel_values = None
- self.image_grid_thw = None
-
- def to(self, device):
- return self
-
-
-class _FakeTokenizer:
- def __init__(self, total_sequence_length: int | list[int]):
- if isinstance(total_sequence_length, list):
- self.total_sequence_lengths = list(total_sequence_length)
- else:
- self.total_sequence_lengths = [total_sequence_length]
-
- def __call__(self, *args, **kwargs):
- if len(self.total_sequence_lengths) > 1:
- total_sequence_length = self.total_sequence_lengths.pop(0)
- else:
- total_sequence_length = self.total_sequence_lengths[0]
- return _FakeModelInputs(total_sequence_length)
-
-
-class _FakeProcessor(_FakeTokenizer):
- pass
-
-
-class _FakeScheduler:
- def __init__(self):
- self.begin_index = None
-
- def set_begin_index(self, begin_index: int):
- self.begin_index = begin_index
-
-
-PIPELINE_CASES = [
- pytest.param(QwenImagePipeline, 34, "tokenizer", id="qwen-image"),
- pytest.param(QwenImageLayeredPipeline, 34, "tokenizer", id="qwen-image-layered"),
- pytest.param(QwenImageEditPipeline, 64, "processor", id="qwen-image-edit"),
- pytest.param(QwenImageEditPlusPipeline, 64, "processor", id="qwen-image-edit-plus"),
-]
-
-
-def _make_pipeline(
- pipeline_class: type,
- *,
- total_sequence_length: int,
- drop_idx: int,
- input_kind: str,
-):
- pipeline = object.__new__(pipeline_class)
- nn.Module.__init__(pipeline)
- pipeline.device = torch.device("cpu")
- pipeline.text_encoder = _RejectingTextEncoder()
- pipeline.tokenizer_max_length = 1024
- pipeline.prompt_template_encode = "{}"
- pipeline.prompt_template_encode_start_idx = drop_idx
- pipeline.tokenizer = _FakeTokenizer([total_sequence_length, 0])
- if input_kind == "processor":
- pipeline.processor = _FakeProcessor(total_sequence_length)
- return pipeline
-
-
-@pytest.mark.parametrize(("pipeline_class", "drop_idx", "input_kind"), PIPELINE_CASES)
-def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(
- pipeline_class: type,
- drop_idx: int,
- input_kind: str,
-):
- pipeline = _make_pipeline(
- pipeline_class,
- total_sequence_length=1025,
- drop_idx=drop_idx,
- input_kind=input_kind,
- )
-
- with pytest.raises(ValueError, match=r"got 1025 tokens, but `max_sequence_length` is 1024"):
- pipeline.encode_prompt(prompt="prompt")
-
-
-@pytest.mark.parametrize(("pipeline_class", "drop_idx", "input_kind"), PIPELINE_CASES)
-def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(
- pipeline_class: type,
- drop_idx: int,
- input_kind: str,
-):
- pipeline = _make_pipeline(
- pipeline_class,
- total_sequence_length=17,
- drop_idx=drop_idx,
- input_kind=input_kind,
- )
-
- with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"):
- pipeline.encode_prompt(prompt="prompt", max_sequence_length=16)
-
-
-def test_prepare_encode_defaults_to_tokenizer_max_length():
- pipeline = object.__new__(QwenImagePipeline)
- nn.Module.__init__(pipeline)
- pipeline.tokenizer_max_length = 1024
- pipeline.vae_scale_factor = 8
- pipeline.default_sample_size = 128
- pipeline.scheduler = _FakeScheduler()
- pipeline._extract_prompts = lambda prompts: (["prompt"], None)
-
- captured = {}
-
- def _fake_prepare_generation_context(**kwargs):
- captured["max_sequence_length"] = kwargs["max_sequence_length"]
- embeds = torch.ones((1, 1, 1))
- mask = torch.ones((1, 1), dtype=torch.long)
- return {
- "prompt_embeds": embeds,
- "prompt_embeds_mask": mask,
- "negative_prompt_embeds": None,
- "negative_prompt_embeds_mask": None,
- "latents": embeds,
- "timesteps": torch.tensor([1]),
- "do_true_cfg": False,
- "guidance": None,
- "img_shapes": [[(1, 1, 1)]],
- "txt_seq_lens": [1],
- "negative_txt_seq_lens": None,
- }
-
- pipeline._prepare_generation_context = _fake_prepare_generation_context
- state = SimpleNamespace(
- prompts=["prompt"],
- sampling=SimpleNamespace(
- height=None,
- width=None,
- num_inference_steps=None,
- sigmas=None,
- guidance_scale_provided=False,
- num_outputs_per_prompt=0,
- generator=None,
- true_cfg_scale=None,
- max_sequence_length=None,
- ),
- )
-
- pipeline.prepare_encode(state)
-
- assert captured["max_sequence_length"] == 1024
-
-
-@pytest.mark.parametrize(
- ("pipeline_class", "drop_idx"),
- [
- pytest.param(QwenImageEditPipeline, 64, id="qwen-image-edit"),
- pytest.param(QwenImageEditPlusPipeline, 64, id="qwen-image-edit-plus"),
- ],
-)
-def test_edit_pipelines_validate_text_prompt_length_before_image_token_expansion(
- pipeline_class: type,
- drop_idx: int,
-):
- pipeline = object.__new__(pipeline_class)
- nn.Module.__init__(pipeline)
- pipeline.device = torch.device("cpu")
- pipeline.text_encoder = _RejectingTextEncoder()
- pipeline.tokenizer_max_length = 1024
- pipeline.prompt_template_encode = "{}"
- pipeline.prompt_template_encode_start_idx = drop_idx
- pipeline.tokenizer = _FakeTokenizer([8, 0])
- pipeline.processor = _FakeProcessor(drop_idx + 1500)
-
- with pytest.raises(AssertionError, match="text encoder should not run"):
- pipeline.encode_prompt(prompt="short prompt")
-
-
-@pytest.mark.parametrize(
- "pipeline_class",
- [
- pytest.param(QwenImagePipeline, id="qwen-image"),
- pytest.param(QwenImageLayeredPipeline, id="qwen-image-layered"),
- ],
-)
-def test_qwen_generation_validator_excludes_template_suffix_from_budget(pipeline_class: type):
- pipeline = object.__new__(pipeline_class)
- nn.Module.__init__(pipeline)
- pipeline.device = torch.device("cpu")
- pipeline.text_encoder = _RejectingTextEncoder()
- pipeline.tokenizer_max_length = 1024
- pipeline.prompt_template_encode = "{}"
- pipeline.prompt_template_encode_start_idx = 34
- pipeline.tokenizer = _FakeTokenizer([1029, 5])
-
- with pytest.raises(AssertionError, match="text encoder should not run"):
- pipeline.encode_prompt(prompt="boundary prompt")
-
-
-@pytest.mark.parametrize(
- "pipeline_class",
- [
- pytest.param(QwenImageEditPipeline, id="qwen-image-edit"),
- pytest.param(QwenImageEditPlusPipeline, id="qwen-image-edit-plus"),
- ],
-)
-def test_qwen_edit_validator_excludes_image_placeholders_from_budget(pipeline_class: type):
- pipeline = object.__new__(pipeline_class)
- nn.Module.__init__(pipeline)
- pipeline.device = torch.device("cpu")
- pipeline.text_encoder = _RejectingTextEncoder()
- pipeline.tokenizer_max_length = 1024
- pipeline.prompt_template_encode = "{}"
- pipeline.prompt_template_encode_start_idx = 64
- pipeline.tokenizer = _FakeTokenizer([30, 20])
- pipeline.processor = _FakeProcessor(1500)
-
- with pytest.raises(AssertionError, match="text encoder should not run"):
- pipeline.encode_prompt(prompt="short prompt")
-
-
-@pytest.mark.parametrize(
- "pipeline_class",
- [
- QwenImagePipeline,
- QwenImageLayeredPipeline,
- QwenImageEditPipeline,
- QwenImageEditPlusPipeline,
- ],
-)
-def test_forward_max_sequence_length_default_is_1024(pipeline_class: type):
- assert inspect.signature(pipeline_class.forward).parameters["max_sequence_length"].default == 1024
diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_size_utils.py b/tests/diffusion/models/qwen_image/test_qwen_image_size_utils.py
deleted file mode 100644
index 7ba8f108a13..00000000000
--- a/tests/diffusion/models/qwen_image/test_qwen_image_size_utils.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import pytest
-
-from vllm_omni.diffusion.utils.size_utils import (
- normalize_min_aligned_size,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.mark.parametrize(
- ("height", "width", "expected"),
- [
- (1, 1, (16, 16)),
- (15, 15, (16, 16)),
- (17, 17, (16, 16)),
- (31, 33, (16, 32)),
- (64, 80, (64, 80)),
- ],
-)
-def test_normalize_min_aligned_size_clamps_to_minimum_aligned_shape(height, width, expected):
- assert normalize_min_aligned_size(height, width, alignment=16) == expected
-
-
-def test_normalize_min_aligned_size_rejects_invalid_alignment():
- with pytest.raises(ValueError, match="positive alignment"):
- normalize_min_aligned_size(16, 16, alignment=0)
diff --git a/tests/diffusion/models/stable_audio/test_stable_audio_hsdp.py b/tests/diffusion/models/stable_audio/test_stable_audio_hsdp.py
deleted file mode 100644
index 923b9a86315..00000000000
--- a/tests/diffusion/models/stable_audio/test_stable_audio_hsdp.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import pytest
-import torch.nn as nn
-
-from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def test_stable_audio_exposes_hsdp_shard_conditions_for_transformer_blocks():
- model = object.__new__(StableAudioDiTModel)
- nn.Module.__init__(model)
- model.transformer_blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
- model.proj_out = nn.Linear(4, 4)
-
- conditions = getattr(model, "_hsdp_shard_conditions", None)
-
- assert conditions is not None
- assert len(conditions) == 1
-
- matched = []
- for name, module in model.named_modules():
- if any(cond(name, module) for cond in conditions):
- matched.append(name)
-
- assert matched == ["transformer_blocks.0", "transformer_blocks.1"]
diff --git a/tests/diffusion/models/t5_encoder/test_t5_encoder_prefix.py b/tests/diffusion/models/t5_encoder/test_t5_encoder_prefix.py
deleted file mode 100644
index 039150f096c..00000000000
--- a/tests/diffusion/models/t5_encoder/test_t5_encoder_prefix.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for T5EncoderModel prefix handling and weight loading fix."""
-
-import pytest
-import torch
-from transformers import T5Config
-from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config
-
-from vllm_omni.diffusion.models.t5_encoder.t5_encoder import T5EncoderModel
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-_SMALL_T5_CONFIG = dict(
- d_model=64,
- d_kv=8,
- d_ff=128,
- num_heads=8,
- num_layers=2,
- vocab_size=256,
- relative_attention_num_buckets=32,
- relative_attention_max_distance=128,
- is_gated_act=True,
- dense_act_fn="gelu_new",
- layer_norm_epsilon=1e-6,
- feed_forward_proj="gated-gelu",
-)
-
-_T5_MODULE = "vllm_omni.diffusion.models.t5_encoder.t5_encoder"
-
-
-@pytest.fixture
-def t5_config() -> T5Config:
- return T5Config(**_SMALL_T5_CONFIG)
-
-
-@pytest.fixture(scope="function", autouse=True)
-def setup_vllm_config(monkeypatch, mocker):
- """Set up VllmConfig and TP=2 mocks for tests."""
- device_config = DeviceConfig(device="cpu")
-
- monkeypatch.setattr("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", lambda: 2)
- monkeypatch.setattr(f"{_T5_MODULE}.get_tensor_model_parallel_world_size", lambda: 2)
- monkeypatch.setattr(
- "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
- lambda: 2,
- )
-
- monkeypatch.setattr(f"{_T5_MODULE}.get_tensor_model_parallel_rank", lambda: 0)
- monkeypatch.setattr(
- "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
- lambda: 0,
- )
-
- mock_tp_group = mocker.MagicMock()
- mock_tp_group.world_size = 2
- mocker.patch("vllm.distributed.parallel_state.get_tp_group", return_value=mock_tp_group)
-
- monkeypatch.setattr(f"{_T5_MODULE}.get_act_fn", lambda _: torch.nn.GELU())
-
- with set_current_vllm_config(VllmConfig(device_config=device_config)):
- yield
-
-
-class TestT5EncoderModelPrefixHandling:
- """Test that T5EncoderModel correctly handles prefix attribute."""
-
- def test_prefix_stored_in_model(self, t5_config):
- """Test that prefix is stored in the model when provided."""
- prefix = "text_encoder"
- model = T5EncoderModel(t5_config, prefix=prefix)
- assert hasattr(model, "prefix")
- assert model.prefix == prefix
-
- def test_prefix_empty_by_default(self, t5_config):
- """Test that prefix defaults to empty string when not provided."""
- model = T5EncoderModel(t5_config)
- assert hasattr(model, "prefix")
- assert model.prefix == ""
-
-
-class TestT5EncoderModelWeightLoadingWithPrefix:
- """Test weight loading with prefix handling."""
-
- def test_load_weights_with_prefix(self, t5_config):
- """Test that weights without prefix are loaded when model has prefix."""
- config = T5Config(**{**_SMALL_T5_CONFIG, "num_layers": 1})
- model = T5EncoderModel(config, prefix="text_encoder")
-
- inner_dim = config.num_heads * config.d_kv
-
- weights = [
- ("encoder.block.0.layer.0.SelfAttention.q.weight", torch.randn(inner_dim, config.d_model)),
- ("encoder.block.0.layer.0.SelfAttention.k.weight", torch.randn(inner_dim, config.d_model)),
- ("encoder.block.0.layer.0.SelfAttention.v.weight", torch.randn(inner_dim, config.d_model)),
- ]
-
- loaded = model.load_weights(weights)
- assert len(loaded) > 0
-
- def test_load_weights_embed_tokens_shared_sync(self, t5_config):
- """Test that embed_tokens and shared weights are synced."""
- model = T5EncoderModel(t5_config, prefix="text_encoder")
-
- d_model = t5_config.d_model
- vocab_size = t5_config.vocab_size
-
- embed_weight = torch.randn(vocab_size, d_model)
- weights = [
- ("encoder.embed_tokens.weight", embed_weight.clone()),
- ]
-
- model.load_weights(weights)
-
- shared_param = model.shared.weight
- embed_param = model.encoder.embed_tokens.weight
-
- assert torch.allclose(shared_param, embed_param), (
- "shared and embed_tokens should have the same weights after loading"
- )
-
- def test_load_weights_shared_without_prefix(self, t5_config):
- """Test shared.weight is recognized without relying on dot context."""
- model = T5EncoderModel(t5_config, prefix="text_encoder")
-
- shared_weight = torch.randn(t5_config.vocab_size, t5_config.d_model)
- loaded = model.load_weights([("shared.weight", shared_weight)])
-
- assert "shared.weight" in loaded
- assert torch.allclose(model.shared.weight, model.encoder.embed_tokens.weight)
-
- def test_unmatched_weights_are_not_reported_loaded(self, t5_config):
- """Test that skipped checkpoint weights are not added to loaded_params."""
- model = T5EncoderModel(t5_config, prefix="text_encoder")
-
- loaded = model.load_weights(
- [
- (
- "text_encoder.encoder.block.0.layer.0.SelfAttention.missing.weight",
- torch.randn(t5_config.d_model, t5_config.d_model),
- ),
- ]
- )
-
- assert loaded == set()
-
-
-class TestT5EncoderModelWeightLoadingWithoutPrefix:
- """Test weight loading without prefix."""
-
- def test_load_weights_without_prefix(self, t5_config):
- """Test that weights without prefix are loaded correctly."""
- config = T5Config(**{**_SMALL_T5_CONFIG, "num_layers": 1})
- model = T5EncoderModel(config)
-
- inner_dim = config.num_heads * config.d_kv
-
- weights = [
- ("encoder.block.0.layer.0.SelfAttention.q.weight", torch.randn(inner_dim, config.d_model)),
- ]
-
- loaded = model.load_weights(weights)
- assert len(loaded) > 0
diff --git a/tests/diffusion/models/wan2_2/__init__.py b/tests/diffusion/models/wan2_2/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/tests/diffusion/models/wan2_2/conftest.py b/tests/diffusion/models/wan2_2/conftest.py
deleted file mode 100644
index f836fa545fd..00000000000
--- a/tests/diffusion/models/wan2_2/conftest.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from __future__ import annotations
-
-from contextlib import contextmanager
-from types import SimpleNamespace
-
-import torch
-from torch import nn
-
-
-class StubTransformer(nn.Module):
- def __init__(self, *, name: str = "transformer", in_channels: int = 4, out_channels: int = 4) -> None:
- super().__init__()
- self.name = name
- self.config = SimpleNamespace(
- patch_size=(1, 2, 2),
- in_channels=in_channels,
- out_channels=out_channels,
- image_dim=None,
- )
-
- @property
- def dtype(self) -> torch.dtype:
- return torch.float32
-
- def forward(self, **kwargs):
- hidden_states = kwargs["hidden_states"]
- return (torch.zeros_like(hidden_states[:, : self.config.out_channels]),)
-
-
-class StubScheduler:
- def __init__(self, timesteps: list[int]) -> None:
- self.timesteps = torch.tensor(timesteps, dtype=torch.int64)
- self.config = SimpleNamespace(num_train_timesteps=1000)
- self.set_timesteps_calls: list[tuple[int, torch.device]] = []
-
- def set_timesteps(self, num_steps: int, device: torch.device) -> None:
- self.set_timesteps_calls.append((num_steps, device))
-
-
-class StubVAE:
- dtype = torch.float32
-
- def __init__(self, z_dim: int = 4) -> None:
- self.config = SimpleNamespace(
- z_dim=z_dim,
- scale_factor_temporal=4,
- scale_factor_spatial=8,
- latents_mean=[0.0] * z_dim,
- latents_std=[1.0] * z_dim,
- )
-
- def encode(self, video: torch.Tensor):
- latent_frames = (video.shape[2] + self.config.scale_factor_temporal - 1) // self.config.scale_factor_temporal
- latent_height = video.shape[-2] // self.config.scale_factor_spatial
- latent_width = video.shape[-1] // self.config.scale_factor_spatial
- latents = torch.ones(
- video.shape[0],
- self.config.z_dim,
- latent_frames,
- latent_height,
- latent_width,
- dtype=video.dtype,
- device=video.device,
- )
- return SimpleNamespace(latents=latents)
-
- def decode(self, latents: torch.Tensor, return_dict: bool = False):
- del return_dict
- return (latents,)
-
-
-@contextmanager
-def noop_progress_bar(*args, **kwargs):
- del args, kwargs
-
- class Bar:
- def update(self) -> None:
- return None
-
- yield Bar()
diff --git a/tests/diffusion/models/wan2_2/test_wan22_i2v_pipeline.py b/tests/diffusion/models/wan2_2/test_wan22_i2v_pipeline.py
deleted file mode 100644
index 576678e2cf0..00000000000
--- a/tests/diffusion/models/wan2_2/test_wan22_i2v_pipeline.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from types import SimpleNamespace
-
-import pytest
-import torch
-from PIL import Image
-from torch import nn
-
-from tests.diffusion.models.wan2_2.conftest import StubTransformer, StubVAE, noop_progress_bar
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import (
- Wan22I2VPipeline,
- get_wan22_i2v_pre_process_func,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
-
-
-def _make_i2v_pipeline(*, expand_timesteps: bool) -> Wan22I2VPipeline:
- pipeline = object.__new__(Wan22I2VPipeline)
- nn.Module.__init__(pipeline)
- pipeline.device = torch.device("cpu")
- pipeline.transformer = StubTransformer(name="high", in_channels=8, out_channels=4)
- pipeline.transformer_2 = StubTransformer(name="low", in_channels=8, out_channels=4)
- pipeline.vae = StubVAE(z_dim=4)
- pipeline.vae_scale_factor_temporal = 4
- pipeline.vae_scale_factor_spatial = 8
- pipeline.expand_timesteps = expand_timesteps
- pipeline.progress_bar = noop_progress_bar
- return pipeline
-
-
-def test_i2v_preprocess_requires_image_and_resizes_to_480p_aspect() -> None:
- preprocess = get_wan22_i2v_pre_process_func(SimpleNamespace())
- request = SimpleNamespace(
- prompts=[{"prompt": "p", "multi_modal_data": {"image": Image.new("RGB", (320, 160), "red")}}],
- sampling_params=SimpleNamespace(height=None, width=None),
- )
-
- result = preprocess(request)
- prompt = result.prompts[0]
-
- assert result.sampling_params.height == 432
- assert result.sampling_params.width == 880
- assert prompt["multi_modal_data"]["image"].size == (880, 432)
-
- missing_image = SimpleNamespace(
- prompts=[{"prompt": "p", "multi_modal_data": {}}],
- sampling_params=SimpleNamespace(height=None, width=None),
- )
- with pytest.raises(ValueError, match="No image is provided"):
- preprocess(missing_image)
-
-
-def test_i2v_diffuse_selects_stage_guidance_and_expands_timesteps() -> None:
- pipeline = _make_i2v_pipeline(expand_timesteps=True)
- latents = torch.zeros(1, 4, 2, 4, 4)
- condition = torch.ones_like(latents)
- first_frame_mask = torch.ones(1, 1, 2, 4, 4)
- first_frame_mask[:, :, 0] = 0
- timesteps = torch.tensor([900, 100])
-
- calls = []
-
- def fake_predict_noise_maybe_with_cfg(**kwargs):
- positive = kwargs["positive_kwargs"]
- calls.append(
- {
- "model": positive["current_model"].name,
- "scale": kwargs["true_cfg_scale"],
- "timestep_shape": tuple(positive["timestep"].shape),
- "timestep_values": positive["timestep"].clone(),
- "hidden_states": positive["hidden_states"].clone(),
- }
- )
- return torch.ones_like(latents)
-
- pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
- pipeline.scheduler_step_maybe_with_cfg = lambda noise, t, current, cfg: current + noise # type: ignore[method-assign]
-
- result = pipeline.diffuse(
- latents=latents,
- timesteps=timesteps,
- prompt_embeds=torch.zeros(1, 2, 3),
- negative_prompt_embeds=None,
- image_embeds=None,
- guidance_low=1.0,
- guidance_high=2.0,
- boundary_timestep=500.0,
- dtype=torch.float32,
- attention_kwargs={},
- condition=condition,
- first_frame_mask=first_frame_mask,
- )
-
- assert [call["model"] for call in calls] == ["high", "low"]
- assert [call["scale"] for call in calls] == [1.0, 2.0]
- assert calls[0]["timestep_shape"] == (1, 8)
- timestep_dtype = calls[0]["timestep_values"].dtype
- torch.testing.assert_close(calls[0]["timestep_values"][0, :4], torch.zeros(4, dtype=timestep_dtype))
- torch.testing.assert_close(calls[0]["timestep_values"][0, 4:], torch.full((4,), 900, dtype=timestep_dtype))
- torch.testing.assert_close(calls[0]["hidden_states"][:, :, 0], torch.ones(1, 4, 4, 4))
- torch.testing.assert_close(result, torch.full_like(latents, 2.0))
-
-
-def test_i2v_prepare_latents_builds_expand_condition_and_first_frame_mask() -> None:
- pipeline = _make_i2v_pipeline(expand_timesteps=True)
- latents, condition, first_frame_mask = pipeline.prepare_latents(
- image=torch.zeros(1, 3, 16, 16),
- batch_size=1,
- num_channels_latents=4,
- height=16,
- width=16,
- num_frames=5,
- dtype=torch.float32,
- device=torch.device("cpu"),
- generator=torch.Generator(device="cpu").manual_seed(0),
- )
-
- assert latents.shape == (1, 4, 2, 2, 2)
- assert condition.shape == (1, 4, 1, 2, 2)
- assert first_frame_mask.shape == (1, 1, 2, 2, 2)
- assert first_frame_mask[:, :, 0].sum() == 0
- assert first_frame_mask[:, :, 1].sum() == 4
diff --git a/tests/diffusion/models/wan2_2/test_wan22_pipeline_diffuse.py b/tests/diffusion/models/wan2_2/test_wan22_pipeline_diffuse.py
deleted file mode 100644
index 54bb672ef81..00000000000
--- a/tests/diffusion/models/wan2_2/test_wan22_pipeline_diffuse.py
+++ /dev/null
@@ -1,155 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from contextlib import contextmanager
-from types import SimpleNamespace
-
-import pytest
-import torch
-from torch import nn
-
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
-
-
-class _StubTransformer(nn.Module):
- @property
- def dtype(self) -> torch.dtype:
- return torch.float32
-
-
-class _StubScheduler:
- def __init__(self, timesteps: list[int]) -> None:
- self.timesteps = torch.tensor(timesteps, dtype=torch.int64)
- self.config = SimpleNamespace(num_train_timesteps=1000)
- self.set_timesteps_calls: list[tuple[int, torch.device]] = []
-
- def set_timesteps(self, num_steps: int, device: torch.device) -> None:
- self.set_timesteps_calls.append((num_steps, device))
-
-
-@contextmanager
-def _noop_progress_bar(*args, **kwargs):
- del args, kwargs
-
- class _Bar:
- def update(self) -> None:
- return None
-
- yield _Bar()
-
-
-def _make_pipeline() -> Wan22Pipeline:
- pipeline = object.__new__(Wan22Pipeline)
- nn.Module.__init__(pipeline)
- pipeline.device = torch.device("cpu")
- pipeline.transformer = _StubTransformer()
- pipeline.transformer_2 = None
- pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2), in_channels=4, out_channels=4)
- pipeline.scheduler = _StubScheduler([9, 5])
- pipeline.od_config = SimpleNamespace(flow_shift=5.0)
- pipeline._sample_solver = "unipc"
- pipeline._flow_shift = 5.0
- pipeline.vae_scale_factor_temporal = 4
- pipeline.vae_scale_factor_spatial = 8
- pipeline.boundary_ratio = 0.875
- pipeline.expand_timesteps = False
- pipeline._guidance_scale = None
- pipeline._guidance_scale_2 = None
- pipeline._num_timesteps = None
- pipeline._current_timestep = None
- pipeline.check_inputs = lambda **kwargs: None
- pipeline.prepare_latents = lambda **kwargs: torch.zeros((1, 4, 1, 8, 8), dtype=torch.float32)
- pipeline.progress_bar = _noop_progress_bar
- return pipeline
-
-
-def test_forward_delegates_denoising_to_diffuse(monkeypatch) -> None:
- pipeline = _make_pipeline()
-
- prompt_embeds = torch.randn(1, 8)
- captured: dict[str, object] = {}
-
- def _fake_diffuse(**kwargs):
- captured.update(kwargs)
- return kwargs["latents"] + 1
-
- pipeline.diffuse = _fake_diffuse # type: ignore[method-assign]
-
- req = SimpleNamespace(
- prompts=["prompt"],
- sampling_params=SimpleNamespace(
- height=None,
- width=None,
- num_frames=1,
- num_inference_steps=2,
- guidance_scale_provided=False,
- guidance_scale=None,
- guidance_scale_2=None,
- boundary_ratio=None,
- generator=None,
- seed=None,
- num_outputs_per_prompt=1,
- max_sequence_length=32,
- latents=None,
- extra_args={},
- ),
- )
-
- output = pipeline.forward(req, prompt_embeds=prompt_embeds, output_type="latent", guidance_scale=1.0)
-
- assert torch.equal(output.output, torch.ones((1, 4, 1, 8, 8)))
- assert torch.equal(captured["timesteps"], pipeline.scheduler.timesteps)
- assert captured["guidance_low"] == 1.0
- assert captured["guidance_high"] == 1.0
- assert captured["boundary_timestep"] == pytest.approx(875.0)
- assert captured["latent_condition"] is None
- assert captured["first_frame_mask"] is None
- assert pipeline.scheduler.set_timesteps_calls == [(2, torch.device("cpu"))]
-
-
-def test_diffuse_runs_prediction_and_scheduler_for_each_timestep() -> None:
- pipeline = _make_pipeline()
- latents = torch.zeros((1, 1, 1, 2, 2), dtype=torch.float32)
- timesteps = torch.tensor([7, 3], dtype=torch.int64)
- prompt_embeds = torch.randn(1, 8)
-
- predict_calls: list[dict[str, object]] = []
- scheduler_calls: list[tuple[float, int, float, bool]] = []
-
- def _fake_predict_noise_maybe_with_cfg(**kwargs):
- predict_calls.append(kwargs)
- timestep = kwargs["positive_kwargs"]["timestep"]
- assert isinstance(timestep, torch.Tensor)
- return torch.full_like(latents, float(timestep[0].item()))
-
- def _fake_scheduler_step_maybe_with_cfg(noise_pred, t, current_latents, do_true_cfg):
- scheduler_calls.append(
- (float(noise_pred[0, 0, 0, 0, 0]), int(t.item()), float(current_latents.sum()), do_true_cfg)
- )
- return current_latents + noise_pred
-
- pipeline.predict_noise_maybe_with_cfg = _fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
- pipeline.scheduler_step_maybe_with_cfg = _fake_scheduler_step_maybe_with_cfg # type: ignore[method-assign]
-
- result = pipeline.diffuse(
- latents=latents,
- timesteps=timesteps,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=None,
- guidance_low=1.0,
- guidance_high=2.0,
- boundary_timestep=5.0,
- dtype=torch.float32,
- attention_kwargs={},
- )
-
- assert len(predict_calls) == 2
- assert predict_calls[0]["true_cfg_scale"] == 1.0
- assert predict_calls[1]["true_cfg_scale"] == 2.0
- assert scheduler_calls == [
- (7.0, 7, 0.0, False),
- (3.0, 3, 28.0, False),
- ]
- assert torch.equal(result, torch.full_like(latents, 10.0))
diff --git a/tests/diffusion/models/wan2_2/test_wan22_pipeline_helpers.py b/tests/diffusion/models/wan2_2/test_wan22_pipeline_helpers.py
deleted file mode 100644
index 31471786976..00000000000
--- a/tests/diffusion/models/wan2_2/test_wan22_pipeline_helpers.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import json
-from types import SimpleNamespace
-
-import pytest
-import torch
-
-import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 as wan22_module
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
- create_transformer_from_config,
- load_transformer_config,
- retrieve_latents,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
-
-
-class _LatentDist:
- def sample(self, generator):
- assert isinstance(generator, torch.Generator)
- return torch.tensor([1.0])
-
- def mode(self):
- return torch.tensor([2.0])
-
-
-def test_retrieve_latents_supports_sample_mode_argmax_and_direct_latents() -> None:
- generator = torch.Generator(device="cpu")
-
- assert retrieve_latents(SimpleNamespace(latent_dist=_LatentDist()), generator).item() == 1.0
- assert retrieve_latents(SimpleNamespace(latent_dist=_LatentDist()), sample_mode="argmax").item() == 2.0
- torch.testing.assert_close(retrieve_latents(SimpleNamespace(latents=torch.tensor([3.0]))), torch.tensor([3.0]))
-
-
-def test_retrieve_latents_rejects_unknown_encoder_output() -> None:
- with pytest.raises(AttributeError, match="Could not access latents"):
- retrieve_latents(SimpleNamespace())
-
-
-def test_load_transformer_config_reads_local_subfolder_config(tmp_path) -> None:
- config_dir = tmp_path / "transformer_2"
- config_dir.mkdir(parents=True)
- (config_dir / "config.json").write_text(json.dumps({"patch_size": [1, 2, 2], "num_layers": 2}))
-
- assert load_transformer_config(str(tmp_path), "transformer_2") == {"patch_size": [1, 2, 2], "num_layers": 2}
- assert load_transformer_config(str(tmp_path), "missing") == {}
-
-
-def test_create_transformer_from_config_maps_supported_keys(monkeypatch) -> None:
- captured = {}
-
- class FakeTransformer:
- def __init__(self, **kwargs) -> None:
- captured.update(kwargs)
-
- monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer)
-
- transformer = create_transformer_from_config(
- {
- "patch_size": [1, 2, 2],
- "num_attention_heads": 8,
- "attention_head_dim": 128,
- "in_channels": 16,
- "out_channels": 16,
- "text_dim": 4096,
- "vace_layers": [0],
- "ignored": "value",
- }
- )
-
- assert isinstance(transformer, FakeTransformer)
- assert captured == {
- "patch_size": (1, 2, 2),
- "num_attention_heads": 8,
- "attention_head_dim": 128,
- "in_channels": 16,
- "out_channels": 16,
- "text_dim": 4096,
- }
diff --git a/tests/diffusion/models/wan2_2/test_wan22_ti2v_pipeline.py b/tests/diffusion/models/wan2_2/test_wan22_ti2v_pipeline.py
deleted file mode 100644
index e611c37b6ad..00000000000
--- a/tests/diffusion/models/wan2_2/test_wan22_ti2v_pipeline.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from types import SimpleNamespace
-
-import pytest
-import torch
-from PIL import Image
-from torch import nn
-
-from tests.diffusion.models.wan2_2.conftest import StubTransformer, StubVAE, noop_progress_bar
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import (
- Wan22TI2VPipeline,
- get_wan22_ti2v_pre_process_func,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
-
-
-def _make_ti2v_pipeline() -> Wan22TI2VPipeline:
- pipeline = object.__new__(Wan22TI2VPipeline)
- nn.Module.__init__(pipeline)
- pipeline.device = torch.device("cpu")
- pipeline.transformer = StubTransformer(in_channels=4, out_channels=4)
- pipeline.vae = StubVAE(z_dim=4)
- pipeline.vae_scale_factor_temporal = 4
- pipeline.vae_scale_factor_spatial = 8
- pipeline.progress_bar = noop_progress_bar
- return pipeline
-
-
-def test_ti2v_preprocess_uses_720p_area_for_image_condition() -> None:
- preprocess = get_wan22_ti2v_pre_process_func(SimpleNamespace())
- request = SimpleNamespace(
- prompts=[{"prompt": "p", "multi_modal_data": {"image": Image.new("RGB", (320, 160), "blue")}}],
- sampling_params=SimpleNamespace(height=None, width=None),
- )
-
- result = preprocess(request)
-
- assert result.sampling_params.height == 672
- assert result.sampling_params.width == 1344
- assert result.prompts[0]["multi_modal_data"]["image"].size == (1344, 672)
-
-
-def test_ti2v_diffuse_without_image_condition_expands_patch_timesteps() -> None:
- pipeline = _make_ti2v_pipeline()
- latents = torch.zeros(1, 4, 2, 4, 4)
- calls = []
-
- def fake_predict_noise_maybe_with_cfg(**kwargs):
- calls.append(kwargs)
- return torch.ones_like(latents)
-
- pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
- pipeline.scheduler_step_maybe_with_cfg = lambda noise, t, current, cfg: current + noise # type: ignore[method-assign]
-
- result = pipeline.diffuse(
- latents=latents,
- timesteps=torch.tensor([7]),
- prompt_embeds=torch.zeros(1, 2, 3),
- negative_prompt_embeds=torch.zeros(1, 2, 3),
- guidance_scale=3.0,
- dtype=torch.float32,
- attention_kwargs={"a": "b"},
- num_latent_frames=2,
- latent_height=4,
- latent_width=4,
- )
-
- positive = calls[0]["positive_kwargs"]
- assert calls[0]["do_true_cfg"] is True
- assert positive["timestep"].shape == (1, 8)
- torch.testing.assert_close(positive["timestep"], torch.full((1, 8), 7, dtype=positive["timestep"].dtype))
- torch.testing.assert_close(positive["hidden_states"], latents)
- torch.testing.assert_close(result, torch.ones_like(latents))
-
-
-def test_ti2v_prepare_i2v_latents_encodes_condition_and_masks_first_frame() -> None:
- pipeline = _make_ti2v_pipeline()
- latents, latent_condition, first_frame_mask = pipeline.prepare_i2v_latents(
- image=torch.zeros(1, 3, 16, 16),
- batch_size=1,
- num_channels_latents=4,
- height=16,
- width=16,
- num_frames=5,
- dtype=torch.float32,
- device=torch.device("cpu"),
- generator=None,
- latents=torch.zeros(1, 4, 2, 2, 2),
- )
-
- torch.testing.assert_close(latents, torch.zeros(1, 4, 2, 2, 2))
- assert latent_condition.shape == (1, 4, 1, 2, 2)
- assert first_frame_mask[:, :, 0].sum() == 0
- assert first_frame_mask[:, :, 1].sum() == 4
diff --git a/tests/diffusion/models/wan2_2/test_wan22_vace_pipeline.py b/tests/diffusion/models/wan2_2/test_wan22_vace_pipeline.py
deleted file mode 100644
index 9fa9b67c499..00000000000
--- a/tests/diffusion/models/wan2_2/test_wan22_vace_pipeline.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from types import SimpleNamespace
-
-import pytest
-import torch
-from PIL import Image
-from torch import nn
-
-from tests.diffusion.models.wan2_2.conftest import StubTransformer, StubVAE, noop_progress_bar
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace import (
- Wan22VACEPipeline,
- create_vace_transformer_from_config,
- get_wan22_vace_pre_process_func,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
-
-
-def _make_vace_pipeline() -> Wan22VACEPipeline:
- pipeline = object.__new__(Wan22VACEPipeline)
- nn.Module.__init__(pipeline)
- pipeline.device = torch.device("cpu")
- pipeline.transformer = StubTransformer(in_channels=4, out_channels=4)
- pipeline.transformer_config = pipeline.transformer.config
- pipeline.vae = StubVAE(z_dim=4)
- pipeline.vae_scale_factor_temporal = 4
- pipeline.vae_scale_factor_spatial = 8
- pipeline.progress_bar = noop_progress_bar
- return pipeline
-
-
-def test_vace_preprocess_collects_reference_video_and_mask_inputs() -> None:
- preprocess = get_wan22_vace_pre_process_func(SimpleNamespace())
- ref = Image.new("RGB", (320, 160), "green")
- frame = Image.new("RGB", (64, 64), "black")
- mask = Image.new("L", (64, 64), 255)
- request = SimpleNamespace(
- prompts=[
- {
- "prompt": "p",
- "multi_modal_data": {
- "image": ref,
- "video": [frame],
- "mask": mask,
- },
- }
- ],
- sampling_params=SimpleNamespace(height=None, width=None),
- )
-
- result = preprocess(request)
- additional_info = result.prompts[0]["additional_information"]
-
- assert result.sampling_params.height == 432
- assert result.sampling_params.width == 880
- assert additional_info["reference_images"] == [ref]
- assert additional_info["source_video"] == [frame]
- assert additional_info["mask"] == [mask]
-
-
-def test_create_vace_transformer_from_config_maps_vace_specific_keys(monkeypatch) -> None:
- captured = {}
-
- class FakeVACETransformer:
- def __init__(self, **kwargs) -> None:
- captured.update(kwargs)
-
- monkeypatch.setattr(
- "vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace.WanVACETransformer3DModel",
- FakeVACETransformer,
- )
-
- transformer = create_vace_transformer_from_config(
- {
- "patch_size": [1, 2, 2],
- "in_channels": 96,
- "out_channels": 16,
- "vace_layers": [0, 1, 2],
- "vace_in_channels": 132,
- "unknown": "ignored",
- }
- )
-
- assert isinstance(transformer, FakeVACETransformer)
- assert captured == {
- "patch_size": (1, 2, 2),
- "in_channels": 96,
- "out_channels": 16,
- "vace_layers": [0, 1, 2],
- "vace_in_channels": 132,
- }
-
-
-def test_vace_prepare_masks_encodes_spatial_stride_and_reference_padding() -> None:
- pipeline = _make_vace_pipeline()
- mask = torch.ones(1, 3, 5, 16, 16)
- reference_images = [[torch.zeros(3, 16, 16), torch.zeros(3, 16, 16)]]
-
- encoded = pipeline.prepare_masks(mask, reference_images)
-
- assert encoded.shape == (1, 64, 4, 2, 2)
- torch.testing.assert_close(encoded[:, :, :2], torch.zeros(1, 64, 2, 2, 2))
- torch.testing.assert_close(encoded[:, :, 2:], torch.ones(1, 64, 2, 2, 2))
-
-
-def test_vace_diffuse_passes_context_and_scale_to_cfg_branches() -> None:
- pipeline = _make_vace_pipeline()
- latents = torch.zeros(1, 4, 1, 2, 2)
- vace_context = torch.ones(1, 12, 1, 2, 2)
- calls = []
-
- def fake_predict_noise_maybe_with_cfg(**kwargs):
- calls.append(kwargs)
- return torch.ones_like(latents)
-
- pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
- pipeline.scheduler_step_maybe_with_cfg = lambda noise, t, current, cfg: current + noise # type: ignore[method-assign]
-
- result = pipeline.diffuse(
- latents=latents,
- timesteps=torch.tensor([5]),
- prompt_embeds=torch.zeros(1, 2, 3),
- negative_prompt_embeds=torch.zeros(1, 2, 3),
- guidance_scale=4.0,
- dtype=torch.float32,
- attention_kwargs={},
- vace_context=vace_context,
- vace_context_scale=0.75,
- )
-
- assert calls[0]["do_true_cfg"] is True
- assert calls[0]["true_cfg_scale"] == 4.0
- assert calls[0]["positive_kwargs"]["vace_context"] is vace_context
- assert calls[0]["negative_kwargs"]["vace_context_scale"] == 0.75
- torch.testing.assert_close(result, torch.ones_like(latents))
diff --git a/tests/diffusion/offloader/test_layerwise_backend.py b/tests/diffusion/offloader/test_layerwise_backend.py
index 5fd80e75c22..7df3c1bb1a1 100644
--- a/tests/diffusion/offloader/test_layerwise_backend.py
+++ b/tests/diffusion/offloader/test_layerwise_backend.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for LayerwiseOffloadHook and LayerWiseOffloadBackend utilities."""
+"""Unit tests for LayerwiseOffloadHook."""
import gc
import os
@@ -15,7 +15,7 @@
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate
import vllm_omni.diffusion.offloader.layerwise_backend as layerwise_backend_module
-from vllm_omni.diffusion.offloader.layerwise_backend import LayerWiseOffloadBackend, LayerwiseOffloadHook
+from vllm_omni.diffusion.offloader.layerwise_backend import LayerwiseOffloadHook
from vllm_omni.platforms import current_omni_platform
pytestmark = [pytest.mark.diffusion, pytest.mark.cpu, pytest.mark.core_model]
@@ -127,116 +127,3 @@ def test_dtensor_wrapper_is_preserved_across_prefetch_and_offload(self, dist_gro
assert current_block.weight.to_local().is_meta
assert current_block.weight.to_local().shape == torch.Size([4])
assert not hook.is_materialized
-
-
-class _DummyBlock(nn.Module):
- def __init__(self):
- super().__init__()
- self.weight = nn.Parameter(torch.randn(10, 10))
-
-
-class _SingleBlockModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["blocks"]
-
- def __init__(self, num_blocks: int = 3):
- super().__init__()
- self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)])
-
-
-class _MultiBlockModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]
-
- def __init__(self, num_transformer: int = 2, num_single: int = 2):
- super().__init__()
- self.transformer_blocks = nn.ModuleList([_DummyBlock() for _ in range(num_transformer)])
- self.single_transformer_blocks = nn.ModuleList([_DummyBlock() for _ in range(num_single)])
-
-
-class _EmptyBlocksModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["blocks"]
-
- def __init__(self):
- super().__init__()
- self.blocks = nn.ModuleList([])
-
-
-class _InvalidAttrModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["nonexistent_blocks", "blocks"]
-
- def __init__(self, num_blocks: int = 2):
- super().__init__()
- self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)])
-
-
-class _DeprecatedSingleAttrModel(nn.Module):
- _layerwise_offload_blocks_attr = "blocks"
-
- def __init__(self, num_blocks: int = 2):
- super().__init__()
- self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)])
-
-
-class _NoAttrsModel(nn.Module):
- def __init__(self, num_blocks: int = 2):
- super().__init__()
- self.blocks = nn.ModuleList([_DummyBlock() for _ in range(num_blocks)])
-
-
-class TestGetBlocksFromDit:
- def test_get_blocks_from_dit_single_block_attr(self):
- model = _SingleBlockModel(num_blocks=3)
- attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
- assert attr_names == ["blocks"]
- assert len(blocks) == 3
- assert all(isinstance(b, _DummyBlock) for b in blocks)
-
- def test_get_blocks_from_dit_multi_block_attrs(self):
- model = _MultiBlockModel(num_transformer=2, num_single=3)
- attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
- assert set(attr_names) == {"transformer_blocks", "single_transformer_blocks"}
- assert len(blocks) == 5
- assert all(isinstance(b, _DummyBlock) for b in blocks)
-
- def test_get_blocks_from_dit_empty_blocks(self):
- model = _EmptyBlocksModel()
- attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
- assert attr_names == []
- assert blocks == []
-
- def test_get_blocks_from_dit_invalid_attr_name(self):
- model = _InvalidAttrModel(num_blocks=2)
- with pytest.raises(
- AttributeError,
- match="Attribute 'nonexistent_blocks' declared in _layerwise_offload_blocks_attrs does not exist",
- ):
- LayerWiseOffloadBackend.get_blocks_from_dit(model)
-
- def test_get_blocks_from_dit_no_attrs_defined(self):
- model = _NoAttrsModel(num_blocks=3)
- attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
- assert attr_names == []
- assert blocks == []
-
- def test_get_blocks_from_dit_deprecated_single_attr(self):
- model = _DeprecatedSingleAttrModel(num_blocks=2)
- attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(model)
- assert attr_names == ["blocks"]
- assert len(blocks) == 2
-
-
-class TestGetBlocksAttrNames:
- def test_get_blocks_attr_names_new_format(self):
- model = _MultiBlockModel()
- attrs = LayerWiseOffloadBackend.get_blocks_attr_names(model)
- assert attrs == ["transformer_blocks", "single_transformer_blocks"]
-
- def test_get_blocks_attr_names_no_attrs(self):
- model = _NoAttrsModel()
- attrs = LayerWiseOffloadBackend.get_blocks_attr_names(model)
- assert attrs == []
-
- def test_set_blocks_attr_names(self):
- model = _NoAttrsModel()
- LayerWiseOffloadBackend.set_blocks_attr_names(model, ["new_blocks"])
- assert hasattr(model.__class__, "_layerwise_offload_blocks_attrs")
- assert model.__class__._layerwise_offload_blocks_attrs == ["new_blocks"]
diff --git a/tests/diffusion/offloader/test_module_collector.py b/tests/diffusion/offloader/test_module_collector.py
deleted file mode 100644
index ab15ad8df60..00000000000
--- a/tests/diffusion/offloader/test_module_collector.py
+++ /dev/null
@@ -1,240 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Unit tests for ModuleDiscovery and SupportsModuleOffload."""
-
-from typing import ClassVar
-
-import pytest
-from torch import nn
-
-from vllm_omni.diffusion.models.interface import SupportsModuleOffload
-from vllm_omni.diffusion.offloader.module_collector import ModuleDiscovery
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.cpu, pytest.mark.core_model]
-
-# NOTE: tests for skipped/warned attributes verify the *behavioral*
-# outcome (attribute excluded from results) but do not assert on log
-# output. vllm's logger sets propagate=False, preventing caplog from
-# capturing records. See https://github.com/pytest-dev/pytest/issues/3697
-
-
-# ---------------------------------------------------------------------------
-# Test pipelines
-# ---------------------------------------------------------------------------
-
-
-class FallbackPipeline(nn.Module):
- """Pipeline with standard attribute names (no protocol)."""
-
- def __init__(self):
- super().__init__()
- self.transformer = nn.Linear(10, 10)
- self.text_encoder = nn.Linear(10, 10)
- self.text_encoder_2 = nn.Linear(10, 10)
- self.vae = nn.Linear(10, 10)
-
-
-class NonModuleAttrPipeline(nn.Module):
- """Pipeline where an attribute is not an nn.Module (fallback path)."""
-
- def __init__(self):
- super().__init__()
- self.transformer = nn.Linear(10, 10)
- self.text_encoder = "not_a_module"
- self.vae = nn.Linear(10, 10)
-
-
-class DuplicateAttrPipeline(nn.Module):
- """Pipeline where two encoder attrs point to the same module."""
-
- def __init__(self):
- super().__init__()
- self.transformer = nn.Linear(10, 10)
- encoder = nn.Linear(10, 10)
- self.text_encoder = encoder
- self.text_encoder_2 = encoder
- self.vae = nn.Linear(10, 10)
-
-
-class ProtocolPipeline(nn.Module, SupportsModuleOffload):
- """Pipeline with non-standard names, using the protocol."""
-
- _dit_modules: ClassVar[list[str]] = ["gen_transformer"]
- _encoder_modules: ClassVar[list[str]] = ["mllm", "vision_model"]
- _vae_modules: ClassVar[list[str]] = ["gen_vae"]
-
- def __init__(self):
- super().__init__()
- self.gen_transformer = nn.Linear(10, 10)
- self.mllm = nn.Linear(10, 10)
- self.vision_model = nn.Linear(10, 10)
- self.gen_vae = nn.Linear(10, 10)
- # Standard name present but NOT declared — should be ignored
- self.transformer = nn.Linear(10, 10)
-
-
-class MissingAttrPipeline(nn.Module, SupportsModuleOffload):
- """Pipeline that declares a non-existent attribute."""
-
- _dit_modules: ClassVar[list[str]] = ["transformer"]
- _encoder_modules: ClassVar[list[str]] = ["nonexistent_encoder"]
- _vae_modules: ClassVar[list[str]] = ["vae"]
-
- def __init__(self):
- super().__init__()
- self.transformer = nn.Linear(10, 10)
- self.vae = nn.Linear(10, 10)
-
-
-class MissingIntermediatePipeline(nn.Module, SupportsModuleOffload):
- """Pipeline with dotted path referencing non-existent intermediate."""
-
- _dit_modules: ClassVar[list[str]] = ["nonexistent.transformer"]
- _encoder_modules: ClassVar[list[str]] = []
- _vae_modules: ClassVar[list[str]] = []
-
- def __init__(self):
- super().__init__()
-
-
-class NestedPipeline(nn.Module, SupportsModuleOffload):
- """Pipeline with nested modules accessed via dotted paths."""
-
- _dit_modules: ClassVar[list[str]] = ["pipe.transformer"]
- _encoder_modules: ClassVar[list[str]] = ["pipe.text_encoder"]
- _vae_modules: ClassVar[list[str]] = ["vae"]
-
- def __init__(self):
- super().__init__()
- self.pipe = nn.Module()
- self.pipe.transformer = nn.Linear(10, 10)
- self.pipe.text_encoder = nn.Linear(10, 10)
- self.vae = nn.Linear(10, 10)
-
-
-class ResidentPipeline(nn.Module, SupportsModuleOffload):
- """Pipeline with resident modules that must stay on GPU."""
-
- _dit_modules: ClassVar[list[str]] = ["language_model.model"]
- _encoder_modules: ClassVar[list[str]] = []
- _vae_modules: ClassVar[list[str]] = ["vae"]
- _resident_modules: ClassVar[list[str]] = [
- "bagel.time_embedder",
- "bagel.vae2llm",
- ]
-
- def __init__(self):
- super().__init__()
- self.language_model = nn.Module()
- self.language_model.model = nn.Linear(10, 10)
- self.bagel = nn.Module()
- self.bagel.time_embedder = nn.Linear(10, 10)
- self.bagel.vae2llm = nn.Linear(10, 10)
- self.vae = nn.Linear(10, 10)
-
-
-class MultiVaePipeline(nn.Module, SupportsModuleOffload):
- """Pipeline with multiple VAEs."""
-
- _dit_modules: ClassVar[list[str]] = ["transformer"]
- _encoder_modules: ClassVar[list[str]] = ["text_encoder"]
- _vae_modules: ClassVar[list[str]] = ["vae", "audio_vae"]
-
- def __init__(self):
- super().__init__()
- self.transformer = nn.Linear(10, 10)
- self.text_encoder = nn.Linear(10, 10)
- self.vae = nn.Linear(10, 10)
- self.audio_vae = nn.Linear(10, 10)
-
-
-# ---------------------------------------------------------------------------
-# Tests
-# ---------------------------------------------------------------------------
-
-
-class TestFallbackDiscovery:
- """Test the fallback attribute scan (no SupportsModuleOffload)."""
-
- def test_discovers_standard_attrs(self):
- pipeline = FallbackPipeline()
- result = ModuleDiscovery.discover(pipeline)
-
- assert not isinstance(pipeline, SupportsModuleOffload)
- assert result.dit_names == ["transformer"]
- assert result.dits[0] is pipeline.transformer
- assert result.encoder_names == ["text_encoder", "text_encoder_2"]
- assert result.vaes[0] is pipeline.vae
- assert result.resident_modules == []
-
- def test_deduplicates_encoders(self):
- pipeline = DuplicateAttrPipeline()
- result = ModuleDiscovery.discover(pipeline)
-
- assert len(result.encoders) == 1
- assert result.encoder_names == ["text_encoder"]
-
- def test_skips_non_module_attr(self):
- pipeline = NonModuleAttrPipeline()
- result = ModuleDiscovery.discover(pipeline)
-
- assert len(result.encoders) == 0
-
-
-class TestProtocolDiscovery:
- """Test discovery via SupportsModuleOffload protocol."""
-
- def test_discovers_declared_attrs_and_ignores_undeclared(self):
- pipeline = ProtocolPipeline()
- result = ModuleDiscovery.discover(pipeline)
-
- assert isinstance(pipeline, SupportsModuleOffload)
- assert result.dit_names == ["gen_transformer"]
- assert result.encoder_names == ["mllm", "vision_model"]
- assert len(result.vaes) == 1
- # self.transformer exists but is NOT in _dit_modules
- assert "transformer" not in result.dit_names
- # No _resident_modules declared — defaults to empty
- assert result.resident_modules == []
-
- def test_skips_missing_attr(self):
- pipeline = MissingAttrPipeline()
- result = ModuleDiscovery.discover(pipeline)
-
- assert len(result.encoders) == 0
-
- def test_skips_missing_intermediate(self):
- result = ModuleDiscovery.discover(MissingIntermediatePipeline())
-
- assert len(result.dits) == 0
-
- def test_dotted_path_resolves_nested_modules(self):
- pipeline = NestedPipeline()
- result = ModuleDiscovery.discover(pipeline)
-
- assert result.dit_names == ["pipe.transformer"]
- assert result.dits[0] is pipeline.pipe.transformer
- assert result.encoder_names == ["pipe.text_encoder"]
- assert result.encoders[0] is pipeline.pipe.text_encoder
- assert result.vaes[0] is pipeline.vae
-
- def test_resident_modules(self):
- pipeline = ResidentPipeline()
- result = ModuleDiscovery.discover(pipeline)
-
- assert result.resident_names == [
- "bagel.time_embedder",
- "bagel.vae2llm",
- ]
- assert result.resident_modules[0] is pipeline.bagel.time_embedder
- assert result.resident_modules[1] is pipeline.bagel.vae2llm
- assert result.dits[0] is pipeline.language_model.model
-
- def test_multiple_vaes(self):
- pipeline = MultiVaePipeline()
- result = ModuleDiscovery.discover(pipeline)
-
- assert len(result.vaes) == 2
- assert result.vaes[0] is pipeline.vae
- assert result.vaes[1] is pipeline.audio_vae
diff --git a/tests/diffusion/offloader/test_sequential_backend.py b/tests/diffusion/offloader/test_sequential_backend.py
index 2539cc06895..d18637a780e 100644
--- a/tests/diffusion/offloader/test_sequential_backend.py
+++ b/tests/diffusion/offloader/test_sequential_backend.py
@@ -3,6 +3,8 @@
"""Unit tests for SequentialOffloadBackend."""
+from unittest.mock import patch
+
import pytest
import torch
from torch import nn
@@ -42,7 +44,7 @@ def mock(self):
class TestMoveParamsPinMemory:
- def test_dtensor_skips_pin_memory(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
+ def test_dtensor_skips_pin_memory(self, accelerator_device):
"""DTensor should skip pin_memory to avoid RuntimeError."""
module = _create_simple_module().to(accelerator_device)
tracker, mock_pin = _track_pin_memory_calls()
@@ -54,73 +56,73 @@ def fake_isinstance(obj, cls):
return True
return original_isinstance(obj, cls)
- monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
- monkeypatch.setattr("builtins.isinstance", fake_isinstance)
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=accelerator_device,
- pin_memory=True,
- use_hsdp=False,
- )
- hook._move_params(
- module,
- torch.device("cpu"),
- non_blocking=False,
- pin_memory=True,
- )
- assert not tracker["called"], "pin_memory should not be called for DTensor"
-
- def test_regular_tensor_calls_pin_memory(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
+ with patch.object(torch.Tensor, "pin_memory", mock_pin):
+ with patch("builtins.isinstance", fake_isinstance):
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=accelerator_device,
+ pin_memory=True,
+ use_hsdp=False,
+ )
+ hook._move_params(
+ module,
+ torch.device("cpu"),
+ non_blocking=False,
+ pin_memory=True,
+ )
+ assert not tracker["called"], "pin_memory should not be called for DTensor"
+
+ def test_regular_tensor_calls_pin_memory(self, accelerator_device):
"""Regular tensor should call pin_memory when moving to CPU."""
module = _create_simple_module().to(accelerator_device)
tracker, mock_pin = _track_pin_memory_calls()
- monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=accelerator_device,
- pin_memory=True,
- use_hsdp=False,
- )
- hook._move_params(
- module,
- torch.device("cpu"),
- non_blocking=False,
- pin_memory=True,
- )
- assert tracker["called"], "pin_memory should be called for regular tensors"
-
- def test_pin_memory_skipped_when_disabled(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
+ with patch.object(torch.Tensor, "pin_memory", mock_pin):
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=accelerator_device,
+ pin_memory=True,
+ use_hsdp=False,
+ )
+ hook._move_params(
+ module,
+ torch.device("cpu"),
+ non_blocking=False,
+ pin_memory=True,
+ )
+ assert tracker["called"], "pin_memory should be called for regular tensors"
+
+ def test_pin_memory_skipped_when_disabled(self, accelerator_device):
"""pin_memory should not be called when pin_memory=False."""
module = _create_simple_module().to(accelerator_device)
tracker, mock_pin = _track_pin_memory_calls()
- monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=accelerator_device,
- pin_memory=False,
- use_hsdp=False,
- )
- hook._move_params(
- module,
- torch.device("cpu"),
- non_blocking=False,
- pin_memory=False,
- )
- assert not tracker["called"], "pin_memory should not be called when disabled"
-
- def test_pin_memory_skipped_for_non_cpu_target(self, accelerator_device, monkeypatch: pytest.MonkeyPatch):
+ with patch.object(torch.Tensor, "pin_memory", mock_pin):
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=accelerator_device,
+ pin_memory=False,
+ use_hsdp=False,
+ )
+ hook._move_params(
+ module,
+ torch.device("cpu"),
+ non_blocking=False,
+ pin_memory=False,
+ )
+ assert not tracker["called"], "pin_memory should not be called when disabled"
+
+ def test_pin_memory_skipped_for_non_cpu_target(self, accelerator_device):
"""pin_memory should not be called for non-CPU targets."""
module = _create_simple_module().to("cpu")
tracker, mock_pin = _track_pin_memory_calls()
- monkeypatch.setattr(torch.Tensor, "pin_memory", mock_pin)
- hook = SequentialOffloadHook(
- offload_targets=[],
- device=torch.device("cpu"),
- pin_memory=True,
- use_hsdp=False,
- )
- hook._move_params(module, accelerator_device, non_blocking=False, pin_memory=True)
- assert not tracker["called"], "pin_memory should not be called for non-CPU target"
+ with patch.object(torch.Tensor, "pin_memory", mock_pin):
+ hook = SequentialOffloadHook(
+ offload_targets=[],
+ device=torch.device("cpu"),
+ pin_memory=True,
+ use_hsdp=False,
+ )
+ hook._move_params(module, accelerator_device, non_blocking=False, pin_memory=True)
+ assert not tracker["called"], "pin_memory should not be called for non-CPU target"
diff --git a/tests/diffusion/quantization/test_component_routing.py b/tests/diffusion/quantization/test_component_routing.py
deleted file mode 100644
index c8b3837e256..00000000000
--- a/tests/diffusion/quantization/test_component_routing.py
+++ /dev/null
@@ -1,408 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for component routing for quantization."""
-
-from unittest.mock import MagicMock
-
-import pytest
-import torch
-from vllm.model_executor.layers.quantization.base_config import (
- QuantizationConfig,
-)
-from vllm.model_executor.models.utils import WeightsMapper
-
-from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker import (
- PRE_QUANTIZED_METHODS,
-)
-from vllm_omni.quantization.component_config import (
- ComponentQuantizationConfig,
-)
-from vllm_omni.quantization.inc_config import OmniINCConfig
-
-pytestmark = [pytest.mark.core_model]
-
-
-# ---------------------------------------------------------------------------
-# Helpers: lightweight mock quant configs
-# ---------------------------------------------------------------------------
-
-
-class _MockQuantConfig(QuantizationConfig):
- """Minimal mock that only implements get_name()."""
-
- def __init__(self, name: str, **attrs):
- self._name = name
- for k, v in attrs.items():
- setattr(self, k, v)
-
- def get_name(self) -> str:
- return self._name
-
- def get_quant_method(self, layer, prefix):
- return MagicMock()
-
- @classmethod
- def get_supported_act_dtypes(cls):
- return [torch.bfloat16, torch.float16]
-
- def get_min_capability(self):
- return 0
-
- @classmethod
- def from_config(cls, config):
- raise NotImplementedError
-
- def get_config_filenames(self):
- return []
-
-
-def _make_inc_config(block_names="thinker.model.layers,talker.model.layers", extra_config=None):
- """Create a real OmniINCConfig with block_name_to_quantize."""
- return OmniINCConfig(
- weight_bits=4,
- group_size=128,
- sym=True,
- block_name_to_quantize=block_names,
- extra_config=extra_config or {},
- )
-
-
-THINKER_MAPPER = WeightsMapper(
- orig_to_new_prefix={
- "thinker.lm_head.": "language_model.lm_head.",
- "thinker.model.": "language_model.model.",
- "thinker.": "",
- }
-)
-
-TALKER_MAPPER = WeightsMapper(
- orig_to_new_prefix={
- "talker.codec_head.": "language_model.lm_head.",
- "talker.model.": "language_model.model.",
- "talker.thinker_to_talker_proj.": "thinker_to_talker_proj.",
- "talker.": "",
- }
-)
-
-
-# ===================================================================
-# 1. OmniINCConfig.apply_vllm_mapper
-# ===================================================================
-
-
-class TestApplyVllmMapper:
- def test_inc_csv_string_normalized_to_list(self):
- """CSV string block_name_to_quantize is split into a list."""
- cfg = _make_inc_config("thinker.model.layers,talker.model.layers")
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- assert isinstance(cfg.block_name_to_quantize, list)
-
- def test_thinker_blocks_remapped(self):
- """thinker.model.layers -> language_model.model.layers after apply_vllm_mapper."""
- cfg = _make_inc_config("thinker.model.layers,talker.model.layers")
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- assert any("language_model.model.layers" in b for b in cfg.block_name_to_quantize)
-
- def test_cross_stage_blocks_kept_unchanged(self):
- """Blocks not matching any mapper prefix are kept unchanged (harmless)."""
- cfg = _make_inc_config("thinker.model.layers,talker.model.layers")
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- # talker.model.layers doesn't match any thinker mapper prefix → stays as-is
- assert "talker.model.layers" in cfg.block_name_to_quantize
-
- def test_talker_remap(self):
- """talker.model.layers -> language_model.model.layers with talker mapper."""
- cfg = _make_inc_config("thinker.model.layers,talker.model.layers")
- cfg.apply_vllm_mapper(TALKER_MAPPER)
- assert any("language_model.model.layers" in b for b in cfg.block_name_to_quantize)
- # thinker.model.layers doesn't match talker mapper → stays as-is
- assert "thinker.model.layers" in cfg.block_name_to_quantize
-
- def test_extra_config_keys_remapped(self):
- """Regex keys in extra_config get their escaped-dot prefixes remapped."""
- extra = {
- r".*thinker\.model\.layers\.0\.mlp\.gate.*": {"bits": 16, "data_type": "float"},
- }
- cfg = _make_inc_config("thinker.model.layers", extra_config=extra)
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- # The key should now reference the vLLM runtime path
- assert any("language_model" in k for k in cfg.extra_config)
- # Original thinker\.model prefix should be replaced
- assert not any(r"thinker\.model" in k for k in cfg.extra_config)
-
- def test_single_block_name(self):
- """Only one block name (not CSV) still works."""
- cfg = _make_inc_config("thinker.model.layers")
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- assert any("language_model.model.layers" in b for b in cfg.block_name_to_quantize)
-
- def test_already_list_block_names(self):
- """block_name_to_quantize already a list (not CSV string) works."""
- cfg = _make_inc_config(["thinker.model.layers", "talker.model.layers"])
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- assert isinstance(cfg.block_name_to_quantize, list)
- assert any("language_model.model.layers" in b for b in cfg.block_name_to_quantize)
-
- def test_mutates_in_place(self):
- """apply_vllm_mapper mutates the config in place (same as upstream INCConfig)."""
- cfg = _make_inc_config("thinker.model.layers")
- original_id = id(cfg)
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- assert id(cfg) == original_id
-
- # -- Stage prefix tests (runtime prefix = container + internal name) --
-
- def test_thinker_block_has_stage_prefix(self):
- """Mapped block name must start with 'thinker.' so runtime startswith() works."""
- cfg = _make_inc_config("thinker.model.layers,talker.model.layers")
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- assert "thinker.language_model.model.layers" in cfg.block_name_to_quantize
-
- def test_talker_block_has_stage_prefix(self):
- """Mapped block name must start with 'talker.' so runtime startswith() works."""
- cfg = _make_inc_config("thinker.model.layers,talker.model.layers")
- cfg.apply_vllm_mapper(TALKER_MAPPER)
- assert "talker.language_model.model.layers" in cfg.block_name_to_quantize
-
- def test_thinker_block_matches_runtime_prefix(self):
- """Simulates get_layer_config's startswith() check for FusedMoE layers."""
- cfg = _make_inc_config("thinker.model.layers,talker.model.layers")
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- runtime_prefix = "thinker.language_model.model.layers.0.mlp.experts"
- assert any(runtime_prefix.startswith(b) for b in cfg.block_name_to_quantize)
-
- def test_talker_block_matches_runtime_prefix(self):
- """Simulates get_layer_config's startswith() check for talker FusedMoE."""
- cfg = _make_inc_config("thinker.model.layers,talker.model.layers")
- cfg.apply_vllm_mapper(TALKER_MAPPER)
- runtime_prefix = "talker.language_model.model.layers.0.mlp.experts"
- assert any(runtime_prefix.startswith(b) for b in cfg.block_name_to_quantize)
-
- def test_extra_config_plain_key_has_stage_prefix(self):
- """Plain extra_config keys are remapped with stage prefix."""
- extra = {
- "talker.model.layers.0.mlp.shared_expert_gate": {"bits": 16},
- }
- cfg = _make_inc_config("talker.model.layers", extra_config=extra)
- cfg.apply_vllm_mapper(TALKER_MAPPER)
- assert "talker.language_model.model.layers.0.mlp.shared_expert_gate" in cfg.extra_config
-
- def test_extra_config_regex_key_still_works(self):
- """Regex extra_config keys use re.search so no stage prefix needed."""
- import re
-
- extra = {
- r".*thinker\.model\.layers\.0\.mlp\.gate.*": {"bits": 16},
- }
- cfg = _make_inc_config("thinker.model.layers", extra_config=extra)
- cfg.apply_vllm_mapper(THINKER_MAPPER)
- runtime_name = "thinker.language_model.model.layers.0.mlp.gate"
- matched = any(re.search(k, runtime_name) for k in cfg.extra_config)
- assert matched
-
-
-# ===================================================================
-# 2. OmniINCConfig upgrade helpers
-# ===================================================================
-
-
-class TestOmniINCConfigUpgrade:
- def test_maybe_upgrade_none(self):
- assert OmniINCConfig.maybe_upgrade(None) is None
-
- def test_maybe_upgrade_non_inc(self):
- """Non-INC configs are passed through unchanged."""
- cfg = _MockQuantConfig("fp8")
- assert OmniINCConfig.maybe_upgrade(cfg) is cfg
-
- def test_maybe_upgrade_already_omni(self):
- """Already OmniINCConfig is returned as-is."""
- cfg = _make_inc_config()
- assert OmniINCConfig.maybe_upgrade(cfg) is cfg
-
- def test_maybe_upgrade_vanilla_inc(self):
- """Vanilla INCConfig is promoted to OmniINCConfig."""
- from vllm.model_executor.layers.quantization.inc import INCConfig
-
- vanilla = INCConfig(weight_bits=4, group_size=128, sym=True)
- upgraded = OmniINCConfig.maybe_upgrade(vanilla)
- assert isinstance(upgraded, OmniINCConfig)
- assert upgraded.weight_bits == 4
- assert upgraded.group_size == 128
-
-
-# ===================================================================
-# 2. Three-branch thinker routing (simulated)
-# ===================================================================
-
-
-def _simulate_thinker_routing(quant_config):
- """Simulate the three-branch routing in thinker __init__.
-
- Returns (visual_quant_config, language_quant_config, wrapped_vllm_quant).
- """
- if isinstance(quant_config, ComponentQuantizationConfig):
- visual_quant_config = quant_config.resolve("visual")
- language_quant_config = quant_config.resolve("language_model")
- return visual_quant_config, language_quant_config, quant_config
- elif quant_config is not None:
- if quant_config.get_name() in PRE_QUANTIZED_METHODS:
- return quant_config, quant_config, quant_config
- else:
- language_quant_config = quant_config
- wrapped = ComponentQuantizationConfig(
- component_configs={"language_model": quant_config},
- default_config=None,
- )
- return None, language_quant_config, wrapped
- else:
- return None, None, None
-
-
-class TestThinkerRouting:
- def test_none(self):
- vis, lang, wrapped = _simulate_thinker_routing(None)
- assert vis is None
- assert lang is None
- assert wrapped is None
-
- @pytest.mark.parametrize("method", ["modelopt", "modelopt_fp4", "modelopt_mxfp8"])
- def test_pre_quantized_all_components(self, method):
- """Pre-quantized methods pass config to all components."""
- cfg = _MockQuantConfig(method)
- vis, lang, wrapped = _simulate_thinker_routing(cfg)
- assert vis is cfg
- assert lang is cfg
- assert wrapped is cfg
-
- def test_fp8_dynamic_language_only(self):
- """fp8 dynamic: visual=None, language gets original config."""
- cfg = _MockQuantConfig("fp8")
- vis, lang, wrapped = _simulate_thinker_routing(cfg)
- assert vis is None
- assert lang is cfg
- assert isinstance(wrapped, ComponentQuantizationConfig)
- assert wrapped.resolve("language_model") is cfg
- assert wrapped.resolve("visual") is None
-
- def test_inc_autoround_language_only(self):
- """INC/AutoRound: not in _PRE_QUANTIZED_METHODS -> wrapped like fp8."""
- cfg = _MockQuantConfig("inc")
- vis, lang, wrapped = _simulate_thinker_routing(cfg)
- assert vis is None
- assert lang is cfg
- assert isinstance(wrapped, ComponentQuantizationConfig)
-
- def test_component_config_passthrough(self):
- """Explicit ComponentQuantizationConfig is used directly."""
- inner_fp8 = _MockQuantConfig("fp8")
- inner_modelopt = _MockQuantConfig("modelopt")
- cqc = ComponentQuantizationConfig(
- component_configs={
- "visual": inner_modelopt,
- "language_model": inner_fp8,
- }
- )
- vis, lang, wrapped = _simulate_thinker_routing(cqc)
- assert vis is inner_modelopt
- assert lang is inner_fp8
- assert wrapped is cqc
-
-
-# ===================================================================
-# 3. Talker visual routing (init_multi_modal guard)
-# ===================================================================
-
-
-def _simulate_talker_visual_routing(quant_config):
- """Simulate the talker init_multi_modal visual routing."""
- if quant_config is not None and quant_config.get_name() in PRE_QUANTIZED_METHODS:
- return quant_config
- return None
-
-
-class TestTalkerVisualRouting:
- def test_none(self):
- assert _simulate_talker_visual_routing(None) is None
-
- @pytest.mark.parametrize("method", ["modelopt", "modelopt_fp4", "modelopt_mxfp8"])
- def test_pre_quantized_passes_through(self, method):
- """Pre-quantized methods pass quant config to visual."""
- cfg = _MockQuantConfig(method)
- assert _simulate_talker_visual_routing(cfg) is cfg
-
- def test_fp8_blocked(self):
- """fp8 dynamic must NOT be passed to visual."""
- cfg = _MockQuantConfig("fp8")
- assert _simulate_talker_visual_routing(cfg) is None
-
- def test_inc_blocked(self):
- """INC/AutoRound must NOT be passed to visual (not in _PRE_QUANTIZED_METHODS)."""
- cfg = _MockQuantConfig("inc")
- assert _simulate_talker_visual_routing(cfg) is None
-
-
-# ===================================================================
-# 4. ComponentQuantizationConfig.resolve
-# ===================================================================
-
-
-class TestComponentResolve:
- def test_longest_prefix_match(self):
- a = _MockQuantConfig("a")
- b = _MockQuantConfig("b")
- cqc = ComponentQuantizationConfig(component_configs={"language_model": a, "language_model.model": b})
- assert cqc.resolve("language_model.model.layers.0") is b
- assert cqc.resolve("language_model.lm_head") is a
-
- def test_no_match_returns_default(self):
- a = _MockQuantConfig("a")
- default = _MockQuantConfig("default")
- cqc = ComponentQuantizationConfig(
- component_configs={"language_model": a},
- default_config=default,
- )
- assert cqc.resolve("visual") is default
-
- def test_no_match_no_default_returns_none(self):
- a = _MockQuantConfig("a")
- cqc = ComponentQuantizationConfig(
- component_configs={"language_model": a},
- )
- assert cqc.resolve("visual") is None
-
- def test_get_name(self):
- cqc = ComponentQuantizationConfig(component_configs={})
- assert cqc.get_name() == "component"
-
- def test_get_quant_method_delegates(self):
- """get_quant_method dispatches to the resolved config."""
- inner = _MockQuantConfig("fp8")
- cqc = ComponentQuantizationConfig(
- component_configs={"language_model": inner},
- )
- layer = MagicMock()
- result = cqc.get_quant_method(layer, "language_model.model.layers.0.mlp")
- assert result is not None # delegates to inner.get_quant_method
-
- def test_get_quant_method_returns_none_for_unmatched(self):
- """get_quant_method returns None when no config matches."""
- inner = _MockQuantConfig("fp8")
- cqc = ComponentQuantizationConfig(
- component_configs={"language_model": inner},
- )
- layer = MagicMock()
- result = cqc.get_quant_method(layer, "visual.blocks.0.mlp")
- assert result is None
-
- def test_min_capability(self):
- a = _MockQuantConfig("a")
- a.get_min_capability = lambda: 80
- b = _MockQuantConfig("b")
- b.get_min_capability = lambda: 70
- cqc = ComponentQuantizationConfig(component_configs={"x": a, "y": b})
- assert cqc.get_min_capability() == 70
-
- def test_min_capability_empty(self):
- cqc = ComponentQuantizationConfig(component_configs={})
- assert cqc.get_min_capability() == 0
diff --git a/tests/diffusion/quantization/test_fp8_config.py b/tests/diffusion/quantization/test_fp8_config.py
index 574af7a6699..9c18c1f551b 100644
--- a/tests/diffusion/quantization/test_fp8_config.py
+++ b/tests/diffusion/quantization/test_fp8_config.py
@@ -5,7 +5,7 @@
import pytest
from torch import nn
-pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion]
def test_build_quant_config_fp8():
diff --git a/tests/diffusion/quantization/test_int8_config.py b/tests/diffusion/quantization/test_int8_config.py
index 875277ece42..d4d5aa5a7fe 100644
--- a/tests/diffusion/quantization/test_int8_config.py
+++ b/tests/diffusion/quantization/test_int8_config.py
@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for Int8 quantization config."""
+from unittest.mock import MagicMock, patch
+
import pytest
import torch
from pytest_mock import MockerFixture
@@ -100,7 +102,7 @@ def test_quantization_config_string_and_dict_equivalent():
assert config_str.quantization_config.activation_scheme == config_dict.quantization_config.activation_scheme
-def test_get_quant_method(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch):
+def test_get_quant_method(mocker: MockerFixture):
"""Test for get_quant_method method for GPU"""
from vllm_omni.quantization.int8_config import Int8OnlineLinearMethod
@@ -109,16 +111,18 @@ def test_get_quant_method(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch
def _fake_init(self, quant_config):
pass
- layer = mocker.Mock(spec=LinearBase)
+ layer = MagicMock(spec=LinearBase)
mocker.patch.object(Int8OnlineLinearMethod, "__init__", _fake_init)
prefix = "test_layer"
# Mock the platform to be GPU
- monkeypatch.setattr(current_omni_platform, "is_cuda", lambda: True)
- monkeypatch.setattr(current_omni_platform, "is_npu", lambda: False)
- method = config.get_quant_method(layer, prefix)
- assert isinstance(method, Int8OnlineLinearMethod)
+ with (
+ patch("vllm_omni.platforms.current_omni_platform.is_cuda", return_value=True),
+ patch("vllm_omni.platforms.current_omni_platform.is_npu", return_value=False),
+ ):
+ method = config.get_quant_method(layer, prefix)
+ assert isinstance(method, Int8OnlineLinearMethod)
# Test skipping quantization for a layer
config.ignored_layers = [prefix]
@@ -126,20 +130,22 @@ def _fake_init(self, quant_config):
assert isinstance(method, UnquantizedLinearMethod)
-def test_get_npu_quant_method(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch):
+def test_get_npu_quant_method():
"""Test for get_quant_method method for NPU"""
from vllm_omni.quantization.int8_config import NPUInt8OnlineLinearMethod
config = build_quant_config("int8")
- layer = mocker.Mock(spec=LinearBase)
+ layer = MagicMock(spec=LinearBase)
prefix = "test_layer"
# Mock the platform to be NPU
- monkeypatch.setattr(current_omni_platform, "is_cuda", lambda: False)
- monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True)
- method = config.get_quant_method(layer, prefix)
- assert isinstance(method, NPUInt8OnlineLinearMethod)
+ with (
+ patch("vllm_omni.platforms.current_omni_platform.is_cuda", return_value=False),
+ patch("vllm_omni.platforms.current_omni_platform.is_npu", return_value=True),
+ ):
+ method = config.get_quant_method(layer, prefix)
+ assert isinstance(method, NPUInt8OnlineLinearMethod)
# Test skipping quantization for a layer
config.ignored_layers = [prefix]
@@ -239,7 +245,7 @@ class TestNPUInt8LinearMethod:
@pytest.fixture
def mock_torch_npu(self, mocker):
- torch_npu = mocker.MagicMock()
+ torch_npu = MagicMock()
mocker.patch("vllm_omni.quantization.int8_config.torch_npu", return_value=torch_npu)
mocker.patch(
diff --git a/tests/diffusion/quantization/test_quantization_quality.py b/tests/diffusion/quantization/test_quantization_quality.py
index ba6a150c4bb..3d8f1873698 100644
--- a/tests/diffusion/quantization/test_quantization_quality.py
+++ b/tests/diffusion/quantization/test_quantization_quality.py
@@ -32,7 +32,7 @@
import pytest
import torch
-from tests.helpers.mark import hardware_marks
+from tests.utils import hardware_marks
# ---------------------------------------------------------------------------
# Configuration — add new quantization methods / models here
diff --git a/tests/diffusion/test_diffusers_adapter.py b/tests/diffusion/test_diffusers_adapter.py
deleted file mode 100644
index ac2ec2e3fef..00000000000
--- a/tests/diffusion/test_diffusers_adapter.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from collections import namedtuple
-from types import SimpleNamespace
-
-import pytest
-import torch
-from diffusers import DiffusionPipeline
-from PIL import Image
-
-from vllm_omni.diffusion.data import (
- DiffusionOutput,
- DiffusionParallelConfig,
- OmniDiffusionConfig,
-)
-from vllm_omni.diffusion.models.diffusers_adapter import DiffusersAdapterPipeline
-from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
-
-
-def _make_od_config(**overrides) -> OmniDiffusionConfig:
- od_config = OmniDiffusionConfig(
- model="test/model",
- model_class_name="DiffusersAdapterPipeline",
- dtype=torch.float16,
- diffusion_load_format="diffusers",
- diffusers_load_kwargs={},
- diffusers_call_kwargs={},
- output_type="pil",
- parallel_config=DiffusionParallelConfig(cfg_parallel_size=1, sequence_parallel_size=1),
- cache_backend="none",
- )
- for key, value in overrides.items():
- setattr(od_config, key, value)
- return od_config
-
-
-def _make_request(**overrides) -> OmniDiffusionRequest:
- prompt = overrides.pop("prompt", "a test prompt")
- negative_prompt = overrides.pop("negative_prompt", None)
- prompt_obj: dict[str, str] = {"prompt": prompt}
- if negative_prompt is not None:
- prompt_obj["negative_prompt"] = negative_prompt
-
- defaults = {
- "prompts": [prompt_obj],
- "sampling_params": OmniDiffusionSamplingParams(
- num_inference_steps=20,
- guidance_scale=7.5,
- height=16,
- width=16,
- num_frames=1,
- num_outputs_per_prompt=1,
- seed=42,
- output_type="pil",
- generator_device="cpu",
- ),
- }
- defaults.update(overrides)
- return OmniDiffusionRequest(**defaults)
-
-
-class TestDiffusersAdapterPipeline:
- def test_adapter_forward_returns_output(self, mocker):
- od_config = _make_od_config()
- request = _make_request()
- stub_image = Image.new("RGB", (request.sampling_params.width, request.sampling_params.height)) # pyright: ignore[reportArgumentType]
-
- adapter = DiffusersAdapterPipeline(od_config=od_config)
- MockPipelineOutput = namedtuple("MockPipelineOutput", ["image"])
- MockPipeline = type("MockPipeline", (DiffusionPipeline,), {})
- adapter._pipeline = MockPipeline()
-
- mocker.patch.object(
- MockPipeline,
- "__call__",
- return_value=MockPipelineOutput(image=stub_image),
- )
- output = adapter.forward(request)
-
- assert isinstance(output, DiffusionOutput)
- assert isinstance(output.output, MockPipelineOutput)
- assert output.output.image is stub_image
-
- @pytest.mark.parametrize(
- "feature_id",
- ["cfg_parallel", "ulysses", "ring", "teacache", "cache_dit", "enforce_eager", "quantization"],
- )
- def test_adapter_guard_unsupported_feature(self, feature_id):
- if feature_id == "cfg_parallel":
- od_config = _make_od_config(
- parallel_config=DiffusionParallelConfig(cfg_parallel_size=2, sequence_parallel_size=1),
- cache_backend="none",
- )
- elif feature_id == "ulysses":
- od_config = _make_od_config(
- parallel_config=DiffusionParallelConfig(cfg_parallel_size=1, ulysses_degree=2),
- cache_backend="none",
- )
- elif feature_id == "ring":
- od_config = _make_od_config(
- parallel_config=DiffusionParallelConfig(cfg_parallel_size=1, ring_degree=2),
- cache_backend="none",
- )
- elif feature_id == "teacache":
- od_config = _make_od_config(
- parallel_config=DiffusionParallelConfig(cfg_parallel_size=1, sequence_parallel_size=1),
- cache_backend="tea_cache",
- )
- elif feature_id == "cache_dit":
- od_config = _make_od_config(
- parallel_config=DiffusionParallelConfig(cfg_parallel_size=1, sequence_parallel_size=1),
- cache_backend="cache_dit",
- )
- elif feature_id == "enforce_eager":
- od_config = _make_od_config(enforce_eager=True)
- elif feature_id == "quantization":
- od_config = _make_od_config(quantization_config=SimpleNamespace(quant_method="fp8"))
- else:
- raise ValueError(f"Unknown feature ID: {feature_id}")
-
- with pytest.raises(NotImplementedError):
- DiffusersAdapterPipeline(od_config=od_config)
-
- def test_adapter_guard_unknown_output_type(self, mocker):
- """Test that the adapter wraps an unknown output type as-is.
- This is useful when `return_dict=True` and the diffusers pipeline returns an OrderedDict subclass."""
-
- adapter = DiffusersAdapterPipeline(od_config=_make_od_config())
- raw_output = {"unexpected": "dict-output"}
-
- MockPipeline = type("MockPipeline", (DiffusionPipeline,), {})
- adapter._pipeline = MockPipeline()
-
- mocker.patch.object(
- MockPipeline,
- "__call__",
- return_value=raw_output,
- )
- output = adapter.forward(_make_request())
-
- assert isinstance(output, DiffusionOutput)
- assert output.output == raw_output
-
- def test_adapter_build_call_kwargs(self):
- adapter = DiffusersAdapterPipeline(
- od_config=_make_od_config(
- diffusers_call_kwargs={
- "guidance_scale": 1.25,
- "eta": 0.3,
- "output_type": "np",
- }
- )
- )
- req = _make_request(
- prompt="a cat on mars",
- negative_prompt="low quality",
- sampling_params=OmniDiffusionSamplingParams(
- num_inference_steps=9,
- guidance_scale=8.0,
- height=320,
- width=640,
- num_frames=8,
- num_outputs_per_prompt=2,
- seed=123,
- output_type="pil",
- ),
- )
-
- kwargs = adapter._build_call_kwargs(req)
-
- assert kwargs["prompt"] == "a cat on mars"
- assert kwargs["negative_prompt"] == "low quality"
- assert kwargs["num_inference_steps"] == 9
- assert kwargs["guidance_scale"] == 8.0
- assert kwargs["height"] == 320
- assert kwargs["width"] == 640
- assert kwargs["num_frames"] == 8
- assert kwargs["num_images_per_prompt"] == 2
- assert kwargs["output_type"] == "pil"
- assert isinstance(kwargs["generator"], torch.Generator)
- assert kwargs["generator"].device.type == "cpu"
- assert kwargs["generator"].initial_seed() == 123
diff --git a/tests/diffusion/test_diffusion_model_runner.py b/tests/diffusion/test_diffusion_model_runner.py
index b63f6d8887f..88b17147e85 100644
--- a/tests/diffusion/test_diffusion_model_runner.py
+++ b/tests/diffusion/test_diffusion_model_runner.py
@@ -8,10 +8,9 @@
import torch
import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module
-from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner
-pytestmark = [pytest.mark.diffusion]
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
@contextmanager
@@ -65,8 +64,6 @@ def _make_runner(cache_backend, cache_backend_name: str, enable_cache_dit_summar
return runner
-@pytest.mark.core_model
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_execute_model_skips_cache_summary_without_active_cache_backend(monkeypatch):
"""Guard cache diagnostics with runtime backend state to avoid stale-config crashes."""
runner = _make_runner(cache_backend=None, cache_backend_name="cache_dit")
@@ -87,8 +84,6 @@ def test_execute_model_skips_cache_summary_without_active_cache_backend(monkeypa
assert cache_summary_calls == []
-@pytest.mark.core_model
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_execute_model_emits_cache_summary_with_active_cache_dit_backend(monkeypatch):
class _EnabledCacheBackend:
def is_enabled(self):
@@ -112,8 +107,6 @@ def is_enabled(self):
assert cache_summary_calls == [(runner.pipeline, True)]
-@pytest.mark.core_model
-@pytest.mark.cpu
def test_load_model_clears_cache_backend_for_unsupported_pipeline(monkeypatch):
class _DummyLoader:
def __init__(self, load_config, od_config=None):
diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py
index a64d9920e03..4324ba1e630 100644
--- a/tests/diffusion/test_diffusion_scheduler.py
+++ b/tests/diffusion/test_diffusion_scheduler.py
@@ -4,10 +4,10 @@
import queue
import threading
from types import SimpleNamespace
+from unittest.mock import Mock, patch
import pytest
import torch
-from pytest_mock import MockerFixture
from vllm_omni.diffusion.data import DiffusionOutput, DiffusionRequestAbortedError
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
@@ -97,19 +97,19 @@ def initialize(self, od_config) -> None:
def add_request(self, request: OmniDiffusionRequest) -> str:
assert request is self._request
- self._state = SimpleNamespace(sched_req_id=self._sched_req_id, req=request)
+ self._state = Mock(sched_req_id=self._sched_req_id, req=request)
return self._sched_req_id
def schedule(self):
if self._scheduled or self._state is None:
- return SimpleNamespace(
+ return Mock(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
scheduled_req_ids=[],
is_empty=True,
)
self._scheduled = True
- return SimpleNamespace(
+ return Mock(
scheduled_new_reqs=[NewRequestData.from_state(self._state)],
scheduled_cached_reqs=CachedRequestData.make_empty(),
scheduled_req_ids=[self._state.sched_req_id],
@@ -153,7 +153,7 @@ def close(self) -> None:
class TestRequestScheduler:
def setup_method(self) -> None:
self.scheduler: RequestScheduler = RequestScheduler()
- self.scheduler.initialize(SimpleNamespace())
+ self.scheduler.initialize(Mock())
def test_single_request_success_lifecycle(self) -> None:
req_id = self.scheduler.add_request(_make_request("a"))
@@ -276,23 +276,23 @@ def test_request_id_mapping_lifecycle(self) -> None:
class TestDiffusionEngine:
- def test_add_req_and_wait_for_response_single_path(self, mocker: MockerFixture) -> None:
+ def test_add_req_and_wait_for_response_single_path(self) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.scheduler = RequestScheduler()
- engine.scheduler.initialize(SimpleNamespace())
+ engine.scheduler.initialize(Mock())
engine._rpc_lock = threading.RLock()
engine.abort_queue = queue.Queue()
request = _make_request("engine")
runner_output = _make_request_output("engine")
- engine.execute_fn = mocker.Mock(return_value=runner_output)
+ engine.execute_fn = Mock(return_value=runner_output)
output = engine.add_req_and_wait_for_response(request)
assert output is runner_output.result
engine.execute_fn.assert_called_once()
- def test_supports_scheduler_interface_injection(self, mocker: MockerFixture) -> None:
+ def test_supports_scheduler_interface_injection(self) -> None:
request = _make_request("engine_iface")
runner_output = _make_request_output("engine_iface")
scheduler = _StubScheduler(request, runner_output)
@@ -301,45 +301,33 @@ def test_supports_scheduler_interface_injection(self, mocker: MockerFixture) ->
engine.scheduler = scheduler
engine._rpc_lock = threading.RLock()
engine.abort_queue = queue.Queue()
- engine.execute_fn = mocker.Mock(return_value=runner_output)
+ engine.execute_fn = Mock(return_value=runner_output)
output = engine.add_req_and_wait_for_response(request)
assert output is runner_output.result
engine.execute_fn.assert_called_once()
- def test_initializes_injected_scheduler(
- self,
- monkeypatch: pytest.MonkeyPatch,
- mocker: MockerFixture,
- ) -> None:
+ def test_initializes_injected_scheduler(self) -> None:
request = _make_request("init")
scheduler = _StubScheduler(request, DiffusionOutput(output=None))
- od_config = SimpleNamespace(model_class_name="mock_model")
- fake_executor_cls = mocker.Mock(return_value=mocker.Mock())
+ od_config = Mock(model_class_name="mock_model")
+ fake_executor_cls = Mock(return_value=Mock())
- monkeypatch.setattr(
- "vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func",
- lambda *args, **kwargs: None,
- )
- monkeypatch.setattr(
- "vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func",
- lambda *args, **kwargs: None,
- )
- monkeypatch.setattr(
- "vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class",
- lambda *args, **kwargs: fake_executor_cls,
- )
- monkeypatch.setattr(DiffusionEngine, "_dummy_run", lambda self: None)
-
- DiffusionEngine(od_config, scheduler=scheduler)
+ with (
+ patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func", return_value=None),
+ patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func", return_value=None),
+ patch("vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class", return_value=fake_executor_cls),
+ patch.object(DiffusionEngine, "_dummy_run", return_value=None),
+ ):
+ DiffusionEngine(od_config, scheduler=scheduler)
assert scheduler.initialized_with is od_config
fake_executor_cls.assert_called_once_with(od_config)
def test_scheduler_alias_keeps_default_request_scheduler(self) -> None:
scheduler = Scheduler()
- scheduler.initialize(SimpleNamespace())
+ scheduler.initialize(Mock())
req_id = scheduler.add_request(_make_request("alias"))
sched_output = scheduler.schedule()
@@ -348,10 +336,10 @@ def test_scheduler_alias_keeps_default_request_scheduler(self) -> None:
assert req_id in finished
assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED
- def test_step_raises_aborted_error(self, mocker: MockerFixture) -> None:
+ def test_step_raises_aborted_error(self) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.pre_process_func = None
- engine.add_req_and_wait_for_response = mocker.Mock(
+ engine.add_req_and_wait_for_response = Mock(
return_value=DiffusionOutput(aborted=True, abort_message="Request req-abort aborted.")
)
@@ -361,7 +349,7 @@ def test_step_raises_aborted_error(self, mocker: MockerFixture) -> None:
def test_abort_queue_marks_request_finished_aborted(self) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.scheduler = RequestScheduler()
- engine.scheduler.initialize(SimpleNamespace())
+ engine.scheduler.initialize(Mock())
engine.abort_queue = queue.Queue()
req_id = engine.scheduler.add_request(_make_request("req-abort"))
@@ -373,7 +361,7 @@ def test_abort_queue_marks_request_finished_aborted(self) -> None:
def test_finalize_finished_request_returns_aborted_output(self) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
engine.scheduler = RequestScheduler()
- engine.scheduler.initialize(SimpleNamespace())
+ engine.scheduler.initialize(Mock())
req_id = engine.scheduler.add_request(_make_request("req-finalize"))
engine.scheduler.finish_requests(req_id, DiffusionRequestStatus.FINISHED_ABORTED)
@@ -383,40 +371,29 @@ def test_finalize_finished_request_returns_aborted_output(self) -> None:
assert output.aborted is True
assert output.abort_message == "Request req-finalize aborted."
- def test_initializes_step_scheduler_when_step_execution_enabled(
- self,
- monkeypatch: pytest.MonkeyPatch,
- mocker: MockerFixture,
- ) -> None:
- od_config = SimpleNamespace(model_class_name="mock_model")
+ def test_initializes_step_scheduler_when_step_execution_enabled(self) -> None:
+ od_config = Mock(model_class_name="mock_model")
od_config.step_execution = True
- fake_executor = mocker.Mock()
- fake_executor_cls = mocker.Mock(return_value=fake_executor)
+ fake_executor = Mock()
+ fake_executor_cls = Mock(return_value=fake_executor)
- monkeypatch.setattr(
- "vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func",
- lambda *args, **kwargs: None,
- )
- monkeypatch.setattr(
- "vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func",
- lambda *args, **kwargs: None,
- )
- monkeypatch.setattr(
- "vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class",
- lambda *args, **kwargs: fake_executor_cls,
- )
- monkeypatch.setattr(DiffusionEngine, "_dummy_run", lambda self: None)
- engine = DiffusionEngine(od_config)
+ with (
+ patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func", return_value=None),
+ patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func", return_value=None),
+ patch("vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class", return_value=fake_executor_cls),
+ patch.object(DiffusionEngine, "_dummy_run", return_value=None),
+ ):
+ engine = DiffusionEngine(od_config)
assert isinstance(engine.scheduler, StepScheduler)
assert engine.execute_fn is fake_executor.execute_step
fake_executor_cls.assert_called_once_with(od_config)
- def test_dummy_run_raises_on_output_error(self, mocker: MockerFixture) -> None:
+ def test_dummy_run_raises_on_output_error(self) -> None:
engine = DiffusionEngine.__new__(DiffusionEngine)
- engine.od_config = SimpleNamespace(model_class_name="mock_model")
+ engine.od_config = Mock(model_class_name="mock_model")
engine.pre_process_func = None
- engine.add_req_and_wait_for_response = mocker.Mock(return_value=DiffusionOutput(error="boom"))
+ engine.add_req_and_wait_for_response = Mock(return_value=DiffusionOutput(error="boom"))
with pytest.raises(RuntimeError, match="Dummy run failed: boom"):
engine._dummy_run()
@@ -425,7 +402,7 @@ def test_dummy_run_raises_on_output_error(self, mocker: MockerFixture) -> None:
class TestStepScheduler:
def setup_method(self) -> None:
self.scheduler: StepScheduler = StepScheduler()
- self.scheduler.initialize(SimpleNamespace())
+ self.scheduler.initialize(Mock())
def test_single_request_step_lifecycle(self) -> None:
request = _make_step_request("step", num_inference_steps=3)
diff --git a/tests/diffusion/test_diffusion_step_pipeline.py b/tests/diffusion/test_diffusion_step_pipeline.py
index 06f8cd14dc8..68aba9ba3bf 100644
--- a/tests/diffusion/test_diffusion_step_pipeline.py
+++ b/tests/diffusion/test_diffusion_step_pipeline.py
@@ -7,13 +7,13 @@
import threading
from contextlib import contextmanager
from types import SimpleNamespace
+from unittest.mock import Mock
import pytest
import torch
-from pytest_mock import MockerFixture
import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module
-from tests.helpers.mark import hardware_test
+from tests.utils import hardware_test
from vllm_omni.diffusion.data import DiffusionOutput
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
@@ -542,11 +542,11 @@ def test_rejects_lora_requests_in_step_mode(self):
class TestExecutor:
"""MultiprocDiffusionExecutor.execute_step"""
- def test_execute_step_passes_through_runner_output(self, mocker: MockerFixture):
+ def test_execute_step_passes_through_runner_output(self):
executor = object.__new__(MultiprocDiffusionExecutor)
executor._ensure_open = lambda: None
expected = RunnerOutput(req_id="req-step", step_index=1, finished=False, result=None)
- executor.collective_rpc = mocker.Mock(return_value=expected)
+ executor.collective_rpc = Mock(return_value=expected)
request = _make_engine_request("req-step", num_inference_steps=2)
scheduler_output = _make_scheduler_output(request, sched_req_id="req-step")
@@ -578,9 +578,9 @@ class TestEngine:
),
],
)
- def test_step_engine_returns_error(self, execute_fn, expected_error, mocker: MockerFixture):
+ def test_step_engine_returns_error(self, execute_fn, expected_error):
scheduler = StepScheduler()
- scheduler.initialize(mocker.Mock())
+ scheduler.initialize(Mock())
engine = _make_engine(scheduler, execute_fn=execute_fn)
output = engine.add_req_and_wait_for_response(_make_engine_request("req-error", num_inference_steps=2))
@@ -588,9 +588,9 @@ def test_step_engine_returns_error(self, execute_fn, expected_error, mocker: Moc
assert output.output is None
assert expected_error in output.error
- def test_step_execution_completes(self, mocker: MockerFixture):
+ def test_step_execution_completes(self):
scheduler = StepScheduler()
- scheduler.initialize(mocker.Mock())
+ scheduler.initialize(Mock())
engine = _make_engine(scheduler)
request = _make_engine_request("req-step", num_inference_steps=2)
@@ -614,9 +614,9 @@ def execute_fn(_):
assert output.error is None
assert torch.equal(output.output, torch.tensor([2.0]))
- def test_step_abort_stops_rescheduling_after_first_step(self, mocker: MockerFixture):
+ def test_step_abort_stops_rescheduling_after_first_step(self):
scheduler = StepScheduler()
- scheduler.initialize(mocker.Mock())
+ scheduler.initialize(Mock())
engine = _make_engine(scheduler)
request = _make_engine_request("req-stop", num_inference_steps=4)
@@ -639,9 +639,9 @@ def execute_fn(_):
assert step["n"] == 1
_assert_aborted_output(output, "req-stop")
- def test_step_abort_after_reschedule_returns_aborted_output(self, mocker: MockerFixture):
+ def test_step_abort_after_reschedule_returns_aborted_output(self):
scheduler = StepScheduler()
- scheduler.initialize(mocker.Mock())
+ scheduler.initialize(Mock())
engine = _make_engine(scheduler)
request = _make_engine_request("req-mid", num_inference_steps=4)
@@ -666,9 +666,9 @@ def execute_fn(sched_output):
assert step["n"] == 2
_assert_aborted_output(output, "req-mid")
- def test_finished_step_without_result_returns_error(self, mocker: MockerFixture):
+ def test_finished_step_without_result_returns_error(self):
scheduler = StepScheduler()
- scheduler.initialize(mocker.Mock())
+ scheduler.initialize(Mock())
engine = _make_engine(
scheduler,
execute_fn=lambda _: RunnerOutput(
diff --git a/tests/diffusion/test_diffusion_worker.py b/tests/diffusion/test_diffusion_worker.py
index fc08c5f7f03..e2bd7ef8a32 100644
--- a/tests/diffusion/test_diffusion_worker.py
+++ b/tests/diffusion/test_diffusion_worker.py
@@ -16,7 +16,7 @@
from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker
-pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.gpu]
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
@pytest.fixture
@@ -81,31 +81,17 @@ def test_load_weights_empty_iterable(self, mocker: MockerFixture, mock_gpu_worke
class TestDiffusionWorkerSleep:
"""Test DiffusionWorker.sleep method."""
- @pytest.fixture(autouse=True)
- def setup_allocator(self, mocker: MockerFixture):
- """
- Unified interception of Allocators, and provision of default security values.
- """
- self.mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator")
- self.mock_allocator = mocker.Mock()
- self.mock_allocator_class.get_instance.return_value = self.mock_allocator
- self.mock_allocator.get_current_usage.return_value = 4 * 1024**3
- self.mock_allocator.sleep = mocker.Mock()
-
def test_sleep_level_1(self, mocker: MockerFixture, mock_gpu_worker):
"""Test sleep mode level 1 (offload weights only)."""
mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator")
- mock_platform = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
- mock_platform.get_free_memory.side_effect = [10 * 1024**3, 12 * 1024**3]
- mock_platform.get_device_total_memory.return_value = 80 * 1024**3
+ mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
mock_get_process_memory = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.get_process_gpu_memory")
# Setup process-scoped memory mocks
# Before sleep: 3GB used
# After sleep: 1GB used (freed 2GB)
- initial_usage = 3 * 1024**3
mock_get_process_memory.side_effect = [
- initial_usage,
+ 3 * 1024**3,
1 * 1024**3,
]
@@ -113,29 +99,25 @@ def test_sleep_level_1(self, mocker: MockerFixture, mock_gpu_worker):
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.sleep = mocker.Mock()
- mock_allocator.get_current_usage.return_value = initial_usage
# Call sleep with level 1
result = mock_gpu_worker.sleep(level=1)
# Verify sleep was called with correct tags
mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",))
- assert bool(result) is True
+ assert result is True
# Verify buffers were NOT saved (level 1 doesn't save buffers)
assert len(mock_gpu_worker._sleep_saved_buffers) == 0
def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker):
"""Test sleep mode level 2 (offload all, save buffers)."""
mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator")
- mock_platform = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
- mock_platform.get_free_memory.side_effect = [5 * 1024**3, 10 * 1024**3]
- mock_platform.get_device_total_memory.return_value = 80 * 1024**3
+ mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
mock_get_process_memory = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.get_process_gpu_memory")
# Setup process-scoped memory mocks
- initial_usage = 5 * 1024**3
mock_get_process_memory.side_effect = [
- initial_usage, # Before sleep
+ 5 * 1024**3, # Before sleep
1 * 1024**3, # After sleep (freed 4GB)
]
@@ -143,7 +125,6 @@ def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker):
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.sleep = mocker.Mock()
- mock_allocator.get_current_usage.return_value = initial_usage
# Mock pipeline buffers
mock_buffer1 = torch.randn(10, 10)
@@ -160,7 +141,7 @@ def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker):
# Verify sleep was called with empty tags (offload all)
mock_allocator.sleep.assert_called_once_with(offload_tags=tuple())
- assert bool(result) is True
+ assert result is True
# Verify buffers were saved
assert len(mock_gpu_worker._sleep_saved_buffers) == 2
@@ -170,26 +151,22 @@ def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker):
def test_sleep_memory_freed_validation(self, mocker: MockerFixture, mock_gpu_worker):
"""Test that sleep validates memory was actually freed."""
mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator")
- mock_platform = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
- mock_platform.get_free_memory.return_value = 10 * 1024**3
- mock_platform.get_device_total_memory.return_value = 80 * 1024**3
+ mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
mock_get_process_memory = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.get_process_gpu_memory")
# Simulate process memory increase (should trigger assertion error)
- initial_usage = 1 * 1024**3
mock_get_process_memory.side_effect = [
- initial_usage, # Before sleep: 1GB used
+ 1 * 1024**3, # Before sleep: 1GB used
3 * 1024**3, # After sleep: 3GB used (negative freed)
]
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.sleep = mocker.Mock()
- mock_allocator.get_current_usage.return_value = initial_usage
# This should raise an assertion error
- result = mock_gpu_worker.sleep(level=1)
- assert result == initial_usage
+ with pytest.raises(AssertionError, match="Memory usage increased after sleeping"):
+ mock_gpu_worker.sleep(level=1)
def test_sleep_falls_back_to_device_memory_when_nvml_unavailable(self, mocker: MockerFixture, mock_gpu_worker):
"""Test sleep uses device-scoped fallback when NVML is unavailable."""
@@ -207,12 +184,11 @@ def test_sleep_falls_back_to_device_memory_when_nvml_unavailable(self, mocker: M
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.sleep = mocker.Mock()
- mock_allocator.get_current_usage.return_value = 2 * 1024**3
result = mock_gpu_worker.sleep(level=1)
mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",))
- assert bool(result) is True
+ assert result is True
class TestDiffusionWorkerWakeUp:
@@ -226,7 +202,6 @@ def test_wake_up_without_buffers(self, mocker: MockerFixture, mock_gpu_worker):
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.wake_up = mocker.Mock()
- mock_allocator.get_current_usage.return_value = 10 * 1024**3
# Ensure no saved buffers
mock_gpu_worker._sleep_saved_buffers = {}
@@ -236,7 +211,7 @@ def test_wake_up_without_buffers(self, mocker: MockerFixture, mock_gpu_worker):
# Verify allocator.wake_up was called
mock_allocator.wake_up.assert_called_once_with(["weights"])
- assert bool(result) is True
+ assert result is True
def test_wake_up_with_buffers(self, mocker: MockerFixture, mock_gpu_worker):
"""Test wake_up with saved buffers (level 2 sleep)."""
@@ -246,7 +221,6 @@ def test_wake_up_with_buffers(self, mocker: MockerFixture, mock_gpu_worker):
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.wake_up = mocker.Mock()
- mock_allocator.get_current_usage.return_value = 10 * 1024**3
# Create saved buffers
saved_buffer1 = torch.randn(10, 10)
@@ -281,7 +255,7 @@ def test_wake_up_with_buffers(self, mocker: MockerFixture, mock_gpu_worker):
# Verify saved buffers were cleared
assert len(mock_gpu_worker._sleep_saved_buffers) == 0
- assert bool(result) is True
+ assert result is True
def test_wake_up_partial_buffer_restore(self, mocker: MockerFixture, mock_gpu_worker):
"""Test wake_up only restores buffers that were saved."""
@@ -291,7 +265,6 @@ def test_wake_up_partial_buffer_restore(self, mocker: MockerFixture, mock_gpu_wo
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.wake_up = mocker.Mock()
- mock_allocator.get_current_usage.return_value = 10 * 1024**3
# Only save buffer1, not buffer2
saved_buffer1 = torch.randn(10, 10)
@@ -320,4 +293,4 @@ def test_wake_up_partial_buffer_restore(self, mocker: MockerFixture, mock_gpu_wo
# buffer2 should NOT be restored since it wasn't saved
mock_buffer2.data.copy_.assert_not_called()
- assert bool(result) is True
+ assert result is True
diff --git a/tests/diffusion/test_diffusion_worker_cuda_profiler.py b/tests/diffusion/test_diffusion_worker_cuda_profiler.py
deleted file mode 100644
index 4a3b22c212e..00000000000
--- a/tests/diffusion/test_diffusion_worker_cuda_profiler.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-from pytest_mock import MockerFixture
-
-from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker
-
-pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
-
-
-@pytest.fixture
-def mock_od_config(mocker: MockerFixture):
- """Create a mock OmniDiffusionConfig with a CUDA profiler backend."""
- config = mocker.Mock()
- config.profiler_config = mocker.Mock()
- config.profiler_config.profiler = "cuda"
- config.diffusion_load_format = "default"
- return config
-
-
-@pytest.fixture
-def mock_diffusion_worker_dependencies(mocker: MockerFixture):
- """Patch heavy worker dependencies for focused profiler tests."""
- mocker.patch.object(DiffusionWorker, "init_device")
- mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.DiffusionModelRunner")
-
-
-class TestDiffusionWorkerCudaProfiler:
- def test_creates_cuda_profiler_wrapper(
- self,
- mocker: MockerFixture,
- mock_od_config,
- mock_diffusion_worker_dependencies,
- ):
- fake_profiler = mocker.Mock()
- cuda_profiler = mocker.patch(
- "vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper",
- return_value=fake_profiler,
- )
- create_omni_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.create_omni_profiler")
-
- worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True)
-
- cuda_profiler.assert_called_once_with(mock_od_config.profiler_config)
- create_omni_profiler.assert_not_called()
- assert worker.profiler is fake_profiler
-
- def test_profile_start_stop_delegates_to_cuda_profiler(
- self,
- mocker: MockerFixture,
- mock_od_config,
- mock_diffusion_worker_dependencies,
- ):
- fake_profiler = mocker.Mock()
- fake_profiler.start = mocker.Mock()
- fake_profiler.stop = mocker.Mock()
- mocker.patch(
- "vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper",
- return_value=fake_profiler,
- )
-
- worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True)
-
- assert worker.profile(is_start=True) is None
- assert worker.profile(is_start=False) is None
-
- fake_profiler.start.assert_called_once_with()
- fake_profiler.stop.assert_called_once_with()
-
- def test_returns_none_when_profiler_config_is_missing(
- self,
- mocker: MockerFixture,
- mock_od_config,
- mock_diffusion_worker_dependencies,
- ):
- mock_od_config.profiler_config = None
- cuda_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper")
- create_omni_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.create_omni_profiler")
-
- worker = DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True)
-
- cuda_profiler.assert_not_called()
- create_omni_profiler.assert_not_called()
- assert worker.profiler is None
-
- def test_cuda_backend_does_not_use_torch_profiler_factory(
- self,
- mocker: MockerFixture,
- mock_od_config,
- mock_diffusion_worker_dependencies,
- ):
- mocker.patch(
- "vllm_omni.diffusion.worker.diffusion_worker.CudaProfilerWrapper",
- return_value=mocker.Mock(),
- )
- create_omni_profiler = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.create_omni_profiler")
-
- DiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config, skip_load_model=True)
-
- create_omni_profiler.assert_not_called()
diff --git a/tests/diffusion/test_inline_stage_diffusion_client.py b/tests/diffusion/test_inline_stage_diffusion_client.py
deleted file mode 100644
index 385f39b1240..00000000000
--- a/tests/diffusion/test_inline_stage_diffusion_client.py
+++ /dev/null
@@ -1,96 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from vllm_omni.diffusion.data import OmniDiffusionConfig
-from vllm_omni.diffusion.inline_stage_diffusion_client import InlineStageDiffusionClient
-from vllm_omni.engine.stage_init_utils import StageMetadata
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.fixture
-def mock_engine():
- with patch("vllm_omni.diffusion.inline_stage_diffusion_client.DiffusionEngine") as mock:
- engine_instance = MagicMock()
- mock.make_engine.return_value = engine_instance
- yield engine_instance
-
-
-@pytest.fixture
-def client(mock_engine):
- metadata = StageMetadata(
- stage_id=0,
- stage_type="diffusion",
- engine_output_type="image",
- is_comprehension=False,
- requires_multimodal_data=False,
- engine_input_source="prompt",
- final_output=True,
- final_output_type="image",
- default_sampling_params={},
- custom_process_input_func=None,
- model_stage=None,
- runtime_cfg=None,
- )
- with patch.object(InlineStageDiffusionClient, "_enrich_config"):
- od_config = MagicMock(spec=OmniDiffusionConfig)
- c = InlineStageDiffusionClient(model="test_model", od_config=od_config, metadata=metadata, batch_size=1)
- yield c
- c.shutdown()
-
-
-@pytest.mark.asyncio
-async def test_inline_dispatch_request_success(client, mock_engine):
- # Setup mock engine step to return a successful result
- mock_result = OmniRequestOutput.from_diffusion(request_id="req-1", images=[MagicMock()])
- mock_engine.step.return_value = [mock_result]
-
- sampling_params = OmniDiffusionSamplingParams()
- await client.add_request_async("req-1", "A test prompt", sampling_params)
-
- # Wait for the task to be processed
- for _ in range(10):
- output = client.get_diffusion_output_nowait()
- if output is not None:
- break
- await asyncio.sleep(0.01)
-
- assert output is not None
- assert output.request_id == "req-1"
- mock_engine.step.assert_called_once()
-
-
-@pytest.mark.asyncio
-async def test_inline_dispatch_request_error(client, mock_engine):
- # Setup mock engine step to raise an exception
- mock_engine.step.side_effect = RuntimeError("Engine failure")
-
- sampling_params = OmniDiffusionSamplingParams()
- await client.add_request_async("req-err", "A test prompt", sampling_params)
-
- for _ in range(10):
- output = client.get_diffusion_output_nowait()
- if output is not None:
- break
- await asyncio.sleep(0.01)
-
- assert output is not None
- assert output.request_id == "req-err"
- assert output.error == "Engine failure"
- assert not output.images
-
-
-def test_inline_shutdown(client, mock_engine):
- assert not client._shutting_down
-
- # Shutting down should cleanly cancel anything queued and close engine
- client.shutdown()
-
- assert client._shutting_down
- mock_engine.close.assert_called_once()
diff --git a/tests/diffusion/test_multiproc_engine_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py
index 9ec06e8107d..517f98ddaa9 100644
--- a/tests/diffusion/test_multiproc_engine_concurrency.py
+++ b/tests/diffusion/test_multiproc_engine_concurrency.py
@@ -1,25 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import asyncio
-import multiprocessing as mp
import queue
import threading
-import time
-from types import SimpleNamespace
-from unittest.mock import MagicMock, Mock
+from unittest.mock import Mock, patch
import pytest
import torch
-import zmq
-from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.diffusion.data import DiffusionOutput
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor
from vllm_omni.diffusion.sched import RequestScheduler
-from vllm_omni.diffusion.stage_diffusion_proc import StageDiffusionProc
-from vllm_omni.outputs import OmniRequestOutput
pytestmark = [pytest.mark.diffusion, pytest.mark.core_model, pytest.mark.cpu]
@@ -32,9 +24,11 @@ def _tagged_output(tag: str) -> DiffusionOutput:
return DiffusionOutput(output=torch.tensor([0]), error=tag)
-def _mock_request(tag: str):
- """Return a lightweight request object identifiable by *tag*."""
- return SimpleNamespace(request_ids=[tag])
+def _mock_request(tag: str) -> Mock:
+ """Return a mock ``OmniDiffusionRequest`` identifiable by *tag*."""
+ req = Mock()
+ req.request_ids = [tag]
+ return req
def _make_executor(num_gpus: int = 1):
@@ -42,25 +36,25 @@ def _make_executor(num_gpus: int = 1):
Returns ``(executor, request_queue, result_queue)``.
"""
- od_cfg = SimpleNamespace(num_gpus=num_gpus)
- monkeypatch = pytest.MonkeyPatch()
- monkeypatch.setattr(MultiprocDiffusionExecutor, "_init_executor", lambda self: None)
- executor = MultiprocDiffusionExecutor(od_cfg)
- monkeypatch.undo()
+ od_cfg = Mock()
+ od_cfg.num_gpus = num_gpus
+
+ with patch.object(MultiprocDiffusionExecutor, "_init_executor"):
+ executor = MultiprocDiffusionExecutor(od_cfg)
req_q: queue.Queue = queue.Queue()
res_q: queue.Queue = queue.Queue()
- mock_broadcast_mq = SimpleNamespace(enqueue=req_q.put)
+ mock_broadcast_mq = Mock()
+ mock_broadcast_mq.enqueue = req_q.put
- mock_rmq = SimpleNamespace(dequeue=lambda timeout=None: res_q.get(timeout=timeout if timeout is not None else 10))
+ mock_rmq = Mock()
+ mock_rmq.dequeue = lambda timeout=None: res_q.get(timeout=timeout if timeout is not None else 10)
executor._broadcast_mq = mock_broadcast_mq
executor._result_mq = mock_rmq
executor._closed = False
executor._processes = []
- executor.is_failed = False
- executor._failure_callbacks = []
return executor, req_q, res_q
@@ -69,7 +63,7 @@ def _make_engine(num_gpus: int = 1):
executor, req_q, res_q = _make_executor(num_gpus)
engine = DiffusionEngine.__new__(DiffusionEngine)
sched = RequestScheduler()
- sched.initialize(SimpleNamespace())
+ sched.initialize(Mock())
engine.scheduler = sched
engine.executor = executor
engine._rpc_lock = threading.RLock()
@@ -344,9 +338,8 @@ def test_collective_rpc_closed_executor_raises(self):
class TestCollectiveRpcTimeoutWhileLockHeld:
"""``collective_rpc(timeout=...)`` must honour its timeout even when
- another thread holds ``engine._rpc_lock`` indefinitely (e.g. request
- execution stalled on ``add_req_and_wait_for_response`` → ``execute_fn``
- → ``collective_rpc`` while blocked on an unresponsive worker).
+ another thread holds ``engine._rpc_lock`` indefinitely (e.g. a stalled
+ ``add_req`` waiting on an unresponsive worker).
"""
def test_rpc_times_out_when_lock_held_directly(self):
@@ -370,10 +363,10 @@ def _hold_lock():
with pytest.raises(TimeoutError):
engine.collective_rpc("health", timeout=0.5)
- def test_rpc_times_out_when_request_execution_stalled_on_worker(self):
+ def test_rpc_times_out_when_add_req_stalled_on_worker(self):
"""Real-world scenario the bot flagged:
- The scheduler/execute path holds ``_rpc_lock`` while blocked on
+ ``add_req`` holds ``_rpc_lock`` while blocked on
``executor._result_mq.dequeue()`` because the worker never replies.
A concurrent ``collective_rpc(timeout=...)`` must still time out
instead of hanging forever waiting for the lock.
@@ -439,353 +432,3 @@ def _hold_and_release():
t.join(5)
assert result.error == "ok"
-
-
-# ───────── error handling: EngineDeadError propagation through layers ─────
-
-
-class TestMultiprocExecutorRaisesEngineDeadError:
- """``collective_rpc`` raises ``EngineDeadError`` when the engine is failed."""
-
- def test_collective_rpc_raises_when_is_failed(self):
- executor = object.__new__(MultiprocDiffusionExecutor)
- executor._closed = False
- executor._broadcast_mq = MagicMock()
- executor._result_mq = MagicMock()
- executor._result_mq.dequeue = MagicMock(side_effect=TimeoutError)
- executor.is_failed = True
-
- with pytest.raises(EngineDeadError):
- executor.collective_rpc(
- "generate",
- args=(MagicMock(),),
- unique_reply_rank=0,
- exec_all_ranks=True,
- )
-
- def test_collective_rpc_raises_mid_dequeue_when_is_failed(self):
- """Worker dies while we are polling the dequeue loop."""
- executor, _, res_q = _make_executor()
-
- call_count = 0
- orig_dequeue = executor._result_mq.dequeue
-
- def _dying_dequeue(timeout=None):
- nonlocal call_count
- call_count += 1
- if call_count == 1:
- executor.is_failed = True
- raise TimeoutError
- return orig_dequeue(timeout=timeout)
-
- executor._result_mq.dequeue = _dying_dequeue
-
- with pytest.raises(EngineDeadError):
- executor.collective_rpc(
- "generate",
- args=(MagicMock(),),
- unique_reply_rank=0,
- exec_all_ranks=True,
- )
-
-
-class TestDiffusionEngineDeadErrorPassthrough:
- """``DiffusionEngine.add_req_and_wait_for_response`` re-raises
- ``EngineDeadError`` from executor and wraps other errors."""
-
- def test_engine_dead_error_propagates(self):
- engine, executor, _, _ = _make_engine()
- engine.execute_fn = Mock(side_effect=EngineDeadError())
-
- with pytest.raises(EngineDeadError):
- engine.add_req_and_wait_for_response(_mock_request("dead"))
-
- def test_runtime_error_wrapped_in_output(self):
- engine, executor, _, _ = _make_engine()
- engine.execute_fn = Mock(side_effect=RuntimeError("gpu fault"))
-
- out = engine.add_req_and_wait_for_response(_mock_request("fault"))
- assert isinstance(out, DiffusionOutput)
- assert "gpu fault" in out.error
-
-
-class TestStageDiffusionClientErrorPropagation:
- """Error surface behaviour of ``StageDiffusionClient``.
-
- Uses ``object.__new__`` to construct a client without spawning a real
- subprocess, then manually sets the fields needed for each test.
- """
-
- def _make_client(self, *, engine_dead=False, proc_alive=True):
- from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
-
- client = object.__new__(StageDiffusionClient)
- client.stage_id = 0
- client.final_output = True
- client.final_output_type = "image"
- client.default_sampling_params = None
- client.custom_process_input_func = None
- client.engine_input_source = None
-
- client._output_queue = asyncio.Queue()
- client._rpc_results = {}
- client._pending_rpcs = set()
- client._tasks = {}
- client._shutting_down = False
- client._engine_dead = engine_dead
- client._owns_process = True
- client._proc = MagicMock(
- is_alive=MagicMock(return_value=proc_alive),
- exitcode=1,
- )
- client._request_socket = MagicMock()
- client._response_socket = MagicMock()
- client._encoder = MagicMock()
- client._decoder = MagicMock()
-
- return client
-
- @pytest.mark.asyncio
- async def test_add_request_raises_when_dead(self):
- client = self._make_client(engine_dead=True)
-
- with pytest.raises(EngineDeadError):
- await client.add_request_async("req-3", "test prompt", None)
-
- def test_check_health_raises_when_dead(self):
- client = self._make_client(engine_dead=True)
-
- with pytest.raises(EngineDeadError):
- client.check_health()
-
- def test_check_health_ok_when_alive(self):
- client = self._make_client()
- client.check_health()
-
- def test_get_output_raises_engine_dead_when_dead(self):
- """When ``_engine_dead`` is True and the output queue is empty,
- ``get_diffusion_output_nowait`` must raise ``EngineDeadError``."""
- client = self._make_client(engine_dead=True)
- # Simulate _drain_responses as a no-op (no ZMQ socket)
- client._response_socket.recv.side_effect = zmq.Again
-
- with pytest.raises(EngineDeadError):
- client.get_diffusion_output_nowait()
-
- def test_get_output_returns_none_when_alive_and_empty(self):
- """When the engine is alive and the queue is empty, return None."""
- client = self._make_client()
- client._response_socket.recv.side_effect = zmq.Again
-
- assert client.get_diffusion_output_nowait() is None
-
- def test_check_health_raises_when_proc_dead(self):
- """``check_health`` detects a dead subprocess via ``_proc.is_alive()``
- and raises ``EngineDeadError``, setting ``_engine_dead`` as a
- side effect."""
- client = self._make_client(proc_alive=False)
-
- with pytest.raises(EngineDeadError, match="not alive"):
- client.check_health()
-
- assert client._engine_dead is True
-
- def test_get_output_raises_when_proc_dead(self):
- """When the subprocess has died (non-signal exit) and the output
- queue is empty, ``get_diffusion_output_nowait`` must raise
- ``EngineDeadError`` with the exit code."""
- client = self._make_client(proc_alive=False)
- client._response_socket.recv.side_effect = zmq.Again
-
- with pytest.raises(EngineDeadError, match="exit code"):
- client.get_diffusion_output_nowait()
-
- assert client._engine_dead is True
-
- def test_get_output_returns_none_on_signal_death(self):
- """When the subprocess was killed by a signal (exit code > 128),
- ``get_diffusion_output_nowait`` returns ``None`` and sets
- ``_shutting_down`` instead of raising."""
- client = self._make_client(proc_alive=False)
- client._proc.exitcode = 137 # SIGKILL (128 + 9)
- client._response_socket.recv.side_effect = zmq.Again
-
- result = client.get_diffusion_output_nowait()
-
- assert result is None
- assert client._shutting_down is True
- assert client._engine_dead is True
-
-
-# ───────── monitor thread & death sentinel integration tests ─────────
-
-
-def _poll_flag(get_flag, *, timeout=5.0, interval=0.05) -> bool:
- """Poll until ``get_flag()`` returns True or *timeout* elapses."""
- deadline = time.monotonic() + timeout
- while time.monotonic() < deadline:
- if get_flag():
- return True
- time.sleep(interval)
- return False
-
-
-def _make_short_lived_process() -> mp.Process:
- """Spawn a real subprocess that exits immediately.
-
- The process must be started with ``"fork"`` (or the platform default)
- so that it can use a plain ``lambda`` as its target — ``"spawn"`` would
- fail to pickle it.
- """
- ctx = mp.get_context("fork")
- p = ctx.Process(target=lambda: None, name="ShortLivedWorker-0")
- p.start()
- return p
-
-
-class TestMultiprocExecutorWorkerMonitor:
- """Integration tests for ``start_worker_monitor``.
-
- Uses real short-lived subprocesses so that OS-level sentinel fd
- readiness is exercised end-to-end.
- """
-
- def test_worker_monitor_sets_is_failed_and_calls_callbacks_on_death(self):
- """When a worker process dies, the monitor thread must:
- 1. Set ``is_failed = True``
- 2. Call ``shutdown()`` (which sets ``_closed = True``)
- 3. Invoke all registered failure callbacks
- """
- executor = object.__new__(MultiprocDiffusionExecutor)
- executor._closed = False
- executor.is_failed = False
- executor._failure_callbacks = []
- executor._broadcast_mq = None
- executor._result_mq = None
- executor.resources = None
- # Use a no-op so shutdown() doesn't crash on None resources.
- executor._finalizer = lambda: None
-
- proc = _make_short_lived_process()
- executor._processes = [proc]
-
- callback_called = threading.Event()
- executor.register_failure_callback(callback_called.set)
-
- executor.start_worker_monitor()
-
- # Wait for the process to exit and the monitor to react.
- proc.join(5)
- assert _poll_flag(lambda: executor.is_failed), "is_failed was not set"
- assert executor._closed, "shutdown() was not called"
- assert callback_called.wait(timeout=2), "failure callback was not invoked"
-
- def test_worker_monitor_noop_when_already_closed(self):
- """If ``_closed`` is already True when the process dies (orderly
- shutdown), the monitor must *not* set ``is_failed``."""
- executor = object.__new__(MultiprocDiffusionExecutor)
- executor._closed = True # already shut down
- executor.is_failed = False
- executor._failure_callbacks = []
- executor._broadcast_mq = None
- executor._result_mq = None
- executor.resources = None
- executor._finalizer = lambda: None
-
- proc = _make_short_lived_process()
- executor._processes = [proc]
-
- executor.start_worker_monitor()
- proc.join(5)
-
- # Give the monitor thread a chance to run (it should early-return).
- time.sleep(0.3)
- assert not executor.is_failed, "is_failed should remain False on orderly shutdown"
-
-
-class TestStageDiffusionClientProcMonitor:
- """Integration test for ``StageDiffusionClient._start_proc_monitor``.
-
- Uses a real short-lived subprocess to verify the sentinel-based
- detection pipeline.
- """
-
- def test_proc_monitor_sets_engine_dead_on_process_death(self):
- """When the subprocess dies, the monitor thread must set
- ``_engine_dead = True``."""
- from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
-
- client = object.__new__(StageDiffusionClient)
- client.stage_id = 0
- client._shutting_down = False
- client._engine_dead = False
-
- proc = _make_short_lived_process()
- client._proc = proc
-
- client._start_proc_monitor()
- proc.join(5)
-
- assert _poll_flag(lambda: client._engine_dead), "_engine_dead was not set"
-
-
-class TestDrainResponsesDeathSentinel:
- """Tests for death sentinel and error routing in
- ``StageDiffusionClient._drain_responses()``.
- """
-
- def _make_client(self):
- from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
-
- client = object.__new__(StageDiffusionClient)
- client.stage_id = 0
- client._engine_dead = False
- client._shutting_down = False
- client._output_queue = asyncio.Queue()
- client._rpc_results = {}
- client._pending_rpcs = set()
- client._response_socket = MagicMock()
- client._decoder = MagicMock()
- return client
-
- def test_drain_responses_sets_engine_dead_on_death_sentinel(self):
- """When ``_drain_responses`` receives the ``DIFFUSION_PROC_DEAD``
- sentinel, it must set ``_engine_dead = True`` and stop draining
- (decoder is never called)."""
- client = self._make_client()
-
- # First recv returns the death sentinel, second would be a normal
- # message but should never be reached.
- client._response_socket.recv.side_effect = [
- StageDiffusionProc.DIFFUSION_PROC_DEAD,
- b"should-not-be-reached",
- ]
-
- client._drain_responses()
-
- assert client._engine_dead is True
- client._decoder.decode.assert_not_called()
-
- def test_drain_responses_routes_error_as_omni_request_output(self):
- """When ``_drain_responses`` receives a ``{"type": "error"}`` message
- with a ``request_id``, it must place an ``OmniRequestOutput`` with
- the error on ``_output_queue``."""
- client = self._make_client()
-
- error_msg = {
- "type": "error",
- "request_id": "req-fail",
- "error": "gpu fault",
- }
- # First recv returns the encoded error, second raises zmq.Again.
- client._response_socket.recv.side_effect = [b"encoded-error", zmq.Again]
- client._decoder.decode.return_value = error_msg
-
- client._drain_responses()
-
- assert not client._output_queue.empty()
- output = client._output_queue.get_nowait()
- assert isinstance(output, OmniRequestOutput)
- assert output.request_id == "req-fail"
- assert output.error == "gpu fault"
- assert output.finished is True
diff --git a/tests/diffusion/test_stage_diffusion_proc.py b/tests/diffusion/test_stage_diffusion_proc.py
deleted file mode 100644
index f1cf4f9b7d1..00000000000
--- a/tests/diffusion/test_stage_diffusion_proc.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import asyncio
-from concurrent.futures import ThreadPoolExecutor
-from dataclasses import asdict
-from types import SimpleNamespace
-
-import pytest
-
-from vllm_omni.diffusion.stage_diffusion_proc import StageDiffusionProc
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
-
-
-def test_process_batch_request_preserves_parent_request_id_and_kv_sender_info():
- async def run_test():
- captured = {}
-
- def step(request):
- captured["request"] = request
- return [
- SimpleNamespace(
- images=["img-1"],
- _multimodal_output={},
- _custom_output={},
- metrics={},
- stage_durations={},
- peak_memory_mb=0.0,
- latents=None,
- trajectory_latents=None,
- trajectory_timesteps=None,
- trajectory_log_probs=None,
- trajectory_decoded=None,
- final_output_type="image",
- ),
- SimpleNamespace(
- images=["img-2"],
- _multimodal_output={},
- _custom_output={},
- metrics={},
- stage_durations={},
- peak_memory_mb=0.0,
- latents=None,
- trajectory_latents=None,
- trajectory_timesteps=None,
- trajectory_log_probs=None,
- trajectory_decoded=None,
- final_output_type="image",
- ),
- ]
-
- proc = object.__new__(StageDiffusionProc)
- proc._engine = SimpleNamespace(step=step)
- proc._executor = ThreadPoolExecutor(max_workers=1)
-
- try:
- result = await proc._process_batch_request(
- request_id="req-parent",
- prompts=["hello", "world"],
- sampling_params_dict=asdict(OmniDiffusionSamplingParams()),
- kv_sender_info={0: {"host": "10.0.0.2", "zmq_port": 50151}},
- )
- finally:
- proc._executor.shutdown(wait=True)
-
- request = captured["request"]
- assert request.request_id == "req-parent"
- assert request.request_ids == ["req-parent-0", "req-parent-1"]
- assert request.kv_sender_info == {0: {"host": "10.0.0.2", "zmq_port": 50151}}
- assert result.request_id == "req-parent"
- assert result.images == ["img-1", "img-2"]
-
- asyncio.run(run_test())
diff --git a/tests/distributed/omni_connectors/test_basic_connectors.py b/tests/distributed/omni_connectors/test_basic_connectors.py
index 662d41fe01e..1b1965355e9 100644
--- a/tests/distributed/omni_connectors/test_basic_connectors.py
+++ b/tests/distributed/omni_connectors/test_basic_connectors.py
@@ -9,7 +9,7 @@
from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
from vllm_omni.distributed.omni_connectors.utils.serialization import OmniSerializer
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+# pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
def test_basic_serialization():
@@ -120,61 +120,3 @@ def test_get_invalid_metadata(shm_connector):
result = shm_connector.get("stage_0", "stage_1", "req_3", {"unknown": "format"})
assert result is None
-
-
-def test_mooncake_connector_defaults_missing_host_to_detected_ip(monkeypatch: pytest.MonkeyPatch):
- import vllm_omni.distributed.omni_connectors.connectors.mooncake_transfer_engine_connector as mooncake_module
-
- class _FakePool:
- is_cuda = False
-
- def pin_memory(self):
- return self
-
- def data_ptr(self):
- return 1234
-
- class _FakeTransferEngine:
- def initialize(self, host, mode, protocol, device_name):
- self.host = host
- self.mode = mode
- self.protocol = protocol
- self.device_name = device_name
- return 0
-
- def get_rpc_port(self):
- return 23456
-
- def register_memory(self, base_ptr, pool_size):
- del base_ptr, pool_size
- return 0
-
- def unregister_memory(self, base_ptr):
- del base_ptr
- return 0
-
- monkeypatch.setattr(mooncake_module, "TransferEngine", _FakeTransferEngine)
- monkeypatch.setattr(mooncake_module.torch, "empty", lambda *args, **kwargs: _FakePool())
- monkeypatch.setattr(
- mooncake_module.MooncakeTransferEngineConnector,
- "_get_local_ip",
- lambda self: "10.20.30.40",
- )
- monkeypatch.setattr(
- mooncake_module.MooncakeTransferEngineConnector,
- "_zmq_listener_loop",
- lambda self: self._listener_ready.set(),
- )
-
- connector = mooncake_module.MooncakeTransferEngineConnector(
- {
- "zmq_port": 50051,
- "memory_pool_size": 4096,
- }
- )
- try:
- assert connector.host == "10.20.30.40"
- assert connector.engine.host == "10.20.30.40"
- assert connector.get_connection_info()["host"] == "10.20.30.40"
- finally:
- connector.close()
diff --git a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py
index 22f7c268be2..dddf49a05de 100644
--- a/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py
+++ b/tests/distributed/omni_connectors/test_chunk_transfer_adapter.py
@@ -4,15 +4,12 @@
import threading
from collections import deque
from types import SimpleNamespace
-from unittest.mock import patch
import pytest
import torch
from pytest_mock import MockerFixture
-from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.request import RequestStatus
-from vllm_omni.data_entry_keys import OmniPayload
from vllm_omni.distributed.omni_connectors.transfer_adapter.base import OmniTransferAdapterBase
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
@@ -112,11 +109,7 @@ def test_load_poll(build_adapter):
request = _req("req-1", RequestStatus.WAITING, external_req_id="external-1")
adapter.load_async(request)
- payload: OmniPayload = {
- "codes": {"audio": [[1]]},
- "hidden_states": {"output": torch.tensor([[2.0]])},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
- }
+ payload = {"code_predictor_codes": [[1]], "hidden_states": torch.tensor([[2.0]]), "finished": True}
connector.get.return_value = (payload, 16)
adapter._poll_single_request(request)
@@ -140,68 +133,15 @@ def test_save_async(build_adapter):
assert task["is_finished"] is False
-def test_send_single_request_cleans_up_after_finished_payload(build_adapter, monkeypatch):
- adapter, _ = build_adapter(stage_id=1)
- request = _req("req-finished", RequestStatus.FINISHED_STOPPED, external_req_id="ext-finished")
-
- adapter.custom_process_next_stage_input_func = lambda **kwargs: {"x": [1], "finished": True}
- cleanup_calls = []
- monkeypatch.setattr(adapter, "cleanup", lambda *a, **kw: cleanup_calls.append((a, kw)))
-
- adapter._send_single_request({"pooling_output": None, "request": request, "is_finished": True})
-
- assert len(cleanup_calls) == 1
- args, _ = cleanup_calls[0]
- assert args[0] == "req-finished"
- assert args[1] == "ext-finished"
-
-
def test_update_request_payload(build_adapter):
adapter, _ = build_adapter()
- first: OmniPayload = {
- "hidden_states": {"output": torch.tensor([[1.0]])},
- "codes": {"audio": [1]},
- "meta": {"finished": torch.tensor(False, dtype=torch.bool)},
- }
- adapter._update_request_payload("ext", first)
- second: OmniPayload = {
- "hidden_states": {"output": torch.tensor([[2.0]])},
- "codes": {"audio": [2]},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
- }
- merged = adapter._update_request_payload("ext", second)
-
- assert torch.equal(merged["hidden_states"]["output"], torch.tensor([[1.0], [2.0]]))
- assert merged["codes"]["audio"] == [1, 2]
- assert merged["meta"]["finished"].item() is True
-
-
-def test_load_poll_ar_request_additional_information_concats_tensors(build_adapter):
- adapter, connector = build_adapter(stage_id=2, model_mode="ar")
- request = _req("req-merged", RequestStatus.WAITING, external_req_id="ext-merged")
-
- adapter.request_ids_mapping["req-merged"] = "ext-merged"
- adapter.request_payload["ext-merged"] = {
- "hidden_states": {"output": torch.tensor([[1.0]])},
- "ids": {"prompt": [11, 12]},
- "meta": {"finished": torch.tensor(False, dtype=torch.bool)},
- }
- payload: OmniPayload = {
- "hidden_states": {"output": torch.tensor([[2.0]])},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
- }
- connector.get.return_value = (payload, 8)
-
- adapter._poll_single_request(request)
+ adapter._update_request_payload("ext", {"h": torch.tensor([[1.0]]), "codes": [1], "finished": False})
+ merged = adapter._update_request_payload("ext", {"h": torch.tensor([[2.0]]), "codes": [2], "finished": True})
- assert torch.equal(
- request.additional_information["hidden_states"]["output"],
- torch.tensor([[1.0], [2.0]]),
- )
- # Keys absent from the new chunk are dropped (matches main's behavior).
- assert "ids" not in request.additional_information
- assert request.additional_information["meta"]["finished"].item() is True
+ assert torch.equal(merged["h"], torch.tensor([[1.0], [2.0]]))
+ assert merged["codes"] == [1, 2]
+ assert merged["finished"] is True
def test_process_and_restore_queues(build_adapter):
@@ -363,10 +303,7 @@ def test_cleanup_after_poll_flow(build_adapter):
adapter.load_async(request)
adapter.request_ids_mapping["req-flow"] = "ext-flow"
- payload: OmniPayload = {
- "hidden_states": {"output": torch.tensor([[1.0]])},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
- }
+ payload = {"hidden_states": torch.tensor([[1.0]]), "finished": True}
connector.get.return_value = (payload, 8)
adapter._poll_single_request(request)
@@ -382,27 +319,6 @@ def test_cleanup_after_poll_flow(build_adapter):
assert "ext-flow" not in adapter.request_payload
-def test_finish_requests_restores_status(build_adapter):
- """Abort path must pop ``requests_origin_status`` and restore pre-wait status.
-
- While ``process_pending_chunks`` holds a request off the scheduler queues, the
- adapter records the prior status (WAITING or RUNNING). ``finish_requests`` must
- put that status back on the live ``Request`` so base ``Scheduler.finish_requests``
- can finish bookkeeping without inconsistent state / crashes.
- """
- adapter, _ = build_adapter(stage_id=1)
- req_id = "req-abort-during-chunk"
- prior = RequestStatus.RUNNING
- request = _req(req_id, RequestStatus.WAITING_FOR_CHUNK)
- adapter.requests_origin_status[req_id] = prior
- requests_map = {req_id: request}
-
- adapter.finish_requests([req_id], RequestStatus.FINISHED_ABORTED, requests_map)
-
- assert request.status == prior
- assert req_id not in adapter.requests_origin_status
-
-
# ---------------------------------------------------------------
# Scheduler trigger tests
# ---------------------------------------------------------------
@@ -493,114 +409,3 @@ def test_generation_scheduler_calls_cleanup_on_finished(monkeypatch, mocker: Moc
args, _ = cleanup_calls[0]
assert args[0] == "req-s1"
assert args[1] == "ext-s1"
-
-
-def test_ar_scheduler_defers_cleanup_and_queues_save_on_finished(mocker: MockerFixture):
- """OmniARScheduler should enqueue save; adapter cleanup is handled in save thread."""
- cleanup_calls = []
- save_calls = []
-
- adapter_mock = mocker.MagicMock()
- adapter_mock.cleanup = lambda *a, **kw: cleanup_calls.append((a, kw))
- adapter_mock.save_async = lambda *a, **kw: save_calls.append((a, kw))
-
- from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler
-
- scheduler = mocker.MagicMock()
- scheduler.chunk_transfer_adapter = adapter_mock
- scheduler.connector = None
- scheduler.perf_metrics = None
- scheduler.log_stats = False
- scheduler.recompute_kv_load_failures = False
- scheduler.structured_output_manager = mocker.MagicMock()
- scheduler.structured_output_manager.should_advance.return_value = False
- scheduler.finished_req_ids_dict = {}
- scheduler.kv_cache_manager = mocker.MagicMock()
- scheduler.kv_cache_manager.take_events.return_value = None
- scheduler.kv_event_publisher = mocker.MagicMock()
- scheduler.waiting_for_transfer_free = set()
- scheduler.transfer_triggered_requests = set()
- scheduler.active_kv_transfers = set()
-
- request = _HashableRequest(
- request_id="req-ar",
- external_req_id="ext-ar",
- status=RequestStatus.RUNNING,
- is_finished=lambda: False,
- num_computed_tokens=1,
- num_prompt_tokens=1,
- prompt_token_ids=[1],
- num_output_placeholders=0,
- sampling_params=None,
- pooling_params=None,
- stop_reason=None,
- client_index=0,
- take_events=lambda: [],
- trace_headers=None,
- num_cached_tokens=0,
- num_external_computed_tokens=0,
- num_nans_in_logits=0,
- get_finished_reason=lambda: "stop",
- )
- scheduler.requests = {"req-ar": request}
-
- scheduler._update_request_with_output = mocker.MagicMock(return_value=([], True))
- scheduler._process_kv_transfer_trigger = mocker.MagicMock(return_value=False)
- scheduler._handle_stopped_request = mocker.MagicMock(return_value=True)
- scheduler._free_request = mocker.MagicMock(return_value=None)
- scheduler._get_routed_experts = mocker.MagicMock(return_value=None)
- scheduler.running = [request]
- scheduler.waiting = mocker.MagicMock()
- scheduler.waiting.remove_requests = mocker.MagicMock()
- scheduler.make_spec_decoding_stats = mocker.MagicMock(return_value=None)
- scheduler.make_stats = mocker.MagicMock(return_value=None)
-
- scheduler_output = SimpleNamespace(
- num_scheduled_tokens={"req-ar": 1},
- scheduled_spec_decode_tokens={},
- num_invalid_spec_tokens=0,
- )
- model_runner_output = SimpleNamespace(
- sampled_token_ids=[[123]],
- logprobs=None,
- prompt_logprobs_dict={},
- pooler_output=None,
- num_nans_in_logits=None,
- kv_connector_output=None,
- cudagraph_stats=None,
- req_id_to_index={"req-ar": 0},
- kv_extracted_req_ids=None,
- )
-
- OmniARScheduler.update_from_output(scheduler, scheduler_output, model_runner_output)
-
- assert len(cleanup_calls) == 0
- assert len(save_calls) == 1
-
-
-def test_omni_ar_scheduler_finish_requests(mocker: MockerFixture):
- """``OmniARScheduler.finish_requests`` must run chunk adapter hook before vLLM base."""
- from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler
-
- order: list[str] = []
-
- adapter = mocker.MagicMock()
-
- def _adapter_finish(request_ids, finished_status, requests):
- order.append("adapter")
- return []
-
- adapter.finish_requests.side_effect = _adapter_finish
-
- def _super_finish(_self, request_ids, finished_status):
- order.append("super")
- return []
-
- sched = OmniARScheduler.__new__(OmniARScheduler)
- sched.chunk_transfer_adapter = adapter
- sched.requests = {}
-
- with patch.object(VLLMScheduler, "finish_requests", _super_finish):
- OmniARScheduler.finish_requests(sched, ["r1"], RequestStatus.FINISHED_ABORTED)
-
- assert order == ["adapter", "super"]
diff --git a/tests/distributed/omni_connectors/test_kv_flow.py b/tests/distributed/omni_connectors/test_kv_flow.py
index cea18601932..b12fc013b7f 100644
--- a/tests/distributed/omni_connectors/test_kv_flow.py
+++ b/tests/distributed/omni_connectors/test_kv_flow.py
@@ -1,14 +1,8 @@
-import json
-import struct
-
-import numpy as np
import pytest
import torch
-import vllm_omni.distributed.omni_connectors.kv_transfer_manager as kv_transfer_manager_module
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import (
- KVCacheTransferData,
OmniKVCacheConfig,
OmniKVTransferManager,
)
@@ -66,35 +60,6 @@ def common_constants():
}
-def _decode_stored_payload(data):
- if isinstance(data, torch.Tensor) and data.dtype == torch.uint8 and data.dim() == 1:
- return KVCacheTransferData.from_bytes(data.cpu().numpy().tobytes())
-
- if isinstance(data, (bytes, bytearray, memoryview)):
- return KVCacheTransferData.from_bytes(data)
-
- return data
-
-
-def _make_serialized_payload() -> tuple[bytes, torch.Tensor]:
- key_tensor = torch.arange(12, dtype=torch.float32).reshape(3, 4)
- payload = KVCacheTransferData(
- request_id="req-payload",
- layer_blocks={"key_cache": [key_tensor], "value_cache": [None]},
- block_ids=[1],
- metadata={"seq_len": 3},
- ).to_bytes()
- return payload, key_tensor
-
-
-def _rewrite_serialized_header(payload: bytes, mutate_header) -> bytes:
- header_len = struct.unpack(">I", payload[:4])[0]
- header = json.loads(payload[4 : 4 + header_len])
- mutate_header(header)
- new_header = json.dumps(header, separators=(",", ":")).encode("utf-8")
- return struct.pack(">I", len(new_header)) + new_header + payload[4 + header_len :]
-
-
def test_manager_extraction(kv_config, mock_connector, common_constants):
"""Test extraction and sending logic in OmniKVTransferManager."""
num_layers = common_constants["num_layers"]
@@ -130,7 +95,7 @@ def test_manager_extraction(kv_config, mock_connector, common_constants):
expected_key = f"stage1->stage2:{full_request_id}"
assert expected_key in mock_connector.store
- data = _decode_stored_payload(mock_connector.store[expected_key])
+ data = mock_connector.store[expected_key]
assert data["request_id"] == req_id
assert "layer_blocks" in data
assert len(data["layer_blocks"]["key_cache"]) == num_layers
@@ -141,116 +106,6 @@ def test_manager_extraction(kv_config, mock_connector, common_constants):
assert data["layer_blocks"]["key_cache"][0].shape == expected_shape
-def test_from_bytes_rejects_out_of_bounds_header_len():
- payload, _ = _make_serialized_payload()
- bad_payload = struct.pack(">I", len(payload)) + payload[4:]
-
- with pytest.raises(ValueError, match="header_len"):
- KVCacheTransferData.from_bytes(bad_payload)
-
- with pytest.raises(ValueError, match="header_len"):
- KVCacheTransferData.from_bytes_gpu(torch.tensor(list(bad_payload), dtype=torch.uint8))
-
-
-def test_from_bytes_rejects_out_of_bounds_tensor_span():
- payload, _ = _make_serialized_payload()
- bad_payload = _rewrite_serialized_header(payload, lambda header: header["td"][0].update({"o": 4096}))
-
- with pytest.raises(ValueError, match="tensor span"):
- KVCacheTransferData.from_bytes(bad_payload)
-
- with pytest.raises(ValueError, match="tensor span"):
- KVCacheTransferData.from_bytes_gpu(torch.tensor(list(bad_payload), dtype=torch.uint8))
-
-
-def test_from_bytes_rejects_unsupported_dtype():
- payload, _ = _make_serialized_payload()
- bad_payload = _rewrite_serialized_header(payload, lambda header: header["td"][0].update({"d": "cuda"}))
-
- with pytest.raises(ValueError, match="Unsupported dtype"):
- KVCacheTransferData.from_bytes(bad_payload)
-
- with pytest.raises(ValueError, match="Unsupported dtype"):
- KVCacheTransferData.from_bytes_gpu(torch.tensor(list(bad_payload), dtype=torch.uint8))
-
-
-def test_from_bytes_uses_explicit_layer_index_descriptor():
- payload, key_tensor = _make_serialized_payload()
- payload_with_explicit_index = _rewrite_serialized_header(
- payload,
- lambda header: header["td"][0].update({"n": "key_cache_extra_suffix", "i": 0}),
- )
-
- data = KVCacheTransferData.from_bytes(payload_with_explicit_index)
-
- assert torch.equal(data["layer_blocks"]["key_cache"][0], key_tensor)
-
-
-def test_update_sender_info_uses_configured_source_stage():
- config = OmniKVCacheConfig(
- connector_config={"type": "mock"},
- stage_id=2,
- engine_input_source=[1],
- need_recv_cache=True,
- )
- manager = OmniKVTransferManager(config)
-
- manager.update_sender_info(
- {
- 0: {"host": "10.0.0.1", "zmq_port": 50151},
- 1: {"host": "10.0.0.2", "zmq_port": 50152},
- }
- )
-
- assert manager.config.connector_config["sender_host"] == "10.0.0.2"
- assert manager.config.connector_config["sender_zmq_port"] == 50152
-
-
-def test_clone_received_payload_tensors_breaks_buffer_alias():
- payload, key_tensor = _make_serialized_payload()
- raw = np.frombuffer(bytearray(payload), dtype=np.uint8)
- data = KVCacheTransferData.from_bytes(memoryview(raw))
-
- OmniKVTransferManager._clone_received_payload_tensors(data)
- raw[:] = 0
-
- assert torch.equal(data["layer_blocks"]["key_cache"][0], key_tensor)
-
-
-def test_receive_kv_cache_uses_exponential_backoff(monkeypatch):
- config = OmniKVCacheConfig(
- connector_config={"type": "mock"},
- from_stage="sender",
- stage_id="receiver",
- need_recv_cache=True,
- recv_timeout=0.3,
- )
- manager = OmniKVTransferManager(config)
-
- class _NeverReadyConnector:
- def get(self, **kwargs):
- del kwargs
- return None
-
- manager._connector = _NeverReadyConnector()
-
- now = {"value": 0.0}
- sleep_intervals = []
-
- monkeypatch.setattr(kv_transfer_manager_module.time, "time", lambda: now["value"])
-
- def _fake_sleep(interval: float) -> None:
- sleep_intervals.append(interval)
- now["value"] += interval
-
- monkeypatch.setattr(kv_transfer_manager_module.time, "sleep", _fake_sleep)
-
- data, size = manager.receive_kv_cache_for_request("req-backoff")
-
- assert (data, size) == (None, 0)
- assert sleep_intervals == pytest.approx([0.01, 0.02, 0.04, 0.08, 0.16])
-
-
def test_manager_extraction_tuple_layout(kv_config, mock_connector, common_constants):
"""Test extraction with tuple layout."""
num_layers = common_constants["num_layers"]
@@ -280,7 +135,7 @@ def test_manager_extraction_tuple_layout(kv_config, mock_connector, common_const
expected_key = f"stage1->stage2:{full_request_id}"
assert expected_key in mock_connector.store
- data = _decode_stored_payload(mock_connector.store[expected_key])
+ data = mock_connector.store[expected_key]
expected_shape = (seq_len, num_heads, head_dim)
for idx in range(len(kv_caches)):
assert data["layer_blocks"]["key_cache"][idx].shape == expected_shape
@@ -310,7 +165,7 @@ def test_manager_extraction_mismatched_kv_block_counts(kv_config, mock_connector
expected_key = f"stage1->stage2:{full_request_id}"
assert expected_key in mock_connector.store
- data = _decode_stored_payload(mock_connector.store[expected_key])
+ data = mock_connector.store[expected_key]
expected_shape = (2 * block_size, num_heads, head_dim)
assert data["layer_blocks"]["key_cache"][0].shape == expected_shape
assert data["layer_blocks"]["value_cache"][0].shape == expected_shape
@@ -399,82 +254,6 @@ def test_manager_reception(kv_config, mock_connector, common_constants):
assert req.kv_metadata["seq_len"] == seq_len
-def test_manager_reception_prefers_parent_request_id_for_batched_request(kv_config, mock_connector, common_constants):
- """Batched diffusion requests must fetch KV using the parent/global request ID."""
- num_layers = common_constants["num_layers"]
- num_heads = common_constants["num_heads"]
- head_dim = common_constants["head_dim"]
- seq_len = common_constants["seq_len"]
- parent_req_id = common_constants["req_id"]
-
- expected_shape = (seq_len, num_heads, head_dim)
- key_cache = [torch.randn(expected_shape) for _ in range(num_layers)]
- value_cache = [torch.randn(expected_shape) for _ in range(num_layers)]
-
- data_to_receive = {
- "request_id": parent_req_id,
- "layer_blocks": {"key_cache": key_cache, "value_cache": value_cache},
- "metadata": {"seq_len": seq_len},
- "block_ids": [],
- }
-
- manager = OmniKVTransferManager(kv_config)
- manager._connector = mock_connector
-
- full_request_id = f"omni_stage1_to_stage2_kv_cache_{parent_req_id}"
- store_key = f"stage1->stage2:{full_request_id}"
- mock_connector.store[store_key] = data_to_receive
-
- req = OmniDiffusionRequest(
- prompts=["prompt-a", "prompt-b"],
- sampling_params=OmniDiffusionSamplingParams(),
- request_ids=[f"{parent_req_id}-0", f"{parent_req_id}-1"],
- request_id=parent_req_id,
- )
-
- success = manager.receive_kv_cache(req, target_device=torch.device("cpu"))
-
- assert success
- assert req.kv_metadata["seq_len"] == seq_len
- assert torch.allclose(req.past_key_values.key_cache[0], key_cache[0])
-
-
-def test_receive_multi_kv_cache_uses_parent_request_id_for_cfg_collection(kv_config):
- manager = OmniKVTransferManager(kv_config)
-
- seen = {}
-
- def collect_cfg(request_id, cfg_request_ids, kv_transfer_manager, target_device):
- seen["request_id"] = request_id
- seen["cfg_request_ids"] = cfg_request_ids
- seen["kv_transfer_manager"] = kv_transfer_manager
- seen["target_device"] = target_device
- return {"cfg_text_kv_metadata": {"ok": True}}
-
- req = OmniDiffusionRequest(
- prompts=["prompt-a", "prompt-b"],
- sampling_params=OmniDiffusionSamplingParams(),
- request_ids=["req-parent-0", "req-parent-1"],
- request_id="req-parent",
- )
- req.sampling_params.cfg_kv_request_ids = {"cfg_text": "req-parent__cfg_text"}
-
- manager.receive_kv_cache = lambda request, target_device=None: request is req
-
- success = manager.receive_multi_kv_cache(
- req,
- cfg_kv_collect_func=collect_cfg,
- target_device=torch.device("cpu"),
- )
-
- assert success
- assert seen["request_id"] == "req-parent"
- assert seen["cfg_request_ids"] == {"cfg_text": "req-parent__cfg_text"}
- assert seen["kv_transfer_manager"] is manager
- assert seen["target_device"] == torch.device("cpu")
- assert req.sampling_params.cfg_text_kv_metadata == {"ok": True}
-
-
def test_integration_flow(common_constants):
"""Simulate extraction -> connector -> reception."""
num_layers = common_constants["num_layers"]
diff --git a/tests/distributed/omni_connectors/test_shm_connector.py b/tests/distributed/omni_connectors/test_shm_connector.py
deleted file mode 100644
index e702318e3f3..00000000000
--- a/tests/distributed/omni_connectors/test_shm_connector.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for SharedMemoryConnector focusing on TP / CFG / metadata fallback."""
-
-import pytest
-
-from vllm_omni.distributed.omni_connectors.connectors.shm_connector import (
- SharedMemoryConnector,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.fixture()
-def connector():
- c = SharedMemoryConnector({"shm_threshold_bytes": 64})
- yield c
- c.close()
-
-
-# ── Key-based read (the fundamental SHM path) ────────────────────────
-
-
-class TestKeyBasedReadWrite:
- def test_put_then_get_by_key(self, connector):
- data = {"hello": "world", "n": 42}
- ok, size, meta = connector.put("s0", "s1", "test_key_1", data)
- assert ok
- assert size > 0
- assert "shm" in meta
- assert "test_key_1" in connector._pending_keys
-
- result = connector.get("s0", "s1", "test_key_1", metadata=None)
- assert result is not None
- obj, rsize = result
- assert obj == data
- assert rsize == size
- assert "test_key_1" not in connector._pending_keys
-
- def test_get_nonexistent_key_returns_none(self, connector):
- result = connector.get("s0", "s1", "no_such_key_xyz", metadata=None)
- assert result is None
-
- def test_rank_aware_keys_independent(self, connector):
- """Each TP rank writes/reads its own key — simulates homogeneous TP."""
- payloads = {}
- for rank in range(4):
- key = f"req1_s0_0_{rank}_{rank}"
- data = {"rank": rank, "values": list(range(rank, rank + 3))}
- ok, _, _ = connector.put("s0", "s1", key, data)
- assert ok
- payloads[rank] = data
-
- for rank in range(4):
- key = f"req1_s0_0_{rank}_{rank}"
- result = connector.get("s0", "s1", key, metadata=None)
- assert result is not None
- obj, _ = result
- assert obj == payloads[rank]
-
-
-# ── Metadata fallback behaviour ──────────────────────────────────────
-
-
-class TestMetadataFallback:
- def test_rdma_style_metadata_falls_back_to_key(self, connector):
- """source_host/source_port metadata should be ignored; key read used."""
- data = {"payload": True}
- connector.put("s0", "s1", "fb_key_1", data)
-
- rdma_meta = {"source_host": "10.0.0.1", "source_port": 12345}
- result = connector.get("s0", "s1", "fb_key_1", metadata=rdma_meta)
- assert result is not None
- obj, _ = result
- assert obj == data
-
- def test_non_dict_metadata_falls_back_to_key(self, connector):
- data = {"val": 99}
- connector.put("s0", "s1", "fb_key_2", data)
-
- result = connector.get("s0", "s1", "fb_key_2", metadata="not_a_dict")
- assert result is not None
- obj, _ = result
- assert obj == data
-
- def test_empty_dict_metadata_falls_back_to_key(self, connector):
- data = {"x": 1}
- connector.put("s0", "s1", "fb_key_3", data)
-
- result = connector.get("s0", "s1", "fb_key_3", metadata={})
- assert result is not None
- obj, _ = result
- assert obj == data
-
- def test_shm_handle_metadata_still_works(self, connector):
- """When metadata contains a proper 'shm' handle, use it directly."""
- data = {"direct": True}
- ok, size, meta = connector.put("s0", "s1", "shm_direct_1", data)
- assert ok
- result = connector.get("s0", "s1", "shm_direct_1", metadata=meta)
- assert result is not None
- obj, _ = result
- assert obj == data
-
- def test_metadata_keyed_by_request_id(self, connector):
- """Metadata wrapped as {get_key: actual_meta} should be unwrapped."""
- data = {"wrapped": True}
- ok, size, meta = connector.put("s0", "s1", "wrap_key", data)
- assert ok
- wrapped = {"wrap_key": meta}
- result = connector.get("s0", "s1", "wrap_key", metadata=wrapped)
- assert result is not None
- obj, _ = result
- assert obj == data
-
-
-# ── Heterogeneous TP multi-key read ──────────────────────────────────
-
-
-class TestHeteroTPMultiKey:
- def test_receiver_reads_multiple_sender_keys(self, connector):
- """Simulates from_tp=2 -> to_tp=1: receiver reads 2 keys and merges."""
- for sender_rank in range(2):
- key = f"req1_s0_0_{sender_rank}_0"
- data = {"sender": sender_rank, "shard": [sender_rank * 10]}
- connector.put("s0", "s1", key, data)
-
- shards = []
- for sender_rank in range(2):
- key = f"req1_s0_0_{sender_rank}_0"
- result = connector.get("s0", "s1", key, metadata=None)
- assert result is not None
- obj, _ = result
- shards.append(obj)
-
- assert len(shards) == 2
- assert shards[0]["sender"] == 0
- assert shards[1]["sender"] == 1
-
- def test_sender_writes_multiple_receiver_keys(self, connector):
- """Simulates from_tp=1 -> to_tp=2: sender writes 2 sliced keys."""
- for recv_rank in range(2):
- key = f"req1_s0_0_0_{recv_rank}"
- data = {"target": recv_rank, "slice": list(range(recv_rank, recv_rank + 2))}
- connector.put("s0", "s1", key, data)
-
- for recv_rank in range(2):
- key = f"req1_s0_0_0_{recv_rank}"
- result = connector.get("s0", "s1", key, metadata=None)
- assert result is not None
- obj, _ = result
- assert obj["target"] == recv_rank
-
-
-# ── Cleanup ──────────────────────────────────────────────────────────
-
-
-class TestCleanup:
- def test_cleanup_removes_unconsumed_segment(self, connector):
- data = {"leak": True}
- connector.put("s0", "s1", "cleanup_req_42", data)
- assert "cleanup_req_42" in connector._pending_keys
-
- connector.cleanup("req_42")
- assert "cleanup_req_42" not in connector._pending_keys
-
- result = connector.get("s0", "s1", "cleanup_req_42", metadata=None)
- assert result is None
-
- def test_cleanup_noop_for_consumed_segment(self, connector):
- data = {"consumed": True}
- connector.put("s0", "s1", "consumed_req_99", data)
- connector.get("s0", "s1", "consumed_req_99", metadata=None)
-
- connector.cleanup("req_99")
- assert "consumed_req_99" not in connector._pending_keys
-
- def test_close_cleans_all_pending(self, connector):
- for i in range(3):
- connector.put("s0", "s1", f"close_test_{i}", {"i": i})
-
- assert len(connector._pending_keys) == 3
- connector.close()
- assert len(connector._pending_keys) == 0
diff --git a/tests/distributed/omni_connectors/test_tp_rank_aware.py b/tests/distributed/omni_connectors/test_tp_rank_aware.py
deleted file mode 100644
index d4793479aaf..00000000000
--- a/tests/distributed/omni_connectors/test_tp_rank_aware.py
+++ /dev/null
@@ -1,716 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for rank-aware KV transfer (TP > 1) and heterogeneous TP support.
-
-Covers:
-- _build_rank_aware_send_keys / _build_rank_aware_recv_keys
-- _get_kv_source_ranks / _get_kv_target_ranks / get_kv_connector_key
-- update_sender_info storing base host/port
-- receive path constructing per-rank metadata for connector.get()
-- Mooncake connector _query_metadata_at and partial-metadata get() path
-"""
-
-from types import SimpleNamespace
-from unittest.mock import MagicMock, patch
-
-import pytest
-import torch
-
-from vllm_omni.distributed.omni_connectors.kv_transfer_manager import (
- KVCacheTransferData,
- OmniKVCacheConfig,
- OmniKVTransferManager,
-)
-from vllm_omni.distributed.omni_connectors.utils.initialization import (
- KV_RANK_PORT_STRIDE,
-)
-from vllm_omni.distributed.omni_connectors.utils.kv_utils import (
- KVTPTopology,
- build_rank_aware_recv_keys,
- build_rank_aware_send_keys,
- get_kv_connector_key,
- get_kv_source_ranks,
- get_kv_target_ranks,
- merge_received_rank_shards,
- slice_received_rank_shard,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _make_manager(
- from_tp: int = 1,
- to_tp: int = 1,
- local_rank: int = 0,
- from_stage: str = "stage0",
- to_stage: str = "stage1",
- stage_id: str = "stage1",
- need_recv: bool = True,
- need_send: bool = False,
- recv_timeout: float = 0.3,
-) -> OmniKVTransferManager:
- """Build a manager with TP params injected, bypassing torch.distributed."""
- config = OmniKVCacheConfig(
- connector_config={"type": "mock"},
- from_stage=from_stage,
- to_stage=to_stage,
- stage_id=stage_id,
- need_recv_cache=need_recv,
- need_send_cache=need_send,
- recv_timeout=recv_timeout,
- from_tp=from_tp,
- to_tp=to_tp,
- )
- with (
- patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_local_tp_rank", return_value=local_rank),
- patch(
- "vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_tp_world_size",
- return_value=max(from_tp, to_tp),
- ),
- ):
- mgr = OmniKVTransferManager(config)
- return mgr
-
-
-def _make_payload(head_values: list[float], request_id: str = "req-1") -> dict:
- head_tensor = torch.tensor(head_values, dtype=torch.float32).view(1, len(head_values), 1).repeat(2, 1, 1)
- return {
- "request_id": request_id,
- "layer_blocks": {
- "key_cache": [head_tensor.clone()],
- "value_cache": [(head_tensor + 100).clone()],
- },
- "block_ids": [0],
- "metadata": {"seq_len": 2},
- }
-
-
-def _make_transfer_data(head_values: list[float], request_id: str = "req-1") -> KVCacheTransferData:
- payload = _make_payload(head_values, request_id=request_id)
- return KVCacheTransferData(
- request_id=request_id,
- layer_blocks=payload["layer_blocks"],
- block_ids=payload["block_ids"],
- metadata=payload["metadata"],
- )
-
-
-# ── Key format helper ────────────────────────────────────────────────
-
-
-class TestConnectorKeyFormat:
- def test_key_format_matches_pr2677(self):
- key = get_kv_connector_key("req-1", "stage0", 0, 1, 2)
- assert key == "req-1_stage0_0_1_2"
-
- def test_key_fields_are_positional(self):
- key = get_kv_connector_key("r", "s", 5, 3, 7)
- parts = key.split("_")
- assert parts == ["r", "s", "5", "3", "7"]
-
-
-# ── Source / target rank mapping ─────────────────────────────────────
-
-
-class TestRankMapping:
- """Verify get_kv_target_ranks and get_kv_source_ranks for various TP configs."""
-
- def test_homogeneous_tp2_rank0(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=0)
- assert get_kv_target_ranks(topo) == [0]
- assert get_kv_source_ranks(topo) == [0]
-
- def test_homogeneous_tp2_rank1(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=1)
- assert get_kv_target_ranks(topo) == [1]
- assert get_kv_source_ranks(topo) == [1]
-
- def test_homogeneous_tp4_rank3(self):
- topo = KVTPTopology(source_tp_size=4, target_tp_size=4, local_rank=3)
- assert get_kv_target_ranks(topo) == [3]
- assert get_kv_source_ranks(topo) == [3]
-
- def test_sender_gt_receiver_tp4_to_tp2_rank0(self):
- """Receiver rank 0 should receive from sender rank 0 and 1."""
- topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=0)
- assert get_kv_source_ranks(topo) == [0, 1]
-
- def test_sender_gt_receiver_tp4_to_tp2_rank1(self):
- """Receiver rank 1 should receive from sender rank 2 and 3."""
- topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=1)
- assert get_kv_source_ranks(topo) == [2, 3]
-
- def test_sender_lt_receiver_tp2_to_tp4_rank0(self):
- """Sender rank 0 should send to receiver ranks 0 and 1."""
- topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=0)
- assert get_kv_target_ranks(topo) == [0, 1]
-
- def test_sender_lt_receiver_tp2_to_tp4_rank1(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=1)
- assert get_kv_target_ranks(topo) == [2, 3]
-
- def test_receiver_lt_sender_source_ranks(self):
- """Receiver rank 0 with tp2_to_tp4 should source from rank 0 only."""
- topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=0)
- assert get_kv_source_ranks(topo) == [0]
-
- def test_invalid_topology_raises(self):
- topo = KVTPTopology(source_tp_size=3, target_tp_size=2, local_rank=0)
- with pytest.raises(ValueError, match="divisible"):
- get_kv_source_ranks(topo)
-
-
-# ── _build_rank_aware_recv_keys ──────────────────────────────────────
-
-
-class TestBuildRankAwareRecvKeys:
- """Verify build_rank_aware_recv_keys returns (key, from_rank) tuples."""
-
- def test_tp1_returns_legacy_key_with_none_rank(self):
- topo = KVTPTopology(source_tp_size=1, target_tp_size=1, local_rank=0)
- pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
- assert len(pairs) == 1
- key, rank = pairs[0]
- assert key == "omni_stage0_to_stage1_kv_cache_req-1"
- assert rank is None
-
- def test_homogeneous_tp2_rank0(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=0)
- pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
- assert len(pairs) == 1
- key, rank = pairs[0]
- assert key == "req-1_stage0_0_0_0"
- assert rank == 0
-
- def test_homogeneous_tp2_rank1(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=1)
- pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
- assert len(pairs) == 1
- key, rank = pairs[0]
- assert key == "req-1_stage0_0_1_1"
- assert rank == 1
-
- def test_heterogeneous_tp4_to_tp2_rank0_gets_two_keys(self):
- """Receiver rank 0 with source_tp=4, target_tp=2 should get 2 keys."""
- topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=0)
- pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
- assert len(pairs) == 2
-
- keys = [k for k, _ in pairs]
- ranks = [r for _, r in pairs]
- assert keys == ["req-1_stage0_0_0_0", "req-1_stage0_0_1_0"]
- assert ranks == [0, 1]
-
- def test_heterogeneous_tp4_to_tp2_rank1_gets_two_keys(self):
- topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=1)
- pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
- assert len(pairs) == 2
-
- ranks = [r for _, r in pairs]
- assert ranks == [2, 3]
-
- def test_heterogeneous_tp2_to_tp4_rank2_gets_one_key(self):
- """Receiver rank 2 with source_tp=2, target_tp=4 should get 1 key from sender rank 1."""
- topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=2)
- pairs = build_rank_aware_recv_keys("req-1", "stage0", "stage1", topo)
- assert len(pairs) == 1
- key, rank = pairs[0]
- assert rank == 1
- assert key == "req-1_stage0_0_1_2"
-
-
-# ── _build_rank_aware_send_keys ──────────────────────────────────────
-
-
-class TestBuildRankAwareSendKeys:
- def test_tp1_returns_legacy_key(self):
- topo = KVTPTopology(source_tp_size=1, target_tp_size=1, local_rank=0)
- keys = build_rank_aware_send_keys("req-1", "stage0", "stage1", topo)
- assert keys == ["omni_stage0_to_stage1_kv_cache_req-1"]
-
- def test_homogeneous_tp2_rank0(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=2, local_rank=0)
- keys = build_rank_aware_send_keys("req-1", "stage0", "stage1", topo)
- assert keys == ["req-1_stage0_0_0_0"]
-
- def test_sender_lt_receiver_tp2_to_tp4_rank0_sends_two_keys(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=0)
- keys = build_rank_aware_send_keys("req-1", "stage0", "stage1", topo)
- assert len(keys) == 2
- assert keys == ["req-1_stage0_0_0_0", "req-1_stage0_0_0_1"]
-
-
-# ── update_sender_info stores base host/port ─────────────────────────
-
-
-class TestUpdateSenderInfoBase:
- def test_stores_base_host_and_port(self):
- mgr = _make_manager(from_tp=2, to_tp=2, local_rank=0)
- mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
-
- assert mgr._sender_base_host == "10.0.0.1"
- assert mgr._sender_base_zmq_port == 50151
-
- def test_rank1_adjusts_default_port_but_preserves_base(self):
- mgr = _make_manager(from_tp=2, to_tp=2, local_rank=1)
- mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
-
- assert mgr._sender_base_host == "10.0.0.1"
- assert mgr._sender_base_zmq_port == 50151
- expected_adjusted = 50151 + 1 * KV_RANK_PORT_STRIDE
- assert mgr.config.connector_config["sender_zmq_port"] == expected_adjusted
-
- def test_nested_sender_info_resolves_correctly(self):
- """Nested sender_info keyed by integer stage id should resolve
- using recv_stages (engine_input_source → recv_from)."""
- config = OmniKVCacheConfig(
- connector_config={"type": "mock"},
- stage_id=2,
- engine_input_source=[1],
- need_recv_cache=True,
- from_tp=2,
- to_tp=2,
- )
- with (
- patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_local_tp_rank", return_value=0),
- patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_tp_world_size", return_value=2),
- ):
- mgr = OmniKVTransferManager(config)
- mgr.update_sender_info(
- {
- 0: {"host": "10.0.0.1", "zmq_port": 50151},
- 1: {"host": "10.0.0.2", "zmq_port": 50152},
- }
- )
- assert mgr._sender_base_host == "10.0.0.2"
- assert mgr._sender_base_zmq_port == 50152
-
-
-# ── receive path constructs per-rank metadata ────────────────────────
-
-
-class TestReceiveConstructsMetadata:
- """Verify that receive_kv_cache_for_request passes metadata with
- correct (host, port) to connector.get() for heterogeneous TP."""
-
- def test_tp1_no_metadata_passed(self):
- """TP=1: connector.get() should be called WITHOUT metadata."""
- mgr = _make_manager(from_tp=1, to_tp=1, local_rank=0, recv_timeout=0.05)
- mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
-
- calls = []
-
- class _Connector:
- def get(self, from_stage, to_stage, get_key, metadata=None):
- calls.append({"key": get_key, "metadata": metadata})
- return None
-
- mgr._connector = _Connector()
- mgr.receive_kv_cache_for_request("req-1")
-
- assert len(calls) > 0
- assert calls[0]["metadata"] is None
-
- def test_homogeneous_tp2_rank0_passes_metadata(self):
- """TP=2 rank 0: metadata should point to sender rank 0's port."""
- mgr = _make_manager(from_tp=2, to_tp=2, local_rank=0, recv_timeout=0.05)
- mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
-
- calls = []
-
- class _Connector:
- def get(self, from_stage, to_stage, get_key, metadata=None):
- calls.append({"key": get_key, "metadata": metadata})
- return None
-
- mgr._connector = _Connector()
- mgr.receive_kv_cache_for_request("req-1")
-
- assert len(calls) > 0
- meta = calls[0]["metadata"]
- assert meta is not None
- assert meta["source_host"] == "10.0.0.1"
- assert meta["source_port"] == 50151 + 0 * KV_RANK_PORT_STRIDE
-
- def test_homogeneous_tp2_rank1_passes_metadata_with_offset(self):
- mgr = _make_manager(from_tp=2, to_tp=2, local_rank=1, recv_timeout=0.05)
- mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
-
- calls = []
-
- class _Connector:
- def get(self, from_stage, to_stage, get_key, metadata=None):
- calls.append({"key": get_key, "metadata": metadata})
- return None
-
- mgr._connector = _Connector()
- mgr.receive_kv_cache_for_request("req-1")
-
- meta = calls[0]["metadata"]
- assert meta["source_port"] == 50151 + 1 * KV_RANK_PORT_STRIDE
-
- def test_heterogeneous_tp4_to_tp2_rank0_multiple_metadata(self):
- """Receiver rank 0 with source_tp=4, target_tp=2 should call get() with
- two different metadata entries for sender ranks 0 and 1."""
- mgr = _make_manager(from_tp=4, to_tp=2, local_rank=0, recv_timeout=0.05)
- mgr.update_sender_info({"host": "10.0.0.1", "zmq_port": 50151})
-
- calls = []
-
- class _Connector:
- def get(self, from_stage, to_stage, get_key, metadata=None):
- calls.append({"key": get_key, "metadata": metadata})
- return None
-
- mgr._connector = _Connector()
- mgr.receive_kv_cache_for_request("req-1")
-
- seen_ports = set()
- for c in calls:
- if c["metadata"]:
- seen_ports.add(c["metadata"]["source_port"])
- expected_ports = {
- 50151 + 0 * KV_RANK_PORT_STRIDE,
- 50151 + 1 * KV_RANK_PORT_STRIDE,
- }
- assert expected_ports.issubset(seen_ports)
-
-
-# ── Mooncake connector _query_metadata_at ────────────────────────────
-
-
-class TestMooncakeQueryMetadataAt:
- """Test the connector's _query_metadata_at method and partial-metadata
- path in get() without requiring real RDMA/Mooncake."""
-
- def test_query_metadata_at_returns_full_metadata(self):
- """Mock the ZMQ interaction to verify _query_metadata_at returns
- complete metadata including data_size."""
-
- try:
- from vllm_omni.distributed.omni_connectors.connectors.mooncake_transfer_engine_connector import (
- MooncakeTransferEngineConnector,
- QueryResponse,
- )
- except ImportError:
- pytest.skip("Mooncake not available")
-
- import msgspec
-
- connector = MagicMock(spec=MooncakeTransferEngineConnector)
- connector._get_req_socket = MagicMock()
-
- mock_socket = MagicMock()
- resp = QueryResponse(request_id="test_key@s0_s1", data_size=4096, is_fast_path=True)
- mock_socket.recv.return_value = msgspec.msgpack.encode(resp)
- connector._get_req_socket.return_value = mock_socket
-
- result = MooncakeTransferEngineConnector._query_metadata_at(
- connector,
- "test_key@s0_s1",
- "10.0.0.1",
- 50151,
- )
-
- assert result is not None
- assert result["source_host"] == "10.0.0.1"
- assert result["source_port"] == 50151
- assert result["data_size"] == 4096
- assert result["is_fast_path"] is True
-
- def test_query_metadata_at_returns_none_on_not_found(self):
- try:
- from vllm_omni.distributed.omni_connectors.connectors.mooncake_transfer_engine_connector import (
- INFO_NOT_FOUND,
- MooncakeTransferEngineConnector,
- )
- except ImportError:
- pytest.skip("Mooncake not available")
-
- connector = MagicMock(spec=MooncakeTransferEngineConnector)
- mock_socket = MagicMock()
- mock_socket.recv.return_value = INFO_NOT_FOUND
- connector._get_req_socket.return_value = mock_socket
-
- result = MooncakeTransferEngineConnector._query_metadata_at(
- connector,
- "test_key@s0_s1",
- "10.0.0.1",
- 50151,
- )
- assert result is None
-
-
-# ── Merge / slice hooks ──────────────────────────────────────────────
-
-
-class TestMergeSliceHooks:
- def test_single_shard_passes_through(self):
- payload = {"layer_blocks": {"key_cache": [1]}}
- assert merge_received_rank_shards([payload]) == payload
-
- def test_default_merger_concats_head_dim(self):
- p0 = _make_payload([0.0])
- p1 = _make_payload([1.0])
- result = merge_received_rank_shards([p0, p1])
- key_cache = result["layer_blocks"]["key_cache"][0]
- value_cache = result["layer_blocks"]["value_cache"][0]
- assert key_cache.shape == (2, 2, 1)
- assert value_cache.shape == (2, 2, 1)
- assert torch.equal(key_cache[:, :, 0], torch.tensor([[0.0, 1.0], [0.0, 1.0]]))
- assert torch.equal(value_cache[:, :, 0], torch.tensor([[100.0, 101.0], [100.0, 101.0]]))
-
- def test_custom_merger_hook_called(self):
- merged = {"merged": True}
- assert merge_received_rank_shards([{}, {}], merger=lambda payloads: merged) == merged
-
- def test_slicer_hook_called(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=0)
- sliced = {"sliced": True}
- assert slice_received_rank_shard({"full": True}, topo, slicer=lambda payload: sliced) == sliced
-
- def test_default_slicer_extracts_rank_local_heads(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=1)
- payload = _make_payload([0.0, 1.0])
- result = slice_received_rank_shard(payload, topo)
- key_cache = result["layer_blocks"]["key_cache"][0]
- value_cache = result["layer_blocks"]["value_cache"][0]
- assert key_cache.shape == (2, 1, 1)
- assert value_cache.shape == (2, 1, 1)
- assert torch.equal(key_cache[:, :, 0], torch.tensor([[1.0], [1.0]]))
- assert torch.equal(value_cache[:, :, 0], torch.tensor([[101.0], [101.0]]))
-
- def test_presliced_payload_is_not_sliced_twice(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=1)
- payload = _make_payload([1.0])
- payload["metadata"]["tp_head_slice"] = {"applied": True, "target_rank": 1}
- result = slice_received_rank_shard(payload, topo)
- assert result is payload
-
- def test_round_trip_merge_from_tp4_to_tp2(self):
- topo = KVTPTopology(source_tp_size=4, target_tp_size=2, local_rank=1)
- source_ranks = get_kv_source_ranks(topo)
- payloads = [_make_payload([float(rank)]) for rank in source_ranks]
- result = merge_received_rank_shards(payloads)
- key_cache = result["layer_blocks"]["key_cache"][0]
- assert torch.equal(key_cache[:, :, 0], torch.tensor([[2.0, 3.0], [2.0, 3.0]]))
-
- def test_round_trip_slice_from_tp2_to_tp4(self):
- topo = KVTPTopology(source_tp_size=2, target_tp_size=4, local_rank=3)
- payload = _make_payload([2.0, 3.0])
- result = slice_received_rank_shard(payload, topo)
- key_cache = result["layer_blocks"]["key_cache"][0]
- assert torch.equal(key_cache[:, :, 0], torch.tensor([[3.0], [3.0]]))
-
-
-class TestSenderSideSlicing:
- def test_transfer_slices_before_sending_to_multiple_targets(self):
- mgr = _make_manager(
- from_tp=2,
- to_tp=4,
- local_rank=0,
- need_send=True,
- need_recv=False,
- )
- sent_payloads = []
-
- class _Connector:
- supports_raw_data = False
-
- def put(self, from_stage, to_stage, put_key, data):
- sent_payloads.append((put_key, KVCacheTransferData.from_bytes(data)))
- return True, len(data), {}
-
- mgr._connector = _Connector()
- mgr._transfer_kv_cache(_make_transfer_data([0.0, 1.0]), "req-1")
-
- assert [key for key, _ in sent_payloads] == ["req-1_stage0_0_0_0", "req-1_stage0_0_0_1"]
- assert sent_payloads[0][1]["layer_blocks"]["key_cache"][0].shape == (2, 1, 1)
- assert sent_payloads[1][1]["layer_blocks"]["key_cache"][0].shape == (2, 1, 1)
- assert torch.equal(
- sent_payloads[0][1]["layer_blocks"]["key_cache"][0][:, :, 0],
- torch.tensor([[0.0], [0.0]]),
- )
- assert torch.equal(
- sent_payloads[1][1]["layer_blocks"]["key_cache"][0][:, :, 0],
- torch.tensor([[1.0], [1.0]]),
- )
- assert sent_payloads[0][1]["metadata"]["tp_head_slice"]["target_rank"] == 0
- assert sent_payloads[1][1]["metadata"]["tp_head_slice"]["target_rank"] == 1
-
-
-class _MockBroadcastGroup:
- def __init__(self, world_size: int, rank_in_group: int, broadcast_value=None, recv_value=None):
- self.world_size = world_size
- self.rank_in_group = rank_in_group
- self.broadcast_value = broadcast_value
- self.recv_value = recv_value
- self.broadcast_calls = []
- self.send_calls = []
- self.recv_calls = []
- self.shm_broadcaster = None
-
- def broadcast_object(self, obj=None, src: int = 0):
- self.broadcast_calls.append((obj, src))
- return self.broadcast_value if self.broadcast_value is not None else obj
-
- def send_object(self, obj, dst: int):
- self.send_calls.append((dst, obj))
-
- def recv_object(self, src: int):
- self.recv_calls.append(src)
- return self.recv_value
-
-
-class TestDistributedReceive:
- def test_tp_cfg_leader_receives_then_sends_branch_local_payloads(self):
- mgr = _make_manager(from_tp=2, to_tp=4, local_rank=0)
- req = SimpleNamespace(request_id="req-1", sampling_params=SimpleNamespace())
- world_group = _MockBroadcastGroup(world_size=4, rank_in_group=2)
- cfg_group = _MockBroadcastGroup(world_size=3, rank_in_group=0)
-
- def _receive(req_obj, cfg_func, target_device):
- req_obj.past_key_values = SimpleNamespace(key_cache=[torch.tensor([1.0])])
- req_obj.kv_metadata = {"source": "leader"}
- req_obj.sampling_params.past_key_values = req_obj.past_key_values
- req_obj.sampling_params.kv_metadata = req_obj.kv_metadata
- req_obj.sampling_params.cfg_text_past_key_values = SimpleNamespace(key_cache=[torch.tensor([2.0])])
- req_obj.sampling_params.cfg_text_kv_metadata = {"source": "cfg_text"}
- req_obj.sampling_params.cfg_img_past_key_values = SimpleNamespace(key_cache=[torch.tensor([3.0])])
- req_obj.sampling_params.cfg_img_kv_metadata = {"source": "cfg_img"}
- return True
-
- mgr.receive_multi_kv_cache = MagicMock(side_effect=_receive)
- with (
- patch("vllm_omni.diffusion.distributed.parallel_state.get_world_group", return_value=world_group),
- patch(
- "vllm_omni.diffusion.distributed.parallel_state.get_classifier_free_guidance_world_size",
- return_value=3,
- ),
- patch(
- "vllm_omni.diffusion.distributed.parallel_state.get_classifier_free_guidance_rank",
- return_value=0,
- ),
- patch("vllm_omni.diffusion.distributed.parallel_state.get_cfg_group", return_value=cfg_group),
- ):
- assert mgr.receive_multi_kv_cache_distributed(req) is True
-
- mgr.receive_multi_kv_cache.assert_called_once()
- assert mgr.receive_multi_kv_cache.call_args.args[2] == torch.device("cpu")
- assert req.kv_metadata == {"source": "leader"}
- assert cfg_group.broadcast_calls == []
- assert [dst for dst, _ in cfg_group.send_calls] == [1, 2]
- rank1_payload = cfg_group.send_calls[0][1]
- rank2_payload = cfg_group.send_calls[1][1]
- assert torch.equal(rank1_payload["past_key_values"].key_cache[0], torch.tensor([1.0]))
- assert torch.equal(rank2_payload["past_key_values"].key_cache[0], torch.tensor([1.0]))
- assert rank1_payload["sp.cfg_active_branch"] == "cfg_text"
- assert rank2_payload["sp.cfg_active_branch"] == "cfg_img"
- assert rank1_payload["sp.cfg_branch_roles"] == ["cfg_text", "cfg_img"]
- assert rank2_payload["sp.cfg_branch_roles"] == ["cfg_text", "cfg_img"]
- assert "sp.cfg_branch_past_key_values" in rank1_payload
- assert "sp.cfg_branch_past_key_values" in rank2_payload
- assert list(rank1_payload["sp.cfg_branch_past_key_values"].keys()) == ["cfg_text"]
- assert list(rank2_payload["sp.cfg_branch_past_key_values"].keys()) == ["cfg_img"]
- assert "sp.cfg_text_past_key_values" in rank1_payload
- assert "sp.cfg_img_past_key_values" not in rank1_payload
- assert "sp.cfg_img_past_key_values" in rank2_payload
- assert "sp.cfg_text_past_key_values" not in rank2_payload
-
- def test_tp_cfg_follower_receives_local_payload_without_receiving(self):
- mgr = _make_manager(from_tp=2, to_tp=4, local_rank=1)
- req = SimpleNamespace(request_id="req-1", sampling_params=SimpleNamespace())
- world_group = _MockBroadcastGroup(world_size=4, rank_in_group=3)
- cfg_payload = {
- "past_key_values": SimpleNamespace(key_cache=[torch.tensor([1.0])]),
- "kv_metadata": {"source": "main"},
- "sp.past_key_values": SimpleNamespace(key_cache=[torch.tensor([1.0])]),
- "sp.kv_metadata": {"source": "main"},
- "sp.cfg_active_branch": "cfg_text",
- "sp.cfg_branch_roles": ["cfg_text", "cfg_img"],
- "sp.cfg_branch_past_key_values": {
- "cfg_text": SimpleNamespace(key_cache=[torch.tensor([2.0])]),
- },
- "sp.cfg_branch_kv_metadata": {"cfg_text": {"source": "cfg-text"}},
- "sp.cfg_text_past_key_values": SimpleNamespace(key_cache=[torch.tensor([2.0])]),
- }
- cfg_group = _MockBroadcastGroup(world_size=2, rank_in_group=1, recv_value=cfg_payload)
-
- mgr.receive_multi_kv_cache = MagicMock(return_value=True)
- with (
- patch("vllm_omni.diffusion.distributed.parallel_state.get_world_group", return_value=world_group),
- patch(
- "vllm_omni.diffusion.distributed.parallel_state.get_classifier_free_guidance_world_size",
- return_value=2,
- ),
- patch(
- "vllm_omni.diffusion.distributed.parallel_state.get_classifier_free_guidance_rank",
- return_value=1,
- ),
- patch("vllm_omni.diffusion.distributed.parallel_state.get_cfg_group", return_value=cfg_group),
- ):
- assert mgr.receive_multi_kv_cache_distributed(req) is True
-
- mgr.receive_multi_kv_cache.assert_not_called()
- assert req.kv_metadata == {"source": "main"}
- assert torch.equal(req.past_key_values.key_cache[0], torch.tensor([1.0]))
- assert torch.equal(req.sampling_params.past_key_values.key_cache[0], torch.tensor([1.0]))
- assert req.sampling_params.cfg_active_branch == "cfg_text"
- assert req.sampling_params.cfg_branch_roles == ["cfg_text", "cfg_img"]
- assert torch.equal(
- req.sampling_params.cfg_branch_past_key_values["cfg_text"].key_cache[0],
- torch.tensor([2.0]),
- )
- assert req.sampling_params.cfg_branch_kv_metadata == {"cfg_text": {"source": "cfg-text"}}
- assert torch.equal(req.sampling_params.cfg_text_past_key_values.key_cache[0], torch.tensor([2.0]))
- assert cfg_group.broadcast_calls == []
- assert cfg_group.recv_calls == [0]
-
- def test_tp_without_cfg_keeps_independent_receive_path(self):
- mgr = _make_manager(from_tp=2, to_tp=2, local_rank=1)
- req = SimpleNamespace(request_id="req-1", sampling_params=SimpleNamespace())
- world_group = _MockBroadcastGroup(world_size=2, rank_in_group=1)
- mgr.receive_multi_kv_cache = MagicMock(return_value=True)
-
- with patch("vllm_omni.diffusion.distributed.parallel_state.get_world_group", return_value=world_group):
- assert mgr.receive_multi_kv_cache_distributed(req, target_device=torch.device("cpu")) is True
-
- mgr.receive_multi_kv_cache.assert_called_once_with(req, None, torch.device("cpu"))
-
-
-# ── TP auto-detect ───────────────────────────────────────────────────
-
-
-class TestAutoDetectTP:
- def test_auto_detect_when_config_defaults(self):
- """When config from_tp/to_tp == 1 (default), manager should auto-detect."""
- config = OmniKVCacheConfig(
- connector_config={"type": "mock"},
- from_stage="s0",
- stage_id="s1",
- need_recv_cache=True,
- )
- with (
- patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_local_tp_rank", return_value=0),
- patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_tp_world_size", return_value=4),
- ):
- mgr = OmniKVTransferManager(config)
- assert mgr._tp_topo.source_tp_size == 4
- assert mgr._tp_topo.target_tp_size == 4
-
- def test_explicit_tp_overrides_auto_detect(self):
- config = OmniKVCacheConfig(
- connector_config={"type": "mock"},
- from_stage="s0",
- stage_id="s1",
- need_recv_cache=True,
- from_tp=2,
- to_tp=4,
- )
- with (
- patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_local_tp_rank", return_value=0),
- patch("vllm_omni.distributed.omni_connectors.kv_transfer_manager.get_tp_world_size", return_value=8),
- ):
- mgr = OmniKVTransferManager(config)
- assert mgr._tp_topo.source_tp_size == 2
- assert mgr._tp_topo.target_tp_size == 4
diff --git a/tests/distributed/omni_coordinator/test_load_balancer.py b/tests/distributed/omni_coordinator/test_load_balancer.py
index 8350b33d396..c54d2489402 100644
--- a/tests/distributed/omni_coordinator/test_load_balancer.py
+++ b/tests/distributed/omni_coordinator/test_load_balancer.py
@@ -3,18 +3,12 @@
from time import time
-import pytest
-
from vllm_omni.distributed.omni_coordinator import (
InstanceInfo,
- LeastQueueLengthBalancer,
RandomBalancer,
- RoundRobinBalancer,
StageStatus,
)
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
def test_load_balancer_select_returns_valid_index():
"""Verify RandomBalancer.select() returns a valid index for instances."""
@@ -62,173 +56,3 @@ def test_load_balancer_select_returns_valid_index():
assert isinstance(index, int)
assert 0 <= index < len(instances)
-
-
-def test_round_robin_balancer_cycles_instances():
- now = time()
- instances = [
- InstanceInfo(
- input_addr="tcp://host:10001",
- output_addr="tcp://host:10001-out",
- stage_id=0,
- status=StageStatus.UP,
- queue_length=2,
- last_heartbeat=now,
- registered_at=now,
- ),
- InstanceInfo(
- input_addr="tcp://host:10002",
- output_addr="tcp://host:10002-out",
- stage_id=0,
- status=StageStatus.UP,
- queue_length=1,
- last_heartbeat=now,
- registered_at=now,
- ),
- InstanceInfo(
- input_addr="tcp://host:10003",
- output_addr="tcp://host:10003-out",
- stage_id=1,
- status=StageStatus.UP,
- queue_length=0,
- last_heartbeat=now,
- registered_at=now,
- ),
- ]
-
- balancer = RoundRobinBalancer()
- results = [balancer.select({}, instances) for _ in range(5)]
-
- # Default start_index=0 => 0,1,2,0,1
- assert results == [0, 1, 2, 0, 1]
-
-
-def test_round_robin_balancer_empty_instances_raises():
- with pytest.raises(ValueError, match="instances must not be empty"):
- RoundRobinBalancer().select({}, [])
-
-
-def test_round_robin_balancer_after_large_index_and_shorter_list():
- """Large start_index % len(instances) then counter wraps with shorter list."""
- now = time()
- two = [
- InstanceInfo(
- input_addr="tcp://host:10001",
- output_addr="tcp://host:10001-out",
- stage_id=0,
- status=StageStatus.UP,
- queue_length=0,
- last_heartbeat=now,
- registered_at=now,
- ),
- InstanceInfo(
- input_addr="tcp://host:10002",
- output_addr="tcp://host:10002-out",
- stage_id=0,
- status=StageStatus.UP,
- queue_length=0,
- last_heartbeat=now,
- registered_at=now,
- ),
- ]
- balancer = RoundRobinBalancer(start_index=7)
- assert balancer.select({}, two) == 1 # 7 % 2
- assert balancer.select({}, two) == 0 # next index wrapped to 0
-
-
-def test_least_queue_length_balancer_picks_min_queue():
- now = time()
- instances = [
- InstanceInfo(
- input_addr="tcp://host:10001",
- output_addr="tcp://host:10001-out",
- stage_id=0,
- status=StageStatus.UP,
- queue_length=2,
- last_heartbeat=now,
- registered_at=now,
- ),
- InstanceInfo(
- input_addr="tcp://host:10002",
- output_addr="tcp://host:10002-out",
- stage_id=0,
- status=StageStatus.UP,
- queue_length=0,
- last_heartbeat=now,
- registered_at=now,
- ),
- InstanceInfo(
- input_addr="tcp://host:10003",
- output_addr="tcp://host:10003-out",
- stage_id=1,
- status=StageStatus.UP,
- queue_length=5,
- last_heartbeat=now,
- registered_at=now,
- ),
- ]
-
- balancer = LeastQueueLengthBalancer()
- index = balancer.select({}, instances)
- assert index == 1
-
-
-def test_least_queue_length_balancer_empty_instances_raises():
- with pytest.raises(ValueError, match="instances must not be empty"):
- LeastQueueLengthBalancer().select({}, [])
-
-
-def test_least_queue_length_balancer_equal_queues_uses_choice(mocker):
- now = time()
- instances = [
- InstanceInfo(
- input_addr="tcp://host:10001",
- output_addr="tcp://host:10001-out",
- stage_id=0,
- status=StageStatus.UP,
- queue_length=3,
- last_heartbeat=now,
- registered_at=now,
- ),
- InstanceInfo(
- input_addr="tcp://host:10002",
- output_addr="tcp://host:10002-out",
- stage_id=0,
- status=StageStatus.UP,
- queue_length=3,
- last_heartbeat=now,
- registered_at=now,
- ),
- InstanceInfo(
- input_addr="tcp://host:10003",
- output_addr="tcp://host:10003-out",
- stage_id=1,
- status=StageStatus.UP,
- queue_length=3,
- last_heartbeat=now,
- registered_at=now,
- ),
- ]
- balancer = LeastQueueLengthBalancer()
- mocker.patch(
- "vllm_omni.distributed.omni_coordinator.load_balancer.random.choice",
- return_value=2,
- )
- assert balancer.select({}, instances) == 2
-
-
-def test_least_queue_length_balancer_negative_queue_raises():
- now = time()
- instances = [
- InstanceInfo(
- input_addr="tcp://host:10001",
- output_addr="tcp://host:10001-out",
- stage_id=0,
- status=StageStatus.UP,
- queue_length=-1,
- last_heartbeat=now,
- registered_at=now,
- ),
- ]
- with pytest.raises(ValueError, match="queue_length must be non-negative"):
- LeastQueueLengthBalancer().select({}, instances)
diff --git a/tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py b/tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py
index 2fbd7c85bf8..24b3319232d 100644
--- a/tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py
+++ b/tests/distributed/omni_coordinator/test_omni_coord_client_for_hub.py
@@ -12,8 +12,6 @@
OmniCoordClientForHub,
)
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
def _bind_pub() -> tuple[zmq.Context, zmq.Socket, str]:
ctx = zmq.Context.instance()
diff --git a/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py b/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py
index 0ba19c7fff7..b74a48f49cd 100644
--- a/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py
+++ b/tests/distributed/omni_coordinator/test_omni_coord_client_for_stage.py
@@ -2,20 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
-import threading
-import pytest
import zmq
from vllm_omni.distributed.omni_coordinator import (
OmniCoordClientForStage,
StageStatus,
)
-from vllm_omni.distributed.omni_coordinator import (
- omni_coord_client_for_stage as stage_client_module,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
def _bind_router() -> tuple[zmq.Context, zmq.Socket, str]:
@@ -26,8 +19,7 @@ def _bind_router() -> tuple[zmq.Context, zmq.Socket, str]:
return ctx, router, endpoint
-def _recv_event(router: zmq.Socket, timeout_ms: int = 2000) -> dict:
- assert router.poll(timeout=timeout_ms) != 0, "Timed out waiting for coordinator event"
+def _recv_event(router: zmq.Socket) -> dict:
frames = router.recv_multipart()
# ROUTER adds identity frame; the last frame is the payload.
payload = frames[-1]
@@ -116,197 +108,3 @@ def test_stage_client_close_sends_down_status():
router.close(0)
ctx.term()
-
-
-def test_stage_client_reconnects_after_send_failure(mocker):
- """Verify send failure path invokes reconnect before retrying send."""
- ctx, router, endpoint = _bind_router()
-
- client = OmniCoordClientForStage(
- endpoint,
- "tcp://stage:reconnect-in",
- "tcp://stage:reconnect-out",
- 0,
- )
-
- # Discard initial registration event from the real socket.
- _recv_event(router)
-
- class _FlakySocket:
- def __init__(self):
- self.send_calls = 0
- self.closed = False
-
- def send(self, *_args, **_kwargs):
- self.send_calls += 1
- if self.send_calls == 1:
- raise RuntimeError("simulated send failure")
-
- def close(self, *_args, **_kwargs):
- self.closed = True
-
- flaky_socket = _FlakySocket()
- client._socket = flaky_socket
- client._reconnect = mocker.Mock(return_value=True)
-
- client.update_info(queue_length=1)
-
- client._reconnect.assert_called_once_with(max_retries=3)
- assert flaky_socket.send_calls == 2
-
- client.close()
- router.close(0)
- ctx.term()
-
-
-def test_stage_client_raises_when_reconnect_fails(mocker):
- """Verify send failure is propagated when reconnect cannot recover."""
- ctx, router, endpoint = _bind_router()
-
- client = OmniCoordClientForStage(
- endpoint,
- "tcp://stage:reconnect-fail-in",
- "tcp://stage:reconnect-fail-out",
- 0,
- )
-
- # Discard initial registration event from the real socket.
- _recv_event(router)
-
- class _AlwaysFailSocket:
- def send(self, *_args, **_kwargs):
- raise RuntimeError("simulated send failure")
-
- def close(self, *_args, **_kwargs):
- pass
-
- client._socket = _AlwaysFailSocket()
- client._reconnect = mocker.Mock(return_value=False)
-
- with pytest.raises(RuntimeError, match="simulated send failure"):
- client.update_info(queue_length=2)
-
- client._reconnect.assert_called_once_with(max_retries=3)
- client.close()
- router.close(0)
- ctx.term()
-
-
-def test_stage_client_close_handles_runtime_error_in_final_update(mocker):
- """Verify close() still releases resources when final update raises RuntimeError."""
- ctx, router, endpoint = _bind_router()
-
- client = OmniCoordClientForStage(
- endpoint,
- "tcp://stage:close-runtime-in",
- "tcp://stage:close-runtime-out",
- 0,
- )
-
- # Discard initial registration event from the real socket.
- _recv_event(router)
-
- client._send_event = mocker.Mock(side_effect=RuntimeError("simulated close-time failure"))
- client.close()
-
- assert client._closed
- assert client._socket.closed
-
- router.close(0)
- ctx.term()
-
-
-def test_reconnect_respects_retry_limit(monkeypatch):
- """Verify _reconnect stops after max_retries on repeated failures."""
- attempts = {"connect": 0}
-
- class _FailSocket:
- def close(self, *_args, **_kwargs):
- pass
-
- def connect(self, *_args, **_kwargs):
- attempts["connect"] += 1
- raise zmq.ZMQError("simulated reconnect failure")
-
- class _FailContext:
- def socket(self, *_args, **_kwargs):
- return _FailSocket()
-
- def term(self):
- pass
-
- client = OmniCoordClientForStage.__new__(OmniCoordClientForStage)
- client._closed = False
- client._coord_zmq_addr = "tcp://127.0.0.1:9999"
- client._stop_event = threading.Event()
- client._send_lock = threading.RLock()
- client._socket = _FailSocket()
- client._ctx = _FailContext()
-
- monkeypatch.setattr(stage_client_module.zmq, "Context", lambda: _FailContext())
- monkeypatch.setattr(stage_client_module.time, "sleep", lambda *_args, **_kwargs: None)
-
- assert client._reconnect(max_retries=3, retry_interval=5.0) is False
- assert attempts["connect"] == 3
-
-
-def test_heartbeat_loop_retries_after_transient_send_failure():
- """Verify heartbeat loop continues after one transient send failure."""
-
- class _FakeStopEvent:
- def __init__(self):
- self.wait_calls = 0
- self._set = False
-
- def wait(self, timeout=None):
- _ = timeout
- self.wait_calls += 1
- # Run two loop iterations, then stop.
- return self._set or self.wait_calls >= 3
-
- def is_set(self):
- return self._set
-
- def set(self):
- self._set = True
-
- client = OmniCoordClientForStage.__new__(OmniCoordClientForStage)
- client._closed = False
- client._heartbeat_interval = 0.0
- client._stop_event = _FakeStopEvent()
-
- calls = {"count": 0}
-
- def _fake_send(event_type):
- assert event_type == "heartbeat"
- calls["count"] += 1
- if calls["count"] == 1:
- raise RuntimeError("transient heartbeat failure")
-
- client._send_event = _fake_send
-
- client._heartbeat_loop()
-
- assert calls["count"] == 2
-
-
-def test_update_info_rejected_while_closing():
- """Verify update_info is rejected once client enters closing state."""
- ctx, router, endpoint = _bind_router()
-
- client = OmniCoordClientForStage(
- endpoint,
- "tcp://stage:closing-in",
- "tcp://stage:closing-out",
- 0,
- )
- _recv_event(router)
-
- client._closing = True
- with pytest.raises(RuntimeError, match="closing"):
- client.update_info(queue_length=3)
-
- client._closing = False
- client.close()
- router.close(0)
- ctx.term()
diff --git a/tests/distributed/omni_coordinator/test_omni_coordinator.py b/tests/distributed/omni_coordinator/test_omni_coordinator.py
index eff3d429e40..0c68e61bb11 100644
--- a/tests/distributed/omni_coordinator/test_omni_coordinator.py
+++ b/tests/distributed/omni_coordinator/test_omni_coordinator.py
@@ -4,7 +4,6 @@
import json
import time
-import pytest
import zmq
from vllm.v1.utils import get_engine_client_zmq_addr
@@ -14,8 +13,6 @@
StageStatus,
)
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
def _recv_instance_list(sub: zmq.Socket, timeout_ms: int = 2000) -> dict | None:
"""Receive InstanceList JSON from SUB socket. Returns None on timeout."""
@@ -41,74 +38,6 @@ def _wait_for_instance_list(
return None
-def _drain_sub_messages(sub: zmq.Socket, max_seconds: float = 0.4) -> None:
- """Drain queued SUB messages for a short window."""
- deadline = time.time() + max_seconds
- while time.time() < deadline:
- _recv_instance_list(sub, timeout_ms=50)
-
-
-def test_omni_coordinator_pub_coalescing_on_rapid_queue_updates():
- """Rapid updates should be coalesced into fewer PUB messages."""
- router_addr = get_engine_client_zmq_addr(
- local_only=False,
- host="127.0.0.1",
- port=0,
- )
- pub_addr = get_engine_client_zmq_addr(
- local_only=False,
- host="127.0.0.1",
- port=0,
- )
- coordinator = OmniCoordinator(
- router_zmq_addr=router_addr,
- pub_zmq_addr=pub_addr,
- heartbeat_timeout=1000.0,
- )
-
- sub_ctx = zmq.Context.instance()
- sub = sub_ctx.socket(zmq.SUB)
- sub.connect(pub_addr)
- sub.setsockopt(zmq.SUBSCRIBE, b"")
-
- time.sleep(0.3) # PUB/SUB slow-joiner
-
- client = OmniCoordClientForStage(
- router_addr,
- "tcp://stage:coalesce",
- "tcp://stage:coalesce-out",
- 0,
- )
-
- # Wait for initial registration broadcast and clear any queued messages.
- msg = _wait_for_instance_list(sub, expected_count=1)
- assert msg is not None
- _drain_sub_messages(sub)
-
- # Burst many queue updates in a short period.
- update_count = 80
- for i in range(update_count):
- client.update_info(queue_length=i)
-
- # With publish_min_interval=0.1s, received messages over ~1s should be
- # much smaller than update_count (coalescing effect).
- window_s = 1.1
- deadline = time.time() + window_s
- recv_count = 0
- while time.time() < deadline:
- if _recv_instance_list(sub, timeout_ms=100) is not None:
- recv_count += 1
-
- assert recv_count < update_count // 2, (
- f"expected coalesced PUB traffic, got {recv_count} for {update_count} updates"
- )
-
- client.close()
- coordinator.close()
- sub.close(0)
- sub_ctx.term()
-
-
def test_omni_coordinator_registration_broadcast():
"""Verify that after multiple OmniCoordClientForStage instances register,
OmniCoordinator publishes an InstanceList containing all registered instances.
diff --git a/tests/e2e/accuracy/conftest.py b/tests/e2e/accuracy/conftest.py
index 709fdf345ec..0a81b02075b 100644
--- a/tests/e2e/accuracy/conftest.py
+++ b/tests/e2e/accuracy/conftest.py
@@ -1,18 +1,16 @@
from __future__ import annotations
import os
+import shutil
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
-from io import BytesIO
from pathlib import Path
import pytest
-import requests
import torch
-from PIL import Image
-from tests.helpers.runtime import OmniServer, OmniServerParams
+from tests.conftest import OmniServer, OmniServerParams
def pytest_addoption(parser):
@@ -116,8 +114,8 @@ def generate_server(self):
params = self.generate_params
model = self.model_prefix + params.model
server_args = params.server_args or []
- if params.use_omni and params.stage_init_timeout is not None:
- server_args = ["--stage-init-timeout", str(params.stage_init_timeout), *server_args]
+ if params.use_omni:
+ server_args = ["--stage-init-timeout", "120", *server_args]
with OmniServer(
model,
server_args,
@@ -185,26 +183,16 @@ def accuracy_artifact_root() -> Path:
return root
-@pytest.fixture(scope="session")
-def qwen_bear_image(accuracy_artifact_root: Path) -> Image.Image:
- """Download the Qwen bear image from the URL and save it to the accuracy artifact root."""
- QWEN_BEAR_IMAGE_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/omni-assets/qwen-bear.png"
- response = requests.get(QWEN_BEAR_IMAGE_URL, timeout=60)
- response.raise_for_status()
- image = Image.open(BytesIO(response.content)).convert("RGB")
- image.save(accuracy_artifact_root / "qwen_bear.png")
- return image
+def reset_artifact_dir(path: Path) -> Path:
+ if path.exists():
+ shutil.rmtree(path)
+ path.mkdir(parents=True, exist_ok=True)
+ return path
-@pytest.fixture(scope="session")
-def rabbit_image(accuracy_artifact_root: Path) -> Image.Image:
- """Download the rabbit image from the URL and save it to the accuracy artifact root."""
- RABBIT_IMAGE_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/omni-assets/rabbit.png"
- response = requests.get(RABBIT_IMAGE_URL, timeout=60)
- response.raise_for_status()
- image = Image.open(BytesIO(response.content)).convert("RGB")
- image.save(accuracy_artifact_root / "rabbit.png")
- return image
+def infer_model_label(model: str) -> str:
+ label = Path(model.rstrip("/\\")).name or "model"
+ return "".join(char if char.isalnum() or char in {"-", "_"} else "_" for char in label)
def _build_accuracy_server_config(
@@ -238,7 +226,6 @@ def _build_accuracy_server_config(
server_args=generate_server_args,
env_dict={"CUDA_VISIBLE_DEVICES": shared_gpu},
use_omni=True,
- stage_init_timeout=300,
),
judge_params=OmniServerParams(
model=judge_model,
diff --git a/tests/e2e/accuracy/helpers.py b/tests/e2e/accuracy/helpers.py
deleted file mode 100644
index 382d3ea9b5f..00000000000
--- a/tests/e2e/accuracy/helpers.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from pathlib import Path
-
-import numpy as np
-import pytest
-import torch
-from PIL import Image
-from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
-
-
-def reset_artifact_dir(path: Path) -> Path:
- import shutil
-
- if path.exists():
- shutil.rmtree(path)
- path.mkdir(parents=True, exist_ok=True)
- return path
-
-
-def infer_model_label(model: str) -> str:
- label = Path(model.rstrip("/\\")).name or "model"
- return "".join(char if char.isalnum() or char in {"-", "_"} else "_" for char in label)
-
-
-def model_output_dir(parent_dir: Path, model: str) -> Path:
- safe_model_name = model.split("/")[-1].replace(".", "_")
- path = parent_dir / safe_model_name
- path.mkdir(parents=True, exist_ok=True)
- return path
-
-
-def assert_similarity(
- *,
- model_name: str,
- vllm_image: Image.Image,
- diffusers_image: Image.Image,
- ssim_threshold: float,
- psnr_threshold: float,
- width: int | None = None,
- height: int | None = None,
- compare_mode: str = "RGB",
-) -> None:
- requested_size = (width, height) if width is not None and height is not None else None
- if requested_size is not None and diffusers_image.size != requested_size:
- pytest.skip(
- "Skipping as diffusers baseline output is corrupt and not comparable: "
- f"dimensions do not match requested size; requested={requested_size}, got={diffusers_image.size}."
- )
-
- assert vllm_image.size == diffusers_image.size, (
- f"Online and diffusers output sizes mismatch: online={vllm_image.size}, diffusers={diffusers_image.size}"
- )
-
- ssim_score, psnr_score = compute_image_ssim_psnr(
- prediction=vllm_image,
- reference=diffusers_image,
- compare_mode=compare_mode,
- )
- print(f"{model_name} similarity metrics:")
- print(f" SSIM: value={ssim_score:.6f}, threshold>={ssim_threshold:.6f}, range=[-1, 1], higher_is_better=True")
- print(
- f" PSNR: value={psnr_score:.6f} dB, threshold>={psnr_threshold:.6f} dB, range=[0, +inf), higher_is_better=True"
- )
-
- assert ssim_score >= ssim_threshold, (
- f"SSIM below threshold for {model_name}: got {ssim_score:.6f}, expected >= {ssim_threshold:.6f}."
- )
- assert psnr_score >= psnr_threshold, (
- f"PSNR below threshold for {model_name}: got {psnr_score:.6f}, expected >= {psnr_threshold:.6f}."
- )
-
-
-def assert_image_sequence_similarity(
- *,
- model_name: str,
- vllm_images: list[Image.Image],
- diffusers_images: list[Image.Image],
- ssim_threshold: float,
- psnr_threshold: float,
- compare_mode: str = "RGB",
-) -> None:
- assert len(vllm_images) == len(diffusers_images), (
- f"Output image count mismatch for {model_name}: online={len(vllm_images)}, diffusers={len(diffusers_images)}"
- )
- for index, (vllm_image, diffusers_image) in enumerate(zip(vllm_images, diffusers_images, strict=True), start=1):
- assert_similarity(
- model_name=f"{model_name}[layer={index}]",
- vllm_image=vllm_image,
- diffusers_image=diffusers_image,
- ssim_threshold=ssim_threshold,
- psnr_threshold=psnr_threshold,
- compare_mode=compare_mode,
- )
-
-
-def compute_image_ssim_psnr(
- *,
- prediction: Image.Image,
- reference: Image.Image,
- compare_mode: str = "RGB",
-) -> tuple[float, float]:
- pred_tensor = _pil_to_batched_tensor(prediction, compare_mode=compare_mode)
- ref_tensor = _pil_to_batched_tensor(reference, compare_mode=compare_mode)
-
- ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0)
- psnr_metric = PeakSignalNoiseRatio(data_range=1.0)
-
- ssim_value = float(ssim_metric(pred_tensor, ref_tensor).item())
- psnr_value = float(psnr_metric(pred_tensor, ref_tensor).item())
- return ssim_value, psnr_value
-
-
-def _pil_to_batched_tensor(image: Image.Image, *, compare_mode: str) -> torch.Tensor:
- array = np.asarray(image.convert(compare_mode), dtype=np.float32) / 255.0
- tensor = torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0)
- return tensor
diff --git a/tests/e2e/accuracy/qwen3_omni/__init__.py b/tests/e2e/accuracy/qwen3_omni/__init__.py
deleted file mode 100644
index 79a31c4f100..00000000000
--- a/tests/e2e/accuracy/qwen3_omni/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-"""Qwen3-Omni accuracy benchmarks (Daily-Omni / Seed-TTS ``vllm bench serve --omni``)."""
diff --git a/tests/e2e/accuracy/qwen3_omni/qwen3_omni_acc_bench_core.py b/tests/e2e/accuracy/qwen3_omni/qwen3_omni_acc_bench_core.py
deleted file mode 100644
index 2ce86d504f0..00000000000
--- a/tests/e2e/accuracy/qwen3_omni/qwen3_omni_acc_bench_core.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-"""Shared helpers for Qwen3-Omni Daily-Omni / Seed-TTS ``vllm bench serve --omni`` accuracy runs.
-
-Local dataset paths are **optional**. When ``VLLM_DAILY_OMNI_QA_JSON`` + ``VLLM_DAILY_OMNI_VIDEO_DIR``
-point to existing files, those are used with inline video. Otherwise the benchmark falls back to
-the HuggingFace dataset id (``liarliar/Daily-Omni``); QA loads via ``datasets``, and the first
-bench request that needs media downloads ``Videos.tar`` from the Hub when no video dir is set.
-
-Similarly for Seed-TTS: a local directory wins; otherwise ``--dataset-path`` uses the Hub id
-and ``huggingface_hub.snapshot_download`` inside ``resolve_seed_tts_root`` pulls files on demand.
-
-Use :func:`build_acc_benchmark_cli_argv` to assemble ``argv`` for a live Omni server (host/port/model
-and small bench defaults) before ``parse_args`` / ``run_acc_benchmark`` in the accuracy driver.
-"""
-
-from __future__ import annotations
-
-import json
-import os
-import shutil
-import subprocess
-from pathlib import Path
-from typing import Any, Protocol
-
-DEFAULT_DAILY_OMNI_HF_REPO = "liarliar/Daily-Omni"
-DEFAULT_SEED_TTS_HF_REPO = "zhaochenyang20/seed-tts-eval"
-
-
-class OmniBenchServerEndpoint(Protocol):
- """Anything with ``host`` / ``port`` / ``model`` (e.g. :class:`tests.conftest.OmniServer`)."""
-
- host: str
- port: int
- model: str
-
-
-def build_acc_benchmark_cli_argv(
- server: OmniBenchServerEndpoint,
- *,
- skip_seed: bool,
- skip_daily: bool,
- num_prompts: int | None = None,
- max_concurrency: int | None = None,
-) -> list[str]:
- """Prefix argv for :func:`run_qwen_omni_acc_benchmark.parse_acc_benchmark_args` + :func:`run_acc_benchmark`.
-
- Wires ``--host`` / ``--port`` / ``--model`` to a running Omni OpenAI server, sets small
- ``--num-prompts`` / ``--max-concurrency`` defaults (overridable via ``ACC_BENCH_NUM_PROMPTS`` /
- ``ACC_BENCH_MAX_CONCURRENCY``), and when Daily-Omni runs adds ``--daily-omni-repo`` so Hub QA
- matches :func:`daily_omni_bench_argv` once ``run_acc_benchmark`` mirrors ``--daily-omni-repo`` into env.
- """
- n_prompts = int(os.environ.get("ACC_BENCH_NUM_PROMPTS", "2000")) if num_prompts is None else int(num_prompts)
- n_conc = int(os.environ.get("ACC_BENCH_MAX_CONCURRENCY", "10")) if max_concurrency is None else int(max_concurrency)
- argv = [
- "--host",
- server.host,
- "--port",
- str(server.port),
- "--model",
- server.model,
- "--num-prompts",
- str(n_prompts),
- "--max-concurrency",
- str(n_conc),
- ]
- if not skip_daily:
- repo = os.environ.get("VLLM_DAILY_OMNI_REPO", DEFAULT_DAILY_OMNI_HF_REPO).strip() or DEFAULT_DAILY_OMNI_HF_REPO
- argv.extend(["--daily-omni-repo", repo])
- if skip_seed:
- argv.append("--skip-seed-tts")
- if skip_daily:
- argv.append("--skip-daily-omni")
- return argv
-
-
-def daily_omni_bench_argv() -> list[str]:
- """CLI args for Daily-Omni (after ``vllm bench serve --omni``)."""
- qa = os.environ.get("VLLM_DAILY_OMNI_QA_JSON", "").strip()
- vd = os.environ.get("VLLM_DAILY_OMNI_VIDEO_DIR", "").strip()
- if qa and vd:
- qap = Path(qa).expanduser()
- vdp = Path(vd).expanduser()
- if qap.is_file() and vdp.is_dir():
- return [
- "--dataset-name",
- "daily-omni",
- "--daily-omni-qa-json",
- str(qap.resolve()),
- "--daily-omni-video-dir",
- str(vdp.resolve()),
- "--daily-omni-inline-local-video",
- ]
- repo = os.environ.get("VLLM_DAILY_OMNI_REPO", DEFAULT_DAILY_OMNI_HF_REPO).strip() or DEFAULT_DAILY_OMNI_HF_REPO
- return [
- "--dataset-name",
- "daily-omni",
- "--dataset-path",
- repo,
- ]
-
-
-def seed_tts_bench_argv(*, locale: str = "en") -> list[str]:
- """CLI args for Seed-TTS (after ``vllm bench serve --omni``)."""
- dp = os.environ.get("VLLM_SEED_TTS_DATASET_PATH", "").strip()
- if dp:
- p = Path(dp).expanduser()
- # Preserve Hugging Face repo ids verbatim. Only canonicalize to an
- # absolute path when the value actually exists as a local directory.
- dataset_path = str(p.resolve()) if p.exists() and p.is_dir() else dp
- else:
- dataset_path = (
- os.environ.get("VLLM_SEED_TTS_REPO", DEFAULT_SEED_TTS_HF_REPO).strip() or DEFAULT_SEED_TTS_HF_REPO
- )
- out = ["--dataset-name", "seed-tts", "--dataset-path", dataset_path]
- root = os.environ.get("SEED_TTS_ROOT", "").strip()
- if root:
- out.extend(["--seed-tts-root", str(Path(root).expanduser().resolve())])
- out.extend(["--seed-tts-locale", locale])
- return out
-
-
-def find_vllm_cli() -> str:
- exe = shutil.which("vllm")
- if not exe:
- raise FileNotFoundError("Could not find `vllm` on PATH (install vLLM-Omni with CLI entrypoints).")
- return exe
-
-
-def run_vllm_bench_subprocess(vllm: str, argv: list[str], *, extra_env: dict[str, str] | None = None) -> None:
- env = os.environ.copy()
- if extra_env:
- env.update(extra_env)
- subprocess.run([vllm, *argv], env=env, check=True)
-
-
-def load_benchmark_result(path: Path) -> dict[str, Any]:
- with path.open(encoding="utf-8") as f:
- return json.load(f)
-
-
-def build_serve_common_argv(
- *,
- host: str,
- port: int,
- model: str,
- num_prompts: int,
- max_concurrency: int,
- num_warmups: int,
- percentile_metrics: str,
- result_dir: Path,
- result_filename: str,
- ready_check_timeout_sec: int | None = None,
-) -> list[str]:
- out = [
- "bench",
- "serve",
- "--omni",
- "--host",
- host,
- "--port",
- str(port),
- "--model",
- model,
- "--endpoint",
- "/v1/chat/completions",
- "--backend",
- "openai-chat-omni",
- "--request-rate",
- "inf",
- "--num-prompts",
- str(num_prompts),
- "--max-concurrency",
- str(max_concurrency),
- "--no-oversample",
- "--num-warmups",
- str(num_warmups),
- "--percentile-metrics",
- percentile_metrics,
- "--save-result",
- "--result-dir",
- str(result_dir),
- "--result-filename",
- result_filename,
- ]
- if ready_check_timeout_sec is not None:
- out.extend(["--ready-check-timeout-sec", str(int(ready_check_timeout_sec))])
- return out
-
-
-def assert_daily_omni_scored(result: dict[str, Any]) -> None:
- acc = result.get("daily_omni_accuracy")
- assert acc is not None, "daily_omni_accuracy missing — wrong dataset or benchmark wiring"
- assert int(result.get("daily_omni_evaluated_ok", 0) or 0) > 0, "no successful MCQ rows (daily_omni_evaluated_ok==0)"
-
-
-def assert_seed_tts_scored(result: dict[str, Any]) -> None:
- err = result.get("seed_tts_eval_setup_error")
- assert not err, f"Seed-TTS eval deps/setup failed: {err}"
- assert int(result.get("seed_tts_content_evaluated", 0) or 0) > 0, (
- "seed_tts_content_evaluated==0 — enable WER eval and check PCM capture / modalities"
- )
diff --git a/tests/e2e/accuracy/qwen3_omni/run_qwen_omni_acc_benchmark.py b/tests/e2e/accuracy/qwen3_omni/run_qwen_omni_acc_benchmark.py
deleted file mode 100644
index 7fb71b28d77..00000000000
--- a/tests/e2e/accuracy/qwen3_omni/run_qwen_omni_acc_benchmark.py
+++ /dev/null
@@ -1,428 +0,0 @@
-#!/usr/bin/env python3
-# SPDX-License-Identifier: Apache-2.0
-"""Accuracy (and light perf) checks for Qwen3-Omni via ``vllm bench serve --omni``.
-
-The standalone CLI uses small ``--num-prompts`` / ``--max-concurrency`` defaults suitable for
-L4-style smoke runs against an already-running server. The pytest wrappers in
-``tests/e2e/accuracy/qwen3_omni/test_qwen3_omni.py`` may still require larger GPUs (currently
-H100 / MI325) because they launch the live Omni server inside the test.
-
-1. **Daily-Omni** — MCQ accuracy fields in the saved JSON (``daily_omni_accuracy``, …); by default the
- run **fails** if accuracy is strictly below **0.69** (``--min-daily-omni-accuracy`` / ``ACC_BENCH_MIN_DAILY_OMNI_ACCURACY``).
-2. **Seed-TTS** — ``seed-tts-eval``-style metrics when ``--seed-tts-wer-eval`` is used
- (WER / SIM / UTMOS keys from :func:`compute_seed_tts_wer_metrics`).
-
-Prerequisites
--------------
-* A running Omni OpenAI-compatible server (same machine or reachable host), e.g.::
-
- vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8000
-
- On L4 you may need a smaller checkpoint, quantization, or tighter engine flags; this script
- only drives the **client** benchmark.
-
-* ``vllm`` CLI from **vLLM-Omni** (so ``bench serve`` registers ``daily-omni`` / ``seed-tts``).
-
-* **Daily-Omni** — if local ``qa.json`` + ``Videos/`` are not both provided (CLI or matching env),
- the client passes ``--dataset-path`` with a Hub id (default ``liarliar/Daily-Omni``). The **child**
- ``vllm bench serve`` process then loads QA via ``datasets.load_dataset`` (needs ``pip install datasets``,
- network or HF cache). Without ``--daily-omni-video-dir``, the benchmark **lazily** downloads and
- extracts ``Videos.tar`` from the Hub (``huggingface_hub``) on first multimodal request. Override
- the dataset repo with ``--daily-omni-repo`` or ``VLLM_DAILY_OMNI_REPO``; override the tar repo
- with ``VLLM_DAILY_OMNI_MEDIA_REPO`` if needed.
-
-* **Seed-TTS** optional extras for WER/SIM/UTMOS::
-
- pip install 'vllm-omni[seed-tts-eval]'
-
-Examples
---------
-Pytest (same checks; needs a running server)::
-
- pytest -sv tests/e2e/accuracy/qwen3_omni/test_qwen3_omni.py
-
-Smoke on localhost (server already up)::
-
- python tests/e2e/accuracy/qwen3_omni/run_qwen_omni_acc_benchmark.py \\
- --model Qwen/Qwen3-Omni-30B-A3B-Instruct \\
- --daily-omni-qa-json ./qa.json \\
- --daily-omni-video-dir ./Videos \\
- --seed-tts-dataset-path ./seed-tts-eval
-
-Skip one suite, tighten gates::
-
- python tests/e2e/accuracy/qwen3_omni/run_qwen_omni_acc_benchmark.py \\
- --skip-daily-omni \\
- --max-seed-tts-mean-wer 0.35 \\
- --min-seed-tts-mean-sim 0.75
-"""
-
-from __future__ import annotations
-
-import argparse
-import contextlib
-import json
-import os
-import sys
-from datetime import datetime
-from pathlib import Path
-from typing import Any
-
-from tests.e2e.accuracy.qwen3_omni.qwen3_omni_acc_bench_core import (
- build_serve_common_argv,
- daily_omni_bench_argv,
- find_vllm_cli,
- load_benchmark_result,
- run_vllm_bench_subprocess,
- seed_tts_bench_argv,
-)
-
-_REPO_ROOT = Path(__file__).resolve().parents[4]
-
-
-def _repo_root() -> Path:
- return _REPO_ROOT
-
-
-def _default_result_dir() -> Path:
- return Path(__file__).resolve().parent / "results" / "qwen_omni_acc"
-
-
-def _validate_daily_omni(result: dict[str, Any], *, min_accuracy: float | None) -> list[str]:
- errs: list[str] = []
- acc = result.get("daily_omni_accuracy")
- if acc is None:
- errs.append("Missing daily_omni_accuracy (wrong dataset or no gold-evaluated rows).")
- return errs
- ev = int(result.get("daily_omni_evaluated_ok", 0) or 0)
- if ev <= 0:
- errs.append("daily_omni_evaluated_ok is 0; no successful MCQ rows to score.")
- if min_accuracy is not None and float(acc) + 1e-12 < float(min_accuracy):
- errs.append(f"daily_omni_accuracy={acc:.6f} < --min-daily-omni-accuracy={min_accuracy}")
- return errs
-
-
-def _validate_seed_tts(
- result: dict[str, Any],
- *,
- max_mean_wer: float | None,
- min_mean_sim: float | None,
- min_mean_utmos: float | None,
-) -> list[str]:
- errs: list[str] = []
- setup = result.get("seed_tts_eval_setup_error")
- if setup:
- errs.append(f"Seed-TTS eval setup failed: {setup}")
- return errs
- n = int(result.get("seed_tts_content_evaluated", 0) or 0)
- if n <= 0:
- errs.append("seed_tts_content_evaluated is 0 (enable --seed-tts-wer-eval and check PCM capture).")
- mean_wer = result.get("seed_tts_content_error_mean")
- if mean_wer is not None and max_mean_wer is not None and float(mean_wer) > float(max_mean_wer) + 1e-12:
- errs.append(f"seed_tts_content_error_mean (WER)={mean_wer:.6f} > --max-seed-tts-mean-wer={max_mean_wer}")
- sim_m = result.get("seed_tts_sim_mean")
- if sim_m is not None and min_mean_sim is not None and float(sim_m) + 1e-12 < float(min_mean_sim):
- errs.append(f"seed_tts_sim_mean={sim_m:.6f} < --min-seed-tts-mean-sim={min_mean_sim}")
- ut_m = result.get("seed_tts_utmos_mean")
- if ut_m is not None and min_mean_utmos is not None and float(ut_m) + 1e-12 < float(min_mean_utmos):
- errs.append(f"seed_tts_utmos_mean={ut_m:.6f} < --min-seed-tts-mean-utmos={min_mean_utmos}")
- return errs
-
-
-def sync_dataset_env_from_ns(ns: argparse.Namespace) -> None:
- """Mirror CLI path flags into env vars read by ``daily_omni_bench_argv`` / ``seed_tts_bench_argv``."""
- repo = getattr(ns, "daily_omni_repo", None)
- if repo is not None and str(repo).strip():
- os.environ["VLLM_DAILY_OMNI_REPO"] = str(repo).strip()
- if ns.daily_omni_qa_json is not None:
- os.environ["VLLM_DAILY_OMNI_QA_JSON"] = str(Path(ns.daily_omni_qa_json).expanduser().resolve())
- if ns.daily_omni_video_dir is not None:
- os.environ["VLLM_DAILY_OMNI_VIDEO_DIR"] = str(Path(ns.daily_omni_video_dir).expanduser().resolve())
- if ns.seed_tts_dataset_path is not None:
- # ``--seed-tts-dataset-path`` accepts either a local directory or a
- # Hugging Face repo id. Only resolve to an absolute filesystem path
- # when the value actually exists locally; otherwise preserve the repo
- # string verbatim so downstream code can pass it to snapshot_download.
- raw = str(ns.seed_tts_dataset_path).strip()
- p = Path(raw).expanduser()
- os.environ["VLLM_SEED_TTS_DATASET_PATH"] = str(p.resolve()) if p.exists() and p.is_dir() else raw
- if ns.seed_tts_root is not None:
- os.environ["SEED_TTS_ROOT"] = str(Path(ns.seed_tts_root).expanduser().resolve())
-
-
-@contextlib.contextmanager
-def _preserve_benchmark_dataset_env() -> Any:
- """Save/restore dataset-related env vars so benchmark tests don't leak state."""
- keys = (
- "VLLM_DAILY_OMNI_REPO",
- "VLLM_DAILY_OMNI_QA_JSON",
- "VLLM_DAILY_OMNI_VIDEO_DIR",
- "VLLM_SEED_TTS_DATASET_PATH",
- "SEED_TTS_ROOT",
- )
- original = {k: os.environ.get(k) for k in keys}
- try:
- yield
- finally:
- for key, value in original.items():
- if value is None:
- os.environ.pop(key, None)
- else:
- os.environ[key] = value
-
-
-def _build_common_args(ns: argparse.Namespace, *, result_filename: str) -> list[str]:
- return build_serve_common_argv(
- host=ns.host,
- port=ns.port,
- model=ns.model,
- num_prompts=ns.num_prompts,
- max_concurrency=ns.max_concurrency,
- num_warmups=ns.num_warmups,
- percentile_metrics=ns.percentile_metrics,
- result_dir=ns.result_dir,
- result_filename=result_filename,
- ready_check_timeout_sec=ns.ready_check_timeout_sec,
- )
-
-
-def run_daily_omni(ns: argparse.Namespace, vllm: str) -> Path:
- ns.result_dir.mkdir(parents=True, exist_ok=True)
- tag = datetime.now().strftime("%Y%m%d-%H%M%S")
- result_filename = f"qwen_omni_acc_daily_omni_{tag}.json"
- extra = json.loads(ns.daily_extra_body_json)
- argv = (
- _build_common_args(ns, result_filename=result_filename)
- + daily_omni_bench_argv()
- + [
- "--daily-omni-input-mode",
- ns.daily_omni_input_mode,
- "--extra-body",
- json.dumps(extra, ensure_ascii=False, separators=(",", ":")),
- ]
- )
- if ns.daily_omni_save_eval_items:
- argv.append("--daily-omni-save-eval-items")
- print("\n$", vllm, *argv, "\n", flush=True)
- run_vllm_bench_subprocess(vllm, argv)
- out = Path(ns.result_dir) / result_filename
- if not out.is_file():
- raise FileNotFoundError(f"Expected result JSON at {out}")
- return out
-
-
-def run_seed_tts(ns: argparse.Namespace, vllm: str) -> Path:
- ns.result_dir.mkdir(parents=True, exist_ok=True)
- tag = datetime.now().strftime("%Y%m%d-%H%M%S")
- result_filename = f"qwen_omni_acc_seed_tts_{tag}.json"
- extra = json.loads(ns.seed_extra_body_json)
- argv = (
- _build_common_args(ns, result_filename=result_filename)
- + seed_tts_bench_argv(locale=ns.seed_tts_locale)
- + [
- "--seed-tts-wer-eval",
- "--extra-body",
- json.dumps(extra, ensure_ascii=False, separators=(",", ":")),
- ]
- )
- if ns.seed_tts_wer_save_items:
- argv.append("--seed-tts-wer-save-items")
- if ns.seed_tts_file_ref_audio:
- argv.append("--seed-tts-file-ref-audio")
- extra_env: dict[str, str] = {"SEED_TTS_WER_EVAL": "1"}
- if ns.seed_tts_eval_device:
- extra_env["SEED_TTS_EVAL_DEVICE"] = ns.seed_tts_eval_device
- print("\n$", vllm, *argv, "\n", flush=True)
- run_vllm_bench_subprocess(vllm, argv, extra_env=extra_env)
- out = Path(ns.result_dir) / result_filename
- if not out.is_file():
- raise FileNotFoundError(f"Expected result JSON at {out}")
- return out
-
-
-def build_arg_parser() -> argparse.ArgumentParser:
- p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
- p.add_argument("--host", default=os.environ.get("ACC_BENCH_HOST", "127.0.0.1"))
- p.add_argument("--port", type=int, default=int(os.environ.get("ACC_BENCH_PORT", "8000")))
- p.add_argument(
- "--model",
- default=os.environ.get(
- "ACC_BENCH_MODEL",
- "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- ),
- help="Model id passed to ``vllm bench serve`` (must match the running server).",
- )
- p.add_argument("--num-prompts", type=int, default=int(os.environ.get("ACC_BENCH_NUM_PROMPTS", "2000")))
- p.add_argument("--max-concurrency", type=int, default=int(os.environ.get("ACC_BENCH_MAX_CONCURRENCY", "10")))
- p.add_argument("--num-warmups", type=int, default=int(os.environ.get("ACC_BENCH_NUM_WARMUPS", "0")))
- p.add_argument(
- "--percentile-metrics",
- default=os.environ.get("ACC_BENCH_PERCENTILE_METRICS", "ttft,tpot,itl,e2el,audio_ttfp,audio_rtf"),
- )
- p.add_argument(
- "--ready-check-timeout-sec",
- type=int,
- default=None,
- help="If set, forwarded to ``vllm bench serve`` (probe first request until success). "
- "Omit to use upstream default (typically skip).",
- )
- p.add_argument(
- "--result-dir",
- type=Path,
- default=Path(os.environ.get("ACC_BENCH_RESULT_DIR", str(_default_result_dir()))),
- )
-
- p.add_argument("--skip-daily-omni", action="store_true")
- p.add_argument("--skip-seed-tts", action="store_true")
-
- p.add_argument(
- "--daily-omni-repo",
- type=str,
- default=None,
- help="Hugging Face dataset id for Daily-Omni Hub mode (sets VLLM_DAILY_OMNI_REPO). "
- "Ignored when local qa.json + video dir are used.",
- )
- p.add_argument(
- "--daily-omni-qa-json",
- type=Path,
- default=None,
- help="Optional local qa.json; if omitted with no env, uses Hub liarliar/Daily-Omni.",
- )
- p.add_argument(
- "--daily-omni-video-dir",
- type=Path,
- default=None,
- help="Optional local Videos root; if omitted, media is fetched lazily from Hub Videos.tar.",
- )
- p.add_argument("--daily-omni-input-mode", choices=("all", "visual", "audio"), default="all")
- p.add_argument(
- "--daily-extra-body-json",
- default='{"modalities":["text"]}',
- help="JSON merged into each chat request for Daily-Omni (default matches common L4 / text-output runs).",
- )
- p.add_argument(
- "--daily-omni-save-eval-items",
- action="store_true",
- help="Sets env via CLI flag so per-item rows are stored in the result JSON.",
- )
- p.add_argument(
- "--min-daily-omni-accuracy",
- type=float,
- default=float((os.environ.get("ACC_BENCH_MIN_DAILY_OMNI_ACCURACY") or "0.69").strip() or "0.69"),
- help="Fail when daily_omni_accuracy is strictly below this threshold (0–1). "
- "Default baseline 0.69; override with env ACC_BENCH_MIN_DAILY_OMNI_ACCURACY or pass 0 to disable the floor.",
- )
-
- p.add_argument(
- "--seed-tts-dataset-path",
- type=str,
- default=None,
- help="Optional local root or Hub id; if omitted, uses zhaochenyang20/seed-tts-eval.",
- )
- p.add_argument("--seed-tts-root", type=Path, default=None, help="Optional override for Seed-TTS filesystem root.")
- p.add_argument("--seed-tts-locale", choices=("en", "zh"), default="en")
- p.add_argument(
- "--seed-extra-body-json",
- default='{"modalities":["text","audio"]}',
- help="JSON for Seed-TTS chat requests (must include audio for synthesis + PCM capture).",
- )
- p.add_argument("--seed-tts-wer-save-items", action="store_true")
- p.add_argument(
- "--seed-tts-file-ref-audio",
- action="store_true",
- help="Use file:// ref_audio; server must allow local media paths.",
- )
- p.add_argument(
- "--seed-tts-eval-device",
- default=os.environ.get("SEED_TTS_EVAL_DEVICE"),
- help="Sets SEED_TTS_EVAL_DEVICE for Whisper / WavLM / UTMOS (e.g. cuda:0).",
- )
- p.add_argument(
- "--max-seed-tts-mean-wer",
- type=float,
- default=0.5,
- help="If set, fail when seed_tts_content_error_mean is strictly above this value.",
- )
- p.add_argument(
- "--min-seed-tts-mean-sim",
- type=float,
- default=None,
- help="If set, fail when seed_tts_sim_mean is strictly below this value.",
- )
- p.add_argument(
- "--min-seed-tts-mean-utmos",
- type=float,
- default=None,
- help="If set, fail when seed_tts_utmos_mean is strictly below this value.",
- )
- return p
-
-
-def parse_acc_benchmark_args(argv: list[str] | None = None) -> argparse.Namespace:
- """Parse CLI args; when ``argv`` is ``None``, use ``sys.argv[1:]`` (standalone script)."""
- if argv is None:
- argv = sys.argv[1:]
- return build_arg_parser().parse_args(argv)
-
-
-def run_acc_benchmark(ns: argparse.Namespace) -> int:
- """Run Daily-Omni and/or Seed-TTS client benches against a running server; return 0 on success."""
- failed: list[str] = []
-
- with _preserve_benchmark_dataset_env():
- sync_dataset_env_from_ns(ns)
-
- vllm = find_vllm_cli()
- print(f"Using vLLM CLI: {vllm}", flush=True)
- print(f"Repo root (for cwd reference): {_repo_root()}", flush=True)
-
- if not ns.skip_daily_omni:
- path = run_daily_omni(ns, vllm)
- print(f"\n[Daily-Omni] result JSON: {path}", flush=True)
- data = load_benchmark_result(path)
- errs = _validate_daily_omni(data, min_accuracy=ns.min_daily_omni_accuracy)
- if errs:
- failed.extend([f"[Daily-Omni] {e}" for e in errs])
- else:
- print(
- f"[Daily-Omni] daily_omni_accuracy={data.get('daily_omni_accuracy')} "
- f"evaluated_ok={data.get('daily_omni_evaluated_ok')}",
- flush=True,
- )
-
- if not ns.skip_seed_tts:
- path = run_seed_tts(ns, vllm)
- print(f"\n[Seed-TTS] result JSON: {path}", flush=True)
- data = load_benchmark_result(path)
- errs = _validate_seed_tts(
- data,
- max_mean_wer=ns.max_seed_tts_mean_wer,
- min_mean_sim=ns.min_seed_tts_mean_sim,
- min_mean_utmos=ns.min_seed_tts_mean_utmos,
- )
- if errs:
- failed.extend([f"[Seed-TTS] {e}" for e in errs])
- else:
- print(
- f"[Seed-TTS] mean_wer={data.get('seed_tts_content_error_mean')} "
- f"mean_sim={data.get('seed_tts_sim_mean')} mean_utmos={data.get('seed_tts_utmos_mean')} "
- f"evaluated={data.get('seed_tts_content_evaluated')}",
- flush=True,
- )
-
- if failed:
- print("\nACCURACY CHECK FAILED:", file=sys.stderr)
- for line in failed:
- print(f" - {line}", file=sys.stderr)
- return 1
-
- print("\nAll configured accuracy checks passed.", flush=True)
- return 0
-
-
-def main() -> int:
- return run_acc_benchmark(parse_acc_benchmark_args())
-
-
-if __name__ == "__main__":
- raise SystemExit(main())
diff --git a/tests/e2e/accuracy/qwen3_omni/test_qwen3_omni.py b/tests/e2e/accuracy/qwen3_omni/test_qwen3_omni.py
deleted file mode 100644
index 773f7c1108c..00000000000
--- a/tests/e2e/accuracy/qwen3_omni/test_qwen3_omni.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-"""Qwen3-Omni accuracy benchmarks (Daily-Omni MCQ + Seed-TTS WER) via ``vllm bench serve --omni``.
-
-Starts a **module-scoped** Omni OpenAI-compatible server (same pattern as ``tests/dfx/perf`` and
-``tests/e2e/online_serving/test_qwen3_omni.py``), then runs the client benches against
-``omni_server.host`` / ``omni_server.port`` / ``omni_server.model``.
-
-**Daily-Omni from Hugging Face:** unless ``VLLM_DAILY_OMNI_QA_JSON`` and ``VLLM_DAILY_OMNI_VIDEO_DIR``
-point at a full local tree, the bench uses ``--dataset-path`` (default ``liarliar/Daily-Omni`` via
-``VLLM_DAILY_OMNI_REPO`` / ``--daily-omni-repo``). QA loads through ``datasets``; ``Videos.tar`` is
-downloaded and extracted under ``HF_HOME`` on demand. The tests patch in
-``--daily-omni-inline-local-video`` so multimodal payloads use data URLs (no
-``--allowed-local-media-path`` on the server). Use small ``--num-prompts`` defaults suitable for CI
-(override with ``ACC_BENCH_NUM_PROMPTS`` / ``ACC_BENCH_MAX_CONCURRENCY``; see
-:func:`tests.e2e.accuracy.qwen3_omni.qwen3_omni_acc_bench_core.build_acc_benchmark_cli_argv`).
-
-This package lives under ``tests/e2e/accuracy/qwen3_omni/``, so pytest still loads
-``tests/e2e/accuracy/conftest.py``, which imports ``tests.conftest`` (heavy deps: ``vllm``, ``torch``, …).
-A broken or partial install can therefore **fail during collection** before these tests run.
-
-If ``vllm`` is not on ``PATH``, the tests **skip** instead of erroring. Without
-``VLLM_SKIP_ACC_BENCH=1``, a failed bench still yields a **failed** run (non-zero subprocess exit).
-
-Run::
-
- pytest -sv tests/e2e/accuracy/qwen3_omni/test_qwen3_omni.py
-
-Only the subprocess accuracy marker::
-
- pytest -sv tests/e2e/accuracy/qwen3_omni/test_qwen3_omni.py -m qwen3_omni_acc
-
-Skip when you do not have GPUs, a server, or datasets (CI opt-out)::
-
- VLLM_SKIP_ACC_BENCH=1 pytest -sv tests/e2e/accuracy/qwen3_omni/test_qwen3_omni.py
-
-Standalone CLI (expects a server already up; uses ``ACC_BENCH_*`` env defaults)::
-
- python tests/e2e/accuracy/qwen3_omni/run_qwen_omni_acc_benchmark.py --help
-"""
-
-from __future__ import annotations
-
-from pathlib import Path
-
-import pytest
-
-from tests.e2e.accuracy.qwen3_omni import run_qwen_omni_acc_benchmark as _acc_bench
-from tests.e2e.accuracy.qwen3_omni.qwen3_omni_acc_bench_core import (
- build_acc_benchmark_cli_argv,
- find_vllm_cli,
-)
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
-from vllm_omni.platforms import current_omni_platform
-
-_E2E_ROOT = Path(__file__).resolve().parent.parent.parent
-
-models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
-
-pytestmark = [pytest.mark.full_model, pytest.mark.omni]
-
-_CI_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
-
-
-def get_chunk_config(config_path: str | None = None):
- """Load the qwen3_omni CI deploy yaml with async_chunk modifications for streaming mode."""
- if config_path is None:
- config_path = _CI_DEPLOY
- # TODO: remove this workaround once legacy `stage_args` path is deleted.
- # The pipeline (qwen3_omni/pipeline.py) already wires
- # thinker2talker_async_chunk / talker2code2wav_async_chunk on stage 0/1,
- # so only async_chunk needs flipping. Writing nested `engine_args:` into
- # the new-schema overlay trips _parse_stage_deploy's legacy branch and
- # drops flat fields (load_format, max_num_seqs, ...).
- return modify_stage_config(config_path, updates={"async_chunk": True})
-
-
-if current_omni_platform.is_xpu():
- stage_configs = [_CI_DEPLOY]
-else: # CUDA + ROCm MI325 share the same deploy config
- stage_configs = [get_chunk_config()]
-
-test_params = [
- OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs
-]
-
-
-def _require_vllm_cli() -> None:
- try:
- find_vllm_cli()
- except FileNotFoundError as exc:
- pytest.skip(str(exc))
-
-
-@pytest.fixture(autouse=True)
-def _daily_omni_hub_inline_media(monkeypatch: pytest.MonkeyPatch) -> None:
- """Hub / lazy-cache mode uses local files → default ``file://`` needs server allowlist.
-
- ``run_qwen_omni_acc_benchmark`` binds ``daily_omni_bench_argv`` at import time; patch that copy
- so we append ``--daily-omni-inline-local-video`` whenever the core helper did not already set it
- (local qa.json + video-dir mode already passes the flag).
- """
- orig = _acc_bench.daily_omni_bench_argv
-
- def _wrapped() -> list[str]:
- out = list(orig())
- if "--daily-omni-inline-local-video" not in out:
- out.append("--daily-omni-inline-local-video")
- return out
-
- monkeypatch.setattr(_acc_bench, "daily_omni_bench_argv", _wrapped)
- monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
- monkeypatch.setenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0")
-
-
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_qwen3_omni_daily_omni_accuracy_bench(omni_server) -> None:
- _require_vllm_cli()
- pytest.importorskip("datasets")
- pytest.importorskip("huggingface_hub")
- ns = _acc_bench.parse_acc_benchmark_args(
- build_acc_benchmark_cli_argv(omni_server, skip_seed=True, skip_daily=False)
- )
- assert _acc_bench.run_acc_benchmark(ns) == 0
-
-
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_qwen3_omni_seed_tts_wer_bench(omni_server) -> None:
- _require_vllm_cli()
- pytest.importorskip("huggingface_hub")
- ns = _acc_bench.parse_acc_benchmark_args(
- build_acc_benchmark_cli_argv(omni_server, skip_seed=False, skip_daily=True)
- )
- assert _acc_bench.run_acc_benchmark(ns) == 0
diff --git a/tests/e2e/accuracy/test_gebench_h100_smoke.py b/tests/e2e/accuracy/test_gebench_h100_smoke.py
index 2702710e4a2..b4b83187135 100644
--- a/tests/e2e/accuracy/test_gebench_h100_smoke.py
+++ b/tests/e2e/accuracy/test_gebench_h100_smoke.py
@@ -6,13 +6,13 @@
import pytest
from benchmarks.accuracy.text_to_image.gbench import main as gbench_main
-from tests.e2e.accuracy.helpers import infer_model_label, reset_artifact_dir
-from tests.helpers.mark import hardware_test
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.e2e.accuracy.conftest import infer_model_label, reset_artifact_dir
+from tests.utils import hardware_test
+@pytest.mark.advanced_model
@pytest.mark.benchmark
+@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"}, num_cards=1)
def test_gebench_h100_smoke(
gebench_accuracy_servers,
diff --git a/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py b/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py
index 789f7ec939b..ac5f2cb3cfd 100644
--- a/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py
+++ b/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py
@@ -7,13 +7,13 @@
from benchmarks.accuracy.image_to_image.gedit_bench import GROUPS
from benchmarks.accuracy.image_to_image.gedit_bench import main as gedit_main
-from tests.e2e.accuracy.helpers import infer_model_label, reset_artifact_dir
-from tests.helpers.mark import hardware_test
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.e2e.accuracy.conftest import infer_model_label, reset_artifact_dir
+from tests.utils import hardware_test
+@pytest.mark.advanced_model
@pytest.mark.benchmark
+@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"}, num_cards=1)
def test_gedit_bench_h100_smoke(
gedit_accuracy_servers,
@@ -106,9 +106,9 @@ def test_gedit_bench_h100_smoke(
group_summary = language_summary["by_group"][group]
assert set(group_summary) == {"count", "Q_SC", "Q_PQ", "Q_O"}
- assert summary["languages"]["en"]["overall"]["Q_SC"] >= 6.95
+ assert summary["languages"]["en"]["overall"]["Q_SC"] >= 7.0
assert summary["languages"]["en"]["overall"]["Q_PQ"] >= 5.8
- assert summary["languages"]["en"]["overall"]["Q_O"] >= 6.15
+ assert summary["languages"]["en"]["overall"]["Q_O"] >= 6.2
assert summary["languages"]["cn"]["overall"]["Q_SC"] >= 6.9
assert summary["languages"]["cn"]["overall"]["Q_PQ"] >= 5.7
assert summary["languages"]["cn"]["overall"]["Q_O"] >= 6.1
diff --git a/tests/e2e/accuracy/test_ltx2_3_video_similarity.py b/tests/e2e/accuracy/test_ltx2_3_video_similarity.py
deleted file mode 100644
index dec533d58ae..00000000000
--- a/tests/e2e/accuracy/test_ltx2_3_video_similarity.py
+++ /dev/null
@@ -1,410 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-SSIM/PSNR accuracy tests for LTX-2.3.
-
-1. **Transformer parity** (``test_ltx2_3_transformer_matches_diffusers``):
- Swaps our custom transformer into diffusers' ``LTX2Pipeline`` to measure
- numerical parity in isolation. Thresholds: SSIM >= 0.95, PSNR >= 28 dB.
- Result: SSIM 0.999987 (bit-identical).
-
-2. **Full pipeline** (``test_ltx2_3_pipeline_matches_diffusers``):
- Runs the full vLLM-Omni serving stack (``OmniServer`` -> HTTP API) and
- compares per-frame against stock diffusers. Currently skipped because
- the OmniServer subprocess creates a different RNG state than in-process
- diffusers, producing different initial latents from the same seed.
- This is a test infrastructure limitation, not a model accuracy issue.
-"""
-
-from __future__ import annotations
-
-import gc
-import os
-import tempfile
-from pathlib import Path
-
-import diffusers
-import numpy as np
-import pytest
-import requests
-import torch
-from PIL import Image
-
-from tests.e2e.accuracy.helpers import compute_image_ssim_psnr, model_output_dir
-from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServer
-
-# Parse diffusers version for compatibility check
-_DIFFUSERS_VERSION = tuple(int(x) for x in diffusers.__version__.split(".")[:2] if x.isdigit())
-_DIFFUSERS_038 = _DIFFUSERS_VERSION >= (0, 38)
-
-MODEL_ID = "dg845/LTX-2.3-Diffusers"
-MODEL_ENV_VAR = "VLLM_TEST_LTX23_MODEL"
-PROMPT = "A lighthouse on a rocky cliff at sunset, waves crashing below, golden hour lighting"
-NEGATIVE_PROMPT = "blurry, low quality, distorted, watermark"
-WIDTH = 512
-HEIGHT = 384
-NUM_FRAMES = 25 # ~1 second at 24fps
-NUM_INFERENCE_STEPS = 20
-GUIDANCE_SCALE = 4.0
-SEED = 42
-
-# Transformer-swap test: near-identical output expected
-TRANSFORMER_SSIM_THRESHOLD = 0.95
-TRANSFORMER_PSNR_THRESHOLD = 28.0
-
-# Full-pipeline test: allows minor divergence from RNG / pipeline differences
-PIPELINE_SSIM_THRESHOLD = 0.94
-PIPELINE_PSNR_THRESHOLD = 28.0
-
-
-def _model_name() -> str:
- return os.environ.get(MODEL_ENV_VAR, MODEL_ID)
-
-
-def _local_files_only(model: str) -> bool:
- return Path(model).exists()
-
-
-# ---------------------------------------------------------------------------
-# Frame extraction helpers
-# ---------------------------------------------------------------------------
-
-
-def _video_to_frames(video_np: np.ndarray) -> list[Image.Image]:
- """Convert numpy video to list of PIL Images."""
- while video_np.ndim > 4:
- video_np = video_np[0]
- if video_np.dtype in (np.float32, np.float64, np.float16):
- video_np = np.clip(video_np * 255, 0, 255).astype(np.uint8)
- return [Image.fromarray(video_np[t]) for t in range(video_np.shape[0])]
-
-
-def _extract_diffusers_frames(result) -> list[Image.Image]:
- """Extract frames from diffusers pipeline output."""
- video = result.frames
- if isinstance(video, np.ndarray):
- return _video_to_frames(video)
- if isinstance(video, list):
- if isinstance(video[0], list):
- return [img.convert("RGB") for img in video[0]]
- if isinstance(video[0], Image.Image):
- return [img.convert("RGB") for img in video]
- raise ValueError(f"Unexpected output type: {type(video)}")
-
-
-def _extract_mp4_frames(mp4_bytes: bytes) -> list[Image.Image]:
- """Extract frames from an MP4 video using ffmpeg."""
- import subprocess
-
- with tempfile.TemporaryDirectory() as tmpdir:
- mp4_path = os.path.join(tmpdir, "video.mp4")
- with open(mp4_path, "wb") as f:
- f.write(mp4_bytes)
-
- # Extract video frames as PNG files using ffmpeg
- frame_pattern = os.path.join(tmpdir, "frame_%04d.png")
- subprocess.run(
- ["ffmpeg", "-i", mp4_path, "-vsync", "0", frame_pattern],
- capture_output=True,
- check=True,
- )
-
- # Load frames in order
- frames = []
- i = 1
- while True:
- fpath = os.path.join(tmpdir, f"frame_{i:04d}.png")
- if not os.path.exists(fpath):
- break
- frames.append(Image.open(fpath).convert("RGB").copy())
- i += 1
- return frames
-
-
-# ---------------------------------------------------------------------------
-# Comparison helper
-# ---------------------------------------------------------------------------
-
-
-def _assert_video_similarity(
- *,
- model_name: str,
- vllm_frames: list[Image.Image],
- diffusers_frames: list[Image.Image],
- ssim_threshold: float,
- psnr_threshold: float,
-) -> tuple[float, float]:
- """Compare video frames and assert SSIM/PSNR meet thresholds."""
- min_frames = min(len(vllm_frames), len(diffusers_frames))
- assert min_frames > 0, "No frames to compare"
-
- ssim_scores = []
- psnr_scores = []
- for i in range(min_frames):
- ssim_val, psnr_val = compute_image_ssim_psnr(
- prediction=vllm_frames[i],
- reference=diffusers_frames[i],
- )
- ssim_scores.append(ssim_val)
- psnr_scores.append(psnr_val)
-
- avg_ssim = sum(ssim_scores) / len(ssim_scores)
- avg_psnr = sum(psnr_scores) / len(psnr_scores)
-
- print(f"\n{model_name} video similarity ({min_frames} frames):")
- print(f" SSIM: avg={avg_ssim:.6f}, min={min(ssim_scores):.6f}, threshold>={ssim_threshold:.6f}")
- print(f" PSNR: avg={avg_psnr:.6f} dB, min={min(psnr_scores):.6f} dB, threshold>={psnr_threshold:.6f} dB")
-
- assert avg_ssim >= ssim_threshold, f"SSIM below threshold: got {avg_ssim:.6f}, expected >= {ssim_threshold:.6f}."
- assert avg_psnr >= psnr_threshold, f"PSNR below threshold: got {avg_psnr:.6f}, expected >= {psnr_threshold:.6f}."
- return avg_ssim, avg_psnr
-
-
-# ---------------------------------------------------------------------------
-# Diffusers baseline (shared by both tests)
-# ---------------------------------------------------------------------------
-
-
-def _run_diffusers_baseline(model: str, output_dir: Path) -> list[Image.Image]:
- """Generate video using stock diffusers LTX2Pipeline."""
- from diffusers import LTX2Pipeline
-
- run_pre_test_cleanup(enable_force=True)
- pipe = None
- try:
- pipe = LTX2Pipeline.from_pretrained(
- model, torch_dtype=torch.bfloat16, local_files_only=_local_files_only(model)
- ).to("cuda")
-
- generator = torch.Generator(device="cuda").manual_seed(SEED)
- result = pipe(
- prompt=PROMPT,
- negative_prompt=NEGATIVE_PROMPT,
- width=WIDTH,
- height=HEIGHT,
- num_frames=NUM_FRAMES,
- num_inference_steps=NUM_INFERENCE_STEPS,
- guidance_scale=GUIDANCE_SCALE,
- generator=generator,
- output_type="np",
- )
- frames = _extract_diffusers_frames(result)
- for i, f in enumerate(frames):
- f.save(output_dir / f"diffusers_frame_{i:04d}.png")
- return frames
- finally:
- del pipe
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- run_post_test_cleanup(enable_force=True)
-
-
-# ---------------------------------------------------------------------------
-# Test 1: Transformer-swap parity
-# ---------------------------------------------------------------------------
-
-
-def _run_with_custom_transformer(model: str, output_dir: Path) -> list[Image.Image]:
- """Run diffusers pipeline with our custom transformer swapped in."""
- from contextlib import nullcontext
-
- from diffusers import LTX2Pipeline
- from vllm.config import VllmConfig, set_current_vllm_config
- from vllm.distributed.parallel_state import init_distributed_environment, initialize_model_parallel
-
- from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import create_transformer_from_config, load_transformer_config
-
- vllm_config = VllmConfig()
- ctx = set_current_vllm_config(vllm_config)
- ctx.__enter__()
-
- if not torch.distributed.is_initialized():
- os.environ.setdefault("MASTER_ADDR", "localhost")
- os.environ.setdefault("MASTER_PORT", "29503")
- os.environ.setdefault("RANK", "0")
- os.environ.setdefault("WORLD_SIZE", "1")
- init_distributed_environment(world_size=1, rank=0, local_rank=0)
- initialize_model_parallel(tensor_model_parallel_size=1)
-
- local = _local_files_only(model)
- pipe = LTX2Pipeline.from_pretrained(model, torch_dtype=torch.bfloat16, local_files_only=local)
-
- transformer_config = load_transformer_config(model, "transformer", local)
- our_transformer = create_transformer_from_config(transformer_config)
-
- diffusers_state = dict(pipe.transformer.named_parameters())
-
- def _weight_iter():
- for name, param in diffusers_state.items():
- yield name, param.data
-
- our_transformer.load_weights(_weight_iter())
- our_transformer = our_transformer.to(dtype=torch.bfloat16, device="cuda").eval()
-
- # Compatibility shims for diffusers pipeline
- our_transformer.dtype = torch.bfloat16
- if not hasattr(our_transformer, "cache_context"):
- our_transformer.cache_context = lambda name: nullcontext()
-
- del pipe.transformer
- pipe.transformer = our_transformer
- for name, component in pipe.components.items():
- if name != "transformer" and hasattr(component, "to"):
- try:
- component.to("cuda")
- except Exception:
- pass
-
- generator = torch.Generator(device="cuda").manual_seed(SEED)
- result = pipe(
- prompt=PROMPT,
- negative_prompt=NEGATIVE_PROMPT,
- width=WIDTH,
- height=HEIGHT,
- num_frames=NUM_FRAMES,
- num_inference_steps=NUM_INFERENCE_STEPS,
- guidance_scale=GUIDANCE_SCALE,
- generator=generator,
- output_type="np",
- )
- frames = _extract_diffusers_frames(result)
- for i, f in enumerate(frames):
- f.save(output_dir / f"vllm_transformer_frame_{i:04d}.png")
-
- del pipe, result, our_transformer
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- return frames
-
-
-@pytest.mark.advanced_model
-@pytest.mark.benchmark
-@pytest.mark.diffusion
-@pytest.mark.skipif(
- not _DIFFUSERS_038, reason="LTX-2.3 requires diffusers >= 0.38.0 for cross_attn_mod and BWE vocoder"
-)
-@hardware_test(res={"cuda": "H100"}, num_cards=1)
-def test_ltx2_3_transformer_matches_diffusers(accuracy_artifact_root: Path) -> None:
- """Transformer-level parity: swap our transformer into diffusers pipeline.
-
- Isolates transformer numerical accuracy from pipeline-level differences.
- Both runs use diffusers' denoising loop, CFG, scheduler, and RNG.
- """
- model = _model_name()
- output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID)
-
- diffusers_frames = _run_diffusers_baseline(model=model, output_dir=output_dir)
- vllm_frames = _run_with_custom_transformer(model=model, output_dir=output_dir)
-
- _assert_video_similarity(
- model_name=f"{MODEL_ID} (transformer-swap)",
- vllm_frames=vllm_frames,
- diffusers_frames=diffusers_frames,
- ssim_threshold=TRANSFORMER_SSIM_THRESHOLD,
- psnr_threshold=TRANSFORMER_PSNR_THRESHOLD,
- )
-
-
-# ---------------------------------------------------------------------------
-# Test 2: Full pipeline (OmniServer → HTTP API vs diffusers)
-# ---------------------------------------------------------------------------
-
-
-def _run_vllm_omni_serving(model: str, output_dir: Path) -> list[Image.Image]:
- """Generate video via the full vLLM-Omni serving stack."""
- server_args = [
- "--model-class-name",
- "LTX23Pipeline",
- "--stage-init-timeout",
- "600",
- ]
- with OmniServer(model, server_args, use_omni=True) as server:
- # Submit generation request
- response = requests.post(
- f"http://{server.host}:{server.port}/v1/videos",
- files={
- "prompt": (None, PROMPT),
- "negative_prompt": (None, NEGATIVE_PROMPT),
- "model": (None, server.model),
- "num_frames": (None, str(NUM_FRAMES)),
- "fps": (None, "24"),
- "size": (None, f"{WIDTH}x{HEIGHT}"),
- "num_inference_steps": (None, str(NUM_INFERENCE_STEPS)),
- "guidance_scale": (None, str(GUIDANCE_SCALE)),
- "seed": (None, str(SEED)),
- },
- timeout=120,
- )
- response.raise_for_status()
- video_id = response.json()["id"]
-
- # Poll for completion
- import time
-
- for _ in range(120):
- status_resp = requests.get(
- f"http://{server.host}:{server.port}/v1/videos/{video_id}",
- timeout=30,
- )
- status_resp.raise_for_status()
- status = status_resp.json()["status"]
- if status == "completed":
- break
- if status in ("error", "failed"):
- raise RuntimeError(f"Video generation failed: {status_resp.json()}")
- time.sleep(5)
- else:
- raise TimeoutError(f"Video generation timed out after 600s (id={video_id})")
-
- # Download video content
- content_resp = requests.get(
- f"http://{server.host}:{server.port}/v1/videos/{video_id}/content",
- timeout=120,
- )
- content_resp.raise_for_status()
- mp4_bytes = content_resp.content
-
- # Save MP4
- mp4_path = output_dir / "vllm_omni_pipeline.mp4"
- with open(mp4_path, "wb") as f:
- f.write(mp4_bytes)
-
- # Extract frames
- frames = _extract_mp4_frames(mp4_bytes)
- for i, frame in enumerate(frames):
- frame.save(output_dir / f"vllm_pipeline_frame_{i:04d}.png")
- return frames
-
-
-@pytest.mark.advanced_model
-@pytest.mark.benchmark
-@pytest.mark.diffusion
-@pytest.mark.skipif(
- not _DIFFUSERS_038, reason="LTX-2.3 requires diffusers >= 0.38.0 for cross_attn_mod and BWE vocoder"
-)
-@hardware_test(res={"cuda": "H100"}, num_cards=1)
-def test_ltx2_3_pipeline_matches_diffusers(accuracy_artifact_root: Path) -> None:
- """Full-pipeline parity: vLLM-Omni serving stack vs diffusers.
-
- Runs the complete vLLM-Omni OmniServer (subprocess, HTTP API, video
- encoding) and compares per-frame against stock diffusers output.
- Follows the Wan2.2 / Qwen Image pattern with seed-based determinism.
- """
- model = _model_name()
- output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID)
-
- diffusers_frames = _run_diffusers_baseline(model=model, output_dir=output_dir)
- vllm_frames = _run_vllm_omni_serving(model=model, output_dir=output_dir)
-
- _assert_video_similarity(
- model_name=f"{MODEL_ID} (full-pipeline)",
- vllm_frames=vllm_frames,
- diffusers_frames=diffusers_frames,
- ssim_threshold=PIPELINE_SSIM_THRESHOLD,
- psnr_threshold=PIPELINE_PSNR_THRESHOLD,
- )
diff --git a/tests/e2e/accuracy/test_qwen_image.py b/tests/e2e/accuracy/test_qwen_image.py
deleted file mode 100644
index 4b8215d54b5..00000000000
--- a/tests/e2e/accuracy/test_qwen_image.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from __future__ import annotations
-
-import base64
-import gc
-import io
-import os
-from pathlib import Path
-
-import pytest
-import requests
-import torch
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from PIL import Image
-
-from tests.e2e.accuracy.helpers import assert_similarity, model_output_dir
-from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServer
-
-pytestmark = [pytest.mark.full_model, pytest.mark.diffusion]
-
-
-MODEL_ID = "Qwen/Qwen-Image"
-MODEL_ENV_VAR = "QWEN_IMAGE_MODEL"
-PROMPT = "A photo of a cat sitting on a laptop keyboard, digital art style."
-NEGATIVE_PROMPT = "blurry, low quality"
-WIDTH = 512
-HEIGHT = 512
-NUM_INFERENCE_STEPS = 20
-TRUE_CFG_SCALE = 4.0
-SEED = 42
-SSIM_THRESHOLD = 0.97
-PSNR_THRESHOLD = 30.0
-
-
-def _model_name() -> str:
- return os.environ.get(MODEL_ENV_VAR, MODEL_ID)
-
-
-def _local_files_only(model: str) -> bool:
- return Path(model).exists()
-
-
-def _run_vllm_omni_qwen_image(*, model: str, output_path: Path) -> Image.Image:
- server_args = ["--num-gpus", "1", "--stage-init-timeout", "300", "--init-timeout", "900"]
- with OmniServer(model, server_args, use_omni=True) as omni_server:
- response = requests.post(
- f"http://{omni_server.host}:{omni_server.port}/v1/images/generations",
- json={
- "model": omni_server.model,
- "prompt": PROMPT,
- "size": f"{WIDTH}x{HEIGHT}",
- "n": 1,
- "response_format": "b64_json",
- "negative_prompt": NEGATIVE_PROMPT,
- "num_inference_steps": NUM_INFERENCE_STEPS,
- "true_cfg_scale": TRUE_CFG_SCALE,
- "seed": SEED,
- },
- timeout=600,
- )
- response.raise_for_status()
- payload = response.json()
- assert len(payload["data"]) == 1
- image_bytes = base64.b64decode(payload["data"][0]["b64_json"])
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
- image.load()
- image.save(output_path)
- return image
-
-
-def _run_diffusers_qwen_image(*, model: str, output_path: Path) -> Image.Image:
- run_pre_test_cleanup(enable_force=True)
- pipe: DiffusionPipeline | None = None
- try:
- pipe = DiffusionPipeline.from_pretrained(
- model,
- torch_dtype=torch.bfloat16,
- trust_remote_code=True,
- local_files_only=_local_files_only(model),
- ).to("cuda")
- generator = torch.Generator(device="cuda").manual_seed(SEED)
- result = pipe( # pyright: ignore[reportCallIssue]
- prompt=PROMPT,
- negative_prompt=NEGATIVE_PROMPT,
- width=WIDTH,
- height=HEIGHT,
- num_inference_steps=NUM_INFERENCE_STEPS,
- true_cfg_scale=TRUE_CFG_SCALE,
- generator=generator,
- )
- output_image = result.images[0].convert("RGB")
- output_image.save(output_path)
- return output_image
- finally:
- if pipe is not None and hasattr(pipe, "maybe_free_model_hooks"):
- pipe.maybe_free_model_hooks()
- del pipe
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- run_post_test_cleanup(enable_force=True)
-
-
-@pytest.mark.benchmark
-@hardware_test(res={"cuda": "H100"}, num_cards=1)
-def test_qwen_image_matches_diffusers(accuracy_artifact_root: Path) -> None:
- model = _model_name()
- output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID)
-
- vllm_output = _run_vllm_omni_qwen_image(model=model, output_path=output_dir / "vllm_omni.png")
- diffusers_output = _run_diffusers_qwen_image(model=model, output_path=output_dir / "diffusers.png")
-
- assert_similarity(
- model_name=MODEL_ID,
- vllm_image=vllm_output,
- diffusers_image=diffusers_output,
- width=WIDTH,
- height=HEIGHT,
- ssim_threshold=SSIM_THRESHOLD,
- psnr_threshold=PSNR_THRESHOLD,
- )
diff --git a/tests/e2e/accuracy/test_qwen_image_edit.py b/tests/e2e/accuracy/test_qwen_image_edit.py
deleted file mode 100644
index 07deecca976..00000000000
--- a/tests/e2e/accuracy/test_qwen_image_edit.py
+++ /dev/null
@@ -1,228 +0,0 @@
-from __future__ import annotations
-
-import gc
-from pathlib import Path
-
-import pytest
-import requests
-import torch
-from diffusers import QwenImageEditPipeline, QwenImageEditPlusPipeline
-from PIL import Image
-
-from benchmarks.accuracy.common import decode_base64_image, pil_to_png_bytes
-from tests.e2e.accuracy.helpers import assert_similarity, model_output_dir
-from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServer
-
-pytestmark = [pytest.mark.full_model, pytest.mark.diffusion]
-
-
-SINGLE_MODEL = "Qwen/Qwen-Image-Edit"
-MULTIPLE_MODEL = "Qwen/Qwen-Image-Edit-2509"
-WIDTH = 512
-HEIGHT = 512
-NUM_INFERENCE_STEPS = 20
-TRUE_CFG_SCALE = 4.0
-SEED = 42
-SSIM_THRESHOLD = 0.94
-PSNR_THRESHOLD = 28.0
-
-PROMPT_SINGLE_IMAGE = "The input is a 2D cartoon bear mascot. Restyle it into a painterly oil artwork with warm colors while preserving the main structure."
-PROMPT_MULTIPLE_IMAGE = "Put the cartoon bear mascot and the furry rabbit into one coherent scene with a painterly oil artwork style and consistent lighting."
-NEGATIVE_PROMPT = "low quality, blurry, artifacts, distortion"
-SERVER_ARGS = ["--num-gpus", "1", "--stage-init-timeout", "300", "--init-timeout", "900"]
-
-
-def _run_vllm_omni_image_edit(
- *,
- omni_server: OmniServer,
- prompt: str,
- input_images: list[Image.Image],
- output_path: Path,
-) -> Image.Image:
- response = requests.post(
- f"http://{omni_server.host}:{omni_server.port}/v1/images/edits",
- data={
- "model": omni_server.model,
- "prompt": prompt,
- "size": f"{WIDTH}x{HEIGHT}",
- "n": 1,
- "response_format": "b64_json",
- "negative_prompt": NEGATIVE_PROMPT,
- "num_inference_steps": NUM_INFERENCE_STEPS,
- "true_cfg_scale": TRUE_CFG_SCALE,
- "seed": SEED,
- },
- files=[
- ("image", (f"image_{index}.png", pil_to_png_bytes(image), "image/png"))
- for index, image in enumerate(input_images)
- ],
- timeout=600,
- )
- response.raise_for_status()
- payload = response.json()
- assert len(payload["data"]) == 1
- image = decode_base64_image(payload["data"][0]["b64_json"])
- image.load()
- image.save(output_path)
- return image
-
-
-def _run_diffusers_image_edit(
- *,
- model: str,
- pipeline_class: type[QwenImageEditPipeline] | type[QwenImageEditPlusPipeline],
- prompt: str,
- input_images: list[Image.Image],
- output_path: Path,
-) -> Image.Image:
- run_pre_test_cleanup(enable_force=True)
- pipe: QwenImageEditPipeline | QwenImageEditPlusPipeline | None = None
- device = torch.device("cuda:0")
- torch.cuda.set_device(device)
- try:
- images = input_images[0] if len(input_images) == 1 else input_images
- pipe = pipeline_class.from_pretrained(
- model,
- torch_dtype=torch.bfloat16,
- trust_remote_code=True,
- ).to(device)
- pipe.set_progress_bar_config(disable=False)
- generator = torch.Generator(device=device).manual_seed(SEED)
- result = pipe( # pyright: ignore[reportCallIssue]
- prompt=prompt,
- image=images,
- negative_prompt=NEGATIVE_PROMPT,
- num_inference_steps=NUM_INFERENCE_STEPS,
- true_cfg_scale=TRUE_CFG_SCALE,
- width=WIDTH,
- height=HEIGHT,
- generator=generator,
- )
- output_image = result.images[0].convert("RGB") # pyright: ignore[reportAttributeAccessIssue]
- output_image.save(output_path)
- return output_image
- finally:
- if pipe is not None and hasattr(pipe, "maybe_free_model_hooks"):
- pipe.maybe_free_model_hooks()
- del pipe
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- run_post_test_cleanup(enable_force=True)
-
-
-def _vllm_omni_output_single_image(
- accuracy_artifact_root: Path,
- qwen_bear_image: Image.Image,
-) -> Image.Image:
- output_dir = model_output_dir(accuracy_artifact_root, SINGLE_MODEL)
- output_path = output_dir / "vllm_omni_single.png"
- with OmniServer(model=SINGLE_MODEL, serve_args=SERVER_ARGS) as server:
- output = _run_vllm_omni_image_edit(
- omni_server=server,
- prompt=PROMPT_SINGLE_IMAGE,
- input_images=[qwen_bear_image],
- output_path=output_path,
- )
- return output
-
-
-def _diffusers_output_single_image(accuracy_artifact_root: Path, qwen_bear_image: Image.Image) -> Image.Image:
- output_dir = model_output_dir(accuracy_artifact_root, SINGLE_MODEL)
- output_path = output_dir / "diffusers_single.png"
- return _run_diffusers_image_edit(
- model=SINGLE_MODEL,
- pipeline_class=QwenImageEditPipeline,
- prompt=PROMPT_SINGLE_IMAGE,
- input_images=[qwen_bear_image],
- output_path=output_path,
- )
-
-
-def _vllm_omni_output_multiple_image(
- accuracy_artifact_root: Path,
- qwen_bear_image: Image.Image,
- rabbit_image: Image.Image,
-) -> Image.Image:
- output_dir = model_output_dir(accuracy_artifact_root, MULTIPLE_MODEL)
- output_path = output_dir / "vllm_omni_multiple.png"
- with OmniServer(model=MULTIPLE_MODEL, serve_args=SERVER_ARGS) as server:
- output = _run_vllm_omni_image_edit(
- omni_server=server,
- prompt=PROMPT_MULTIPLE_IMAGE,
- input_images=[qwen_bear_image, rabbit_image],
- output_path=output_path,
- )
- return output
-
-
-def _diffusers_output_multiple_image(
- accuracy_artifact_root: Path, qwen_bear_image: Image.Image, rabbit_image: Image.Image
-) -> Image.Image:
- output_dir = model_output_dir(accuracy_artifact_root, MULTIPLE_MODEL)
- output_path = output_dir / "diffusers_multiple.png"
- return _run_diffusers_image_edit(
- model=MULTIPLE_MODEL,
- pipeline_class=QwenImageEditPlusPipeline,
- prompt=PROMPT_MULTIPLE_IMAGE,
- input_images=[qwen_bear_image, rabbit_image],
- output_path=output_path,
- )
-
-
-@pytest.mark.benchmark
-@hardware_test(res={"cuda": "H100"}, num_cards=1)
-def test_qwen_image_edit_single_matches_diffusers(
- accuracy_artifact_root: Path,
- qwen_bear_image: Image.Image,
-) -> None:
- vllm_image = _vllm_omni_output_single_image(
- accuracy_artifact_root=accuracy_artifact_root,
- qwen_bear_image=qwen_bear_image,
- )
- diffusers_image = _diffusers_output_single_image(
- accuracy_artifact_root=accuracy_artifact_root,
- qwen_bear_image=qwen_bear_image,
- )
- assert_similarity(
- model_name=SINGLE_MODEL,
- vllm_image=vllm_image,
- diffusers_image=diffusers_image,
- width=WIDTH,
- height=HEIGHT,
- ssim_threshold=SSIM_THRESHOLD,
- psnr_threshold=PSNR_THRESHOLD,
- )
-
-
-@pytest.mark.benchmark
-@hardware_test(res={"cuda": "H100"}, num_cards=1)
-@pytest.mark.skip(
- reason="Skipping as the second image seems to be ignored by the API. Will come back to this later after #2772 is merged."
-)
-def test_qwen_image_edit_multiple_matches_diffusers(
- accuracy_artifact_root: Path,
- qwen_bear_image: Image.Image,
- rabbit_image: Image.Image,
-) -> None:
- vllm_image = _vllm_omni_output_multiple_image(
- accuracy_artifact_root=accuracy_artifact_root,
- qwen_bear_image=qwen_bear_image,
- rabbit_image=rabbit_image,
- )
- diffusers_image = _diffusers_output_multiple_image(
- accuracy_artifact_root=accuracy_artifact_root,
- qwen_bear_image=qwen_bear_image,
- rabbit_image=rabbit_image,
- )
- assert_similarity(
- model_name=MULTIPLE_MODEL,
- vllm_image=vllm_image,
- diffusers_image=diffusers_image,
- width=WIDTH,
- height=HEIGHT,
- ssim_threshold=SSIM_THRESHOLD,
- psnr_threshold=PSNR_THRESHOLD,
- )
diff --git a/tests/e2e/accuracy/test_qwen_image_layered.py b/tests/e2e/accuracy/test_qwen_image_layered.py
deleted file mode 100644
index 30ad2966ff6..00000000000
--- a/tests/e2e/accuracy/test_qwen_image_layered.py
+++ /dev/null
@@ -1,149 +0,0 @@
-from __future__ import annotations
-
-import base64
-import gc
-import io
-import os
-from pathlib import Path
-
-import pytest
-import requests
-import torch
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from PIL import Image
-
-from tests.e2e.accuracy.helpers import assert_image_sequence_similarity, model_output_dir
-from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServer
-
-pytestmark = [pytest.mark.full_model, pytest.mark.diffusion]
-
-
-MODEL_ID = "Qwen/Qwen-Image-Layered"
-MODEL_ENV_VAR = "QWEN_IMAGE_LAYERED_MODEL"
-PROMPT = "decompose into layers"
-NEGATIVE_PROMPT = " "
-NUM_INFERENCE_STEPS = 20
-TRUE_CFG_SCALE = 4.0
-SEED = 777
-LAYERS = 3
-RESOLUTION = 640
-SSIM_THRESHOLD = 0.97
-PSNR_THRESHOLD = 30.0
-
-
-def _model_name() -> str:
- return os.environ.get(MODEL_ENV_VAR, MODEL_ID)
-
-
-def _local_files_only(model: str) -> bool:
- return Path(model).exists()
-
-
-def _normalize_layered_images(images: object) -> list[Image.Image]:
- if not isinstance(images, list) or not images:
- raise AssertionError(f"Unexpected layered output container: {type(images).__name__}")
-
- first_item = images[0]
- if isinstance(first_item, Image.Image):
- return [image.convert("RGBA") for image in images if isinstance(image, Image.Image)]
- if isinstance(first_item, (list, tuple)):
- return [image.convert("RGBA") for image in first_item if isinstance(image, Image.Image)]
- raise AssertionError(f"Unexpected layered image element type: {type(first_item).__name__}")
-
-
-def _run_vllm_omni_qwen_image_layered(*, model: str, input_image: Image.Image, output_dir: Path) -> list[Image.Image]:
- input_image.save(output_dir / "input.png")
- server_args = ["--num-gpus", "1", "--stage-init-timeout", "300", "--init-timeout", "900"]
- with OmniServer(model, server_args, use_omni=True) as omni_server:
- buffer = io.BytesIO()
- input_image.save(buffer, format="PNG")
- buffer.seek(0)
- response = requests.post(
- f"http://{omni_server.host}:{omni_server.port}/v1/images/edits",
- data={
- "model": omni_server.model,
- "prompt": PROMPT,
- "size": "auto",
- "n": 1,
- "response_format": "b64_json",
- "negative_prompt": NEGATIVE_PROMPT,
- "num_inference_steps": NUM_INFERENCE_STEPS,
- "true_cfg_scale": TRUE_CFG_SCALE,
- "seed": SEED,
- "layers": LAYERS,
- "resolution": RESOLUTION,
- },
- files=[("image", ("input.png", buffer, "image/png"))],
- timeout=600,
- )
- response.raise_for_status()
- payload = response.json()
- assert len(payload["data"]) == LAYERS
- output_images = []
- for item in payload["data"]:
- image_bytes = base64.b64decode(item["b64_json"])
- image = Image.open(io.BytesIO(image_bytes)).convert("RGBA")
- image.load()
- output_images.append(image)
- for index, image in enumerate(output_images, start=1):
- image.save(output_dir / f"vllm_omni_layer_{index}.png")
- return output_images
-
-
-def _run_diffusers_qwen_image_layered(*, model: str, input_image: Image.Image, output_dir: Path) -> list[Image.Image]:
- run_pre_test_cleanup(enable_force=True)
- pipe: DiffusionPipeline | None = None
- try:
- pipe = DiffusionPipeline.from_pretrained(
- model,
- torch_dtype=torch.bfloat16,
- trust_remote_code=True,
- local_files_only=_local_files_only(model),
- ).to("cuda")
- generator = torch.Generator(device="cuda").manual_seed(SEED)
- result = pipe( # pyright: ignore[reportCallIssue]
- image=input_image,
- prompt=PROMPT,
- negative_prompt=NEGATIVE_PROMPT,
- num_inference_steps=NUM_INFERENCE_STEPS,
- true_cfg_scale=TRUE_CFG_SCALE,
- generator=generator,
- num_images_per_prompt=1,
- layers=LAYERS,
- resolution=RESOLUTION,
- )
- output_images = _normalize_layered_images(result.images)
- assert len(output_images) == LAYERS, f"Expected {LAYERS} diffusers layers, got {len(output_images)}"
- for index, image in enumerate(output_images, start=1):
- image.save(output_dir / f"diffusers_layer_{index}.png")
- return output_images
- finally:
- if pipe is not None and hasattr(pipe, "maybe_free_model_hooks"):
- pipe.maybe_free_model_hooks()
- del pipe
- gc.collect()
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- run_post_test_cleanup(enable_force=True)
-
-
-@pytest.mark.benchmark
-@hardware_test(res={"cuda": "H100"}, num_cards=1)
-def test_qwen_image_layered_matches_diffusers(accuracy_artifact_root: Path, qwen_bear_image: Image.Image) -> None:
- model = _model_name()
- output_dir = model_output_dir(accuracy_artifact_root, MODEL_ID)
- input_image = qwen_bear_image.convert("RGBA")
-
- vllm_outputs = _run_vllm_omni_qwen_image_layered(model=model, input_image=input_image, output_dir=output_dir)
- diffusers_outputs = _run_diffusers_qwen_image_layered(model=model, input_image=input_image, output_dir=output_dir)
-
- assert_image_sequence_similarity(
- model_name=MODEL_ID,
- vllm_images=vllm_outputs,
- diffusers_images=diffusers_outputs,
- ssim_threshold=SSIM_THRESHOLD,
- psnr_threshold=PSNR_THRESHOLD,
- compare_mode="RGBA",
- )
diff --git a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
index 1caef3bff54..3cdda1f9ffa 100644
--- a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
+++ b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
@@ -22,6 +22,7 @@
from diffusers import UniPCMultistepScheduler
from PIL import Image
+from tests.conftest import OmniServerParams
from tests.e2e.accuracy.wan22_i2v.run_wan22_i2v_diffusers_cp import (
_configure_scheduler,
_ensure_wan_ftfy_fallback,
@@ -47,10 +48,7 @@
SSIM_THRESHOLD,
WIDTH,
)
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.utils import hardware_test
def test_parse_video_metadata_extracts_dimensions_and_fps() -> None:
@@ -539,7 +537,9 @@ def _generate_offline_video(*, image_source: str) -> tuple[Path, Path]:
return offline_path, offline_metadata_path
+@pytest.mark.advanced_model
@pytest.mark.benchmark
+@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"}, num_cards=1)
def test_wan22_i2v_diffusers_offline_generates_video(
wan22_i2v_image_source: str | None,
@@ -563,7 +563,9 @@ def test_wan22_i2v_diffusers_offline_generates_video(
assert offline_metadata["frame_count"] == NUM_FRAMES
+@pytest.mark.advanced_model
@pytest.mark.benchmark
+@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"}, num_cards=2)
@pytest.mark.parametrize("omni_server", SERVER_CASES, indirect=True)
def test_wan22_i2v_online_serving_generates_video(
@@ -592,7 +594,9 @@ def test_wan22_i2v_online_serving_generates_video(
assert online_metadata["frame_count"] == NUM_FRAMES
+@pytest.mark.advanced_model
@pytest.mark.benchmark
+@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"}, num_cards=2)
def test_wan22_i2v_serving_matches_diffusers_video_similarity(
wan22_i2v_image_source: str | None,
diff --git a/tests/e2e/offline_inference/custom_pipeline/qwen_image_pipeline_with_logprob.py b/tests/e2e/offline_inference/custom_pipeline/qwen_image_pipeline_with_logprob.py
index 709c6655565..ed5b219f80f 100644
--- a/tests/e2e/offline_inference/custom_pipeline/qwen_image_pipeline_with_logprob.py
+++ b/tests/e2e/offline_inference/custom_pipeline/qwen_image_pipeline_with_logprob.py
@@ -6,8 +6,7 @@
This pipeline follows the structure of the user's reference implementation:
- supports pre-tokenized prompt IDs via OmniCustomPrompt-style dict input
- uses an SDE scheduler that can return step logprobs
-- returns structured trajectory_* fields (latents, timesteps, log_probs)
- consistent with the BAGEL trajectory recording design
+- returns rich custom_output fields for testing
"""
from __future__ import annotations
@@ -394,10 +393,10 @@ def forward(
return DiffusionOutput(
output=_maybe_to_cpu(image),
- trajectory_latents=_maybe_to_cpu(all_latents),
- trajectory_log_probs=_maybe_to_cpu(all_log_probs),
- trajectory_timesteps=_maybe_to_cpu(all_timesteps),
custom_output={
+ "all_latents": _maybe_to_cpu(all_latents),
+ "all_log_probs": _maybe_to_cpu(all_log_probs),
+ "all_timesteps": _maybe_to_cpu(all_timesteps),
"prompt_embeds": _maybe_to_cpu(prompt_embeds),
"prompt_embeds_mask": _maybe_to_cpu(prompt_embeds_mask),
"negative_prompt_embeds": _maybe_to_cpu(negative_prompt_embeds),
diff --git a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py
index bd3f2e09975..57743d62bf6 100644
--- a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py
+++ b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py
@@ -26,7 +26,7 @@
import pytest
-from tests.helpers.mark import hardware_test
+from tests.utils import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
diff --git a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py
index 0681687fe73..f1b4595c9df 100644
--- a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py
+++ b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py
@@ -1,12 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""E2E tests for AsyncOmni Qwen-Image generation with trajectory_* fields.
-
-Validates that the custom Qwen-Image pipeline returns structured trajectory
-outputs (latents, timesteps, log_probs) via OmniRequestOutput's trajectory_*
-fields instead of the legacy custom_output dict.
-"""
+"""E2E tests for AsyncOmni Qwen-Image generation flow (no Ray, no HTTP server)."""
from __future__ import annotations
@@ -19,7 +14,7 @@
import pytest
from transformers import AutoTokenizer
-from tests.helpers.mark import hardware_test
+from tests.utils import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -196,17 +191,10 @@ async def test_async_omni_generate_with_logprobs():
_assert_valid_image_output(output)
- assert output.trajectory_latents is not None, "trajectory_latents should be present"
- assert hasattr(output.trajectory_latents, "shape")
- assert output.trajectory_latents.numel() > 0
-
- assert output.trajectory_timesteps is not None, "trajectory_timesteps should be present"
- assert hasattr(output.trajectory_timesteps, "shape")
- assert output.trajectory_timesteps.numel() > 0
-
- assert output.trajectory_log_probs is not None, "trajectory_log_probs should be present when logprobs=True"
- assert hasattr(output.trajectory_log_probs, "shape")
- assert output.trajectory_log_probs.numel() > 0
+ all_log_probs = output.custom_output.get("all_log_probs")
+ assert all_log_probs is not None, "all_log_probs should be present when logprobs=True"
+ assert hasattr(all_log_probs, "shape")
+ assert all_log_probs.numel() > 0
@pytest.mark.core_model
diff --git a/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py b/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py
index 653b35d7e2f..ffbe703ca78 100644
--- a/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py
+++ b/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py
@@ -10,7 +10,7 @@
from tests.e2e.offline_inference.custom_pipeline.worker_extension import (
vLLMOmniColocateWorkerExtensionForTest,
)
-from tests.helpers.mark import hardware_test
+from tests.utils import hardware_test
from vllm_omni.diffusion.worker.diffusion_worker import CustomPipelineWorkerExtension
from vllm_omni.entrypoints.async_omni import AsyncOmni
diff --git a/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml b/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml
new file mode 100644
index 00000000000..590244acd26
--- /dev/null
+++ b/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml
@@ -0,0 +1,89 @@
+# stage config for running BAGEL with Mooncake connector for CI e2e tests.
+# This config is optimized for single GPU tests with Mooncake inter-stage communication.
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: BagelForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: text
+ distributed_executor_backend: mp
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ load_format: dummy
+ omni_kv_config:
+ need_send_cache: true
+ kv_transfer_criteria:
+ type: prefill_finished
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 52
+ detokenize: true
+ repetition_penalty: 1.05
+ output_connectors:
+ to_stage_1: mooncake_connector
+ - stage_id: 1
+ stage_type: diffusion
+ cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: dit
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: image
+ distributed_executor_backend: mp
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ load_format: dummy
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0]
+ final_output: true
+ final_output_type: image
+ is_comprehension: false
+ default_sampling_params:
+ seed: 52
+ input_connectors:
+ from_stage_0: mooncake_connector
+
+# Top-level runtime config with Mooncake connector
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+ connectors:
+ mooncake_connector:
+ name: MooncakeConnector
+ extra:
+ host: "${MOONCAKE_HOST}"
+ metadata_server: "http://${MOONCAKE_HOST}:${MOONCAKE_HTTP_PORT}/metadata"
+ master: "${MOONCAKE_HOST}:${MOONCAKE_RPC_PORT}"
+ segment: 64000000
+ localbuf: 64000000
+ proto: tcp
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml b/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml
new file mode 100644
index 00000000000..b7999652e23
--- /dev/null
+++ b/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml
@@ -0,0 +1,87 @@
+# stage config for running BAGEL with SharedMemory connector for CI e2e tests.
+# This config is optimized for single GPU tests with SharedMemory inter-stage communication.
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: OmniBagelForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: text
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ load_format: dummy
+ omni_kv_config:
+ need_send_cache: true
+ kv_transfer_criteria:
+ type: prefill_finished  # or when a special token is generated
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 52
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: diffusion
+ cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: dit
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: image
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ load_format: dummy
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0]
+
+ final_output: true
+ final_output_type: image
+ is_comprehension: false
+ default_sampling_params:
+ seed: 52
+
+# Top-level runtime config: defaults, connectors, and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+
+ # Distributed connectors configuration (optional)
+ # More connectors will be supported in the future.
+ connectors:
+ shared_memory_connector:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536 # 64KB threshold
+
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml
new file mode 100644
index 00000000000..f93a6c71473
--- /dev/null
+++ b/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml
@@ -0,0 +1,103 @@
+# stage config for running qwen2.5-omni for multi-stage omni runtime.
+
+# This config is optimized for CI e2e tests.
+stage_args:
+ - stage_id: 0
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ max_model_len: 896
+ max_num_batched_tokens: 896
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.8
+ skip_mm_profiling: true
+ enforce_eager: true  # Only eager mode is supported for now
+ trust_remote_code: true
+ engine_output_type: latent
+ enable_prefix_caching: false
+ mm_processor_cache_gb: 0
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 128
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+ - stage_id: 1
+ runtime:
+ process: true
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ max_model_len: 896
+ max_num_batched_tokens: 896
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.8
+ skip_mm_profiling: true
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+ default_sampling_params:
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 128
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+ stop_token_ids: [8294]
+ - stage_id: 2
+ runtime:
+ process: true
+ devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ gpu_memory_utilization: 0.15
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ engine_input_source: [1]
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 128
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
+ edges:
+ - from: 0 # thinker → talker: trigger only after receiving full input (-1)
+ to: 1
+ window_size: -1
+ - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
+ to: 2
+ window_size: -1
diff --git a/tests/e2e/offline_inference/test_bagel_img2img.py b/tests/e2e/offline_inference/test_bagel_img2img.py
index b4de059f2d0..a0c3f6cc9fc 100644
--- a/tests/e2e/offline_inference/test_bagel_img2img.py
+++ b/tests/e2e/offline_inference/test_bagel_img2img.py
@@ -15,49 +15,47 @@
"""
import socket
+from pathlib import Path
from typing import Any
import pytest
from PIL import Image
from vllm.assets.image import ImageAsset
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+from tests.conftest import modify_stage_config
+from tests.utils import hardware_test
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.platforms import current_omni_platform
-BAGEL_CI_DEPLOY = get_deploy_config_path("ci/bagel.yaml")
-
# Reference pixel data extracted from the known-good output image
# Generated with seed=52, num_inference_steps=15,
# prompt='Change the grass color to red',
# input image: 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (156, 172, 217)},
- {"position": (400, 50), "rgb": (105, 144, 217)},
- {"position": (700, 100), "rgb": (118, 159, 232)},
- {"position": (150, 400), "rgb": (180, 22, 52)},
- {"position": (512, 336), "rgb": (221, 211, 194)},
- {"position": (700, 400), "rgb": (192, 10, 46)},
- {"position": (100, 600), "rgb": (102, 12, 22)},
- {"position": (400, 600), "rgb": (161, 28, 47)},
- {"position": (700, 600), "rgb": (100, 87, 94)},
- {"position": (256, 256), "rgb": (181, 201, 221)},
+ {"position": (100, 100), "rgb": (157, 172, 217)},
+ {"position": (400, 50), "rgb": (105, 144, 218)},
+ {"position": (700, 100), "rgb": (118, 159, 233)},
+ {"position": (150, 400), "rgb": (195, 34, 60)},
+ {"position": (512, 336), "rgb": (222, 214, 193)},
+ {"position": (700, 400), "rgb": (197, 15, 43)},
+ {"position": (100, 600), "rgb": (105, 13, 18)},
+ {"position": (400, 600), "rgb": (169, 33, 44)},
+ {"position": (700, 600), "rgb": (101, 86, 93)},
+ {"position": (256, 256), "rgb": (181, 202, 222)},
]
if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (156, 172, 217)},
- {"position": (400, 50), "rgb": (105, 144, 217)},
- {"position": (700, 100), "rgb": (118, 159, 232)},
- {"position": (150, 400), "rgb": (180, 22, 52)},
- {"position": (512, 336), "rgb": (221, 211, 194)},
- {"position": (700, 400), "rgb": (192, 10, 46)},
- {"position": (100, 600), "rgb": (102, 12, 22)},
- {"position": (400, 600), "rgb": (161, 28, 47)},
- {"position": (700, 600), "rgb": (100, 87, 94)},
- {"position": (256, 256), "rgb": (181, 201, 221)},
+ {"position": (100, 100), "rgb": (156, 172, 215)},
+ {"position": (400, 50), "rgb": (106, 144, 216)},
+ {"position": (700, 100), "rgb": (118, 158, 231)},
+ {"position": (150, 400), "rgb": (183, 23, 48)},
+ {"position": (512, 336), "rgb": (218, 215, 191)},
+ {"position": (700, 400), "rgb": (194, 14, 42)},
+ {"position": (100, 600), "rgb": (105, 10, 16)},
+ {"position": (400, 600), "rgb": (167, 33, 46)},
+ {"position": (700, 600), "rgb": (102, 86, 92)},
+ {"position": (256, 256), "rgb": (181, 201, 220)},
]
PIXEL_TOLERANCE = 10
@@ -184,8 +182,8 @@ def _generate_bagel_img2img(
return generated_image
-def _resolve_deploy_config(config_path: str, run_level: str) -> str:
- """Resolve deploy config based on run level.
+def _resolve_stage_config(config_path: str, run_level: str) -> str:
+ """Resolve stage config based on run level.
For advanced_model (real weights), strip load_format: dummy so the model
falls back to loading real weights from HuggingFace.
@@ -194,9 +192,9 @@ def _resolve_deploy_config(config_path: str, run_level: str) -> str:
return modify_stage_config(
config_path,
deletes={
- "stages": {
- 0: ["load_format"],
- 1: ["load_format"],
+ "stage_args": {
+ 0: ["engine_args.load_format"],
+ 1: ["engine_args.load_format"],
}
},
)
@@ -210,11 +208,13 @@ def _resolve_deploy_config(config_path: str, run_level: str) -> str:
def test_bagel_img2img_shared_memory_connector(run_level):
"""Test Bagel img2img with shared memory connector."""
input_image = _load_input_image()
- config_path = _resolve_deploy_config(BAGEL_CI_DEPLOY, run_level)
- with OmniRunner(
- "ByteDance-Seed/BAGEL-7B-MoT",
- stage_configs_path=config_path,
- ) as runner:
- generated_image = _generate_bagel_img2img(runner.omni, input_image)
+ config_path = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
+ config_path = _resolve_stage_config(config_path, run_level)
+ omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=config_path, stage_init_timeout=300)
+
+ try:
+ generated_image = _generate_bagel_img2img(omni, input_image)
if run_level == "advanced_model":
_validate_pixels(generated_image)
+ finally:
+ omni.close()
diff --git a/tests/e2e/offline_inference/test_bagel_lora.py b/tests/e2e/offline_inference/test_bagel_lora.py
deleted file mode 100644
index 785d0c7fb8f..00000000000
--- a/tests/e2e/offline_inference/test_bagel_lora.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-End-to-end test for BAGEL LoRA support (Stage 1 / DiT).
-
-Validates that LoRA adapters are correctly loaded, applied with controllable
-scale, and cleanly deactivated. Uses a synthetic rank-1 adapter targeting the
-first decoder layer's QKV projection.
-
-Assertions:
- (a) LoRA at scale=1.0 visibly changes the output (diff > 0.5)
- (b) scale=2.0 produces a larger delta than scale=1.0 (linearity)
- (c) The delta is bounded (diff < 80, not corrupted)
- (d) Deactivating LoRA exactly restores the baseline (diff == 0)
-"""
-
-import json
-import os
-from pathlib import Path
-
-from vllm_omni.inputs.data import OmniSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
-import numpy as np
-import pytest
-import torch
-from PIL import Image
-from safetensors.torch import save_file
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
-from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.lora.request import LoRARequest
-from vllm_omni.lora.utils import stable_lora_int_id
-
-MODEL = "ByteDance-Seed/BAGEL-7B-MoT"
-BAGEL_STAGE_CONFIG = get_deploy_config_path("ci/bagel.yaml")
-DEFAULT_PROMPT = "<|im_start|>A cute cat<|im_end|>"
-
-
-# ---------------------------------------------------------------------------
-# Helpers (reused from test_bagel_text2img.py patterns)
-# ---------------------------------------------------------------------------
-
-
-def _resolve_deploy_config(config_path: str, run_level: str) -> str:
- if run_level == "advanced_model":
- return modify_stage_config(
- config_path,
- deletes={
- "stages": {
- 0: ["load_format"],
- 1: ["load_format"],
- }
- },
- )
- return config_path
-
-
-def _configure_sampling_params(omni: Omni, num_inference_steps: int = 10) -> list[OmniSamplingParams]:
- params_list = omni.default_sampling_params_list
- if len(params_list) > 1:
- params_list[1].num_inference_steps = num_inference_steps
- params_list[1].extra_args = {
- "cfg_text_scale": 4.0,
- "cfg_img_scale": 1.5,
- }
- return params_list
-
-
-def _extract_generated_image(omni_outputs: list[OmniRequestOutput]) -> Image.Image | None:
- for req_output in omni_outputs:
- if req_output.images:
- return req_output.images[0]
- return None
-
-
-def _generate_bagel_image(omni: Omni) -> Image.Image:
- params_list = _configure_sampling_params(omni)
- params_list[1].lora_request = None
- outputs = list(
- omni.generate(
- prompts=[{"prompt": DEFAULT_PROMPT, "modalities": ["image"]}],
- sampling_params_list=params_list,
- )
- )
- img = _extract_generated_image(outputs)
- assert img is not None, "No image generated"
- return img
-
-
-def _generate_bagel_image_with_lora(
- omni: Omni,
- lora_request: LoRARequest,
- lora_scale: float = 1.0,
-) -> Image.Image:
- params_list = _configure_sampling_params(omni)
- params_list[1].lora_request = lora_request
- params_list[1].lora_scale = lora_scale
- outputs = list(
- omni.generate(
- prompts=[{"prompt": DEFAULT_PROMPT, "modalities": ["image"]}],
- sampling_params_list=params_list,
- )
- )
- img = _extract_generated_image(outputs)
- assert img is not None, "No image generated with LoRA"
- return img
-
-
-# BAGEL uses GQA: hidden_size=3584, 28 Q heads, 4 KV heads, head_dim=128
-# QKV packed dim = 28*128 + 4*128 + 4*128 = 3584 + 512 + 512 = 4608
-_LORA_DIM = 3584
-_LORA_QKV_DIM = 4608
-_LORA_MODULE = "bagel.language_model.model.layers.0.self_attn.qkv_proj"
-_LORA_RANK = 4
-
-
-def _make_file_lora_request(adapter_dir: Path) -> LoRARequest:
- """Write synthetic adapter to disk and return a file-backed LoRARequest."""
- adapter_dir.mkdir(parents=True, exist_ok=True)
- gen = torch.Generator().manual_seed(42)
- lora_a = torch.randn((_LORA_RANK, _LORA_DIM), dtype=torch.float32, generator=gen) * 0.1
- lora_b = torch.randn((_LORA_QKV_DIM, _LORA_RANK), dtype=torch.float32, generator=gen) * 0.5
- save_file(
- {
- f"base_model.model.{_LORA_MODULE}.lora_A.weight": lora_a,
- f"base_model.model.{_LORA_MODULE}.lora_B.weight": lora_b,
- },
- str(adapter_dir / "adapter_model.safetensors"),
- )
- (adapter_dir / "adapter_config.json").write_text(
- json.dumps({"r": _LORA_RANK, "lora_alpha": _LORA_RANK, "target_modules": [_LORA_MODULE]}),
- encoding="utf-8",
- )
- lora_dir = str(adapter_dir)
- return LoRARequest(lora_name="test_file", lora_int_id=stable_lora_int_id(lora_dir), lora_path=lora_dir)
-
-
-# ---------------------------------------------------------------------------
-# Test
-# ---------------------------------------------------------------------------
-
-
-@pytest.mark.core_model
-@pytest.mark.advanced_model
-@pytest.mark.diffusion
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
-def test_bagel_lora_scale_and_deactivation(run_level, tmp_path):
- """Validate LoRA effect, bounded perturbation, and clean deactivation."""
- config_path = _resolve_deploy_config(BAGEL_STAGE_CONFIG, run_level)
- with OmniRunner(MODEL, stage_configs_path=config_path) as runner:
- omni = runner.omni
- lora_request = _make_file_lora_request(tmp_path / "bagel_lora")
-
- # 1) Baseline (no LoRA)
- baseline = _generate_bagel_image(omni)
-
- # 2) LoRA with scale=1.0
- img_1x = _generate_bagel_image_with_lora(omni, lora_request, lora_scale=1.0)
-
- # 3) LoRA with scale=2.0
- img_2x = _generate_bagel_image_with_lora(omni, lora_request, lora_scale=2.0)
-
- # 4) No LoRA again (deactivation)
- restored = _generate_bagel_image(omni)
-
- baseline_arr = np.array(baseline, dtype=np.int16)
- img_1x_arr = np.array(img_1x, dtype=np.int16)
- img_2x_arr = np.array(img_2x, dtype=np.int16)
- restored_arr = np.array(restored, dtype=np.int16)
-
- diff_1x = np.abs(baseline_arr - img_1x_arr).mean()
- diff_2x = np.abs(baseline_arr - img_2x_arr).mean()
- diff_restored = np.abs(baseline_arr - restored_arr).mean()
-
- # (a) Adapter has visible effect at both scales
- assert diff_1x > 0.5, f"LoRA scale=1.0 had no visible effect: diff={diff_1x}"
- assert diff_2x > 0.5, f"LoRA scale=2.0 had no visible effect: diff={diff_2x}"
-
- # (b) Different scales produce different outputs
- assert not np.isclose(diff_1x, diff_2x, atol=1.0), (
- f"LoRA scale has no effect: diff_1x={diff_1x:.2f}, diff_2x={diff_2x:.2f}"
- )
-
- # (c) Output is not corrupted (scale=2.0 can produce ~2x the diff of scale=1.0)
- assert diff_1x < 80, f"LoRA output looks corrupted: diff_1x={diff_1x}"
- assert diff_2x < 120, f"LoRA output looks corrupted: diff_2x={diff_2x}"
-
- # (d) Deactivation fully restores base model
- assert diff_restored == 0.0, f"Base model not restored after LoRA deactivation: diff={diff_restored}"
diff --git a/tests/e2e/offline_inference/test_bagel_text2img.py b/tests/e2e/offline_inference/test_bagel_text2img.py
index 65cd8425cd0..7cce8da3a73 100644
--- a/tests/e2e/offline_inference/test_bagel_text2img.py
+++ b/tests/e2e/offline_inference/test_bagel_text2img.py
@@ -16,54 +16,52 @@
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
import signal
import socket
import subprocess
import tempfile
import time
+from pathlib import Path
from typing import Any
import pytest
from PIL import Image
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+from tests.conftest import modify_stage_config
+from tests.utils import hardware_test
from vllm_omni.entrypoints.omni import Omni
from vllm_omni.platforms import current_omni_platform
-BAGEL_CI_DEPLOY = get_deploy_config_path("ci/bagel.yaml")
-BAGEL_MOONCAKE_CI_DEPLOY = get_deploy_config_path("ci/bagel_mooncake.yaml")
-
# Reference pixel data extracted from the known-good output image
# Each entry contains (x, y) position and expected (R, G, B) values
# "Generated with seed=52, num_inference_steps=15,
# prompt='A futuristic city skyline at twilight, cyberpunk style'"
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (115, 113, 94)},
- {"position": (400, 50), "rgb": (159, 160, 144)},
- {"position": (700, 100), "rgb": (164, 151, 123)},
- {"position": (150, 400), "rgb": (120, 121, 107)},
- {"position": (512, 512), "rgb": (165, 133, 127)},
- {"position": (700, 400), "rgb": (217, 130, 66)},
- {"position": (100, 700), "rgb": (191, 168, 152)},
- {"position": (400, 700), "rgb": (130, 96, 77)},
- {"position": (700, 700), "rgb": (247, 203, 140)},
- {"position": (256, 256), "rgb": (167, 156, 150)},
+ {"position": (100, 100), "rgb": (121, 118, 100)},
+ {"position": (400, 50), "rgb": (163, 162, 143)},
+ {"position": (700, 100), "rgb": (170, 156, 127)},
+ {"position": (150, 400), "rgb": (129, 127, 112)},
+ {"position": (512, 512), "rgb": (135, 61, 59)},
+ {"position": (700, 400), "rgb": (205, 107, 43)},
+ {"position": (100, 700), "rgb": (197, 177, 157)},
+ {"position": (400, 700), "rgb": (139, 107, 86)},
+ {"position": (700, 700), "rgb": (247, 205, 146)},
+ {"position": (256, 256), "rgb": (171, 160, 153)},
]
if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (115, 113, 94)},
- {"position": (400, 50), "rgb": (159, 160, 144)},
- {"position": (700, 100), "rgb": (164, 151, 123)},
- {"position": (150, 400), "rgb": (120, 121, 107)},
- {"position": (512, 512), "rgb": (165, 133, 127)},
- {"position": (700, 400), "rgb": (217, 130, 66)},
- {"position": (100, 700), "rgb": (191, 168, 152)},
- {"position": (400, 700), "rgb": (130, 96, 77)},
- {"position": (700, 700), "rgb": (247, 203, 140)},
- {"position": (256, 256), "rgb": (167, 156, 150)},
+ {"position": (100, 100), "rgb": (123, 119, 100)},
+ {"position": (400, 50), "rgb": (162, 161, 142)},
+ {"position": (700, 100), "rgb": (171, 156, 127)},
+ {"position": (150, 400), "rgb": (131, 128, 112)},
+ {"position": (512, 512), "rgb": (134, 61, 59)},
+ {"position": (700, 400), "rgb": (204, 107, 43)},
+ {"position": (100, 700), "rgb": (201, 180, 165)},
+ {"position": (400, 700), "rgb": (140, 108, 87)},
+ {"position": (700, 700), "rgb": (247, 205, 145)},
+ {"position": (256, 256), "rgb": (171, 160, 153)},
]
# Maximum allowed difference per color channel
@@ -174,8 +172,8 @@ def _generate_bagel_image(omni: Omni, prompt: str = DEFAULT_PROMPT) -> Image.Ima
return generated_image
-def _resolve_deploy_config(config_path: str, run_level: str) -> str:
- """Resolve deploy config based on run level.
+def _resolve_stage_config(config_path: str, run_level: str) -> str:
+ """Resolve stage config based on run level.
For advanced_model (real weights), strip load_format: dummy so the model
falls back to loading real weights from HuggingFace.
@@ -184,9 +182,9 @@ def _resolve_deploy_config(config_path: str, run_level: str) -> str:
return modify_stage_config(
config_path,
deletes={
- "stages": {
- 0: ["load_format"],
- 1: ["load_format"],
+ "stage_args": {
+ 0: ["engine_args.load_format"],
+ 1: ["engine_args.load_format"],
}
},
)
@@ -199,14 +197,16 @@ def _resolve_deploy_config(config_path: str, run_level: str) -> str:
@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
def test_bagel_text2img_shared_memory_connector(run_level):
"""Test Bagel text2img with shared memory connector."""
- config_path = _resolve_deploy_config(BAGEL_CI_DEPLOY, run_level)
- with OmniRunner(
- "ByteDance-Seed/BAGEL-7B-MoT",
- stage_configs_path=config_path,
- ) as runner:
- generated_image = _generate_bagel_image(runner.omni)
+ config_path = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
+ config_path = _resolve_stage_config(config_path, run_level)
+ omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=config_path, stage_init_timeout=300)
+
+ try:
+ generated_image = _generate_bagel_image(omni)
if run_level == "advanced_model":
_validate_pixels(generated_image)
+ finally:
+ omni.close()
def _wait_for_port(host: str, port: int, timeout: int = 30) -> bool:
@@ -278,7 +278,7 @@ def _cleanup_mooncake_processes(timeout_secs: int = 5) -> None:
def _load_mooncake_config(host: str, rpc_port: int, http_port: int) -> str:
- """Load Mooncake config from CI overlay and substitute placeholders.
+ """Load Mooncake config from YAML and substitute placeholders.
Args:
host: Mooncake host address.
@@ -288,13 +288,16 @@ def _load_mooncake_config(host: str, rpc_port: int, http_port: int) -> str:
Returns:
Path to the temporary config file with substituted values.
"""
- with open(BAGEL_MOONCAKE_CI_DEPLOY) as f:
+ config_path = str(Path(__file__).parent / "stage_configs" / "bagel_mooncake_ci.yaml")
+ with open(config_path) as f:
config_content = f.read()
+ # Substitute placeholders
config_content = config_content.replace("${MOONCAKE_HOST}", host)
config_content = config_content.replace("${MOONCAKE_RPC_PORT}", str(rpc_port))
config_content = config_content.replace("${MOONCAKE_HTTP_PORT}", str(http_port))
+ # Write to temp file
temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
temp_file.write(config_content)
temp_file.close()
@@ -316,6 +319,7 @@ def test_bagel_text2img_mooncake_connector(run_level):
mooncake_master_proc = None
temp_config_file = None
+ omni = None
try:
_cleanup_mooncake_processes()
@@ -344,17 +348,16 @@ def test_bagel_text2img_mooncake_connector(run_level):
http_port=MOONCAKE_HTTP_PORT,
)
- temp_config_file = _resolve_deploy_config(temp_config_file, run_level)
- with OmniRunner(
- "ByteDance-Seed/BAGEL-7B-MoT",
- stage_configs_path=temp_config_file,
- stage_init_timeout=300,
- ) as runner:
- generated_image = _generate_bagel_image(runner.omni)
- if run_level == "advanced_model":
- _validate_pixels(generated_image)
+ temp_config_file = _resolve_stage_config(temp_config_file, run_level)
+ omni = Omni(model="ByteDance-Seed/BAGEL-7B-MoT", stage_configs_path=temp_config_file, stage_init_timeout=300)
+
+ generated_image = _generate_bagel_image(omni)
+ if run_level == "advanced_model":
+ _validate_pixels(generated_image)
finally:
+ if omni:
+ omni.close()
if temp_config_file:
try:
os.unlink(temp_config_file)
diff --git a/tests/e2e/offline_inference/test_bagel_understanding.py b/tests/e2e/offline_inference/test_bagel_understanding.py
deleted file mode 100644
index e342152fc02..00000000000
--- a/tests/e2e/offline_inference/test_bagel_understanding.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-End-to-end tests for Bagel text2text and img2text (understanding) tasks.
-
-These tests validate that the Bagel multistage pipeline correctly generates
-text output for understanding tasks, matching reference results.
-
-Equivalent to running:
- python3 examples/offline_inference/bagel/end2end.py \
- --modality text2text \
- --prompts "Where is the capital of France?"
-
- python3 examples/offline_inference/bagel/end2end.py \
- --modality img2text \
- --prompts "Please describe this image" \
- --image-path 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
-"""
-
-import os
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
-import pytest
-from vllm.assets.image import ImageAsset
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
-
-MODEL_NAME = "ByteDance-Seed/BAGEL-7B-MoT"
-STAGE_CONFIG = get_deploy_config_path("ci/bagel.yaml")
-
-REFERENCE_TEXT_TEXT2TEXT = "The capital of France is Paris."
-
-REFERENCE_TEXT_IMG2TEXT = (
- "This is a photo of a wooden boardwalk or pathway that leads through "
- "tall green grass. The path appears to be in a natural setting, possibly "
- "a wetland or marsh area. The sky above is blue with some scattered "
- "clouds, suggesting it might be a sunny day. The overall scene looks "
- "peaceful and serene."
-)
-
-
-def _resolve_deploy_config(config_path: str, run_level: str) -> str:
- """Strip load_format: dummy for advanced_model (real weights)."""
- if run_level == "advanced_model":
- return modify_stage_config(
- config_path,
- deletes={
- "stages": {
- 0: ["load_format"],
- 1: ["load_format"],
- }
- },
- )
- return config_path
-
-
-def _extract_text(omni_outputs: list) -> str:
- """Extract generated text from OmniRequestOutput list."""
- for req_output in omni_outputs:
- ro = getattr(req_output, "request_output", None)
- if ro and getattr(ro, "outputs", None):
- return "".join(getattr(o, "text", "") or "" for o in ro.outputs)
- return ""
-
-
-@pytest.mark.core_model
-@pytest.mark.advanced_model
-@pytest.mark.diffusion
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
-def test_bagel_text2text(run_level):
- """Test Bagel text2text produces correct text output."""
- config_path = _resolve_deploy_config(STAGE_CONFIG, run_level)
- with OmniRunner(
- MODEL_NAME,
- stage_configs_path=config_path,
- ) as runner:
- omni = runner.omni
- prompt = "<|im_start|>user\nWhere is the capital of France?<|im_end|>\n<|im_start|>assistant\n"
- params_list = omni.default_sampling_params_list
- omni_outputs = list(
- omni.generate(
- prompts=[{"prompt": prompt, "modalities": ["text"]}],
- sampling_params_list=params_list,
- )
- )
-
- assert len(omni_outputs) > 0, "No outputs returned"
- text = _extract_text(omni_outputs)
- assert len(text) > 0, "Generated text is empty"
-
- if run_level == "advanced_model":
- assert text == REFERENCE_TEXT_TEXT2TEXT, (
- f"Text mismatch: expected {REFERENCE_TEXT_TEXT2TEXT!r}, got {text!r}"
- )
-
-
-@pytest.mark.core_model
-@pytest.mark.advanced_model
-@pytest.mark.diffusion
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
-def test_bagel_img2text(run_level):
- """Test Bagel img2text produces correct text output."""
- input_image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB")
- config_path = _resolve_deploy_config(STAGE_CONFIG, run_level)
- with OmniRunner(
- MODEL_NAME,
- stage_configs_path=config_path,
- stage_init_timeout=300,
- ) as runner:
- omni = runner.omni
- prompt = "<|im_start|>user\n<|image_pad|>\nPlease describe this image<|im_end|>\n<|im_start|>assistant\n"
- params_list = omni.default_sampling_params_list
- omni_outputs = list(
- omni.generate(
- prompts=[
- {
- "prompt": prompt,
- "multi_modal_data": {"image": input_image},
- "modalities": ["text"],
- }
- ],
- sampling_params_list=params_list,
- )
- )
-
- assert len(omni_outputs) > 0, "No outputs returned"
- text = _extract_text(omni_outputs)
- assert len(text) > 0, "Generated text is empty"
-
- if run_level == "advanced_model":
- assert text == REFERENCE_TEXT_IMG2TEXT, f"Text mismatch: expected {REFERENCE_TEXT_IMG2TEXT!r}, got {text!r}"
diff --git a/tests/e2e/offline_inference/test_cache_dit.py b/tests/e2e/offline_inference/test_cache_dit.py
index 1577dd9f6db..0e31413dc07 100644
--- a/tests/e2e/offline_inference/test_cache_dit.py
+++ b/tests/e2e/offline_inference/test_cache_dit.py
@@ -8,15 +8,27 @@
It uses minimal settings to keep test time short for CI.
"""
+import os
+import sys
+from pathlib import Path
+
import pytest
import torch
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+# ruff: noqa: E402
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
+
# Use random weights model for testing
models = ["riverclouds/qwen_image_random"]
@@ -36,17 +48,20 @@ def test_cache_dit(model_name: str):
"residual_diff_threshold": 0.24,
"max_continuous_cached_steps": 3,
}
- with OmniRunner(
- model_name,
- cache_backend="cache_dit",
- cache_config=cache_config,
- ) as runner:
+ m = None
+ try:
+ m = Omni(
+ model=model_name,
+ cache_backend="cache_dit",
+ cache_config=cache_config,
+ )
+
# Use minimal settings for fast testing
height = 256
width = 256
num_inference_steps = 4 # Minimal steps for fast test
- outputs = runner.omni.generate(
+ outputs = m.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
@@ -75,3 +90,9 @@ def test_cache_dit(model_name: str):
# Check image size
assert images[0].width == width
assert images[0].height == height
+ except Exception as e:
+ print(f"Test failed with error: {e}")
+ raise
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
diff --git a/tests/e2e/offline_inference/test_cosyvoice3.py b/tests/e2e/offline_inference/test_cosyvoice3.py
deleted file mode 100644
index 7206f1e7b0c..00000000000
--- a/tests/e2e/offline_inference/test_cosyvoice3.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Offline E2E smoke test for CosyVoice3 zero-shot reference inference.
-
-This test uses the official upstream zero-shot prompt text/audio pair and
-verifies a stable reference recipe:
-- config-derived top_p/top_k and token-length ratios
-- model EOS token as the stop token
-- a conservative repetition penalty to avoid degenerate loops
-"""
-
-from __future__ import annotations
-
-import functools
-import io
-import os
-from pathlib import Path
-from urllib.request import urlopen
-
-import numpy as np
-import pytest
-import soundfile as sf
-from huggingface_hub import snapshot_download
-from vllm.sampling_params import SamplingParams
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from tests.helpers.stage_config import get_deploy_config_path
-from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config
-from vllm_omni.model_executor.models.cosyvoice3.tokenizer import get_qwen_tokenizer
-
-MODEL = "FunAudioLLM/Fun-CosyVoice3-0.5B-2512"
-MODEL_DIR_ENV = "VLLM_OMNI_COSYVOICE3_MODEL_DIR"
-
-REFERENCE_PROMPT_WAV_URL = "https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/main/asset/zero_shot_prompt.wav"
-REFERENCE_PROMPT_TEXT = "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。"
-REFERENCE_SYNTH_TEXT = (
- "CosyVoice is undergoing a comprehensive upgrade, providing more accurate, "
- "stable, faster, and better voice generation capabilities."
-)
-REFERENCE_STAGE0_TEMPERATURE = 1.0
-REFERENCE_STAGE0_REPETITION_PENALTY = 2.0
-
-
-ASYNC_CHUNK_MODES = [
- pytest.param(False, id="sync"),
- pytest.param(True, id="async_chunk"),
-]
-
-
-@functools.lru_cache(maxsize=1)
-def _load_reference_prompt_wav() -> tuple[np.ndarray, int]:
- with urlopen(REFERENCE_PROMPT_WAV_URL, timeout=30) as resp:
- data = resp.read()
- audio, sr = sf.read(io.BytesIO(data), dtype="float32", always_2d=False)
- if isinstance(audio, np.ndarray) and audio.ndim > 1:
- audio = np.mean(audio, axis=-1)
- return np.asarray(audio, dtype=np.float32), int(sr)
-
-
-@functools.lru_cache(maxsize=1)
-def _resolve_model_dir() -> Path:
- override = os.environ.get(MODEL_DIR_ENV)
- if override:
- return Path(override).expanduser().resolve()
- return Path(snapshot_download(MODEL, allow_patterns=["*"]))
-
-
-def _reference_zero_shot_stage0_sampling(*, text: str) -> SamplingParams:
- config = CosyVoice3Config()
- sampling_cfg = config.llm.get("sampling", {})
- eos_token_id = int(config.llm["eos_token_id"])
- model_dir = _resolve_model_dir()
- tokenizer = get_qwen_tokenizer(
- token_path=str(model_dir / config.qwen_pretrain_path),
- skip_special_tokens=config.skip_special_tokens,
- version=config.version,
- )
- text_len = max(1, len(tokenizer.encode(text, allowed_special=config.allowed_special)))
- return SamplingParams(
- temperature=REFERENCE_STAGE0_TEMPERATURE,
- top_p=float(sampling_cfg.get("top_p", 0.8)),
- top_k=int(sampling_cfg.get("top_k", 25)),
- repetition_penalty=REFERENCE_STAGE0_REPETITION_PENALTY,
- stop_token_ids=[eos_token_id],
- min_tokens=int(text_len * config.min_token_text_ratio),
- max_tokens=int(text_len * config.max_token_text_ratio),
- )
-
-
-def _concat_audio(audio_val) -> np.ndarray:
- import torch
-
- if isinstance(audio_val, list):
- tensors = []
- for t in audio_val:
- if t is None:
- continue
- if hasattr(t, "detach"):
- t = t.detach()
- if hasattr(t, "cpu"):
- t = t.cpu()
- if hasattr(t, "float"):
- t = t.float()
- if isinstance(t, torch.Tensor):
- tensors.append(t.reshape(-1))
- if not tensors:
- return np.zeros((0,), dtype=np.float32)
- return torch.cat(tensors, dim=-1).numpy().astype(np.float32, copy=False)
-
- if hasattr(audio_val, "detach"):
- audio_val = audio_val.detach()
- if hasattr(audio_val, "cpu"):
- audio_val = audio_val.cpu()
- if hasattr(audio_val, "float"):
- audio_val = audio_val.float()
- if hasattr(audio_val, "numpy"):
- audio_val = audio_val.numpy()
- audio_np = np.asarray(audio_val, dtype=np.float32)
- return audio_np.reshape(-1)
-
-
-def _get_stage_engine_outputs(omni_runner: OmniRunner, stage_id: int):
- stage_list = getattr(omni_runner.omni, "stage_list", None)
- if stage_list is not None:
- return getattr(stage_list[stage_id], "engine_outputs", None) or []
-
- stage_clients = getattr(getattr(omni_runner.omni, "engine", None), "stage_clients", None)
- if stage_clients is not None:
- return getattr(stage_clients[stage_id], "engine_outputs", None) or []
-
- raise AttributeError("Unable to locate stage outputs on Omni runner")
-
-
-def _build_reference_inputs(prompt_audio: tuple[np.ndarray, int]) -> list[dict[str, object]]:
- return [
- {
- "prompt": REFERENCE_SYNTH_TEXT,
- "multi_modal_data": {"audio": prompt_audio},
- "modalities": ["audio"],
- "mm_processor_kwargs": {"prompt_text": REFERENCE_PROMPT_TEXT},
- }
- ]
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
-@pytest.mark.parametrize("async_chunk", ASYNC_CHUNK_MODES)
-def test_cosyvoice3_offline_reference_zero_shot(async_chunk: bool) -> None:
- """CosyVoice3 zero-shot reference inference should stop cleanly and produce sane audio."""
- prompt_audio, prompt_sr = _load_reference_prompt_wav()
- model_dir = _resolve_model_dir()
- expected_stop_token = int(CosyVoice3Config().llm["eos_token_id"])
-
- with OmniRunner(
- str(model_dir),
- seed=42,
- stage_configs_path=get_deploy_config_path("cosyvoice3.yaml"),
- async_chunk=async_chunk,
- stage_init_timeout=300,
- ) as omni_runner:
- sampling_params_list = omni_runner.get_default_sampling_params_list()
- sampling_params_list[0] = _reference_zero_shot_stage0_sampling(text=REFERENCE_SYNTH_TEXT)
-
- outputs = omni_runner.omni.generate(_build_reference_inputs((prompt_audio, prompt_sr)), sampling_params_list)
-
- assert outputs, "No outputs returned"
- audio_mm = outputs[0].multimodal_output
- assert "audio" in audio_mm, "No audio output found"
-
- audio = _concat_audio(audio_mm["audio"])
- assert audio.size > 0, "Generated audio is empty"
-
- sr_val = audio_mm.get("sr", 24000)
- if isinstance(sr_val, list) and sr_val:
- sr_val = sr_val[-1]
- if hasattr(sr_val, "item"):
- sr_val = sr_val.item()
- sr = int(sr_val)
- assert sr == 24000, f"Unexpected sample_rate={sr}"
-
- duration_s = audio.size / sr
- assert 2.8 <= duration_s <= 8.8, f"Unexpected duration={duration_s:.3f}s (samples={audio.size}, sr={sr})"
-
- stage0_outputs = _get_stage_engine_outputs(omni_runner, 0)
- if stage0_outputs:
- completion = stage0_outputs[0].outputs[0]
- finish_reason = getattr(completion, "finish_reason", None)
- stop_reason = getattr(completion, "stop_reason", None)
- num_tokens = len(getattr(completion, "token_ids", []) or [])
-
- assert finish_reason == "stop", f"Stage-0 finish_reason={finish_reason}, expected 'stop'"
- assert int(stop_reason) == expected_stop_token, (
- f"Stage-0 stop_reason={stop_reason}, expected {expected_stop_token}"
- )
- assert 80 <= num_tokens <= 220, f"Stage-0 num_tokens={num_tokens}, expected sane stop-bound range"
- else:
- assert async_chunk, "Stage-0 produced no engine outputs"
diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
index d7fd6f72f5b..f3830f02e97 100644
--- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
@@ -1,15 +1,22 @@
import gc
+import sys
+from pathlib import Path
import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
-from tests.helpers.env import DeviceMemoryMonitor
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import DeviceMemoryMonitor, hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
+# ruff: noqa: E402
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from vllm_omni import Omni
+
models = ["riverclouds/qwen_image_random"]
@@ -20,29 +27,30 @@ def inference(model_name: str, offload: bool = True):
current_omni_platform.reset_peak_memory_stats()
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- with OmniRunner(
- model_name,
+ m = Omni(
+ model=model_name,
# TODO: we might want to add overlapped feature e2e tests
# cache_backend="cache_dit",
enable_cpu_offload=offload,
- ) as runner:
- current_omni_platform.reset_peak_memory_stats()
- height = 256
- width = 256
+ )
+ current_omni_platform.reset_peak_memory_stats()
+ height = 256
+ width = 256
- runner.omni.generate(
- "a photo of a cat sitting on a laptop keyboard",
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=9,
- guidance_scale=0.0,
- generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
- ),
- )
+ m.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=9,
+ guidance_scale=0.0,
+ generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
+ ),
+ )
peak = monitor.peak_used_mb
monitor.stop()
+ del m
gc.collect()
current_omni_platform.empty_cache()
diff --git a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
index 4f19c100476..6132f1bd0eb 100644
--- a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
@@ -1,12 +1,21 @@
+import sys
+from pathlib import Path
+
import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
-from tests.helpers.env import DeviceMemoryMonitor
-from tests.helpers.runtime import OmniRunner
+from tests.utils import DeviceMemoryMonitor
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
+# ruff: noqa: E402
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from vllm_omni import Omni
+
# Models to test and expected saved memory in MB, correspondingly
MODELS_SAVED_MEMORY_MB = {
"riverclouds/qwen_image_random": 4500,
@@ -24,33 +33,34 @@ def run_inference(
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- with OmniRunner(
- model_name,
+ m = Omni(
+ model=model_name,
enable_layerwise_offload=layerwise_offload,
# TODO: we might want to add overlapped feature e2e tests
# cache_backend="cache_dit",
boundary_ratio=0.875,
flow_shift=5.0,
- ) as runner:
- current_omni_platform.reset_peak_memory_stats()
-
- # Refer to tests/e2e/offline_inference/test_t2v_model.py
- # Use minimal settings for testing
- height = 480
- width = 640
- num_frames = 5
-
- runner.omni.generate(
- "A cat sitting on a table",
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
- guidance_scale=1.0,
- num_inference_steps=num_inference_steps,
- num_frames=num_frames,
- ),
- )
+ )
+
+ current_omni_platform.reset_peak_memory_stats()
+
+ # Refer to tests/e2e/offline_inference/test_t2v_model.py
+ # Use minimal settings for testing
+ height = 480
+ width = 640
+ num_frames = 5
+
+ m.generate(
+ "A cat sitting on a table",
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
+ guidance_scale=1.0,
+ num_inference_steps=num_inference_steps,
+ num_frames=num_frames,
+ ),
+ )
peak = monitor.peak_used_mb
monitor.stop()
diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py
index 027dadb3f4e..b414fe30eeb 100644
--- a/tests/e2e/offline_inference/test_diffusion_lora.py
+++ b/tests/e2e/offline_inference/test_diffusion_lora.py
@@ -7,7 +7,6 @@
import torch
from safetensors.torch import save_file
-from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
@@ -17,12 +16,15 @@
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
+from vllm_omni import Omni
+
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
# This test is specific to Z-Image LoRA behavior. Keep it focused on a single
# model to reduce runtime and avoid extra downloads.
models = ["Tongyi-MAI/Z-Image-Turbo"]
+DIFFUSION_INIT_TIMEOUT_S = 600
@pytest.mark.parametrize("model_name", models)
@@ -75,8 +77,12 @@ def _write_zimage_lora(adapter_dir: Path) -> str:
)
return str(adapter_dir)
- with OmniRunner(model_name) as runner:
- m = runner.omni
+ m = Omni(
+ model=model_name,
+ stage_init_timeout=DIFFUSION_INIT_TIMEOUT_S,
+ init_timeout=DIFFUSION_INIT_TIMEOUT_S,
+ )
+ try:
# high resolution may cause OOM on L4
height = 256
width = 256
@@ -134,3 +140,5 @@ def _write_zimage_lora(adapter_dir: Path) -> str:
diff = np.abs(np.array(images[0], dtype=np.int16) - np.array(images_lora[0], dtype=np.int16)).mean()
assert diff > 0.0
+ finally:
+ m.close()
diff --git a/tests/e2e/offline_inference/test_dynin_omni.py b/tests/e2e/offline_inference/test_dynin_omni.py
deleted file mode 100644
index f891fc4f12e..00000000000
--- a/tests/e2e/offline_inference/test_dynin_omni.py
+++ /dev/null
@@ -1,374 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-E2E offline smoke tests for Dynin-Omni.
-
-- model: "snu-aidas/Dynin-Omni"
-- stage config: tests/e2e/stage_configs/dynin_omni_ci.yaml
-"""
-
-from __future__ import annotations
-
-import os
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import pytest
-import torch
-from transformers import AutoTokenizer
-
-from tests.helpers.mark import hardware_test
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-
-_REPO_ROOT = Path(__file__).resolve().parents[3]
-_DEFAULT_DYNIN_CONFIG_PATH: Path | None = None
-_DEFAULT_STAGE_CONFIG_PATH = _REPO_ROOT / "tests" / "e2e" / "stage_configs" / "dynin_omni_ci.yaml"
-
-models = ["snu-aidas/Dynin-Omni"]
-stage_configs = [str(_DEFAULT_STAGE_CONFIG_PATH)]
-test_params = [(model, stage_config) for model in models for stage_config in stage_configs]
-
-DYNIN_CONFIG_PATH = str(_DEFAULT_DYNIN_CONFIG_PATH) if _DEFAULT_DYNIN_CONFIG_PATH is not None else None
-
-pytestmark = [
- pytest.mark.core_model,
- pytest.mark.omni,
- pytest.mark.parametrize("omni_runner", test_params, indirect=True),
-]
-
-
-# prompting util
-def _build_mmu_prompt(tokenizer: Any, question: str, dynin_config_path: str | None) -> dict[str, Any]:
- encoded = tokenizer(question, return_tensors="pt", add_special_tokens=True)
- token_ids = [int(v) for v in encoded["input_ids"][0].tolist()]
- attention_mask = [int(v) for v in encoded["attention_mask"][0].tolist()]
- additional_information: dict[str, Any] = {
- "task": ["mmu"],
- "detok_id": [0],
- "prompt_length": [len(token_ids)],
- "attention_mask": [attention_mask],
- "max_new_tokens": [64],
- "steps": [64],
- "block_length": [16],
- "temperature": [0.0],
- }
- if dynin_config_path:
- additional_information["dynin_config_path"] = [str(dynin_config_path)]
- return {
- "prompt_token_ids": token_ids,
- "additional_information": additional_information,
- "modalities": ["text"],
- }
-
-
-def _build_mmu_multimodal_prompt(
- tokenizer: Any,
- question: str,
- dynin_config_path: str | None,
- *,
- image: Any | None = None,
- audio: tuple[np.ndarray, int] | None = None,
-) -> dict[str, Any]:
- if image is None and audio is None:
- raise ValueError("At least one multimodal input (image or audio) must be provided.")
-
- prefix_chunks: list[str] = []
- mm_data: dict[str, Any] = {}
- if image is not None:
- prefix_chunks.append("<|soi|><|image|><|eoi|>")
- mm_data["image"] = image
- if audio is not None:
- prefix_chunks.append("<|soa|><|audio|><|eoa|>")
- mm_data["audio"] = audio
-
- prefixed_question = " ".join(prefix_chunks + [question]).strip()
- prompt = _build_mmu_prompt(
- tokenizer=tokenizer,
- question=prefixed_question,
- dynin_config_path=dynin_config_path,
- )
- prompt["multi_modal_data"] = mm_data
- prompt["modalities"] = ["text"]
- return prompt
-
-
-def _generate_synthetic_image(width: int = 224, height: int = 224) -> np.ndarray:
- x = np.linspace(0, 255, width, dtype=np.uint8)
- y = np.linspace(0, 255, height, dtype=np.uint8)[:, None]
- red = np.tile(x, (height, 1))
- green = np.tile(y, (1, width))
- blue = ((red.astype(np.uint16) + green.astype(np.uint16)) // 2).astype(np.uint8)
- return np.stack([red, green, blue], axis=-1)
-
-
-def _generate_synthetic_audio(duration_s: int = 5, sample_rate: int = 48_000) -> tuple[np.ndarray, int]:
- t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False, dtype=np.float32)
- waveform = 0.1 * np.sin(2.0 * np.pi * 440.0 * t)
- return waveform.astype(np.float32), sample_rate
-
-
-# prompting util
-def _build_t2s_decode_prompt(dynin_config_path: str | None) -> dict[str, Any]:
- # Bypass stage-0 generation and directly validate token->audio decode path.
- generated_audio_token_ids = [int(v) for v in ([10, 11, 12, 13, 14] * 32)]
- additional_information: dict[str, Any] = {
- "task": ["t2s"],
- "detok_id": [1],
- "generated_token_ids": [generated_audio_token_ids],
- "audio_codebook_size": [4096],
- }
- if dynin_config_path:
- additional_information["dynin_config_path"] = [str(dynin_config_path)]
- return {
- "prompt_token_ids": [0],
- "additional_information": additional_information,
- "modalities": ["audio"],
- }
-
-
-# prompting util
-def _build_t2i_decode_prompt(dynin_config_path: str | None) -> dict[str, Any]:
- # Bypass stage-0 generation and directly validate token->image decode path.
- # MAGVIT decode path expects a square token grid; 1024 tokens -> 32x32.
- generated_image_token_ids = [int(v) for v in ([10, 11, 12, 13, 14, 15, 16, 17] * 128)]
- additional_information: dict[str, Any] = {
- "task": ["t2i"],
- "detok_id": [2],
- "generated_token_ids": [generated_image_token_ids],
- "codebook_size": [8192],
- }
- if dynin_config_path:
- additional_information["dynin_config_path"] = [str(dynin_config_path)]
- return {
- "prompt_token_ids": [0],
- "additional_information": additional_information,
- "modalities": ["image"],
- }
-
-
-def _configure_dynin_config_env() -> None:
- if DYNIN_CONFIG_PATH:
- os.environ["DYNIN_CONFIG_PATH"] = str(DYNIN_CONFIG_PATH)
- else:
- os.environ.pop("DYNIN_CONFIG_PATH", None)
-
-
-def _is_finished_request_output(request_output: Any) -> bool:
- if request_output is None:
- return False
- req_list = request_output if isinstance(request_output, list) else [request_output]
- for req in req_list:
- if req is not None and bool(getattr(req, "finished", False)):
- return True
- return False
-
-
-def _find_stage_output(outputs: list[Any], output_type: str) -> Any | None:
- matched = [
- stage_output for stage_output in outputs if getattr(stage_output, "final_output_type", None) == output_type
- ]
- if not matched:
- return None
-
- # Prefer the latest finished chunk to avoid picking an intermediate stream output.
- for stage_output in reversed(matched):
- if _is_finished_request_output(getattr(stage_output, "request_output", None)):
- return stage_output
- return matched[-1]
-
-
-def _to_token_list(value: Any) -> list[int]:
- if value is None:
- return []
- if hasattr(value, "detach"):
- value = value.detach()
- if hasattr(value, "cpu"):
- value = value.cpu()
- if hasattr(value, "flatten"):
- value = value.flatten().tolist()
- if isinstance(value, tuple):
- value = list(value)
- if not isinstance(value, list):
- return []
- out: list[int] = []
- for token in value:
- if isinstance(token, bool):
- continue
- try:
- out.append(int(token))
- except Exception:
- continue
- return out
-
-
-def _extract_text(stage_output: Any, tokenizer: Any | None = None) -> str:
- request_output = getattr(stage_output, "request_output", None)
- if request_output is None:
- return ""
- req_list = request_output if isinstance(request_output, list) else [request_output]
- for req in req_list:
- completions = getattr(req, "outputs", None) or []
- if not completions:
- continue
- completion = completions[0]
- mm_out = (
- getattr(completion, "multimodal_output", None)
- or getattr(req, "multimodal_output", None)
- or getattr(stage_output, "multimodal_output", None)
- or {}
- )
- text = mm_out.get("text")
- if isinstance(text, list) and text:
- text = text[-1]
- if isinstance(text, str) and text.strip():
- return text.strip()
- if tokenizer is not None:
- for key in ("text_tokens", "token_ids"):
- token_ids = _to_token_list(mm_out.get(key))
- if not token_ids:
- continue
- decoded = tokenizer.decode(token_ids, skip_special_tokens=True)
- if isinstance(decoded, str) and decoded.strip():
- return decoded.strip()
- fallback = getattr(completion, "text", None)
- if isinstance(fallback, str) and fallback.strip():
- return fallback.strip()
- return ""
-
-
-def _extract_audio(stage_output: Any) -> Any | None:
- request_output = getattr(stage_output, "request_output", None)
- if request_output is None:
- return None
- req_list = request_output if isinstance(request_output, list) else [request_output]
- for req in req_list:
- completions = getattr(req, "outputs", None) or []
- if not completions:
- continue
- completion = completions[0]
- mm_out = getattr(completion, "multimodal_output", None) or {}
- if "audio" in mm_out:
- return mm_out["audio"]
- return None
-
-
-def _extract_image(stage_output: Any) -> Any | None:
- request_output = getattr(stage_output, "request_output", None)
- if request_output is None:
- return None
- req_list = request_output if isinstance(request_output, list) else [request_output]
- for req in req_list:
- completions = getattr(req, "outputs", None) or []
- if not completions:
- continue
- completion = completions[0]
- mm_out = getattr(completion, "multimodal_output", None) or {}
- if "image" in mm_out:
- return mm_out["image"]
- return None
-
-
-def _numel(value: Any) -> int:
- if value is None:
- return 0
- if isinstance(value, torch.Tensor):
- return int(value.numel())
- shape = getattr(value, "shape", None)
- if shape is not None:
- try:
- total = 1
- for dim in shape:
- total *= int(dim)
- return int(total)
- except Exception:
- pass
- if isinstance(value, (list, tuple)):
- return len(value)
- return 0
-
-
-@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-def test_dynin_t2i_decode_to_image(omni_runner) -> None:
- _configure_dynin_config_env()
- prompt = _build_t2i_decode_prompt(dynin_config_path=DYNIN_CONFIG_PATH)
-
- outputs = omni_runner.generate([prompt])
-
- image_output = _find_stage_output(outputs, "image")
- assert image_output is not None
- image_value = _extract_image(image_output)
- assert image_value is not None
- assert _numel(image_value) > 0
-
-
-@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-def test_dynin_mmu_to_text(omni_runner) -> None:
- _configure_dynin_config_env()
- tokenizer = AutoTokenizer.from_pretrained(omni_runner.model_name, trust_remote_code=True)
- prompt = _build_mmu_prompt(
- tokenizer=tokenizer,
- question="What is 2 + 2? Answer in one short sentence.",
- dynin_config_path=DYNIN_CONFIG_PATH,
- )
-
- outputs = omni_runner.generate([prompt])
-
- text_output = _find_stage_output(outputs, "text")
- assert text_output is not None
- text_content = _extract_text(text_output, tokenizer=tokenizer)
- assert text_content
-
-
-@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-def test_dynin_image_to_text(omni_runner) -> None:
- _configure_dynin_config_env()
- tokenizer = AutoTokenizer.from_pretrained(omni_runner.model_name, trust_remote_code=True)
- prompt = _build_mmu_multimodal_prompt(
- tokenizer=tokenizer,
- question="Describe the image briefly in one sentence.",
- dynin_config_path=DYNIN_CONFIG_PATH,
- image=_generate_synthetic_image(),
- )
-
- outputs = omni_runner.generate([prompt])
-
- text_output = _find_stage_output(outputs, "text")
- assert text_output is not None
- text_content = _extract_text(text_output, tokenizer=tokenizer)
- assert text_content
-
-
-@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-def test_dynin_speech_to_text(omni_runner) -> None:
- _configure_dynin_config_env()
- tokenizer = AutoTokenizer.from_pretrained(omni_runner.model_name, trust_remote_code=True)
- prompt = _build_mmu_multimodal_prompt(
- tokenizer=tokenizer,
- question="Transcribe the audio briefly in one sentence.",
- dynin_config_path=DYNIN_CONFIG_PATH,
- audio=_generate_synthetic_audio(),
- )
-
- outputs = omni_runner.generate([prompt])
-
- text_output = _find_stage_output(outputs, "text")
- assert text_output is not None
- text_content = _extract_text(text_output, tokenizer=tokenizer)
- assert text_content
-
-
-@hardware_test(res={"cuda": "L4", "rocm": "MI325"})
-def test_dynin_t2s_decode_to_audio(omni_runner) -> None:
- _configure_dynin_config_env()
- prompt = _build_t2s_decode_prompt(dynin_config_path=DYNIN_CONFIG_PATH)
-
- outputs = omni_runner.generate([prompt])
-
- audio_output = _find_stage_output(outputs, "audio")
- assert audio_output is not None
- audio_value = _extract_audio(audio_output)
- assert audio_value is not None
- assert _numel(audio_value) > 0
diff --git a/tests/e2e/offline_inference/test_expert_parallel.py b/tests/e2e/offline_inference/test_expert_parallel.py
index f11646b300d..ba126986ec7 100644
--- a/tests/e2e/offline_inference/test_expert_parallel.py
+++ b/tests/e2e/offline_inference/test_expert_parallel.py
@@ -18,8 +18,8 @@
import torch.distributed as dist
from PIL import Image
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import hardware_test
+from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -96,26 +96,12 @@ def _run_inference(
tensor_parallel_size=tensor_parallel_size,
enable_expert_parallel=enable_expert_parallel,
)
+ omni = Omni(model=model_name, parallel_config=parallel_config)
+
try:
- with OmniRunner(model_name, parallel_config=parallel_config) as runner:
- omni = runner.omni
- # Warmup run (not timed)
- if warmup:
- _ = omni.generate(
- PROMPT,
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=DEFAULT_STEPS,
- guidance_scale=guidance_scale,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
- num_outputs_per_prompt=1,
- ),
- )
-
- # Timed run
- start = time.time()
- outputs = omni.generate(
+ # Warmup run (not timed)
+ if warmup:
+ _ = omni.generate(
PROMPT,
OmniDiffusionSamplingParams(
height=height,
@@ -126,13 +112,28 @@ def _run_inference(
num_outputs_per_prompt=1,
),
)
- elapsed_ms = (time.time() - start) * 1000
- return InferenceResult(
- images=outputs[0].images,
- elapsed_ms=elapsed_ms,
- )
+ # Timed run
+ start = time.time()
+ outputs = omni.generate(
+ PROMPT,
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=DEFAULT_STEPS,
+ guidance_scale=guidance_scale,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
+ num_outputs_per_prompt=1,
+ ),
+ )
+ elapsed_ms = (time.time() - start) * 1000
+
+ return InferenceResult(
+ images=outputs[0].images,
+ elapsed_ms=elapsed_ms,
+ )
finally:
+ omni.close()
_cleanup_distributed()
diff --git a/tests/e2e/offline_inference/test_flux.py b/tests/e2e/offline_inference/test_flux.py
deleted file mode 100644
index 02c6787be2b..00000000000
--- a/tests/e2e/offline_inference/test_flux.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Tests for Flux1 Schnell."""
-
-import pytest
-from PIL import Image
-
-from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-MODEL = "black-forest-labs/FLUX.1-schnell"
-
-
-@pytest.mark.core_model
-@pytest.mark.diffusion
-def test_flux_schnell_text_to_image():
- """Test FLUX.1-schnell text-to-image generation."""
- omni = Omni(model=MODEL)
-
- omni_outputs = list(
- omni.generate(
- prompts=["A photo of a cat sitting on a laptop"],
- sampling_params_list=OmniDiffusionSamplingParams(
- height=512,
- width=512,
- num_inference_steps=2,
- seed=42,
- ),
- )
- )
-
- assert len(omni_outputs) > 0
- images = omni_outputs[0].images
- assert len(images) == 1
- assert isinstance(images[0], Image.Image)
- assert images[0].size == (512, 512)
diff --git a/tests/e2e/offline_inference/test_flux2_klein.py b/tests/e2e/offline_inference/test_flux2_klein.py
deleted file mode 100644
index a1376753467..00000000000
--- a/tests/e2e/offline_inference/test_flux2_klein.py
+++ /dev/null
@@ -1,227 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-End-to-end test for Flux2 Klein inpainting.
-
-"""
-
-# ruff: noqa: E402
-
-import os
-import sys
-from pathlib import Path
-
-import pytest
-import torch
-from PIL import Image, ImageDraw
-
-from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
-from vllm_omni.platforms import current_omni_platform
-
-REPO_ROOT = Path(__file__).resolve().parents[2]
-if str(REPO_ROOT) not in sys.path:
- sys.path.insert(0, str(REPO_ROOT))
-
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
-
-MODEL = "black-forest-labs/FLUX.2-klein-4B"
-
-_HEIGHT = 512
-_WIDTH = 512
-_NUM_INFERENCE_STEPS = 4
-
-
-def _create_test_image(width: int = _WIDTH, height: int = _HEIGHT, color: tuple = (128, 128, 128)) -> Image.Image:
- return Image.new("RGB", (width, height), color)
-
-
-def _create_test_mask(width: int = _WIDTH, height: int = _HEIGHT) -> Image.Image:
- mask = Image.new("L", (width, height), 0)
- draw = ImageDraw.Draw(mask)
- draw.rectangle([width // 4, height // 4, width * 3 // 4, height * 3 // 4], fill=255)
- return mask
-
-
-def _create_test_inputs(color: tuple = (100, 150, 200)):
- return _create_test_image(_WIDTH, _HEIGHT, color), _create_test_mask(_WIDTH, _HEIGHT)
-
-
-def _extract_images_from_output(outputs: list) -> list[Image.Image]:
- images = []
- for req_output in outputs:
- if hasattr(req_output, "images") and req_output.images:
- images.extend(req_output.images)
- elif hasattr(req_output, "request_output") and req_output.request_output:
- stage_out = req_output.request_output
- if isinstance(stage_out, OmniRequestOutput) and hasattr(stage_out, "images"):
- images.extend(stage_out.images)
- elif isinstance(stage_out, list):
- for s in stage_out:
- if hasattr(s, "images") and s.images:
- images.extend(s.images)
- return images
-
-
-# Regression test for https://github.com/vllm-project/vllm-omni/issues/3097
-@pytest.mark.core_model
-@pytest.mark.diffusion
-def test_flux2_klein_can_accept_text_inputs():
- model = Omni(model=MODEL)
- outputs = model.generate(
- "a cup of coffee on the table",
- OmniDiffusionSamplingParams(num_inference_steps=2, seed=42),
- )
- assert len(outputs[0].images) == 1
-
-
-@pytest.mark.core_model
-@pytest.mark.diffusion
-def test_flux2_klein_inpaint_basic():
- m = None
- try:
- m = Omni(model=MODEL)
- input_image, mask_image = _create_test_inputs()
-
- outputs = m.generate(
- prompts=[
- {
- "prompt": "Fill in the masked area with a beautiful garden",
- "multi_modal_data": {"image": input_image, "mask_image": mask_image},
- }
- ],
- sampling_params_list=OmniDiffusionSamplingParams(
- height=_HEIGHT,
- width=_WIDTH,
- num_inference_steps=_NUM_INFERENCE_STEPS,
- guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
- num_outputs_per_prompt=1,
- ),
- )
-
- images = _extract_images_from_output(list(outputs))
- assert len(images) == 1
- assert images[0].size == (_WIDTH, _HEIGHT)
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
-
-
-@pytest.mark.diffusion
-def test_flux2_klein_inpaint_deterministic():
- m = None
- try:
- m = Omni(model=MODEL)
- input_image, mask_image = _create_test_inputs()
- seed = 12345
-
- gen1 = torch.Generator(current_omni_platform.device_type).manual_seed(seed)
- gen2 = torch.Generator(current_omni_platform.device_type).manual_seed(seed)
-
- outputs1 = m.generate(
- prompts=[
- {
- "prompt": "A red flower in a field",
- "multi_modal_data": {"image": input_image, "mask_image": mask_image},
- }
- ],
- sampling_params_list=OmniDiffusionSamplingParams(
- height=_HEIGHT,
- width=_WIDTH,
- num_inference_steps=_NUM_INFERENCE_STEPS,
- guidance_scale=0.0,
- generator=gen1,
- num_outputs_per_prompt=1,
- ),
- )
-
- outputs2 = m.generate(
- prompts=[
- {
- "prompt": "A red flower in a field",
- "multi_modal_data": {"image": input_image, "mask_image": mask_image},
- }
- ],
- sampling_params_list=OmniDiffusionSamplingParams(
- height=_HEIGHT,
- width=_WIDTH,
- num_inference_steps=_NUM_INFERENCE_STEPS,
- guidance_scale=0.0,
- generator=gen2,
- num_outputs_per_prompt=1,
- ),
- )
-
- images1 = _extract_images_from_output(list(outputs1))
- images2 = _extract_images_from_output(list(outputs2))
-
- assert len(images1) == 1
- assert len(images2) == 1
-
- assert list(images1[0].getdata()) == list(images2[0].getdata()), (
- "Same input with same seed should produce identical output. "
- "This is critical for offline/online consistency."
- )
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
-
-
-@pytest.mark.diffusion
-def test_flux2_klein_inpaint_different_seeds_different_output():
- m = None
- try:
- m = Omni(model=MODEL)
- input_image, mask_image = _create_test_inputs()
-
- gen1 = torch.Generator(current_omni_platform.device_type).manual_seed(42)
- gen2 = torch.Generator(current_omni_platform.device_type).manual_seed(99999)
-
- outputs1 = m.generate(
- prompts=[
- {
- "prompt": "A beautiful landscape",
- "multi_modal_data": {"image": input_image, "mask_image": mask_image},
- }
- ],
- sampling_params_list=OmniDiffusionSamplingParams(
- height=_HEIGHT,
- width=_WIDTH,
- num_inference_steps=_NUM_INFERENCE_STEPS,
- guidance_scale=0.0,
- generator=gen1,
- num_outputs_per_prompt=1,
- ),
- )
-
- outputs2 = m.generate(
- prompts=[
- {
- "prompt": "A beautiful landscape",
- "multi_modal_data": {"image": input_image, "mask_image": mask_image},
- }
- ],
- sampling_params_list=OmniDiffusionSamplingParams(
- height=_HEIGHT,
- width=_WIDTH,
- num_inference_steps=_NUM_INFERENCE_STEPS,
- guidance_scale=0.0,
- generator=gen2,
- num_outputs_per_prompt=1,
- ),
- )
-
- images1 = _extract_images_from_output(list(outputs1))
- images2 = _extract_images_from_output(list(outputs2))
-
- assert len(images1) == 1
- assert len(images2) == 1
-
- different_pixel_count = sum(1 for p1, p2 in zip(images1[0].getdata(), images2[0].getdata()) if p1 != p2)
- assert different_pixel_count > 0, "Different seeds should produce different outputs"
- finally:
- if m is not None and hasattr(m, "close"):
- m.close()
diff --git a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
index ef5d6f9e051..42aab7f26a8 100644
--- a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
+++ b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
@@ -8,22 +8,31 @@
"""
import gc
-import os as _os
+import sys
+from pathlib import Path
import pytest
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
-from tests.helpers.env import DeviceMemoryMonitor
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import DeviceMemoryMonitor, hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
+# ruff: noqa: E402
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from vllm_omni import Omni
+
QUANTIZED_MODEL = "vllm-project-org/FLUX.1-dev-AutoRound-w4a16"
BASELINE_MODEL = "black-forest-labs/FLUX.1-dev"
+# Allow overriding via environment for local testing
+import os as _os
+
QUANTIZED_MODEL = _os.environ.get("FLUX_AUTOROUND_MODEL", QUANTIZED_MODEL)
BASELINE_MODEL = _os.environ.get("FLUX_BASELINE_MODEL", BASELINE_MODEL)
@@ -42,18 +51,19 @@ def _generate_image(model_name: str, **extra_kwargs) -> tuple[list, float]:
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
- with OmniRunner(model_name, enforce_eager=True, **extra_kwargs) as runner:
- current_omni_platform.reset_peak_memory_stats()
- outputs = runner.omni.generate(
- "a photo of a cat sitting on a laptop keyboard",
- OmniDiffusionSamplingParams(
- height=HEIGHT,
- width=WIDTH,
- num_inference_steps=NUM_STEPS,
- guidance_scale=0.0,
- generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
- ),
- )
+ m = Omni(model=model_name, enforce_eager=True, **extra_kwargs)
+
+ current_omni_platform.reset_peak_memory_stats()
+ outputs = m.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ OmniDiffusionSamplingParams(
+ height=HEIGHT,
+ width=WIDTH,
+ num_inference_steps=NUM_STEPS,
+ guidance_scale=0.0,
+ generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42),
+ ),
+ )
peak = monitor.peak_used_mb
monitor.stop()
@@ -64,6 +74,7 @@ def _generate_image(model_name: str, **extra_kwargs) -> tuple[list, float]:
assert isinstance(req_out, OmniRequestOutput) and hasattr(req_out, "images")
images = req_out.images
+ del m
gc.collect()
current_omni_platform.empty_cache()
diff --git a/tests/e2e/offline_inference/test_flux_kontext.py b/tests/e2e/offline_inference/test_flux_kontext.py
index 057319c855f..93dca21c9ad 100644
--- a/tests/e2e/offline_inference/test_flux_kontext.py
+++ b/tests/e2e/offline_inference/test_flux_kontext.py
@@ -9,14 +9,23 @@
- Image editing with text guidance
"""
+import os
+import sys
+from pathlib import Path
+
import pytest
from PIL import Image
-from vllm.assets.image import ImageAsset
-from tests.helpers.runtime import OmniRunner
from vllm_omni.diffusion.data import DiffusionParallelConfig
+from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
+
MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
@@ -24,15 +33,17 @@
@pytest.mark.diffusion
def test_flux_kontext_text_to_image():
"""Test FluxKontext text-to-image generation with real model."""
- with OmniRunner(
- MODEL,
+ omni = Omni(
+ model=MODEL,
parallel_config=DiffusionParallelConfig(
tensor_parallel_size=2,
),
enable_cpu_offload=False,
- ) as runner:
+ )
+
+ try:
omni_outputs = list(
- runner.omni.generate(
+ omni.generate(
prompts=["A photo of a cat sitting on a laptop"],
sampling_params_list=OmniDiffusionSamplingParams(
height=512,
@@ -43,37 +54,43 @@ def test_flux_kontext_text_to_image():
)
)
- assert len(omni_outputs) > 0
- output = omni_outputs[0]
- images = None
- if output.images:
- images = output.images
- elif hasattr(output, "request_output") and output.request_output:
- for stage_out in output.request_output:
- if hasattr(stage_out, "images") and stage_out.images:
- images = stage_out.images
- break
+ assert len(omni_outputs) > 0
+ output = omni_outputs[0]
+ images = None
+ if output.images:
+ images = output.images
+ elif hasattr(output, "request_output") and output.request_output:
+ for stage_out in output.request_output:
+ if hasattr(stage_out, "images") and stage_out.images:
+ images = stage_out.images
+ break
- assert images is not None
- assert len(images) > 0
- assert isinstance(images[0], Image.Image)
- assert images[0].size == (512, 512)
+ assert images is not None
+ assert len(images) > 0
+ assert isinstance(images[0], Image.Image)
+ assert images[0].size == (512, 512)
+ finally:
+ omni.close()
@pytest.mark.core_model
@pytest.mark.diffusion
def test_flux_kontext_image_edit():
"""Test FluxKontext image-to-image editing with real model."""
+ from vllm.assets.image import ImageAsset
+
input_image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB")
- with OmniRunner(
- MODEL,
+ omni = Omni(
+ model=MODEL,
parallel_config=DiffusionParallelConfig(
tensor_parallel_size=2,
),
enable_cpu_offload=False,
- ) as runner:
+ )
+
+ try:
omni_outputs = list(
- runner.omni.generate(
+ omni.generate(
prompts=[
{
"prompt": "Transform this image into a Vincent van Gogh style painting",
@@ -90,18 +107,20 @@ def test_flux_kontext_image_edit():
)
)
- assert len(omni_outputs) > 0
- output = omni_outputs[0]
- images = None
- if output.images:
- images = output.images
- elif hasattr(output, "request_output") and output.request_output:
- for stage_out in output.request_output:
- if hasattr(stage_out, "images") and stage_out.images:
- images = stage_out.images
- break
-
- assert images is not None
- assert len(images) > 0
- assert isinstance(images[0], Image.Image)
- assert images[0].size == (512, 512)
+ assert len(omni_outputs) > 0
+ output = omni_outputs[0]
+ images = None
+ if output.images:
+ images = output.images
+ elif hasattr(output, "request_output") and output.request_output:
+ for stage_out in output.request_output:
+ if hasattr(stage_out, "images") and stage_out.images:
+ images = stage_out.images
+ break
+
+ assert images is not None
+ assert len(images) > 0
+ assert isinstance(images[0], Image.Image)
+ assert images[0].size == (512, 512)
+ finally:
+ omni.close()
diff --git a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
deleted file mode 100644
index bd0d132d093..00000000000
--- a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
+++ /dev/null
@@ -1,343 +0,0 @@
-# ruff: noqa: E501
-from collections.abc import Generator
-from pathlib import Path
-
-import pytest
-import torch
-import torch.nn.functional as F
-from PIL import Image
-from transformers import CLIPModel, CLIPProcessor
-
-from tests.helpers.runtime import OmniRunner
-from vllm_omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.platforms import current_omni_platform
-
-PROMPT = "A brown and white dog is running on the grass"
-MODEL_NAME = "tencent/HunyuanImage-3.0"
-LOCAL_CLIP_PATH = "openai/clip-vit-base-patch32"
-REPO_ROOT = Path(__file__).resolve().parents[3]
-STAGE_CONFIG_PATH = REPO_ROOT / "vllm_omni" / "model_executor" / "stage_configs" / "hunyuan_image3_t2i.yaml"
-
-pytestmark = [pytest.mark.advanced_model, pytest.mark.diffusion]
-
-# System prompt type. Options: None, dynamic, en_vanilla, en_recaption, en_think_recaption, en_unified
-# Below are the CLIP embedding tensors from the official HunyuanImage model (seed=1234, prompt: "A brown and white dog is running on the grass").
-# SEED_1234 denotes the output without system prompt, while the remaining entries correspond to outputs generated with different system prompts.
-# fmt: off
-SEED_1234 = torch.tensor(
- [
- 0.027797, 0.028964, -0.005051, 0.001059, 0.017021, -0.034029, 0.021989, 0.033318, -0.000308, 0.016179, 0.010504, -0.034201, 0.050230, -0.021170, 0.083530, -0.003621,
- 0.040758, 0.039913, 0.044305, -0.019285, -0.058387, -0.001099, 0.042782, -0.036136, -0.014955, 0.002147, 0.009439, 0.012943, -0.028732, -0.018349, 0.002861, 0.013019,
- 0.014362, -0.038833, 0.029413, 0.020724, 0.002714, 0.010416, -0.020527, 0.050266, -0.081026, -0.006814, -0.007457, -0.032333, 0.008417, -0.122455, -0.006085, -0.025610,
- 0.012614, 0.025817, -0.005419, 0.038657, 0.000789, 0.067111, 0.002818, 0.028696, 0.047305, -0.009993, -0.019508, 0.038604, 0.099657, 0.026728, 0.012361, 0.013626,
- 0.023164, -0.037186, 0.007535, 0.054645, -0.009012, -0.019383, -0.005234, -0.018715, -0.000346, 0.051317, -0.028744, 0.029933, -0.006382, -0.018414, -0.033906, -0.028892,
- -0.015301, -0.004276, 0.014626, -0.008505, 0.013717, -0.027323, -0.001332, -0.040227, 0.047021, -0.019082, -0.037260, -0.029780, -0.594026, 0.016573, -0.010523, 0.042616,
- -0.013136, 0.030540, -0.151685, -0.005367, 0.016209, -0.034183, 0.009852, 0.038452, 0.005494, -0.017887, -0.007167, 0.017262, -0.038980, 0.011995, 0.021952, -0.031660,
- 0.020507, -0.035880, 0.035183, -0.026975, -0.050788, -0.002553, 0.037774, -0.020082, -0.015403, 0.045022, 0.072167, -0.029237, 0.003895, -0.051250, 0.008581, 0.023545,
- -0.026827, 0.020895, 0.041780, -0.040766, -0.008146, 0.080630, 0.000404, 0.032003, -0.005279, -0.090707, -0.013813, 0.010204, -0.001513, 0.016394, -0.001321, 0.020535,
- -0.038645, 0.024858, 0.024378, 0.018717, -0.056314, 0.024402, 0.018694, 0.029009, -0.008502, -0.014694, -0.028345, 0.005202, 0.046116, -0.032166, -0.030706, -0.038738,
- -0.031356, -0.009683, 0.040069, 0.001596, -0.012621, 0.018590, -0.024138, 0.035330, 0.011546, 0.015791, -0.026932, 0.004531, 0.022455, -0.012871, 0.013915, -0.009567,
- -0.010976, 0.013497, 0.042590, 0.002072, -0.052718, -0.045494, 0.013036, -0.005403, -0.005947, -0.003437, 0.016653, -0.016805, -0.040291, 0.007927, 0.001296, -0.008319,
- 0.021514, -0.001452, -0.121998, 0.015396, -0.022594, -0.006977, -0.040108, -0.035550, -0.021872, -0.014721, 0.019799, 0.036556, 0.015072, -0.057988, -0.011684, -0.045220,
- -0.026295, 0.052647, 0.013741, -0.013428, 0.061794, 0.021431, -0.011316, -0.009963, 0.008198, 0.027746, 0.074219, -0.019499, 0.042673, 0.016028, 0.007214, -0.010650,
- -0.019682, 0.001902, 0.038867, -0.007333, 0.031749, 0.004391, 0.018688, 0.044654, 0.030615, -0.027816, 0.031711, -0.056952, -0.033499, -0.039368, 0.025801, -0.027610,
- -0.009329, -0.001799, 0.024061, -0.012593, -0.050266, -0.012512, 0.019528, -0.083434, 0.018238, 0.034138, -0.020120, -0.009910, -0.002280, 0.035325, 0.034440, -0.055205,
- -0.017698, -0.000439, -0.034703, 0.013356, -0.037287, 0.048494, -0.018570, 0.028069, 0.019269, -0.007263, -0.008521, 0.000426, -0.016677, 0.056162, -0.011944, 0.017322,
- 0.022219, -0.014266, -0.009292, -0.009979, 0.014973, 0.011623, -0.017799, 0.032925, -0.024668, 0.007312, -0.025035, -0.008967, -0.026827, 0.011889, -0.138517, -0.009608,
- -0.020592, -0.001272, 0.015676, -0.025706, 0.031775, -0.004195, 0.026876, -0.014748, -0.025966, -0.008741, 0.035437, 0.017139, -0.005140, -0.007101, -0.012510, -0.023600,
- 0.032969, -0.005510, 0.020010, 0.032567, 0.015558, 0.004265, -0.036300, 0.048210, 0.080424, -0.052820, -0.002063, -0.020875, 0.052530, -0.001638, -0.020299, -0.035202,
- 0.087818, 0.034614, -0.032735, 0.033201, -0.001751, 0.029574, 0.009926, 0.011619, -0.001267, -0.020149, -0.003826, -0.029860, 0.011437, -0.051276, 0.024344, 0.003096,
- -0.011573, 0.038228, -0.005730, -0.052328, 0.001909, -0.025877, 0.019976, -0.010160, 0.023892, 0.049161, -0.028978, 0.018700, -0.026460, 0.001090, -0.072128, -0.008406,
- 0.010828, 0.020621, -0.005706, 0.023797, 0.036231, -0.112069, 0.017601, 0.007496, 0.045999, 0.016771, 0.021977, 0.022305, 0.018377, 0.002036, -0.029815, -0.082922,
- -0.012710, -0.026355, 0.003790, 0.017472, -0.023148, -0.002901, -0.057854, 0.028393, 0.230866, -0.023486, 0.051094, 0.047508, 0.018957, -0.037130, 0.001054, -0.026126,
- 0.021970, -0.046915, -0.019419, -0.014077, 0.002502, -0.079454, -0.057149, -0.081701, 0.041979, -0.043074, -0.009425, -0.035776, -0.021794, -0.004826, -0.057263, -0.072940,
- 0.037651, -0.013991, -0.043863, -0.020581, 0.034319, -0.052566, -0.010355, -0.022963, 0.027144, -0.017339, 0.088930, -0.000670, -0.026547, -0.026586, -0.032531, 0.040314,
- 0.010148, 0.021104, 0.009228, -0.073227, 0.036650, -0.019337, 0.010211, -0.089620, -0.024676, -0.020729, -0.004070, 0.000784, -0.110561, 0.015390, 0.027151, -0.003228,
- -0.066704, -0.004797, -0.026117, -0.018131, -0.090114, 0.020659, -0.007157, 0.013608, -0.022324, 0.027487, 0.018873, 0.027854, 0.045085, -0.039992, -0.017829, 0.011071,
- -0.011393, -0.004454, -0.037189, -0.030299, 0.059668, 0.005064, 0.024655, -0.037239, 0.046882, -0.010356, -0.009690, 0.061909, -0.024736, 0.016849, 0.000784, 0.000201,
- 0.066165, 0.010234, -0.012134, -0.002823, -0.060847, 0.008953, 0.010348, 0.022292, -0.044602, -0.020981, 0.038839, 0.006616, -0.016836, -0.043995, -0.005463, -0.036413,
- 0.034895, -0.018008, -0.009543, -0.025080, -0.035243, 0.042696, -0.028911, -0.030676, -0.038542, -0.027798, -0.026607, 0.019467, 0.070629, -0.037356, -0.042648, -0.000284,
- 0.033095, 0.077781, -0.052930, 0.022515, -0.029926, -0.033821, -0.003277, -0.000038, -0.026871, 0.018223, -0.004221, 0.023454, -0.030611, -0.006396, -0.009873, -0.008402,
- ],
- dtype=torch.float32,
-)
-SYSTEM_PROMPT_DYNAMIC = torch.tensor(
- [
- 0.010809, 0.021177, -0.017600, -0.016814, 0.012351, -0.024554, 0.018299, 0.039305, 0.003331, 0.030473, 0.005557, -0.040898, 0.047294, -0.016136, 0.076989, -0.002723,
- 0.017622, 0.042330, 0.058266, -0.016232, -0.029502, 0.004529, 0.033543, -0.041481, -0.017631, 0.002727, 0.018874, 0.019932, -0.030052, -0.009997, 0.004582, 0.002135,
- -0.003720, -0.030923, 0.021174, 0.034033, -0.007096, 0.011522, -0.009518, 0.055688, -0.092351, -0.003914, 0.004589, -0.032635, 0.012479, -0.140607, -0.014141, -0.031821,
- 0.001396, 0.026780, -0.007623, 0.039957, 0.006434, 0.047516, 0.014377, 0.015237, 0.034212, 0.003576, -0.027357, 0.038888, 0.087272, 0.020248, 0.015165, 0.016002,
- 0.020781, -0.040509, -0.008929, 0.080857, -0.002642, -0.009738, -0.005683, -0.000615, -0.012801, 0.046457, -0.045004, 0.024689, 0.002498, -0.017333, -0.027366, -0.023231,
- -0.006064, -0.021505, 0.007405, -0.021249, 0.026252, -0.018690, 0.020093, -0.036954, 0.037510, -0.032027, -0.030871, -0.011173, -0.618627, 0.021213, -0.004366, 0.029555,
- -0.004324, 0.020221, -0.143832, -0.021386, 0.010482, -0.042113, 0.016164, 0.040350, 0.014627, -0.011778, -0.018102, 0.035380, -0.020305, 0.010590, 0.009227, -0.011415,
- 0.018623, -0.036384, 0.031003, -0.017073, -0.056456, -0.010423, 0.033029, -0.023511, -0.008717, 0.045716, 0.068273, -0.027886, 0.009665, -0.039801, 0.001465, 0.024361,
- -0.015039, 0.022903, 0.033362, -0.022804, 0.008631, 0.076518, 0.000619, 0.022786, -0.015435, -0.095242, -0.006092, 0.015496, -0.009081, 0.015740, 0.004280, 0.013103,
- -0.031836, 0.034241, 0.031836, 0.032636, -0.053721, 0.034370, 0.019172, 0.018383, 0.006907, -0.036039, -0.027927, 0.008646, 0.040496, -0.060314, -0.039116, -0.021488,
- -0.031682, -0.005077, 0.034920, 0.002148, -0.008087, 0.002024, -0.008480, 0.041096, 0.011401, 0.020380, -0.025078, 0.005002, 0.022252, -0.014577, 0.008051, -0.014476,
- -0.007078, 0.021075, 0.036965, 0.005343, -0.038671, -0.037222, 0.014052, -0.009952, -0.003958, -0.001878, 0.017848, -0.016608, -0.030813, 0.010921, 0.001068, 0.003095,
- 0.007076, -0.001936, -0.102996, 0.006838, -0.005243, -0.009140, -0.043796, -0.027227, -0.008426, -0.013177, 0.015602, 0.021036, 0.025484, -0.064836, -0.003593, -0.038036,
- -0.023102, 0.064053, 0.007850, 0.000771, 0.039297, 0.011903, -0.015866, -0.017612, 0.006308, 0.024342, 0.086761, -0.016705, 0.039239, 0.025079, -0.006452, 0.003174,
- -0.010146, 0.010787, 0.035932, -0.015346, 0.037191, 0.010990, 0.011573, 0.044958, 0.035560, -0.017339, 0.018878, -0.025394, -0.044339, -0.029852, 0.015951, -0.032248,
- -0.012019, 0.013497, 0.012224, -0.001284, -0.034041, -0.015768, 0.000230, -0.086076, 0.024878, 0.031929, -0.016668, -0.019815, -0.001325, 0.007944, 0.017674, -0.036097,
- -0.019651, -0.001272, -0.032842, 0.002056, -0.037140, 0.043191, -0.003710, 0.011767, 0.020313, -0.018396, -0.015935, 0.010228, -0.017349, 0.049363, -0.010007, 0.019533,
- 0.018076, 0.016608, -0.005523, -0.007793, 0.016868, 0.019341, -0.008236, 0.026765, -0.025324, -0.007849, -0.023648, -0.007791, -0.018508, 0.015357, -0.166499, -0.003718,
- -0.035447, -0.005229, 0.019327, -0.014207, 0.028433, -0.002619, 0.013888, -0.033146, -0.017015, 0.004677, 0.039554, 0.003803, -0.014592, -0.018886, -0.023868, -0.022708,
- 0.033661, 0.008626, 0.015687, 0.046395, 0.014173, 0.015083, -0.025994, 0.039120, 0.076334, -0.061165, 0.001791, -0.017579, 0.067567, -0.002415, -0.032495, -0.025576,
- 0.079027, 0.036370, -0.013303, 0.030510, -0.009061, 0.019135, 0.015627, 0.024864, 0.015093, -0.017066, -0.014075, -0.021907, 0.017388, -0.033492, 0.013317, -0.000040,
- 0.003396, 0.044030, -0.009194, -0.049524, -0.005015, -0.040007, 0.009104, 0.000580, 0.005603, 0.035891, -0.038913, 0.023239, -0.017022, -0.002695, -0.095759, 0.018503,
- 0.017365, 0.011104, -0.003433, 0.024113, 0.052609, -0.085274, 0.027565, -0.005833, 0.020700, 0.015842, 0.019148, 0.020203, -0.000698, -0.005337, -0.037400, -0.060144,
- -0.031893, -0.038396, -0.001949, 0.018901, -0.014268, -0.004721, -0.055913, 0.013814, 0.215024, -0.011357, 0.057530, 0.050092, 0.016513, -0.059254, 0.001494, -0.031472,
- 0.032190, -0.047512, -0.020501, -0.002571, 0.007844, -0.063630, -0.043938, -0.079595, 0.032820, -0.021659, -0.003738, -0.035267, -0.013794, -0.021172, -0.046356, -0.077079,
- 0.021526, -0.007447, -0.050276, -0.029743, 0.022208, -0.039137, -0.021426, -0.029825, 0.029390, -0.002943, 0.073158, -0.000435, -0.032029, -0.038524, -0.029886, 0.017473,
- 0.013513, 0.022738, 0.000632, -0.073718, 0.029219, -0.018896, 0.007302, -0.116122, -0.013324, -0.012214, -0.005960, -0.003720, -0.155869, 0.019896, 0.016919, -0.021133,
- -0.066911, -0.000926, -0.020871, -0.015295, -0.086108, 0.014918, -0.009284, 0.001689, -0.038155, 0.039163, 0.015988, 0.014413, 0.034205, -0.053273, 0.001687, 0.012227,
- -0.007341, -0.006123, -0.005731, -0.026863, 0.060196, 0.028929, 0.019328, -0.033709, 0.038789, -0.015624, 0.013323, 0.053821, -0.015538, -0.001610, 0.012959, -0.013897,
- 0.082010, 0.012866, -0.017269, 0.000017, -0.059458, 0.015870, 0.028455, 0.025234, -0.051163, -0.022976, 0.011866, -0.005613, -0.008738, -0.047658, -0.002155, -0.029432,
- 0.039242, -0.013491, -0.001641, -0.024210, -0.019187, 0.026716, -0.025698, -0.027591, -0.034678, -0.002473, -0.019391, 0.017597, 0.064385, -0.029104, -0.034501, -0.004955,
- 0.015008, 0.060749, -0.051693, 0.020279, -0.027170, -0.027003, 0.000254, 0.011352, -0.028116, 0.028938, -0.007224, 0.019978, -0.025379, -0.004874, -0.019361, -0.020278,
- ],
- dtype=torch.float32,
-)
-SYSTEM_EN_RECAPTION = torch.tensor(
- [
- 0.007721, 0.015421, -0.019305, -0.000920, 0.016031, -0.019730, 0.029683, 0.026810, -0.010510, 0.021463, 0.008833, -0.040851, 0.043260, -0.007042, 0.057224, 0.011995,
- 0.007818, 0.046369, 0.059838, -0.028548, -0.047399, -0.000983, 0.024343, -0.052259, -0.013638, 0.006856, 0.009186, 0.014235, -0.031497, -0.008644, -0.009349, 0.018900,
- 0.002913, -0.022475, 0.039518, 0.019052, -0.007600, 0.010634, -0.011830, 0.075675, -0.071738, -0.014947, 0.004995, -0.025804, -0.002553, -0.093262, 0.002881, -0.033744,
- -0.007234, 0.013659, 0.009897, 0.039185, -0.005366, 0.041534, -0.005924, 0.019786, 0.048566, -0.009356, -0.027360, 0.042557, 0.091286, 0.009286, 0.015410, 0.028166,
- 0.022476, -0.025162, 0.012144, 0.084603, -0.003150, -0.008549, -0.002099, -0.014987, -0.019480, 0.046843, -0.030613, 0.015557, -0.008965, -0.008798, -0.027032, -0.014112,
- 0.018703, -0.014749, -0.000928, -0.024660, 0.024004, 0.004560, 0.028156, -0.028467, 0.025444, -0.038699, -0.014927, -0.031593, -0.648498, 0.018529, 0.003378, 0.030188,
- -0.002314, 0.014950, -0.146615, -0.009005, 0.016579, -0.037867, 0.020907, 0.033160, 0.007877, -0.026345, -0.056428, 0.031255, -0.018404, 0.013334, 0.009988, -0.022790,
- 0.020803, -0.036862, 0.036222, -0.006646, -0.058084, -0.012036, 0.044199, -0.027665, -0.015779, 0.051554, 0.059970, -0.025977, 0.003967, -0.035247, -0.000488, 0.023182,
- 0.000468, 0.019190, 0.047268, -0.032279, -0.005302, 0.078669, -0.001915, 0.024918, -0.014952, -0.078905, -0.018333, 0.001362, -0.015115, 0.005435, 0.002313, 0.018766,
- -0.032773, 0.037344, 0.024061, 0.012143, -0.057106, 0.029490, 0.019537, 0.009099, 0.026064, -0.015927, -0.037047, 0.006002, 0.025191, -0.035318, -0.032245, -0.047822,
- -0.023568, -0.004533, 0.025100, 0.002758, -0.002649, -0.012287, -0.012139, 0.043080, 0.003295, 0.024667, -0.021050, 0.006752, 0.025315, -0.011127, 0.009800, -0.021343,
- -0.024866, 0.010098, 0.026954, 0.012467, -0.035866, -0.031780, 0.007479, -0.003388, -0.012619, -0.012099, 0.014974, -0.001908, -0.032700, 0.004703, 0.003238, -0.007498,
- 0.023241, 0.002715, -0.111739, 0.003317, 0.006475, -0.019792, -0.046558, -0.032593, -0.020762, -0.005059, 0.016934, 0.029195, 0.028744, -0.050633, 0.001907, -0.028791,
- -0.016695, 0.052143, 0.010439, 0.007204, 0.028502, 0.012607, -0.012414, -0.031238, 0.007305, 0.032309, 0.087924, -0.010530, 0.029925, 0.032666, -0.002202, 0.017539,
- -0.009091, -0.001631, 0.024906, -0.013102, 0.031772, 0.018465, 0.012035, 0.031460, 0.030193, 0.005289, 0.025859, -0.038971, -0.046577, -0.025852, 0.035235, -0.038514,
- 0.001042, 0.013012, 0.023701, -0.014630, -0.029269, -0.011981, 0.008219, -0.067347, -0.003456, 0.028198, -0.008657, -0.017773, 0.010540, 0.023964, 0.021012, -0.034465,
- -0.023748, 0.004065, -0.021598, 0.008440, -0.031533, 0.038390, -0.007680, -0.003852, 0.016136, -0.017906, -0.008927, 0.006300, -0.001251, 0.029337, -0.008632, 0.020568,
- 0.021560, -0.007222, 0.005313, -0.013089, 0.012299, 0.031303, -0.013951, 0.016547, -0.024771, -0.008753, -0.030908, -0.014421, -0.017656, 0.014044, -0.114986, 0.000956,
- -0.035588, 0.003756, 0.015383, -0.013358, 0.009385, -0.001359, 0.012623, -0.028724, 0.001607, 0.012809, 0.032668, 0.011834, -0.015587, -0.007170, -0.021344, -0.019664,
- 0.017690, -0.014538, 0.016511, 0.038037, 0.029919, 0.020907, -0.018565, 0.032964, 0.078548, -0.050386, -0.003012, -0.016965, 0.064131, 0.008077, -0.025879, -0.035820,
- 0.095075, 0.019901, -0.019114, 0.022832, 0.003741, 0.027148, 0.018231, 0.027741, 0.020328, 0.001700, -0.006939, -0.024154, 0.018523, -0.029819, 0.008050, -0.004477,
- 0.006087, 0.056878, -0.009083, -0.061537, -0.011531, -0.037551, 0.000434, -0.005843, 0.024739, 0.032020, -0.053119, 0.020704, -0.012385, -0.002726, -0.082489, 0.009072,
- 0.013341, 0.000316, 0.001899, 0.022868, 0.034407, -0.066857, 0.020589, 0.012195, 0.023211, -0.001520, 0.000897, 0.029670, -0.015930, 0.006509, -0.035172, -0.061215,
- -0.014099, -0.038584, -0.012213, 0.018613, -0.012365, -0.002777, -0.055184, 0.017146, 0.214358, -0.015750, 0.052488, 0.045205, 0.025334, -0.054615, 0.002117, -0.038122,
- 0.012402, -0.053418, -0.025405, 0.007235, 0.013208, -0.092481, -0.048700, -0.085186, 0.029039, -0.036767, -0.000777, -0.017625, -0.012556, -0.004887, -0.033660, -0.082310,
- 0.013387, -0.003256, -0.062981, -0.019886, 0.017624, -0.037421, -0.020743, -0.020894, 0.041974, -0.008502, 0.088413, -0.018697, -0.029398, -0.029389, -0.043721, 0.013872,
- 0.003944, 0.030361, 0.005355, -0.081355, 0.041843, -0.016395, 0.011954, -0.060440, -0.000966, -0.019101, 0.006803, -0.011310, -0.148581, 0.020342, 0.012795, -0.016473,
- -0.053300, -0.012340, -0.016640, -0.029834, -0.082405, 0.011859, -0.004255, -0.004396, -0.012515, 0.031962, 0.030438, 0.013792, 0.031557, -0.047200, 0.006485, 0.024815,
- -0.019376, -0.011454, -0.034184, -0.021329, 0.050115, 0.021720, 0.002874, -0.047163, 0.044031, -0.014663, 0.020534, 0.056017, 0.007017, 0.003323, 0.005734, -0.002777,
- 0.082836, 0.012048, -0.023236, -0.007401, -0.071598, 0.016760, 0.017282, 0.028306, -0.026220, -0.008016, -0.000202, -0.020271, -0.019828, -0.046986, -0.005805, -0.039647,
- 0.042879, -0.004463, 0.007753, -0.028916, -0.020612, 0.028833, -0.039839, -0.052447, -0.013275, -0.002407, -0.018937, 0.033216, 0.075535, -0.045026, -0.009901, 0.016637,
- -0.000322, 0.073925, -0.055701, 0.014912, -0.045671, -0.021189, 0.006761, -0.002015, -0.027410, 0.018250, -0.015916, 0.016254, -0.044964, 0.029261, -0.029319, -0.005222,
- ],
- dtype=torch.float32,
-)
-SYSTEM_EN_THINK_RECAPTION = torch.tensor(
- [
- 0.011004, 0.017341, -0.019959, -0.018314, 0.016520, -0.027395, 0.017946, 0.039665, 0.000645, 0.035903, 0.002499, -0.045664, 0.039472, -0.013479, 0.081302, 0.000182,
- 0.006947, 0.042845, 0.059741, -0.010796, -0.035240, 0.004176, 0.029557, -0.043467, -0.017271, 0.006896, 0.010997, 0.022498, -0.023308, -0.013046, -0.000742, 0.016209,
- -0.007152, -0.029868, 0.028747, 0.033743, -0.000227, 0.018419, -0.015023, 0.050376, -0.098475, -0.002375, 0.007897, -0.023936, 0.007843, -0.122463, -0.011680, -0.027267,
- -0.007270, 0.021869, -0.011415, 0.043770, 0.000551, 0.048573, 0.003132, 0.014233, 0.037080, -0.004818, -0.028738, 0.044468, 0.073843, 0.016947, 0.014484, 0.021931,
- 0.020110, -0.032309, -0.003811, 0.095704, -0.006950, -0.007237, -0.005529, -0.020573, -0.016259, 0.041909, -0.038748, 0.018029, 0.005066, -0.021186, -0.020102, -0.019719,
- 0.006239, -0.021284, 0.004213, -0.024963, 0.032345, -0.012557, 0.037268, -0.038075, 0.040998, -0.032766, -0.023509, -0.016426, -0.627412, 0.022675, 0.000101, 0.023162,
- -0.002081, 0.015922, -0.138671, -0.027995, 0.011579, -0.042859, 0.019935, 0.038077, 0.012640, -0.017377, -0.027456, 0.035151, -0.015756, 0.018530, 0.004646, -0.002589,
- 0.019645, -0.043736, 0.034947, -0.010166, -0.061165, -0.019195, 0.028909, -0.019415, -0.009485, 0.049566, 0.068621, -0.038644, 0.011278, -0.036133, 0.000564, 0.022611,
- -0.013612, 0.020854, 0.030614, -0.025578, 0.005673, 0.076526, -0.004887, 0.027769, -0.022605, -0.092657, -0.013218, 0.008081, -0.015227, 0.018031, -0.005145, 0.015028,
- -0.027193, 0.034767, 0.028710, 0.032007, -0.053175, 0.033528, 0.019437, 0.011517, 0.012107, -0.027679, -0.026937, 0.008612, 0.036909, -0.051484, -0.039971, -0.034372,
- -0.023825, -0.003025, 0.033648, -0.001852, 0.007309, 0.000714, -0.001075, 0.038534, 0.007586, 0.016213, -0.025223, -0.001099, 0.015852, -0.011477, 0.020635, -0.010696,
- -0.019634, 0.025613, 0.034374, 0.007169, -0.035000, -0.032268, 0.015114, -0.014217, -0.005229, -0.005495, 0.018189, -0.011360, -0.026755, 0.007036, -0.002333, -0.001174,
- 0.014729, 0.001739, -0.108591, 0.004699, 0.002048, -0.014801, -0.042855, -0.028846, -0.009609, -0.004500, 0.019466, 0.021848, 0.022140, -0.063035, -0.004272, -0.030798,
- -0.018452, 0.055169, 0.012240, -0.003555, 0.038293, 0.008503, -0.016608, -0.021309, 0.000690, 0.027093, 0.088054, -0.008881, 0.034087, 0.030647, 0.003284, 0.005038,
- -0.008359, 0.006311, 0.032462, -0.009699, 0.035283, 0.015261, 0.012827, 0.038169, 0.033959, -0.018048, 0.018122, -0.025259, -0.040084, -0.030879, 0.019853, -0.042558,
- -0.011938, 0.019602, 0.016537, -0.003378, -0.027890, -0.014909, -0.005464, -0.071862, 0.012335, 0.021899, -0.017008, -0.023228, 0.003263, 0.004571, 0.016447, -0.029446,
- -0.022645, -0.001261, -0.018573, 0.007431, -0.027587, 0.035362, -0.006785, -0.000614, 0.026044, -0.009056, -0.009843, 0.010467, -0.011929, 0.042025, -0.014068, 0.023113,
- 0.023880, 0.014948, 0.004370, -0.005262, 0.012587, 0.021608, -0.001783, 0.023697, -0.024945, -0.011533, -0.020953, -0.007205, -0.024693, 0.012961, -0.168760, 0.001767,
- -0.041265, -0.007044, 0.015021, -0.008407, 0.029642, -0.000956, 0.008607, -0.035365, -0.012187, 0.011744, 0.032612, 0.006226, -0.015891, -0.017747, -0.022565, -0.024505,
- 0.031279, 0.004188, 0.011939, 0.038032, 0.008798, 0.012314, -0.024830, 0.034484, 0.076395, -0.060108, 0.001019, -0.016138, 0.067729, 0.003899, -0.029845, -0.019960,
- 0.086663, 0.040965, -0.010458, 0.027808, -0.006394, 0.017343, 0.014788, 0.024756, 0.016446, -0.012537, -0.008406, -0.028109, 0.013369, -0.033571, 0.012170, -0.002199,
- 0.005263, 0.052280, -0.018171, -0.047898, -0.010087, -0.038632, 0.006773, -0.000838, 0.011197, 0.038187, -0.049525, 0.021689, -0.007385, -0.005987, -0.094551, 0.019019,
- 0.012760, 0.009617, -0.002262, 0.030228, 0.047823, -0.079764, 0.023391, -0.005561, 0.018866, 0.012817, 0.020878, 0.027037, -0.013905, -0.002874, -0.035522, -0.046266,
- -0.032448, -0.036010, -0.007776, 0.016512, -0.012279, -0.005665, -0.057974, 0.016967, 0.202836, -0.009066, 0.066093, 0.045689, 0.018319, -0.048465, 0.000242, -0.040874,
- 0.027824, -0.049045, -0.015616, -0.000307, 0.009163, -0.072975, -0.042979, -0.082254, 0.040549, -0.027049, 0.000725, -0.034118, -0.019604, -0.019097, -0.042483, -0.075446,
- 0.019387, -0.005218, -0.053573, -0.029975, 0.008195, -0.036608, -0.018920, -0.025610, 0.028426, -0.002688, 0.074996, -0.003423, -0.032505, -0.030565, -0.028142, 0.014437,
- 0.013359, 0.019376, 0.008356, -0.069731, 0.031824, -0.011103, 0.019327, -0.117090, -0.009352, -0.010290, -0.002129, -0.009198, -0.172915, 0.021232, 0.017274, -0.030060,
- -0.061449, -0.006598, -0.013069, -0.012857, -0.081220, 0.019058, -0.004841, 0.003066, -0.037741, 0.041806, 0.018281, 0.009458, 0.036761, -0.044987, 0.003557, 0.008890,
- -0.008011, -0.004063, -0.013474, -0.022090, 0.055398, 0.037475, 0.006991, -0.035962, 0.045503, -0.017162, 0.022391, 0.052754, -0.005924, -0.005936, 0.012673, -0.017922,
- 0.084548, 0.014695, -0.013817, 0.000421, -0.065167, 0.018269, 0.023317, 0.023523, -0.034229, -0.019588, 0.007911, -0.002426, -0.017109, -0.050870, 0.002848, -0.033077,
- 0.043451, -0.010609, -0.000375, -0.023206, -0.018155, 0.027102, -0.036006, -0.035115, -0.023922, 0.005989, -0.015372, 0.027123, 0.075210, -0.035302, -0.029799, 0.003642,
- 0.007714, 0.063498, -0.053234, 0.015699, -0.040459, -0.027354, -0.002433, 0.010923, -0.020134, 0.029292, -0.010176, 0.013508, -0.032403, 0.004323, -0.017504, -0.015237,
- ],
- dtype=torch.float32,
-)
-SYSTEM_EN_VANILLA = torch.tensor(
- [
- 0.010809, 0.021177, -0.017600, -0.016814, 0.012351, -0.024554, 0.018299, 0.039305, 0.003331, 0.030473, 0.005557, -0.040898, 0.047294, -0.016136, 0.076989, -0.002723,
- 0.017622, 0.042330, 0.058266, -0.016232, -0.029502, 0.004529, 0.033543, -0.041481, -0.017631, 0.002727, 0.018874, 0.019932, -0.030052, -0.009997, 0.004582, 0.002135,
- -0.003720, -0.030923, 0.021174, 0.034033, -0.007096, 0.011522, -0.009518, 0.055688, -0.092351, -0.003914, 0.004589, -0.032635, 0.012479, -0.140607, -0.014141, -0.031821,
- 0.001396, 0.026780, -0.007623, 0.039957, 0.006434, 0.047516, 0.014377, 0.015237, 0.034212, 0.003576, -0.027357, 0.038888, 0.087272, 0.020248, 0.015165, 0.016002,
- 0.020781, -0.040509, -0.008929, 0.080857, -0.002642, -0.009738, -0.005683, -0.000615, -0.012801, 0.046457, -0.045004, 0.024689, 0.002498, -0.017333, -0.027366, -0.023231,
- -0.006064, -0.021505, 0.007405, -0.021249, 0.026252, -0.018690, 0.020093, -0.036954, 0.037510, -0.032027, -0.030871, -0.011173, -0.618627, 0.021213, -0.004366, 0.029555,
- -0.004324, 0.020221, -0.143832, -0.021386, 0.010482, -0.042113, 0.016164, 0.040350, 0.014627, -0.011778, -0.018102, 0.035380, -0.020305, 0.010590, 0.009227, -0.011415,
- 0.018623, -0.036384, 0.031003, -0.017073, -0.056456, -0.010423, 0.033029, -0.023511, -0.008717, 0.045716, 0.068273, -0.027886, 0.009665, -0.039801, 0.001465, 0.024361,
- -0.015039, 0.022903, 0.033362, -0.022804, 0.008631, 0.076518, 0.000619, 0.022786, -0.015435, -0.095242, -0.006092, 0.015496, -0.009081, 0.015740, 0.004280, 0.013103,
- -0.031836, 0.034241, 0.031836, 0.032636, -0.053721, 0.034370, 0.019172, 0.018383, 0.006907, -0.036039, -0.027927, 0.008646, 0.040496, -0.060314, -0.039116, -0.021488,
- -0.031682, -0.005077, 0.034920, 0.002148, -0.008087, 0.002024, -0.008480, 0.041096, 0.011401, 0.020380, -0.025078, 0.005002, 0.022252, -0.014577, 0.008051, -0.014476,
- -0.007078, 0.021075, 0.036965, 0.005343, -0.038671, -0.037222, 0.014052, -0.009952, -0.003958, -0.001878, 0.017848, -0.016608, -0.030813, 0.010921, 0.001068, 0.003095,
- 0.007076, -0.001936, -0.102996, 0.006838, -0.005243, -0.009140, -0.043796, -0.027227, -0.008426, -0.013177, 0.015602, 0.021036, 0.025484, -0.064836, -0.003593, -0.038036,
- -0.023102, 0.064053, 0.007850, 0.000771, 0.039297, 0.011903, -0.015866, -0.017612, 0.006308, 0.024342, 0.086761, -0.016705, 0.039239, 0.025079, -0.006452, 0.003174,
- -0.010146, 0.010787, 0.035932, -0.015346, 0.037191, 0.010990, 0.011573, 0.044958, 0.035560, -0.017339, 0.018878, -0.025394, -0.044339, -0.029852, 0.015951, -0.032248,
- -0.012019, 0.013497, 0.012224, -0.001284, -0.034041, -0.015768, 0.000230, -0.086076, 0.024878, 0.031929, -0.016668, -0.019815, -0.001325, 0.007944, 0.017674, -0.036097,
- -0.019651, -0.001272, -0.032842, 0.002056, -0.037140, 0.043191, -0.003710, 0.011767, 0.020313, -0.018396, -0.015935, 0.010228, -0.017349, 0.049363, -0.010007, 0.019533,
- 0.018076, 0.016608, -0.005523, -0.007793, 0.016868, 0.019341, -0.008236, 0.026765, -0.025324, -0.007849, -0.023648, -0.007791, -0.018508, 0.015357, -0.166499, -0.003718,
- -0.035447, -0.005229, 0.019327, -0.014207, 0.028433, -0.002619, 0.013888, -0.033146, -0.017015, 0.004677, 0.039554, 0.003803, -0.014592, -0.018886, -0.023868, -0.022708,
- 0.033661, 0.008626, 0.015687, 0.046395, 0.014173, 0.015083, -0.025994, 0.039120, 0.076334, -0.061165, 0.001791, -0.017579, 0.067567, -0.002415, -0.032495, -0.025576,
- 0.079027, 0.036370, -0.013303, 0.030510, -0.009061, 0.019135, 0.015627, 0.024864, 0.015093, -0.017066, -0.014075, -0.021907, 0.017388, -0.033492, 0.013317, -0.000040,
- 0.003396, 0.044030, -0.009194, -0.049524, -0.005015, -0.040007, 0.009104, 0.000580, 0.005603, 0.035891, -0.038913, 0.023239, -0.017022, -0.002695, -0.095759, 0.018503,
- 0.017365, 0.011104, -0.003433, 0.024113, 0.052609, -0.085274, 0.027565, -0.005833, 0.020700, 0.015842, 0.019148, 0.020203, -0.000698, -0.005337, -0.037400, -0.060144,
- -0.031893, -0.038396, -0.001949, 0.018901, -0.014268, -0.004721, -0.055913, 0.013814, 0.215024, -0.011357, 0.057530, 0.050092, 0.016513, -0.059254, 0.001494, -0.031472,
- 0.032190, -0.047512, -0.020501, -0.002571, 0.007844, -0.063630, -0.043938, -0.079595, 0.032820, -0.021659, -0.003738, -0.035267, -0.013794, -0.021172, -0.046356, -0.077079,
- 0.021526, -0.007447, -0.050276, -0.029743, 0.022208, -0.039137, -0.021426, -0.029825, 0.029390, -0.002943, 0.073158, -0.000435, -0.032029, -0.038524, -0.029886, 0.017473,
- 0.013513, 0.022738, 0.000632, -0.073718, 0.029219, -0.018896, 0.007302, -0.116122, -0.013324, -0.012214, -0.005960, -0.003720, -0.155869, 0.019896, 0.016919, -0.021133,
- -0.066911, -0.000926, -0.020871, -0.015295, -0.086108, 0.014918, -0.009284, 0.001689, -0.038155, 0.039163, 0.015988, 0.014413, 0.034205, -0.053273, 0.001687, 0.012227,
- -0.007341, -0.006123, -0.005731, -0.026863, 0.060196, 0.028929, 0.019328, -0.033709, 0.038789, -0.015624, 0.013323, 0.053821, -0.015538, -0.001610, 0.012959, -0.013897,
- 0.082010, 0.012866, -0.017269, 0.000017, -0.059458, 0.015870, 0.028455, 0.025234, -0.051163, -0.022976, 0.011866, -0.005613, -0.008738, -0.047658, -0.002155, -0.029432,
- 0.039242, -0.013491, -0.001641, -0.024210, -0.019187, 0.026716, -0.025698, -0.027591, -0.034678, -0.002473, -0.019391, 0.017597, 0.064385, -0.029104, -0.034501, -0.004955,
- 0.015008, 0.060749, -0.051693, 0.020279, -0.027170, -0.027003, 0.000254, 0.011352, -0.028116, 0.028938, -0.007224, 0.019978, -0.025379, -0.004874, -0.019361, -0.020278,
- ],
- dtype=torch.float32,
-)
-SYSTEM_EN_UNIFIED = torch.tensor(
- [
- 0.011409, 0.014191, -0.023163, -0.020119, 0.019190, -0.029559, 0.019616, 0.035872, 0.010434, 0.028709, 0.011616, -0.039422, 0.038369, -0.004631, 0.081177, 0.007400,
- 0.008903, 0.040408, 0.055323, -0.011950, -0.026940, 0.004916, 0.028101, -0.046200, -0.016732, 0.005115, 0.012100, 0.016136, -0.026057, -0.013827, -0.004914, 0.015261,
- -0.010824, -0.028188, 0.022934, 0.026204, -0.003855, 0.013797, -0.014518, 0.050289, -0.100077, -0.002962, 0.009050, -0.028205, 0.016294, -0.128956, -0.012730, -0.023647,
- -0.009306, 0.020066, 0.000033, 0.043619, 0.003250, 0.053425, 0.005889, 0.021529, 0.036032, -0.003254, -0.029715, 0.048345, 0.077978, 0.010674, 0.019296, 0.018721,
- 0.019244, -0.040115, -0.004245, 0.085214, -0.005280, -0.010746, -0.000164, -0.023405, -0.015641, 0.040193, -0.038735, 0.018966, -0.004031, -0.017879, -0.023017, -0.030379,
- 0.006468, -0.015959, 0.000532, -0.026530, 0.042640, -0.006095, 0.037899, -0.043658, 0.040965, -0.034682, -0.023729, -0.019291, -0.630840, 0.029658, 0.005462, 0.026650,
- -0.000292, 0.013954, -0.149594, -0.019405, 0.015321, -0.045104, 0.030332, 0.031727, 0.012349, -0.009553, -0.022371, 0.034043, -0.014838, 0.015398, -0.003657, 0.000477,
- 0.021084, -0.041406, 0.029946, -0.013832, -0.057358, -0.018086, 0.031598, -0.031835, -0.006697, 0.040866, 0.068602, -0.042203, 0.007362, -0.036959, 0.003794, 0.026533,
- -0.011873, 0.017343, 0.028333, -0.021804, 0.004007, 0.075133, 0.003340, 0.025326, -0.015068, -0.092280, -0.011514, 0.006827, -0.008254, 0.021181, -0.005035, 0.022263,
- -0.022443, 0.043919, 0.026637, 0.028568, -0.056881, 0.036740, 0.024430, 0.015891, 0.012257, -0.031126, -0.030108, 0.007229, 0.026998, -0.051685, -0.033003, -0.031170,
- -0.024021, 0.004235, 0.030164, 0.002674, 0.008018, 0.005532, 0.001621, 0.044790, 0.006413, 0.027160, -0.015022, 0.000911, 0.019723, -0.016244, 0.020077, -0.006847,
- -0.014110, 0.022461, 0.031656, 0.002760, -0.039078, -0.026893, 0.006628, -0.011775, -0.000240, -0.005908, 0.014943, -0.012131, -0.021755, 0.004732, -0.005297, -0.002922,
- 0.014631, -0.002010, -0.112400, 0.000842, -0.002732, -0.014861, -0.052099, -0.034167, -0.011613, -0.006101, 0.013278, 0.018867, 0.026530, -0.068150, -0.003306, -0.032801,
- -0.018523, 0.050875, 0.005488, -0.007241, 0.045707, 0.023119, -0.021519, -0.022683, 0.004806, 0.024827, 0.091371, -0.014424, 0.043836, 0.033094, 0.002390, 0.005450,
- -0.004893, 0.013608, 0.031272, -0.002449, 0.031607, 0.014646, 0.014146, 0.043995, 0.028826, -0.012219, 0.021008, -0.020911, -0.036967, -0.036256, 0.013328, -0.038382,
- -0.012084, 0.018183, 0.018782, -0.004697, -0.024284, -0.015474, -0.001463, -0.076015, 0.013923, 0.022125, -0.018765, -0.010793, 0.008409, 0.002067, 0.017961, -0.029716,
- -0.020915, -0.001779, -0.009217, -0.001933, -0.036081, 0.042577, 0.000118, -0.013920, 0.014901, -0.016486, -0.010278, -0.000449, -0.017234, 0.042453, -0.009893, 0.021087,
- 0.017671, 0.009861, -0.004210, 0.004944, 0.015627, 0.014370, -0.001128, 0.030247, -0.019552, -0.014017, -0.020859, -0.002614, -0.024405, 0.016532, -0.173204, -0.001196,
- -0.037415, -0.010990, 0.010449, -0.006124, 0.019211, 0.003695, 0.011679, -0.031852, -0.009764, 0.005773, 0.035793, 0.003455, -0.011772, -0.020532, -0.027434, -0.024761,
- 0.027483, -0.001554, 0.010411, 0.037888, 0.015619, 0.019186, -0.021204, 0.038158, 0.074991, -0.064521, -0.002503, -0.014499, 0.068165, 0.006145, -0.032891, -0.021540,
- 0.091385, 0.047584, -0.009590, 0.028004, -0.002962, 0.021061, 0.014854, 0.025840, 0.016068, -0.014364, -0.016418, -0.033454, 0.011734, -0.036518, 0.013015, -0.003966,
- 0.000855, 0.051373, -0.010960, -0.047078, -0.011048, -0.042015, 0.006818, 0.005483, 0.010251, 0.034951, -0.046162, 0.021258, -0.013397, -0.005259, -0.093775, 0.019974,
- 0.014992, 0.004043, -0.005931, 0.035662, 0.050723, -0.083293, 0.028047, -0.008042, 0.020763, 0.016763, 0.022913, 0.027129, -0.014314, -0.009854, -0.039019, -0.044870,
- -0.028101, -0.038026, -0.006294, 0.018265, -0.015425, -0.007866, -0.052784, 0.010470, 0.200260, -0.007798, 0.064482, 0.046612, 0.025353, -0.059695, -0.001831, -0.039643,
- 0.025148, -0.042752, -0.014928, -0.010216, 0.014195, -0.069149, -0.041424, -0.078360, 0.036999, -0.021357, 0.011032, -0.026564, -0.016214, -0.023440, -0.044723, -0.064498,
- 0.018283, -0.007165, -0.051802, -0.026299, 0.005867, -0.034691, -0.020621, -0.030512, 0.024458, -0.011330, 0.066558, -0.004069, -0.031624, -0.030639, -0.037451, 0.013079,
- 0.015152, 0.008058, 0.009223, -0.069514, 0.030702, -0.009681, 0.014826, -0.115441, -0.005514, -0.011925, 0.001046, -0.007148, -0.164128, 0.018043, 0.017001, -0.026352,
- -0.049691, -0.011637, -0.013045, -0.014851, -0.079469, 0.017692, -0.006575, 0.001063, -0.028299, 0.038777, 0.019930, 0.010641, 0.036955, -0.039004, -0.006477, 0.004278,
- -0.001006, -0.002514, -0.017242, -0.023927, 0.049113, 0.038393, 0.011633, -0.031537, 0.041725, -0.012146, 0.023445, 0.049999, -0.008538, 0.001319, 0.012732, -0.021170,
- 0.082096, 0.009610, -0.025717, 0.002566, -0.060849, 0.017403, 0.032650, 0.018658, -0.030629, -0.025032, 0.005555, 0.000522, -0.009667, -0.043099, 0.005939, -0.027156,
- 0.045634, -0.011986, 0.002713, -0.032225, -0.015494, 0.028734, -0.036528, -0.033101, -0.027174, 0.009490, -0.016537, 0.029435, 0.065709, -0.037711, -0.020497, -0.005578,
- 0.011768, 0.061035, -0.044676, 0.016113, -0.042945, -0.022579, 0.002430, 0.012474, -0.018198, 0.030468, -0.016646, 0.019020, -0.035804, 0.001175, -0.018312, -0.010760,
- ],
- dtype=torch.float32,
-)
-# fmt: on
-SYSTEM_PROMPT_CASES = [
- pytest.param("none", None, SEED_1234, id="none"),
- pytest.param("dynamic", "dynamic", SYSTEM_PROMPT_DYNAMIC, id="dynamic"),
- pytest.param("en_vanilla", "en_vanilla", SYSTEM_EN_VANILLA, id="en_vanilla"),
- pytest.param("en_recaption", "en_recaption", SYSTEM_EN_RECAPTION, id="en_recaption"),
- pytest.param("en_think_recaption", "en_think_recaption", SYSTEM_EN_THINK_RECAPTION, id="en_think_recaption"),
- pytest.param("en_unified", "en_unified", SYSTEM_EN_UNIFIED, id="en_unified"),
-]
-
-
-@pytest.fixture(scope="session")
-def clip_bundle() -> tuple[CLIPModel, CLIPProcessor]:
- try:
- model = CLIPModel.from_pretrained(LOCAL_CLIP_PATH, local_files_only=True)
- processor = CLIPProcessor.from_pretrained(LOCAL_CLIP_PATH, local_files_only=True)
- except OSError as exc:
- pytest.skip(f"Could not load CLIP model from local cache ({LOCAL_CLIP_PATH}): {exc}")
-
- model.eval()
- return model, processor
-
-
-@pytest.fixture(scope="module")
-def omni() -> Generator[Omni, None, None]:
- with OmniRunner(
- MODEL_NAME,
- stage_configs_path=str(STAGE_CONFIG_PATH),
- ) as runner:
- yield runner.omni
-
-
-def _extract_generated_image(outputs: list[object]) -> Image.Image:
- if not outputs:
- raise AssertionError("No outputs were returned from Omni.generate()")
-
- first_output = outputs[0]
- if images := getattr(first_output, "images", None):
- return images[0]
-
- request_output = getattr(first_output, "request_output", None)
- if request_output is not None and (images := getattr(request_output, "images", None)):
- return images[0]
-
- raise AssertionError("No generated image found in Omni output")
-
-
-def extract_embedding(image: Image.Image, clip_model: CLIPModel, clip_processor: CLIPProcessor) -> torch.Tensor:
- inputs = clip_processor(images=image.convert("RGB"), return_tensors="pt")
- with torch.inference_mode():
- features = clip_model.get_image_features(**inputs)
- features = F.normalize(features, p=2, dim=-1)
- return features.squeeze(0)
-
-
-def compare_semantic(
- expected_embedding: torch.Tensor,
- image: Image.Image,
- clip_model: CLIPModel,
- clip_processor: CLIPProcessor,
-) -> float:
- features = extract_embedding(image, clip_model, clip_processor)
- expected = F.normalize(expected_embedding, p=2, dim=-1)
- return torch.dot(expected, features).item()
-
-
-def _generate_image(omni: Omni, use_system_prompt: str | None) -> Image.Image:
- generator_device = current_omni_platform.device_type or "cuda"
- sampling_params = OmniDiffusionSamplingParams(
- seed=1234,
- generator=torch.Generator(device=generator_device).manual_seed(1234),
- num_outputs_per_prompt=1,
- )
- if use_system_prompt is not None:
- sampling_params.extra_args = {"use_system_prompt": use_system_prompt}
-
- outputs = omni.generate({"prompt": PROMPT}, sampling_params)
- return _extract_generated_image(outputs)
-
-
-@pytest.mark.skipif(torch.cuda.device_count() < 8, reason="Need at least 8 CUDA GPUs for this test.")
-@pytest.mark.parametrize("system_prompt_name,use_system_prompt,expected_embedding", SYSTEM_PROMPT_CASES)
-def test_system_prompt_scores(
- omni: Omni,
- clip_bundle: tuple[CLIPModel, CLIPProcessor],
- system_prompt_name: str,
- use_system_prompt: str | None,
- expected_embedding: torch.Tensor,
-) -> None:
- clip_model, clip_processor = clip_bundle
- generated_image = _generate_image(omni, use_system_prompt)
- score = compare_semantic(expected_embedding, generated_image, clip_model, clip_processor)
-
- print(f"{system_prompt_name}: CLIP cosine similarity = {score:.6f}")
diff --git a/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py b/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py
deleted file mode 100644
index 07aa5a647be..00000000000
--- a/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py
+++ /dev/null
@@ -1,243 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import hashlib
-import os
-import subprocess
-import sys
-from pathlib import Path
-
-import numpy as np
-import pytest
-from PIL import Image
-
-from tests.helpers.mark import hardware_test
-
-REPO_ROOT = Path(__file__).resolve().parents[3]
-T2V_EXAMPLE = REPO_ROOT / "examples" / "offline_inference" / "text_to_video" / "text_to_video.py"
-I2V_EXAMPLE = REPO_ROOT / "examples" / "offline_inference" / "image_to_video" / "image_to_video.py"
-
-T2V_PROMPT = (
- "At sunrise, a glowing paper lantern boat drifts through a narrow canal between mossy stone walls, "
- "soft fog above the water, the camera slowly gliding forward as golden reflections shimmer across "
- "the ripples, cinematic, realistic, highly detailed."
-)
-T2V_NEGATIVE_PROMPT = "worst quality, blurry, jittery motion, distorted, oversaturated, artifacts"
-I2V_PROMPT = "A cinematic dolly shot of a boat drifting on calm water at sunset"
-I2V_NEGATIVE_PROMPT = "worst quality, blurry, jittery motion"
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
-
-def _get_ltx2_model() -> str:
- return os.environ.get("VLLM_TEST_LTX2_MODEL", "Lightricks/LTX-2")
-
-
-def _md5(path: Path) -> str:
- digest = hashlib.md5(usedforsecurity=False)
- with path.open("rb") as f:
- for chunk in iter(lambda: f.read(1024 * 1024), b""):
- digest.update(chunk)
- return digest.hexdigest()
-
-
-def _make_deterministic_test_image(path: Path) -> None:
- """Create a deterministic 256x256 test image for I2V tests."""
- rng = np.random.RandomState(42)
- img = Image.fromarray(rng.randint(0, 255, (256, 256, 3), dtype=np.uint8))
- img.save(path)
-
-
-def _run_and_check(cmd: list[str], env: dict, output_path: Path, expected_md5: str) -> None:
- result = subprocess.run(cmd, cwd=REPO_ROOT, env=env, capture_output=True, text=True, check=False)
- assert result.returncode == 0, (
- f"Command failed (exit {result.returncode}).\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"
- )
- generated_md5 = _md5(output_path)
- assert generated_md5 == expected_md5, (
- f"Unexpected output md5: {generated_md5} != {expected_md5}.\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"
- )
-
-
-# ── T2V tests ──
-
-
-@pytest.mark.advanced_model
-@pytest.mark.diffusion
-@pytest.mark.parallel
-@pytest.mark.slow
-@hardware_test(res={"cuda": "L4"}, num_cards=2)
-def test_ltx2_t2v_cfg_parallel(tmp_path: Path):
- """T2V with CFG=4.0, cfg-parallel-size=2."""
- output = tmp_path / "t2v_cfg4.mp4"
- env = os.environ.copy()
- env.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
- cmd = [
- sys.executable,
- str(T2V_EXAMPLE),
- "--model",
- _get_ltx2_model(),
- "--prompt",
- T2V_PROMPT,
- "--negative-prompt",
- T2V_NEGATIVE_PROMPT,
- "--height",
- "256",
- "--width",
- "256",
- "--num-frames",
- "145",
- "--num-inference-steps",
- "6",
- "--guidance-scale",
- "4.0",
- "--frame-rate",
- "24",
- "--fps",
- "24",
- "--seed",
- "42",
- "--cfg-parallel-size",
- "2",
- "--enforce-eager",
- "--output",
- str(output),
- ]
- _run_and_check(cmd, env, output, expected_md5="08e606b9c522fee4b6f30cee8b77db40")
-
-
-@pytest.mark.advanced_model
-@pytest.mark.diffusion
-@pytest.mark.slow
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
-def test_ltx2_t2v_no_cfg(tmp_path: Path):
- """T2V with CFG=1.0 (no classifier-free guidance)."""
- output = tmp_path / "t2v_nocfg.mp4"
- env = os.environ.copy()
- env.setdefault("CUDA_VISIBLE_DEVICES", "0")
- cmd = [
- sys.executable,
- str(T2V_EXAMPLE),
- "--model",
- _get_ltx2_model(),
- "--prompt",
- T2V_PROMPT,
- "--height",
- "256",
- "--width",
- "256",
- "--num-frames",
- "145",
- "--num-inference-steps",
- "6",
- "--guidance-scale",
- "1.0",
- "--frame-rate",
- "24",
- "--fps",
- "24",
- "--seed",
- "42",
- "--enforce-eager",
- "--output",
- str(output),
- ]
- _run_and_check(cmd, env, output, expected_md5="a83994b94b6e67c54a524e0383c45ce8")
-
-
-# ── I2V tests ──
-
-
-@pytest.mark.advanced_model
-@pytest.mark.diffusion
-@pytest.mark.parallel
-@pytest.mark.slow
-@hardware_test(res={"cuda": "L4"}, num_cards=2)
-def test_ltx2_i2v_cfg_parallel(tmp_path: Path):
- """I2V with CFG=4.0, cfg-parallel-size=2."""
- test_image = tmp_path / "test_input.png"
- _make_deterministic_test_image(test_image)
- output = tmp_path / "i2v_cfg4.mp4"
- env = os.environ.copy()
- env.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
- cmd = [
- sys.executable,
- str(I2V_EXAMPLE),
- "--model",
- _get_ltx2_model(),
- "--model-class-name",
- "LTX2ImageToVideoPipeline",
- "--image",
- str(test_image),
- "--prompt",
- I2V_PROMPT,
- "--negative-prompt",
- I2V_NEGATIVE_PROMPT,
- "--height",
- "256",
- "--width",
- "256",
- "--num-frames",
- "73",
- "--num-inference-steps",
- "6",
- "--guidance-scale",
- "4.0",
- "--frame-rate",
- "24",
- "--fps",
- "24",
- "--seed",
- "42",
- "--cfg-parallel-size",
- "2",
- "--enforce-eager",
- "--output",
- str(output),
- ]
- _run_and_check(cmd, env, output, expected_md5="aed7e56084b36373244d8f839b16d115")
-
-
-@pytest.mark.advanced_model
-@pytest.mark.diffusion
-@pytest.mark.slow
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
-def test_ltx2_i2v_no_cfg(tmp_path: Path):
- """I2V with CFG=1.0 (no classifier-free guidance)."""
- test_image = tmp_path / "test_input.png"
- _make_deterministic_test_image(test_image)
- output = tmp_path / "i2v_nocfg.mp4"
- env = os.environ.copy()
- env.setdefault("CUDA_VISIBLE_DEVICES", "0")
- cmd = [
- sys.executable,
- str(I2V_EXAMPLE),
- "--model",
- _get_ltx2_model(),
- "--model-class-name",
- "LTX2ImageToVideoPipeline",
- "--image",
- str(test_image),
- "--prompt",
- I2V_PROMPT,
- "--height",
- "256",
- "--width",
- "256",
- "--num-frames",
- "73",
- "--num-inference-steps",
- "6",
- "--guidance-scale",
- "1.0",
- "--frame-rate",
- "24",
- "--fps",
- "24",
- "--seed",
- "42",
- "--enforce-eager",
- "--output",
- str(output),
- ]
- _run_and_check(cmd, env, output, expected_md5="81b21ede12753e9e14a357a6c548b666")
diff --git a/tests/e2e/offline_inference/test_magi_human.py b/tests/e2e/offline_inference/test_magi_human.py
deleted file mode 100644
index 6d46141729e..00000000000
--- a/tests/e2e/offline_inference/test_magi_human.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""End-to-end tests for MagiHuman pipeline via vLLM-Omni."""
-
-import io
-
-import av
-import numpy as np
-import pytest
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-
-def _validate_mp4(video_bytes: bytes, min_frames: int = 10) -> None:
- """Validate that the MP4 contains meaningful video and audio tracks."""
- container = av.open(io.BytesIO(video_bytes))
-
- v_streams = [s for s in container.streams if s.type == "video"]
- assert len(v_streams) >= 1, "No video stream found in MP4"
-
- a_streams = [s for s in container.streams if s.type == "audio"]
- assert len(a_streams) >= 1, "No audio stream found in MP4"
-
- v_stream = v_streams[0]
- assert v_stream.width >= 1080, f"Unexpected video width: {v_stream.width}"
- assert v_stream.height >= 1056, f"Unexpected video height: {v_stream.height}"
-
- frame_count = 0
- for frame in container.decode(video=0):
- frame_count += 1
- if frame_count >= min_frames:
- break
- assert frame_count >= min_frames, f"Video has only {frame_count} frames (expected >= {min_frames})"
-
- container.close()
-
-
-@pytest.mark.core_model
-@pytest.mark.advanced_model
-@pytest.mark.diffusion
-@hardware_test(res={"cuda": "H100"}, num_cards=2)
-def test_magi_human_e2e(run_level):
- """End-to-end test for MagiHuman generating video and audio."""
- if run_level != "advanced_model":
- pytest.skip("MagiHuman e2e test requires advanced_model run level with real weights.")
-
- model_path = "SII-GAIR/daVinci-MagiHuman-Base-1080p"
-
- prompt = (
- "A young woman with long, wavy golden blonde hair and bright blue eyes, "
- "wearing a fitted ivory silk blouse with a delicate lace collar, sits "
- "stationary in front of a softly lit, blurred warm-toned interior. Her "
- "overall disposition is warm, composed, and gently confident. The camera "
- "holds a static medium close-up, framing her from the shoulders up, "
- "with shallow depth of field keeping her face in sharp focus. Soft "
- "directional key light falls from the upper left, casting a gentle "
- "highlight along her cheekbone and nose bridge. She draws a quiet breath, "
- "the levator labii superiors relaxing as her lips part. She speaks in "
- "clear, warm, unhurried American English: "
- "\"The most beautiful things in life aren't things at all — "
- "they're moments, feelings, and the people who make you feel truly alive.\" "
- "Her jaw descends smoothly on each stressed syllable; the orbicularis oris "
- "shapes each vowel with precision. A faint, genuine smile engages the "
- "zygomaticus major, lifting her lip corners fractionally. Her brows rest "
- "in a soft, neutral arch throughout. She maintains steady, forward-facing "
- "eye contact. Head position remains level; no torso displacement occurs.\n\n"
- "Dialogue:\n"
- ": "
- "\"The most beautiful things in life aren't things at all — "
- "they're moments, feelings, and the people who make you feel truly alive.\"\n\n"
- "Background Sound:\n"
- ""
- )
-
- sampling_params = OmniDiffusionSamplingParams(
- height=256,
- width=448,
- num_inference_steps=8,
- seed=52,
- extra_args={
- "seconds": 5,
- "sr_height": 1080,
- "sr_width": 1920,
- "sr_num_inference_steps": 5,
- },
- )
-
- with OmniRunner(
- model_path,
- init_timeout=1200,
- tensor_parallel_size=2,
- ) as runner:
- omni = runner.omni
- outputs = list(
- omni.generate(
- prompts=[prompt],
- sampling_params_list=[sampling_params],
- )
- )
-
- assert len(outputs) > 0, "No outputs returned"
- first = outputs[0]
-
- assert hasattr(first, "images") and first.images, "No video frames in output"
- video_frames = first.images[0]
- assert isinstance(video_frames, np.ndarray), f"Expected numpy array, got {type(video_frames)}"
- assert video_frames.ndim == 4, f"Expected 4D array (T,H,W,3), got shape {video_frames.shape}"
-
- mm = first.multimodal_output
- assert mm, "multimodal_output is empty or missing"
-
- audio_waveform = mm.get("audio")
- assert audio_waveform is not None, "No audio waveform in multimodal_output"
-
- audio_sample_rate = mm.get("audio_sample_rate")
- assert audio_sample_rate is not None, (
- "audio_sample_rate not found in multimodal_output; model post-process must propagate it"
- )
- assert isinstance(audio_sample_rate, (int, float)), (
- f"audio_sample_rate should be numeric, got {type(audio_sample_rate)}"
- )
- assert int(audio_sample_rate) > 0, f"audio_sample_rate must be positive, got {audio_sample_rate}"
-
- fps = mm.get("fps")
- assert fps is not None, "fps not found in multimodal_output; model post-process must propagate it"
- assert isinstance(fps, (int, float)), f"fps should be numeric, got {type(fps)}"
- assert int(fps) > 0, f"fps must be positive, got {fps}"
-
- video_bytes = mux_video_audio_bytes(
- video_frames,
- audio_waveform,
- fps=float(fps),
- audio_sample_rate=int(audio_sample_rate),
- )
- assert isinstance(video_bytes, bytes), f"Expected MP4 bytes, got {type(video_bytes)}"
- assert len(video_bytes) > 1000, f"MP4 too small ({len(video_bytes)} bytes)"
-
- _validate_mp4(video_bytes)
diff --git a/tests/e2e/offline_inference/test_mammoth_moda2.py b/tests/e2e/offline_inference/test_mammoth_moda2.py
index c3d95844c11..5293b5ed1b7 100644
--- a/tests/e2e/offline_inference/test_mammoth_moda2.py
+++ b/tests/e2e/offline_inference/test_mammoth_moda2.py
@@ -23,8 +23,7 @@
import torch
from vllm.sampling_params import SamplingParams
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import hardware_test
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
@@ -117,6 +116,8 @@ def test_mammothmoda2_t2i_e2e():
- A fixed set of pixel values matches a golden reference
(regenerate with ``UPDATE_GOLDEN=1``).
"""
+ from vllm_omni import Omni
+
if not Path(MODEL_PATH).exists():
pytest.skip(f"Model weights not found at {MODEL_PATH}")
if not Path(T2I_STAGE_CONFIG).exists():
@@ -134,8 +135,8 @@ def test_mammothmoda2_t2i_e2e():
prompt_text = "A cat sitting on a laptop keyboard"
formatted_prompt = _format_t2i_prompt(prompt_text, ar_width, ar_height)
- with OmniRunner(MODEL_PATH, stage_configs_path=T2I_STAGE_CONFIG, trust_remote_code=True) as runner:
- omni = runner.omni
+ omni = Omni(model=MODEL_PATH, stage_configs_path=T2I_STAGE_CONFIG, trust_remote_code=True)
+ try:
# Greedy / deterministic sampling so pixel values are reproducible.
ar_sampling = SamplingParams(
temperature=0.0,
@@ -210,3 +211,5 @@ def test_mammothmoda2_t2i_e2e():
found_image = True
assert found_image, "No image tensor found in pipeline output"
+ finally:
+ omni.close()
diff --git a/tests/e2e/offline_inference/test_ming_flash_omni.py b/tests/e2e/offline_inference/test_ming_flash_omni.py
deleted file mode 100644
index ca0b0fe0d96..00000000000
--- a/tests/e2e/offline_inference/test_ming_flash_omni.py
+++ /dev/null
@@ -1,195 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import os
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-
-from pathlib import Path
-
-import pytest
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import (
- generate_synthetic_audio,
- generate_synthetic_image,
- generate_synthetic_video,
-)
-from tests.helpers.stage_config import modify_stage_config
-
-models = ["Jonathan1909/Ming-flash-omni-2.0"]
-
-# Ming-specific
-SYSTEM_PROMPT = "你是一个友好的AI助手。\n\ndetailed thinking off"
-EOS_TOKEN = "<|role_end|>"
-IMAGE_TOKEN = ""
-VIDEO_TOKEN = ""
-AUDIO_TOKEN = ""
-
-
-def build_prompt(user_text: str) -> str:
- """Build a Ming chat prompt."""
- return (
- f"SYSTEM {SYSTEM_PROMPT}{EOS_TOKEN}HUMAN {user_text}{EOS_TOKEN}ASSISTANT "
- )
-
-
-def get_eager_config():
- path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_thinker_only_ci.yaml"),
- updates={
- "stage_args": {
- 0: {
- "engine_args.enforce_eager": "true",
- },
- },
- },
- )
- return path
-
-
-def get_eager_tts_config():
- """Thinker+talker CI config with enforce_eager on the thinker stage."""
- path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_ci.yaml"),
- updates={
- "stage_args": {
- 0: {
- "engine_args.enforce_eager": "true",
- },
- },
- },
- )
- return path
-
-
-# Thinker-only config — used by text-output tests.
-stage_configs = [get_eager_config()]
-test_params = [(model, stage_config) for model in models for stage_config in stage_configs]
-
-# Thinker+talker config — used by audio-output tests.
-stage_configs_tts = [get_eager_tts_config()]
-test_params_tts = [(model, stage_config) for model in models for stage_config in stage_configs_tts]
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
-def test_text_to_text(omni_runner, omni_runner_handler) -> None:
- """
- Test text-only input processing and text output generation.
- Input Modal: text
- Output Modal: text
- """
- prompt = build_prompt("请详细介绍鹦鹉的生活习性。")
- request_config = {"prompts": prompt, "modalities": ["text"]}
-
- omni_runner_handler.send_request(request_config)
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
-def test_image_to_text(omni_runner, omni_runner_handler) -> None:
- """
- Test image understanding with text output.
- Input Modal: image + text
- Output Modal: text
- """
- image = generate_synthetic_image(224, 224)["np_array"]
- prompt = build_prompt(f"{IMAGE_TOKEN}Describe this image briefly.")
- request_config = {"prompts": prompt, "images": image, "modalities": ["text"]}
-
- omni_runner_handler.send_request(request_config)
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
-def test_audio_to_text(omni_runner, omni_runner_handler) -> None:
- """
- Test audio understanding with text output.
- Input Modal: audio + text
- Output Modal: text
- """
- audio = generate_synthetic_audio(2, 1, 16000)["np_array"]
- if len(audio.shape) == 2:
- audio = audio.squeeze()
- prompt = build_prompt(f"{AUDIO_TOKEN}Please recognize the language of this speech and transcribe it. Format: oral.")
- request_config = {"prompts": prompt, "audios": audio, "modalities": ["text"]}
-
- omni_runner_handler.send_request(request_config)
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
-def test_video_to_text(omni_runner, omni_runner_handler) -> None:
- """
- Test video understanding with text output.
- Input Modal: video + text
- Output Modal: text
- """
- video = generate_synthetic_video(224, 224, 30)["np_array"]
- prompt = build_prompt(f"{VIDEO_TOKEN}Describe what is happening in this video.")
- request_config = {"prompts": prompt, "videos": video, "modalities": ["text"]}
-
- omni_runner_handler.send_request(request_config)
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
-def test_mixed_to_text(omni_runner, omni_runner_handler) -> None:
- """
- Test mixed modality input (image + audio) with text output.
- Input Modal: image + audio + text
- Output Modal: text
- """
- image = generate_synthetic_image(224, 224)["np_array"]
- audio = generate_synthetic_audio(2, 1, 16000)["np_array"]
- if len(audio.shape) == 2:
- audio = audio.squeeze()
- prompt = build_prompt(f"{IMAGE_TOKEN}{AUDIO_TOKEN}Describe the image and transcribe the audio.")
- request_config = {"prompts": prompt, "images": image, "audios": audio, "modalities": ["text"]}
-
- omni_runner_handler.send_request(request_config)
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", test_params_tts, indirect=True)
-def test_text_to_audio(omni_runner, omni_runner_handler) -> None:
- """
- Test text input with audio output via the thinker+talker pipeline.
- Input Modal: text
- Output Modal: audio
- """
- prompt = build_prompt("请简单介绍一下北京。")
- request_config = {"prompts": prompt, "modalities": ["audio"]}
-
- omni_runner_handler.send_request(request_config)
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", test_params_tts, indirect=True)
-def test_image_to_audio(omni_runner, omni_runner_handler) -> None:
- """
- Test image + text input with audio output via the thinker+talker pipeline.
- Input Modal: image + text
- Output Modal: audio
- """
- image = generate_synthetic_image(224, 224)["np_array"]
- prompt = build_prompt(f"{IMAGE_TOKEN}Describe this image briefly.")
- request_config = {"prompts": prompt, "images": image, "modalities": ["audio"]}
-
- omni_runner_handler.send_request(request_config)
diff --git a/tests/e2e/offline_inference/test_omni_sleep_mode.py b/tests/e2e/offline_inference/test_omni_sleep_mode.py
deleted file mode 100644
index 5a3ae9ab728..00000000000
--- a/tests/e2e/offline_inference/test_omni_sleep_mode.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-End-to-end tests for Omni Sleep Mode across various model architectures.
-"""
-
-import pytest
-import torch
-from vllm import SamplingParams
-
-from tests.helpers.mark import hardware_test
-from vllm_omni.entrypoints.async_omni import AsyncOmni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-MODEL = "ByteDance-Seed/BAGEL-7B-MoT"
-MODEL_DIFF = "riverclouds/qwen_image_random"
-
-
-def get_ack_info(ack, key, default=None):
- if hasattr(ack, key):
- return getattr(ack, key)
- if isinstance(ack, dict):
- return ack.get(key, default)
- return default
-
-
-def get_dynamic_devices(stage_idx, num_stages, tp_size):
- total_gpus = torch.cuda.device_count()
- gpus_per_stage = tp_size
- start_idx = stage_idx * gpus_per_stage
- if start_idx + gpus_per_stage > total_gpus:
- start_idx = start_idx % total_gpus
- device_ids = [str(start_idx + i) for i in range(gpus_per_stage)]
- return ",".join(device_ids)
-
-
-# Test 1: Diffusion Model (2-Stage BAGEL)
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@pytest.mark.parametrize("tp_size", [1])
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=1)
-@pytest.mark.asyncio
-async def test_diffusion_model_sleep_tp(tp_size):
- num_gpus = torch.cuda.device_count()
- if num_gpus < tp_size:
- pytest.skip(f"Skipping TP={tp_size}")
-
- engine_args = {
- "model": MODEL,
- "enable_sleep_mode": True,
- "tensor_parallel_size": tp_size,
- "enforce_eager": True,
- "trust_remote_code": True,
- "dtype": "bfloat16",
- "gpu_memory_utilization": 0.5,
- }
-
- engine = AsyncOmni(**engine_args, stage_init_timeout=1200)
- try:
- # BAGEL requires 2 params
- diff_sp = OmniDiffusionSamplingParams(num_inference_steps=2, height=256, width=256)
- llm_sp = SamplingParams()
-
- # Warmup
- async for _ in engine.generate("test", sampling_params=[llm_sp, diff_sp]):
- pass
-
- # Sleep all
- acks = await engine.sleep(level=2)
- statuses = [get_ack_info(ack, "status") for ack in acks]
- assert all(s == "SUCCESS" for s in statuses), f"Sleep failed. Statuses: {statuses}"
-
- # Wakeup & Verify
- await engine.wake_up()
- async for _ in engine.generate("verify", sampling_params=[llm_sp, diff_sp]):
- pass
-
- print(f"Diffusion TP={tp_size} Lifecycle OK")
- finally:
- engine.shutdown()
-
-
-# Test 2: Multi-stage Manual Config
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@pytest.mark.parametrize("tp_size", [1, 2])
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.asyncio
-async def test_multistage_sleep_h100(tp_size):
- num_gpus = torch.cuda.device_count()
- if num_gpus < tp_size * 2:
- pytest.skip("Not enough GPUs")
-
- stages = []
- for i in range(2):
- devs = get_dynamic_devices(i, 2, tp_size)
- stages.append(
- {
- "stage_id": i,
- "stage_type": "llm" if i == 0 else "diffusion",
- "runtime": {"process": True, "devices": devs},
- "engine_args": {
- "model": MODEL,
- "model_stage": "thinker" if i == 0 else "base",
- "tensor_parallel_size": tp_size,
- "gpu_memory_utilization": 0.4,
- "dtype": "bfloat16",
- "enable_sleep_mode": True,
- "trust_remote_code": True,
- },
- }
- )
-
- connectors = [{"src_stage_id": 0, "dst_stage_id": 1, "connector_type": "queue"}]
-
- engine = AsyncOmni(
- model=MODEL, stages=stages, connectors=connectors, enable_sleep_mode=True, stage_init_timeout=1200
- )
- try:
- sp = OmniDiffusionSamplingParams(num_inference_steps=2)
- async for _ in engine.generate("warmup", sampling_params=[SamplingParams(), sp]):
- pass
-
- acks = await engine.sleep(stage_ids=[0, 1], level=2)
- assert len(acks) == 2
-
- await engine.wake_up(stage_ids=[0, 1])
- async for _ in engine.generate("verify", sampling_params=[SamplingParams(), sp]):
- pass
- finally:
- engine.shutdown()
-
-
-# Test 3: Pure Diffusion Single-Stage
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@pytest.mark.parametrize("tp_size", [1, 2])
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.asyncio
-async def test_pure_diffusion_scenario(tp_size):
- engine_args = {
- "model": MODEL_DIFF,
- "enable_sleep_mode": True,
- "tensor_parallel_size": tp_size,
- "enforce_eager": True,
- "dtype": "bfloat16",
- "gpu_memory_utilization": 0.5,
- }
-
- engine = AsyncOmni(**engine_args, stage_init_timeout=1200)
- try:
- await engine.sleep(level=1)
- await engine.wake_up()
- async for _ in engine.generate("test", sampling_params=[SamplingParams()]):
- pass
- print("Pure Diffusion OK")
- finally:
- engine.shutdown()
diff --git a/tests/e2e/offline_inference/test_omnivoice.py b/tests/e2e/offline_inference/test_omnivoice.py
index 30a3427bee6..4b093e357d9 100644
--- a/tests/e2e/offline_inference/test_omnivoice.py
+++ b/tests/e2e/offline_inference/test_omnivoice.py
@@ -16,8 +16,7 @@
import numpy as np
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import hardware_test
MODEL = "k2-fsa/OmniVoice"
@@ -38,42 +37,48 @@ def test_omnivoice_text_to_audio() -> None:
Input Modal: text
Output Modal: audio
"""
- from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+ from vllm_omni.entrypoints.omni import Omni
- with OmniRunner(
- MODEL,
+ omni = Omni(
+ model=MODEL,
stage_configs_path=get_stage_config(),
trust_remote_code=True,
log_stats=True,
- ) as runner:
+ )
+
+ try:
prompts = {"prompt": "Hello, this is a test for text to audio."}
+ from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
sampling_params_list = [OmniDiffusionSamplingParams()]
- outputs = list(runner.omni.generate(prompts, sampling_params_list=sampling_params_list))
+ outputs = list(omni.generate(prompts, sampling_params_list=sampling_params_list))
- assert len(outputs) > 0, "No outputs generated"
+ assert len(outputs) > 0, "No outputs generated"
- # Check final output has audio
- final_output = outputs[-1]
- ro = final_output.request_output
- assert ro is not None, "No request_output"
+ # Check final output has audio
+ final_output = outputs[-1]
+ ro = final_output.request_output
+ assert ro is not None, "No request_output"
- mm = getattr(ro, "multimodal_output", None)
- if not mm and ro.outputs:
- mm = getattr(ro.outputs[0], "multimodal_output", None)
+ mm = getattr(ro, "multimodal_output", None)
+ if not mm and ro.outputs:
+ mm = getattr(ro.outputs[0], "multimodal_output", None)
- assert mm is not None, "No multimodal_output"
- assert "audio" in mm, f"No 'audio' key in multimodal_output: {mm.keys()}"
+ assert mm is not None, "No multimodal_output"
+ assert "audio" in mm, f"No 'audio' key in multimodal_output: {mm.keys()}"
- audio = mm["audio"]
- if isinstance(audio, np.ndarray):
- audio_np = audio
- else:
- audio_np = audio.cpu().numpy().squeeze()
+ audio = mm["audio"]
+ if isinstance(audio, np.ndarray):
+ audio_np = audio
+ else:
+ audio_np = audio.cpu().numpy().squeeze()
- assert audio_np.size > 0, "Audio output is empty"
- rms = np.sqrt(np.mean(audio_np**2))
- assert rms > 0.01, f"Audio RMS too low ({rms:.4f}), likely silence"
+ assert audio_np.size > 0, "Audio output is empty"
+ rms = np.sqrt(np.mean(audio_np**2))
+ assert rms > 0.01, f"Audio RMS too low ({rms:.4f}), likely silence"
- print(f"Generated audio: {len(audio_np) / 24000:.2f}s, rms={rms:.4f}")
+ print(f"Generated audio: {len(audio_np) / 24000:.2f}s, rms={rms:.4f}")
+ finally:
+ omni.close()
diff --git a/tests/e2e/offline_inference/test_ovis_image.py b/tests/e2e/offline_inference/test_ovis_image.py
index 70fab4fe101..41e21bca3a9 100644
--- a/tests/e2e/offline_inference/test_ovis_image.py
+++ b/tests/e2e/offline_inference/test_ovis_image.py
@@ -16,7 +16,7 @@
import torch
from pytest_mock import MockerFixture
-from tests.helpers.mark import hardware_test
+from tests.utils import hardware_test
from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig
# Mock the OvisImageTransformer2DModel to avoid complex init if needed,
diff --git a/tests/e2e/offline_inference/test_quantization_fp8.py b/tests/e2e/offline_inference/test_quantization_fp8.py
index 9801e0ae797..f71c53de74c 100644
--- a/tests/e2e/offline_inference/test_quantization_fp8.py
+++ b/tests/e2e/offline_inference/test_quantization_fp8.py
@@ -29,23 +29,22 @@
import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
+from pathlib import Path
from typing import Any
import pytest
import torch
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.utils import hardware_test
+from vllm_omni.entrypoints.omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
pytestmark = [pytest.mark.core_model, pytest.mark.diffusion]
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
# ─── helpers ──────────────────────────────────────────────────────────────────
@@ -62,15 +61,16 @@ def _generate_single_stage_image(
Returns (images, peak_memory_gib).
"""
- omni_kwargs: dict[str, Any] = dict(extra_omni_kwargs)
+ omni_kwargs: dict[str, Any] = {"model": model, **extra_omni_kwargs}
if quantization:
omni_kwargs["quantization"] = quantization
- with OmniRunner(model, **omni_kwargs) as runner:
+ omni = Omni(**omni_kwargs)
+ try:
torch.cuda.reset_peak_memory_stats()
generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed)
- outputs = runner.omni.generate(
+ outputs = omni.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
@@ -94,6 +94,8 @@ def _generate_single_stage_image(
assert images[0].height == height
return images, peak_mem
+ finally:
+ omni.close()
def _generate_bagel_image(
@@ -104,7 +106,7 @@ def _generate_bagel_image(
Returns (generated_image, peak_memory_gib).
"""
- config_path = get_deploy_config_path("ci/bagel.yaml")
+ config_path = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
omni_kwargs: dict[str, Any] = {
"model": "ByteDance-Seed/BAGEL-7B-MoT",
"stage_configs_path": config_path,
@@ -113,9 +115,8 @@ def _generate_bagel_image(
if quantization_config:
omni_kwargs["quantization_config"] = quantization_config
- model_name = omni_kwargs.pop("model")
- with OmniRunner(model_name, **omni_kwargs) as runner:
- omni = runner.omni
+ omni = Omni(**omni_kwargs)
+ try:
torch.cuda.reset_peak_memory_stats()
params_list = omni.default_sampling_params_list
@@ -167,6 +168,8 @@ def _generate_bagel_image(
)
return generated_image, peak_mem
+ finally:
+ omni.close()
# ─── Single-stage diffusion model tests ──────────────────────────────────────
diff --git a/tests/e2e/offline_inference/test_qwen2_5_omni.py b/tests/e2e/offline_inference/test_qwen2_5_omni.py
index 8ea41b00778..4c4315aab9c 100644
--- a/tests/e2e/offline_inference/test_qwen2_5_omni.py
+++ b/tests/e2e/offline_inference/test_qwen2_5_omni.py
@@ -2,39 +2,46 @@
E2E tests for Qwen2.5-Omni model with mixed modality inputs, audio and text output.
"""
+from pathlib import Path
+
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import (
+from tests.conftest import (
generate_synthetic_audio,
generate_synthetic_image,
generate_synthetic_video,
+ modify_stage_config,
)
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+from tests.utils import hardware_test
from vllm_omni.platforms import current_omni_platform
models = ["Qwen/Qwen2.5-Omni-7B"]
-# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
-# platforms: section. NPU still uses the legacy per-platform YAML until it
-# also migrates to the new schema.
-_CI_DEPLOY = get_deploy_config_path("ci/qwen2_5_omni.yaml")
-
def get_cuda_graph_config():
- return modify_stage_config(
- _CI_DEPLOY,
+ path = modify_stage_config(
+ str(Path(__file__).parent.parent / "stage_configs" / "qwen2_5_omni_ci.yaml"),
updates={
- "stages": {
- 0: {"enforce_eager": True},
- 1: {"enforce_eager": True},
+ "stage_args": {
+ 0: {
+ "engine_args.enforce_eager": "true",
+ },
+ 1: {"engine_args.enforce_eager": "true"},
},
},
)
-
-
-if current_omni_platform.is_rocm() or current_omni_platform.is_xpu() or current_omni_platform.is_npu():
- stage_config = _CI_DEPLOY
+ return path
+
+
+# CI stage config optimized for 24GB GPU (L4/RTX3090) or NPU
+if current_omni_platform.is_npu():
+ stage_config = str(Path(__file__).parent / "stage_configs" / "npu" / "qwen2_5_omni_ci.yaml")
+elif current_omni_platform.is_rocm():
+ # ROCm stage config optimized for MI325 GPU
+ stage_config = str(Path(__file__).parent.parent / "stage_configs" / "rocm" / "qwen2_5_omni_ci.yaml")
+elif current_omni_platform.is_xpu():
+ # Intel XPU stage config optimized for B60 GPU
+ stage_config = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen2_5_omni_ci.yaml")
else:
stage_config = get_cuda_graph_config()
diff --git a/tests/e2e/offline_inference/test_qwen2_5_omni_autoround_w4a16.py b/tests/e2e/offline_inference/test_qwen2_5_omni_autoround_w4a16.py
deleted file mode 100644
index e59a8c7d709..00000000000
--- a/tests/e2e/offline_inference/test_qwen2_5_omni_autoround_w4a16.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""E2E tests for Qwen2.5-Omni AutoRound W4A16 quantized inference.
-
-These tests cover text, audio, image, video, and mixed-modality inputs
-to verify multimodal understanding with quantized weights.
-
-Requirements:
- - CUDA GPUs (4x L4 / 24 GB or equivalent)
- - The quantized model checkpoint (Intel/Qwen2.5-Omni-7B-int4-AutoRound)
-"""
-
-import os
-
-import pytest
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import (
- generate_synthetic_audio,
- generate_synthetic_image,
- generate_synthetic_video,
-)
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
-
-QUANTIZED_MODEL = "Intel/Qwen2.5-Omni-7B-int4-AutoRound"
-BASELINE_MODEL = "Qwen/Qwen2.5-Omni-7B"
-
-# Allow overriding via environment for local testing
-QUANTIZED_MODEL = os.environ.get("QWEN2_5_OMNI_AUTOROUND_MODEL", QUANTIZED_MODEL)
-BASELINE_MODEL = os.environ.get("QWEN2_5_OMNI_BASELINE_MODEL", BASELINE_MODEL)
-
-_CI_DEPLOY = get_deploy_config_path("ci/qwen2_5_omni.yaml")
-
-
-def _get_stage_config():
- """Build a CI-friendly stage config with eager mode."""
- return modify_stage_config(
- _CI_DEPLOY,
- updates={
- "stages": {
- 0: {"enforce_eager": True},
- 1: {"enforce_eager": True},
- },
- },
- )
-
-
-stage_config = _get_stage_config()
-
-# Parametrise: (model, stage_config)
-quant_params = [(QUANTIZED_MODEL, stage_config)]
-
-
-# ------------------------------------------------------------------
-# Test: text-only input → text output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_text_to_text(omni_runner, omni_runner_handler):
- """Text input → text output with W4A16 quantized Qwen2.5-Omni."""
- request_config = {
- "prompts": "What is the capital of China?",
- "modalities": ["text"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
- assert response.text_content and len(response.text_content.strip()) > 0
-
-
-# ------------------------------------------------------------------
-# Test: audio input → text output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_audio_to_text(omni_runner, omni_runner_handler):
- """Audio input → text output with W4A16 quantized Qwen2.5-Omni."""
- audio = generate_synthetic_audio(1, 1, 16000)["np_array"]
- if len(audio.shape) == 2:
- audio = audio.squeeze()
-
- request_config = {
- "prompts": "What is the content of this audio?",
- "audios": (audio, 16000),
- "modalities": ["text"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
- assert response.text_content and len(response.text_content.strip()) > 0
-
-
-# ------------------------------------------------------------------
-# Test: image input → text output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_image_to_text(omni_runner, omni_runner_handler):
- """Image input → text output with W4A16 quantized Qwen2.5-Omni."""
- image = generate_synthetic_image(16, 16)["np_array"]
-
- request_config = {
- "prompts": "Describe what you see in this image.",
- "images": image,
- "modalities": ["text"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
- assert response.text_content and len(response.text_content.strip()) > 0
-
-
-# ------------------------------------------------------------------
-# Test: video input → text output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_video_to_text(omni_runner, omni_runner_handler):
- """Video input → text output with W4A16 quantized Qwen2.5-Omni."""
- video = generate_synthetic_video(16, 16, 30)["np_array"]
-
- request_config = {
- "prompts": "Describe the video briefly.",
- "videos": video,
- "modalities": ["text"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
- assert response.text_content and len(response.text_content.strip()) > 0
-
-
-# ------------------------------------------------------------------
-# Test: mixed modality (audio + image + video) → audio output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=4)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_mix_to_audio(omni_runner, omni_runner_handler):
- """Mixed-modality input → audio output with W4A16 quantized Qwen2.5-Omni."""
- video = generate_synthetic_video(16, 16, 30)["np_array"]
- image = generate_synthetic_image(16, 16)["np_array"]
- audio = generate_synthetic_audio(1, 1, 16000)["np_array"]
- if len(audio.shape) == 2:
- audio = audio.squeeze()
-
- request_config = {
- "prompts": "What is recited in the audio? What is in this image? Describe the video briefly.",
- "videos": video,
- "images": image,
- "audios": (audio, 16000),
- "modalities": ["audio"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
diff --git a/tests/e2e/offline_inference/test_qwen3_omni.py b/tests/e2e/offline_inference/test_qwen3_omni.py
index c4d257b5114..cc0af437eca 100644
--- a/tests/e2e/offline_inference/test_qwen3_omni.py
+++ b/tests/e2e/offline_inference/test_qwen3_omni.py
@@ -7,35 +7,41 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+from pathlib import Path
+
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import generate_synthetic_video
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+from tests.conftest import (
+ generate_synthetic_video,
+ modify_stage_config,
+)
+from tests.utils import hardware_test
from vllm_omni.platforms import current_omni_platform
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
-# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
-# platforms: section. Only CUDA needs an extra enforce_eager tweak.
-_CI_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
-
-
def get_cuda_graph_config():
- return modify_stage_config(
- _CI_DEPLOY,
+ path = modify_stage_config(
+ str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml"),
updates={
- "stages": {
- 0: {"enforce_eager": True},
- 1: {"enforce_eager": True},
+ "stage_args": {
+ 0: {
+ "engine_args.enforce_eager": "true",
+ },
+ 1: {"engine_args.enforce_eager": "true"},
},
},
)
+ return path
-if current_omni_platform.is_rocm() or current_omni_platform.is_xpu():
- stage_configs = [_CI_DEPLOY]
+# CI stage config for 2xH100-80G GPUs or AMD GPU MI325
+if current_omni_platform.is_rocm():
+ # ROCm stage config optimized for MI325 GPU
+ stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "rocm" / "qwen3_omni_ci.yaml")]
+elif current_omni_platform.is_xpu():
+ stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")]
else:
stage_configs = [get_cuda_graph_config()]
diff --git a/tests/e2e/offline_inference/test_qwen3_omni_autoround_w4a16.py b/tests/e2e/offline_inference/test_qwen3_omni_autoround_w4a16.py
deleted file mode 100644
index 3a3c874b64b..00000000000
--- a/tests/e2e/offline_inference/test_qwen3_omni_autoround_w4a16.py
+++ /dev/null
@@ -1,205 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""E2E tests for Qwen3-Omni AutoRound W4A16 quantized inference.
-
-These tests cover text, audio, image, video, and mixed-modality inputs
-to verify multimodal understanding with quantized weights.
-
-Requirements:
- - CUDA GPUs (2x H100-80G or equivalent)
- - The quantized model checkpoint (Intel/Qwen3-Omni-30B-A3B-Instruct-int4-AutoRound)
-"""
-
-import os
-
-import pytest
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import (
- generate_synthetic_audio,
- generate_synthetic_image,
- generate_synthetic_video,
-)
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
-
-QUANTIZED_MODEL = "Intel/Qwen3-Omni-30B-A3B-Instruct-int4-AutoRound"
-BASELINE_MODEL = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
-
-# Allow overriding via environment for local testing
-QUANTIZED_MODEL = os.environ.get("QWEN3_OMNI_AUTOROUND_MODEL", QUANTIZED_MODEL)
-BASELINE_MODEL = os.environ.get("QWEN3_OMNI_BASELINE_MODEL", BASELINE_MODEL)
-
-_CI_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
-
-
-@pytest.fixture(scope="module", autouse=True)
-def _qwen3_omni_env():
- """Set env vars required by multi-stage worker spawning.
-
- Must run before CUDA context init. Reverted after every test module
- so that values do not leak into unrelated test files.
- """
- with pytest.MonkeyPatch.context() as mp:
- mp.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
- mp.setenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0")
- yield
-
-
-def _get_stage_config():
- """Build a CI-friendly stage config with eager mode."""
- return modify_stage_config(
- _CI_DEPLOY,
- updates={
- "stages": {
- 0: {"enforce_eager": True},
- 1: {"enforce_eager": True},
- },
- },
- )
-
-
-stage_config = _get_stage_config()
-
-# Parametrise: (model, stage_config)
-quant_params = [(QUANTIZED_MODEL, stage_config)]
-
-
-# ------------------------------------------------------------------
-# Test: text-only input → text output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=2)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_text_to_text(omni_runner, omni_runner_handler):
- """Text input → text output with W4A16 quantized Qwen3-Omni."""
- request_config = {
- "prompts": "What is the capital of France?",
- "modalities": ["text"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
- assert response.text_content and len(response.text_content.strip()) > 0
-
-
-# ------------------------------------------------------------------
-# Test: audio input → text output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=2)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_audio_to_text(omni_runner, omni_runner_handler):
- """Audio input → text output with W4A16 quantized Qwen3-Omni."""
- audio = generate_synthetic_audio(1, 1, 16000)["np_array"]
- if len(audio.shape) == 2:
- audio = audio.squeeze()
-
- request_config = {
- "prompts": "What is the content of this audio?",
- "audios": (audio, 16000),
- "modalities": ["text"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
- assert response.text_content and len(response.text_content.strip()) > 0
-
-
-# ------------------------------------------------------------------
-# Test: image input → text output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=2)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_image_to_text(omni_runner, omni_runner_handler):
- """Image input → text output with W4A16 quantized Qwen3-Omni."""
- image = generate_synthetic_image(16, 16)["np_array"]
-
- request_config = {
- "prompts": "Describe what you see in this image.",
- "images": image,
- "modalities": ["text"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
- assert response.text_content and len(response.text_content.strip()) > 0
-
-
-# ------------------------------------------------------------------
-# Test: video input → text output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=2)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_video_to_text(omni_runner, omni_runner_handler):
- """Video input → text output with W4A16 quantized Qwen3-Omni."""
- video = generate_synthetic_video(224, 224, 300)["np_array"]
-
- request_config = {
- "prompts": "Describe the video briefly.",
- "videos": video,
- "modalities": ["text"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
- assert response.text_content and len(response.text_content.strip()) > 0
-
-
-# ------------------------------------------------------------------
-# Test: video input → audio output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=2)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_video_to_audio(omni_runner, omni_runner_handler):
- """Video input → audio output with W4A16 quantized Qwen3-Omni."""
- video = generate_synthetic_video(224, 224, 300)["np_array"]
-
- request_config = {
- "prompts": "Describe the video briefly.",
- "videos": video,
- "modalities": ["audio"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
-
-
-# ------------------------------------------------------------------
-# Test: mixed modality (audio + image + video) → audio output
-# ------------------------------------------------------------------
-
-
-@pytest.mark.advanced_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=2)
-@pytest.mark.parametrize("omni_runner", quant_params, indirect=True)
-def test_mix_to_audio(omni_runner, omni_runner_handler):
- """Mixed-modality input → audio output with W4A16 quantized Qwen3-Omni."""
- video = generate_synthetic_video(224, 224, 300)["np_array"]
- image = generate_synthetic_image(16, 16)["np_array"]
- audio = generate_synthetic_audio(1, 1, 16000)["np_array"]
- if len(audio.shape) == 2:
- audio = audio.squeeze()
-
- request_config = {
- "prompts": "What is recited in the audio? What is in this image? Describe the video briefly.",
- "videos": video,
- "images": image,
- "audios": (audio, 16000),
- "modalities": ["audio"],
- }
- response = omni_runner_handler.send_request(request_config)
- assert response.success, f"Request failed: {response.error_message}"
diff --git a/tests/e2e/offline_inference/test_qwen3_tts_base.py b/tests/e2e/offline_inference/test_qwen3_tts_base.py
index af2b5195b98..be7bd50a36a 100644
--- a/tests/e2e/offline_inference/test_qwen3_tts_base.py
+++ b/tests/e2e/offline_inference/test_qwen3_tts_base.py
@@ -13,10 +13,12 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+from pathlib import Path
+
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+from tests.conftest import modify_stage_config
+from tests.utils import hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
REF_AUDIO_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"
@@ -24,31 +26,23 @@
def get_cuda_graph_config():
- """Build a temp deploy yaml mirroring the deleted qwen3_tts_no_async_chunk.yaml.
-
- Composes the synchronous (no-async-chunk) variant on top of the bundled
- qwen3_tts.yaml prod default, with cudagraphs disabled. Replaces the deleted
- standalone variant yaml; same effective config, no checked-in file needed.
- """
- return modify_stage_config(
- get_deploy_config_path("qwen3_tts.yaml"),
+ path = modify_stage_config(
+ get_stage_config(),
updates={
- "async_chunk": False,
- "stages": {
+ "stage_args": {
0: {
- "max_num_seqs": 1,
- "gpu_memory_utilization": 0.2,
- "enforce_eager": True,
- "async_scheduling": False,
- },
- 1: {
- "gpu_memory_utilization": 0.2,
- "enforce_eager": True,
- "async_scheduling": False,
+ "engine_args.enforce_eager": "true",
},
+ 1: {"engine_args.enforce_eager": "true"},
},
},
)
+ return path
+
+
+def get_stage_config(name: str = "qwen3_tts_no_async_chunk.yaml"):
+ """Get the no_async_chunk stage config path (async_chunk disable, cuda_graph disabled)."""
+ return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
# Same structure as test_qwen3_omni: models, stage_configs, test_params
diff --git a/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py b/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
index 3214541af8f..67d72df908c 100644
--- a/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
+++ b/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
@@ -13,40 +13,34 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+from pathlib import Path
+
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+from tests.conftest import modify_stage_config
+from tests.utils import hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
def get_cuda_graph_config():
- """Build a temp deploy yaml mirroring the deleted qwen3_tts_no_async_chunk.yaml.
-
- Composes the synchronous (no-async-chunk) variant on top of the bundled
- qwen3_tts.yaml prod default, with cudagraphs disabled. Replaces the deleted
- standalone variant yaml; same effective config, no checked-in file needed.
- """
- return modify_stage_config(
- get_deploy_config_path("qwen3_tts.yaml"),
+ path = modify_stage_config(
+ get_stage_config(),
updates={
- "async_chunk": False,
- "stages": {
+ "stage_args": {
0: {
- "max_num_seqs": 1,
- "gpu_memory_utilization": 0.2,
- "enforce_eager": True,
- "async_scheduling": False,
- },
- 1: {
- "gpu_memory_utilization": 0.2,
- "enforce_eager": True,
- "async_scheduling": False,
+ "engine_args.enforce_eager": "true",
},
+ 1: {"engine_args.enforce_eager": "true"},
},
},
)
+ return path
+
+
+def get_stage_config(name: str = "qwen3_tts_no_async_chunk.yaml"):
+ """Get the no_async_chunk stage config path (async_chunk disable, cuda_graph disabled)."""
+ return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
# Same structure as test_qwen3_omni: models, stage_configs, test_params
diff --git a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
index 2ce113d5bfd..d5f82f893e6 100644
--- a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
+++ b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
@@ -28,6 +28,7 @@
import argparse
import asyncio
+import os
import sys
import time
import uuid
@@ -36,8 +37,7 @@
import pytest
import torch
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -48,6 +48,9 @@
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
+from vllm_omni import Omni
+
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
# ------------------------------------------------------------------
models = ["tiny-random/Qwen-Image"]
@@ -388,28 +391,31 @@ async def main(model: str, num_prompts: int, mode: str, batch_size: int = 1) ->
def test_diffusion_batching_sync_sequential(model_name: str):
"""Test that synchronous Omni can generate images for multiple prompts
submitted sequentially (one at a time) and each returns a valid image."""
+ m = None
try:
- with OmniRunner(model_name) as runner:
- m = runner.omni
- sp = _default_sync_sampling_params()
- prompts = TEST_PROMPTS[:4]
+ m = Omni(model=model_name)
+ sp = _default_sync_sampling_params()
+ prompts = TEST_PROMPTS[:4]
- for i, prompt in enumerate(prompts):
- outputs = m.generate(prompt, sp)
- first_output = outputs[0]
- assert first_output.final_output_type == "image", (
- f"Expected 'image', got '{first_output.final_output_type}'"
- )
+ for i, prompt in enumerate(prompts):
+ outputs = m.generate(prompt, sp)
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image", (
+ f"Expected 'image', got '{first_output.final_output_type}'"
+ )
- # Images are surfaced both at top-level and inside request_output
- images = _extract_images(first_output)
- assert len(images) >= 1, f"Expected at least 1 image for prompt {i}, got {len(images)}"
- assert images[0].width == 256
- assert images[0].height == 256
- print(f" prompt {i}: OK ({len(images)} images)")
+ # Images are surfaced both at top-level and inside request_output
+ images = _extract_images(first_output)
+ assert len(images) >= 1, f"Expected at least 1 image for prompt {i}, got {len(images)}"
+ assert images[0].width == 256
+ assert images[0].height == 256
+ print(f" prompt {i}: OK ({len(images)} images)")
except Exception as e:
print(f"Test failed with error: {e}")
raise
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
@pytest.mark.core_model
@@ -425,31 +431,34 @@ def test_diffusion_batching_sync_multi_prompt(model_name: str):
handling at the diffusion stage, not the explicit list-batch path
(which is only available via AsyncOmni).
"""
+ m = None
try:
- with OmniRunner(model_name) as runner:
- m = runner.omni
- sp = _default_sync_sampling_params()
- prompts = TEST_PROMPTS[:4]
-
- outputs = m.generate(prompts, sp)
- assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
+ m = Omni(model=model_name)
+ sp = _default_sync_sampling_params()
+ prompts = TEST_PROMPTS[:4]
- for i, output in enumerate(outputs):
- assert output.final_output_type == "image", (
- f"Output {i} final_output_type expected 'image', got '{output.final_output_type}'"
- )
- images = _extract_images(output)
- assert images and len(images) >= 1, f"Expected at least 1 image for prompt {i}"
- assert images[0].width == 256
- assert images[0].height == 256
- print(f" prompt {i}: OK ({len(images)} images, request_id={output.request_id})")
+ outputs = m.generate(prompts, sp)
+ assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
- # Verify all request_ids are distinct
- request_ids = [o.request_id for o in outputs]
- assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids found: {request_ids}"
+ for i, output in enumerate(outputs):
+ assert output.final_output_type == "image", (
+ f"Output {i} final_output_type expected 'image', got '{output.final_output_type}'"
+ )
+ images = _extract_images(output)
+ assert images and len(images) >= 1, f"Expected at least 1 image for prompt {i}"
+ assert images[0].width == 256
+ assert images[0].height == 256
+ print(f" prompt {i}: OK ({len(images)} images, request_id={output.request_id})")
+
+ # Verify all request_ids are distinct
+ request_ids = [o.request_id for o in outputs]
+ assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids found: {request_ids}"
except Exception as e:
print(f"Test failed with error: {e}")
raise
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
@pytest.mark.core_model
@@ -543,29 +552,32 @@ async def _inner():
def test_diffusion_batching_num_outputs(model_name: str):
"""Test that the diffusion model respects num_outputs_per_prompt and
generates the correct number of images per request."""
+ m = None
try:
- with OmniRunner(model_name) as runner:
- m = runner.omni
- num_outputs = 2
- sp = _default_sync_sampling_params(num_outputs_per_prompt=num_outputs)
-
- outputs = m.generate(
- "a photo of a cat sitting on a laptop keyboard",
- sp,
- )
+ m = Omni(model=model_name)
+ num_outputs = 2
+ sp = _default_sync_sampling_params(num_outputs_per_prompt=num_outputs)
- first_output = outputs[0]
- assert first_output.final_output_type == "image"
- images = _extract_images(first_output)
- assert images is not None and len(images) == num_outputs, (
- f"Expected {num_outputs} images, got {len(images) if images else 0}"
- )
- for img in images:
- assert img.width == 256
- assert img.height == 256
+ outputs = m.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ sp,
+ )
+
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image"
+ images = _extract_images(first_output)
+ assert images is not None and len(images) == num_outputs, (
+ f"Expected {num_outputs} images, got {len(images) if images else 0}"
+ )
+ for img in images:
+ assert img.width == 256
+ assert img.height == 256
except Exception as e:
print(f"Test failed with error: {e}")
raise
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
@pytest.mark.core_model
@@ -575,31 +587,34 @@ def test_diffusion_batching_num_outputs(model_name: str):
def test_diffusion_batching_distinct_results(model_name: str):
"""Test that different prompts produce distinct images when batched,
ensuring the batching logic does not mix up results across requests."""
+ m = None
try:
- with OmniRunner(model_name) as runner:
- m = runner.omni
- sp = _default_sync_sampling_params()
- prompts = [
- {"prompt": "a bright red apple on a white table", "negative_prompt": "blurry"},
- {"prompt": "a blue ocean with white waves crashing", "negative_prompt": "blurry"},
- ]
-
- outputs = m.generate(prompts, sp)
- assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
-
- # Verify each output has a unique request_id
- request_ids = [o.request_id for o in outputs]
- assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids: {request_ids}"
-
- # Verify each output has images
- for i, output in enumerate(outputs):
- images = _extract_images(output)
- assert images and len(images) >= 1, f"No images for prompt {i}"
- assert images[0].width == 256
- assert images[0].height == 256
+ m = Omni(model=model_name)
+ sp = _default_sync_sampling_params()
+ prompts = [
+ {"prompt": "a bright red apple on a white table", "negative_prompt": "blurry"},
+ {"prompt": "a blue ocean with white waves crashing", "negative_prompt": "blurry"},
+ ]
+
+ outputs = m.generate(prompts, sp)
+ assert len(outputs) == len(prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
+
+ # Verify each output has a unique request_id
+ request_ids = [o.request_id for o in outputs]
+ assert len(set(request_ids)) == len(request_ids), f"Duplicate request_ids: {request_ids}"
+
+ # Verify each output has images
+ for i, output in enumerate(outputs):
+ images = _extract_images(output)
+ assert images and len(images) >= 1, f"No images for prompt {i}"
+ assert images[0].width == 256
+ assert images[0].height == 256
except Exception as e:
print(f"Test failed with error: {e}")
raise
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
# ------------------------------------------------------------------
diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py
index 9f76b3b75c5..16239a1c52f 100644
--- a/tests/e2e/offline_inference/test_sequence_parallel.py
+++ b/tests/e2e/offline_inference/test_sequence_parallel.py
@@ -20,8 +20,8 @@
import torch.distributed as dist
from PIL import Image
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import hardware_test
+from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
@@ -92,48 +92,49 @@ def _run_inference(
warmup: If True, run one warmup iteration before the timed run.
"""
parallel_config = DiffusionParallelConfig(ulysses_degree=ulysses_degree, ring_degree=ring_degree)
+ omni = Omni(
+ model=model_name,
+ parallel_config=parallel_config,
+ dtype=dtype,
+ attention_backend=attn_backend,
+ )
+
try:
- with OmniRunner(
- model_name,
- parallel_config=parallel_config,
- dtype=dtype,
- attention_backend=attn_backend,
- ) as runner:
- omni = runner.omni
- # Warmup run (not timed)
- if warmup:
- _ = omni.generate(
- PROMPT,
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=DEFAULT_STEPS,
- guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed + 1000),
- num_outputs_per_prompt=1,
- ),
- )
-
- # Timed run
- start = time.time()
- outputs = omni.generate(
+ # Warmup run (not timed)
+ if warmup:
+ _ = omni.generate(
PROMPT,
OmniDiffusionSamplingParams(
height=height,
width=width,
num_inference_steps=DEFAULT_STEPS,
guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed + 1000),
num_outputs_per_prompt=1,
),
)
- elapsed_ms = (time.time() - start) * 1000
- return InferenceResult(
- images=outputs[0].request_output.images,
- elapsed_ms=elapsed_ms,
- )
+ # Timed run
+ start = time.time()
+ outputs = omni.generate(
+ PROMPT,
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=DEFAULT_STEPS,
+ guidance_scale=0.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
+ num_outputs_per_prompt=1,
+ ),
+ )
+ elapsed_ms = (time.time() - start) * 1000
+
+ return InferenceResult(
+ images=outputs[0].request_output.images,
+ elapsed_ms=elapsed_ms,
+ )
finally:
+ omni.close()
_cleanup_distributed()
diff --git a/tests/e2e/offline_inference/test_stable_audio_expansion.py b/tests/e2e/offline_inference/test_stable_audio_expansion.py
deleted file mode 100644
index a5d3e6d2281..00000000000
--- a/tests/e2e/offline_inference/test_stable_audio_expansion.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Stable Audio offline e2e: real weights, FP8 + TeaCache (single job to save GPU).
-
-NOTE: This test instantiates Omni directly instead of using the omni_runner
-fixture (introduced in PR #2711) because the fixture's parametrize interface
-only accepts (model, stage_config_path) and does not support extra kwargs like
-quantization, cache_backend, or cache_config.
-"""
-
-from __future__ import annotations
-
-import numpy as np
-import pytest
-import torch
-
-from tests.helpers.assertions import assert_audio_valid
-from tests.helpers.mark import hardware_test
-from vllm_omni import Omni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
-from vllm_omni.platforms import current_omni_platform
-
-pytestmark = [pytest.mark.full_model, pytest.mark.diffusion]
-
-_SAMPLE_RATE = 44100
-_CLIP_DURATION_S = 2.0
-
-
-def generate_stable_audio_short_clip(
- omni: Omni,
- *,
- audio_start_in_s: float = 0.0,
- audio_end_in_s: float = 2.0,
- num_inference_steps: int = 4,
- seed: int = 42,
-) -> np.ndarray:
- """Run a minimal Stable Audio generation and return audio as (batch, channels, samples)."""
- outputs = omni.generate(
- prompts={
- "prompt": "The sound of a dog barking",
- "negative_prompt": "Low quality.",
- },
- sampling_params_list=OmniDiffusionSamplingParams(
- num_inference_steps=num_inference_steps,
- guidance_scale=7.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(seed),
- num_outputs_per_prompt=1,
- extra_args={
- "audio_start_in_s": audio_start_in_s,
- "audio_end_in_s": audio_end_in_s,
- },
- ),
- )
-
- assert outputs is not None
- first_output = outputs[0]
- # Outer OmniRequestOutput.final_output_type comes from get_stage_metadata.
- # The nested request_output is the worker OmniRequestOutput
- # (e.g. final_output_type="audio") and holds the multimodal payload.
- # Follow-up: add StableAudioPipeline stage YAML, and pass model into
- # _create_default_diffusion_stage_cfg so default diffusion metadata can set
- # final_output_type to "audio" for future audio pipelines without YAML.
- assert first_output.final_output_type == "image"
- assert hasattr(first_output, "request_output") and first_output.request_output
-
- req_out = first_output.request_output
- assert isinstance(req_out, OmniRequestOutput)
- assert req_out.final_output_type == "audio"
- assert hasattr(req_out, "multimodal_output") and req_out.multimodal_output
- audio = req_out.multimodal_output.get("audio")
- assert isinstance(audio, np.ndarray)
- return audio
-
-
-@pytest.mark.cache
-@hardware_test(res={"cuda": "L4", "xpu": "B60"})
-def test_stable_audio_quantization_and_teacache() -> None:
- """Stable Audio Open on real Hub weights with FP8 + TeaCache (covers former L2 smoke + L4 features).
-
- CI should provide ``HF_TOKEN`` if the checkpoint is gated.
- """
- m = Omni(
- model="stabilityai/stable-audio-open-1.0",
- quantization="fp8",
- cache_backend="tea_cache",
- cache_config={"rel_l1_thresh": 0.2},
- )
- try:
- audio = generate_stable_audio_short_clip(m)
- assert_audio_valid(
- audio,
- sample_rate=_SAMPLE_RATE,
- channels=2,
- duration_s=_CLIP_DURATION_S,
- )
- finally:
- m.close()
diff --git a/tests/e2e/offline_inference/test_stable_audio_model.py b/tests/e2e/offline_inference/test_stable_audio_model.py
new file mode 100644
index 00000000000..ff4d9b40172
--- /dev/null
+++ b/tests/e2e/offline_inference/test_stable_audio_model.py
@@ -0,0 +1,72 @@
+import sys
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+
+from tests.utils import hardware_test
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.outputs import OmniRequestOutput
+from vllm_omni.platforms import current_omni_platform
+
+# ruff: noqa: E402
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from vllm_omni import Omni
+
+# Use random weights model for CI testing (small, no authentication required)
+models = ["linyueqian/stable_audio_random"]
+
+
+@pytest.mark.core_model
+@pytest.mark.diffusion
+@hardware_test(res={"cuda": "L4", "xpu": "B60"})
+@pytest.mark.parametrize("model_name", models)
+def test_stable_audio_model(model_name: str):
+ m = Omni(model=model_name)
+
+ # Use minimal settings for testing
+ # Generate a short 2-second audio clip with minimal inference steps
+ audio_start_in_s = 0.0
+ audio_end_in_s = 2.0 # Short duration for fast testing
+ sample_rate = 44100 # Stable Audio uses 44100 Hz
+
+ outputs = m.generate(
+ prompts={
+ "prompt": "The sound of a dog barking",
+ "negative_prompt": "Low quality.",
+ },
+ sampling_params_list=OmniDiffusionSamplingParams(
+ num_inference_steps=4, # Minimal steps for speed
+ guidance_scale=7.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
+ num_outputs_per_prompt=1,
+ extra_args={
+ "audio_start_in_s": audio_start_in_s,
+ "audio_end_in_s": audio_end_in_s,
+ },
+ ),
+ )
+
+ # Extract audio from OmniRequestOutput
+ assert outputs is not None
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image"
+ assert hasattr(first_output, "request_output") and first_output.request_output
+
+ req_out = first_output.request_output
+ assert isinstance(req_out, OmniRequestOutput)
+ assert req_out.final_output_type == "audio"
+ assert hasattr(req_out, "multimodal_output") and req_out.multimodal_output
+ audio = req_out.multimodal_output.get("audio")
+ assert isinstance(audio, np.ndarray)
+ # audio shape: (batch, channels, samples)
+ # For stable-audio-open-1.0: sample_rate=44100, so 2 seconds = 88200 samples
+ assert audio.ndim == 3
+ assert audio.shape[0] == 1 # batch size
+ assert audio.shape[1] == 2 # stereo channels
+ expected_samples = int((audio_end_in_s - audio_start_in_s) * sample_rate)
+ assert audio.shape[2] == expected_samples # 88200 samples for 2 seconds
diff --git a/tests/e2e/offline_inference/test_t2i_model.py b/tests/e2e/offline_inference/test_t2i_model.py
index 702f902cdb9..77b2b3aaf20 100644
--- a/tests/e2e/offline_inference/test_t2i_model.py
+++ b/tests/e2e/offline_inference/test_t2i_model.py
@@ -1,17 +1,23 @@
+import os
+import sys
+from pathlib import Path
+
import pytest
import torch
-from tests.helpers.mark import hardware_test
+from tests.utils import hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
-# Match unprefixed HF id even when MODEL_PREFIX is set (omni_runner resolves full path).
-_QWEN_IMAGE_RANDOM_ID = "riverclouds/qwen_image_random"
+# ruff: noqa: E402
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+from vllm_omni import Omni
-def _is_qwen_image_random(model_path: str) -> bool:
- return model_path.rstrip("/").endswith(_QWEN_IMAGE_RANDOM_ID)
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
models = ["Tongyi-MAI/Z-Image-Turbo", "riverclouds/qwen_image_random"]
@@ -20,56 +26,62 @@ def _is_qwen_image_random(model_path: str) -> bool:
# TODO: When NPU support is ready, remove this branch.
if current_omni_platform.is_npu():
models = ["Tongyi-MAI/Z-Image-Turbo", "Qwen/Qwen-Image"]
-
-# omni_runner expects (model, stage_configs_path); single-stage diffusion has no YAML.
-test_params = [(m, None) for m in models]
+elif current_omni_platform.is_rocm():
+ # TODO: When ROCm support is ready, remove this branch.
+ # Current upstream vLLM has issues running riverclouds/qwen_image_random
+ # on ROCm
+ models = ["Tongyi-MAI/Z-Image-Turbo"]
@pytest.mark.core_model
@pytest.mark.advanced_model
@pytest.mark.diffusion
-@hardware_test(res={"cuda": "L4", "rocm": "MI325", "xpu": "B60"}, num_cards={"cuda": 1, "rocm": 1, "xpu": 2})
-@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
-def test_diffusion_model(omni_runner, run_level):
- resolved = omni_runner.model_name
- if run_level == "core_model" and not _is_qwen_image_random(resolved):
+@hardware_test(res={"cuda": "L4", "rocm": "MI325", "xpu": "B60"}, num_cards={"cuda": 1, "rocm": 2, "xpu": 2})
+@pytest.mark.parametrize("model_name", models)
+def test_diffusion_model(model_name: str, run_level):
+ if run_level == "core_model" and model_name != "riverclouds/qwen_image_random":
pytest.skip()
- if run_level == "advanced_model" and _is_qwen_image_random(resolved):
+ if run_level == "advanced_model" and model_name == "riverclouds/qwen_image_random":
pytest.skip()
- # high resolution may cause OOM on L4
- height = 256
- width = 256
- sampling = OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=2,
- guidance_scale=0.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
- num_outputs_per_prompt=2,
- )
-
- # OmniRunner.generate() is typed for list[TextPrompt]; diffusion uses Omni.generate(str, ...).
- outputs = omni_runner.omni.generate(
- "a photo of a cat sitting on a laptop keyboard",
- sampling,
- )
-
- # Extract images from request_output['images']
- first_output = outputs[0]
- assert first_output.final_output_type == "image"
- if not hasattr(first_output, "request_output") or not first_output.request_output:
- raise ValueError("No request_output found in OmniRequestOutput")
-
- req_out = first_output.request_output
- if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
- raise ValueError("Invalid request_output structure or missing 'images' key")
-
- images = req_out.images
-
- assert len(images) == 2
- # check image size
- assert images[0].width == width
- assert images[0].height == height
- images[0].save("image_output.png")
+ m = None
+ try:
+ m = Omni(model=model_name)
+ # high resolution may cause OOM on L4
+ height = 256
+ width = 256
+ outputs = m.generate(
+ "a photo of a cat sitting on a laptop keyboard",
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=2,
+ guidance_scale=0.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
+ num_outputs_per_prompt=2,
+ ),
+ )
+ # Extract images from request_output['images']
+ first_output = outputs[0]
+ assert first_output.final_output_type == "image"
+ if not hasattr(first_output, "request_output") or not first_output.request_output:
+ raise ValueError("No request_output found in OmniRequestOutput")
+
+ req_out = first_output.request_output
+ if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"):
+ raise ValueError("Invalid request_output structure or missing 'images' key")
+
+ images = req_out.images
+
+ assert len(images) == 2
+ # check image size
+ assert images[0].width == width
+ assert images[0].height == height
+ images[0].save("image_output.png")
+ except Exception as e:
+ print(f"Test failed with error: {e}")
+ raise
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
diff --git a/tests/e2e/offline_inference/test_t2v_model.py b/tests/e2e/offline_inference/test_t2v_model.py
index cedc9e59b37..94c9dedf741 100644
--- a/tests/e2e/offline_inference/test_t2v_model.py
+++ b/tests/e2e/offline_inference/test_t2v_model.py
@@ -1,13 +1,22 @@
import os
+import sys
+from pathlib import Path
import pytest
import torch
-from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+# ruff: noqa: E402
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from vllm_omni import Omni
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
+# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
models = ["Wan-AI/Wan2.2-T2V-A14B-Diffusers"]
@@ -15,28 +24,28 @@
@pytest.mark.parametrize("model_name", models)
def test_video_diffusion_model(model_name: str):
- with OmniRunner(
- model_name,
+ m = Omni(
+ model=model_name,
boundary_ratio=0.875,
flow_shift=5.0,
- ) as runner:
- # Use minimal settings for testing
- # num_frames must satisfy: num_frames % vae_scale_factor_temporal == 1
- # For Wan2.2, vae_scale_factor_temporal=4, so valid values are 5, 9, 13, 17, ...
- height = 480
- width = 640
- num_frames = 5
- outputs = runner.omni.generate(
- prompts="A cat sitting on a table",
- sampling_params_list=OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_frames=num_frames,
- num_inference_steps=2,
- guidance_scale=1.0,
- generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
- ),
- )
+ )
+ # Use minimal settings for testing
+ # num_frames must satisfy: num_frames % vae_scale_factor_temporal == 1
+ # For Wan2.2, vae_scale_factor_temporal=4, so valid values are 5, 9, 13, 17, ...
+ height = 480
+ width = 640
+ num_frames = 5
+ outputs = m.generate(
+ prompts="A cat sitting on a table",
+ sampling_params_list=OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ num_inference_steps=2,
+ guidance_scale=1.0,
+ generator=torch.Generator(current_omni_platform.device_type).manual_seed(42),
+ ),
+ )
first_output = outputs[0]
assert first_output.final_output_type == "image"
if not hasattr(first_output, "request_output") or not first_output.request_output:
diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py
index 8152792fc01..efc0e43e86f 100644
--- a/tests/e2e/offline_inference/test_teacache.py
+++ b/tests/e2e/offline_inference/test_teacache.py
@@ -8,15 +8,27 @@
It uses minimal settings to keep test time short for CI.
"""
+import os
+import sys
+from pathlib import Path
+
import pytest
import torch
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
+# ruff: noqa: E402
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from vllm_omni import Omni
+from vllm_omni.outputs import OmniRequestOutput
+
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
+
# Use random weights model for testing
models = ["riverclouds/qwen_image_random"]
@@ -32,17 +44,20 @@ def test_teacache(model_name: str):
cache_config = {
"rel_l1_thresh": 0.2, # Default threshold
}
- with OmniRunner(
- model_name,
- cache_backend="tea_cache",
- cache_config=cache_config,
- ) as runner:
+ m = None
+ try:
+ m = Omni(
+ model=model_name,
+ cache_backend="tea_cache",
+ cache_config=cache_config,
+ )
+
# Use minimal settings for fast testing
height = 256
width = 256
num_inference_steps = 4 # Minimal steps for fast test
- outputs = runner.omni.generate(
+ outputs = m.generate(
"a photo of a cat sitting on a laptop keyboard",
OmniDiffusionSamplingParams(
height=height,
@@ -71,3 +86,9 @@ def test_teacache(model_name: str):
# Check image size
assert images[0].width == width
assert images[0].height == height
+ except Exception as e:
+ print(f"Test failed with error: {e}")
+ raise
+ finally:
+ if m is not None and hasattr(m, "close"):
+ m.close()
diff --git a/tests/e2e/offline_inference/test_vae_decode_parallelism.py b/tests/e2e/offline_inference/test_vae_decode_parallelism.py
index 32902c318fa..cee76fac2e9 100644
--- a/tests/e2e/offline_inference/test_vae_decode_parallelism.py
+++ b/tests/e2e/offline_inference/test_vae_decode_parallelism.py
@@ -18,7 +18,7 @@
import time
-from tests.helpers.runtime import OmniRunner
+from vllm_omni import Omni
from vllm_omni.platforms import current_omni_platform
# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
@@ -72,22 +72,23 @@ def is_nextstep_model(model_name: str) -> bool:
def model_run(model_configs, tp, out_height, out_width, out_frames, using_tile, vae_patch_parallel_size=1):
- parallel_config = DiffusionParallelConfig(
- tensor_parallel_size=tp,
- vae_patch_parallel_size=vae_patch_parallel_size,
- )
+ m = None
+ try:
+ parallel_config = DiffusionParallelConfig(
+ tensor_parallel_size=tp,
+ vae_patch_parallel_size=vae_patch_parallel_size,
+ )
- omni_kwargs = {
- "vae_use_tiling": using_tile,
- "parallel_config": parallel_config,
- }
- use_nextstep = is_nextstep_model(model_configs["model_name"])
- if use_nextstep:
- # NextStep-1.1 requires explicit pipeline class
- omni_kwargs["model_class_name"] = "NextStep11Pipeline"
-
- with OmniRunner(model_configs["model_name"], **omni_kwargs) as runner:
- m = runner.omni
+ omni_kwargs = {
+ "model": model_configs["model_name"],
+ "vae_use_tiling": using_tile,
+ "parallel_config": parallel_config,
+ }
+ use_nextstep = is_nextstep_model(model_configs["model_name"])
+ if use_nextstep:
+ # NextStep-1.1 requires explicit pipeline class
+ omni_kwargs["model_class_name"] = "NextStep11Pipeline"
+ m = Omni(**omni_kwargs)
image = Image.new("RGB", (out_width, out_height), (0, 0, 0))
start = time.perf_counter()
outputs = m.generate(
@@ -114,6 +115,9 @@ def model_run(model_configs, tp, out_height, out_width, out_frames, using_tile,
# frames shape: (batch, num_frames, height, width, channels)
cost = (end - start) * 1000
return frames, cost
+ finally:
+ if m is not None:
+ m.close()
cleanup_dist_env_and_memory()
diff --git a/tests/e2e/offline_inference/test_voxcpm.py b/tests/e2e/offline_inference/test_voxcpm.py
deleted file mode 100644
index bda087612de..00000000000
--- a/tests/e2e/offline_inference/test_voxcpm.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""E2E test for VoxCPM offline inference."""
-
-import json
-import os
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import pytest
-import torch
-
-import tests.helpers.runtime as omni_runtime
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from vllm_omni.model_executor.models.voxcpm.voxcpm_runtime_utils import (
- prepare_voxcpm_hf_config_dir,
- resolve_voxcpm_model_dir,
-)
-
-VOXCPM_MODEL = os.environ.get("VOXCPM_MODEL", "OpenBMB/VoxCPM1.5")
-STAGE_CONFIG = str(
- Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "voxcpm.yaml"
-)
-SAMPLE_RATE = 24000
-
-
-@pytest.fixture(autouse=True)
-def _patch_npu_cleanup_for_voxcpm(monkeypatch: pytest.MonkeyPatch):
- """Limit the NPU cleanup workaround to this VoxCPM test module only."""
- original_cleanup = omni_runtime.cleanup_dist_env_and_memory
-
- def _safe_cleanup() -> None:
- try:
- original_cleanup()
- except RuntimeError as exc:
- if "Allocator for npu is not a DeviceAllocator" in str(exc):
- return
- raise
-
- monkeypatch.setattr(omni_runtime, "cleanup_dist_env_and_memory", _safe_cleanup)
-
-
-def _build_prompt(text: str) -> dict[str, Any]:
- return {
- "prompt_token_ids": [1],
- "additional_information": {
- "text": [text],
- "cfg_value": [2.0],
- "inference_timesteps": [10],
- "min_len": [2],
- "max_new_tokens": [1024],
- },
- }
-
-
-def _extract_audio_tensor(multimodal_output: dict[str, Any]) -> torch.Tensor:
- audio = multimodal_output.get("audio", multimodal_output.get("model_outputs"))
- assert audio is not None, f"No audio output found, keys={list(multimodal_output.keys())}"
-
- if isinstance(audio, list):
- parts: list[torch.Tensor] = []
- for item in audio:
- if item is None:
- continue
- tensor = torch.as_tensor(item)
- if tensor.numel() == 0:
- continue
- parts.append(tensor.float().cpu().reshape(-1))
- return torch.cat(parts, dim=-1) if parts else torch.zeros((0,), dtype=torch.float32)
-
- return torch.as_tensor(audio).float().cpu().reshape(-1)
-
-
-def _extract_final_multimodal_output(outputs) -> dict[str, Any]:
- for item in reversed(outputs):
- request_output = getattr(item, "request_output", None)
- if request_output is not None:
- multimodal_output = getattr(request_output, "multimodal_output", None)
- if isinstance(multimodal_output, dict):
- return multimodal_output
- completions = getattr(request_output, "outputs", None) or []
- for completion in completions:
- multimodal_output = getattr(completion, "multimodal_output", None)
- if isinstance(multimodal_output, dict):
- return multimodal_output
-
- multimodal_output = getattr(item, "multimodal_output", None)
- if isinstance(multimodal_output, dict):
- return multimodal_output
-
- raise AssertionError("No multimodal audio output found in VoxCPM generate results")
-
-
-@pytest.fixture
-def voxcpm_model_path(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> str:
- model_dir = resolve_voxcpm_model_dir(VOXCPM_MODEL)
-
- hf_config_env = os.environ.get("VLLM_OMNI_VOXCPM_HF_CONFIG_PATH")
- if hf_config_env:
- hf_config_dir = Path(hf_config_env).expanduser()
- else:
- hf_config_dir = tmp_path / "voxcpm_hf_config"
-
- if not (hf_config_dir / "config.json").exists():
- prepare_voxcpm_hf_config_dir(model_dir, hf_config_dir)
-
- monkeypatch.setenv("VLLM_OMNI_VOXCPM_HF_CONFIG_PATH", str(hf_config_dir))
- return str(model_dir)
-
-
-def test_prepare_voxcpm_hf_config_dir(tmp_path: Path):
- model_dir = tmp_path / "model"
- model_dir.mkdir()
- (model_dir / "config.json").write_text(json.dumps({"hidden_size": 1024}), encoding="utf-8")
- (model_dir / "generation_config.json").write_text(json.dumps({"do_sample": False}), encoding="utf-8")
-
- hf_config_dir = prepare_voxcpm_hf_config_dir(model_dir, tmp_path / "voxcpm_hf_config")
-
- prepared_config = json.loads((hf_config_dir / "config.json").read_text(encoding="utf-8"))
- assert prepared_config["model_type"] == "voxcpm"
- assert prepared_config["architectures"] == ["VoxCPMForConditionalGeneration"]
- assert (hf_config_dir / "generation_config.json").exists()
-
-
-def test_resolve_voxcpm_model_dir_local_path(tmp_path: Path):
- model_dir = tmp_path / "OpenBMB" / "VoxCPM1.5"
- model_dir.mkdir(parents=True)
-
- assert resolve_voxcpm_model_dir(str(model_dir)) == model_dir
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
-def test_voxcpm_zero_shot_001(voxcpm_model_path: str):
- with OmniRunner(voxcpm_model_path, stage_configs_path=STAGE_CONFIG) as runner:
- outputs = list(runner.omni.generate(_build_prompt("Hello, this is a VoxCPM offline inference test.")))
-
- assert outputs, "No outputs returned"
-
- multimodal_output = _extract_final_multimodal_output(outputs)
- audio = _extract_audio_tensor(multimodal_output)
- assert audio.numel() > SAMPLE_RATE // 2, f"Audio too short: {audio.numel()} samples"
-
- duration_s = audio.shape[0] / SAMPLE_RATE
- assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s"
-
- peak = float(torch.max(torch.abs(audio)).item()) if audio.numel() > 0 else 0.0
- assert peak > 0.01, "Generated audio appears to be silence"
-
- audio_np = audio.numpy()
- rms = float(np.sqrt(np.mean(np.square(audio_np)))) if audio_np.size else 0.0
- assert rms > 1e-4, "Generated audio RMS too low"
diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py
deleted file mode 100644
index 913a6969d6b..00000000000
--- a/tests/e2e/offline_inference/test_voxcpm2.py
+++ /dev/null
@@ -1,122 +0,0 @@
-"""E2E test for VoxCPM2 native AR offline inference."""
-
-import os
-
-import pytest
-import torch
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from tests.helpers.stage_config import get_deploy_config_path
-
-VOXCPM2_MODEL = "openbmb/VoxCPM2"
-DEPLOY_CONFIG = get_deploy_config_path("voxcpm2.yaml")
-SAMPLE_RATE = 48000
-
-
-@pytest.fixture(scope="module")
-def voxcpm2_engine():
- """Create VoxCPM2 engine for testing."""
- with OmniRunner(VOXCPM2_MODEL, stage_configs_path=DEPLOY_CONFIG) as runner:
- yield runner.omni
-
-
-def _extract_audio(multimodal_output: dict) -> torch.Tensor:
- """Extract the final complete audio tensor from multimodal output."""
- assert isinstance(multimodal_output, dict), f"Expected dict, got {type(multimodal_output)}"
-
- # Output processor accumulates per-step audio chunks under "audio".
- audio = multimodal_output.get("audio")
- if audio is None:
- audio = multimodal_output.get("model_outputs")
- assert audio is not None, f"No audio key, got {list(multimodal_output.keys())}"
-
- if isinstance(audio, list):
- valid = [torch.as_tensor(x).float().cpu().reshape(-1) for x in audio if x is not None]
- assert valid, "No valid audio tensors in output list"
- audio = torch.cat(valid, dim=0) if len(valid) > 1 else valid[0]
-
- assert isinstance(audio, torch.Tensor), f"Expected Tensor, got {type(audio)}"
- return audio
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
-def test_voxcpm2_zero_shot_001(voxcpm2_engine):
- """Test zero-shot TTS produces valid audio output."""
- outputs = voxcpm2_engine.generate([{"prompt": "Hello, this is a test."}])
- assert len(outputs) == 1
-
- audio = _extract_audio(outputs[0].outputs[0].multimodal_output)
- duration_s = audio.shape[0] / SAMPLE_RATE
- assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s"
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
-def test_voxcpm2_voice_clone_002(voxcpm2_engine):
- """Test voice cloning with a reference audio file.
-
- Uses the example ``reference_speaker.wav`` bundled with the voxcpm
- package. Skipped if the file is not present.
- """
- # Try to locate a reference wav from the voxcpm package / env override
- candidates = []
- env_path = os.environ.get("VLLM_OMNI_VOXCPM_CODE_PATH")
- if env_path:
- candidates.append(os.path.join(env_path, "..", "examples", "reference_speaker.wav"))
- try:
- import voxcpm # noqa: F401 (only used to locate path)
-
- vox_dir = os.path.dirname(os.path.dirname(os.path.abspath(voxcpm.__file__)))
- candidates.append(os.path.join(vox_dir, "examples", "reference_speaker.wav"))
- except ImportError:
- pass
-
- ref_path = next((p for p in candidates if p and os.path.exists(p)), None)
- if ref_path is None:
- pytest.skip("No reference audio available for voice clone test")
-
- outputs = voxcpm2_engine.generate(
- [
- {
- "prompt": "Hello, this is a voice clone demo.",
- "additional_information": {"reference_audio": ref_path},
- }
- ]
- )
- assert len(outputs) == 1
-
- audio = _extract_audio(outputs[0].outputs[0].multimodal_output)
- duration_s = audio.shape[0] / SAMPLE_RATE
- assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s"
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
-def test_voxcpm2_prefill_decode_mixed_batch_003(voxcpm2_engine):
- """Regression: prefill+decode mixed batch must not crash (PR #2903)."""
- long_prompt = (
- "This is a deliberately long prompt that will stay in the decode "
- "phase for many steps so that subsequent shorter prompts keep "
- "entering prefill alongside it, reproducing the prefill plus "
- "decode mixed batch scheduling pattern."
- )
- short_prompts = [
- "Hello one.",
- "Hello two.",
- "Hello three.",
- "Hello four.",
- ]
- requests = [{"prompt": long_prompt}] + [{"prompt": p} for p in short_prompts]
-
- outputs = voxcpm2_engine.generate(requests)
- assert len(outputs) == len(requests)
-
- for i, out in enumerate(outputs):
- audio = _extract_audio(out.outputs[0].multimodal_output)
- duration_s = audio.shape[0] / SAMPLE_RATE
- assert 0.1 < duration_s < 30.0, f"Request {i} audio duration out of range: {duration_s:.2f}s"
diff --git a/tests/e2e/offline_inference/test_voxtral_tts.py b/tests/e2e/offline_inference/test_voxtral_tts.py
index 5386e2db4e3..b559cc252dc 100644
--- a/tests/e2e/offline_inference/test_voxtral_tts.py
+++ b/tests/e2e/offline_inference/test_voxtral_tts.py
@@ -19,6 +19,7 @@
import uuid
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
from pathlib import Path
@@ -29,13 +30,15 @@
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from vllm import SamplingParams
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+from tests.conftest import modify_stage_config
+from tests.utils import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
+from vllm_omni.entrypoints.omni import Omni
MODEL = "mistralai/Voxtral-4B-TTS-2603"
-STAGE_CONFIG = get_deploy_config_path("voxtral_tts.yaml")
+STAGE_CONFIG = str(
+ Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "voxtral_tts.yaml"
+)
SAMPLE_RATE = 24000
# Minimum expected audio samples for a short sentence (~0.04s of 24kHz audio)
MIN_AUDIO_SAMPLES = 1000
@@ -64,9 +67,9 @@ def _resolve_stage_config(run_level: str) -> str:
return modify_stage_config(
STAGE_CONFIG,
deletes={
- "stages": {
- 0: ["load_format"],
- 1: ["load_format"],
+ "stage_args": {
+ 0: ["engine_args.load_format"],
+ 1: ["engine_args.load_format"],
}
},
)
@@ -80,12 +83,14 @@ def test_voxtral_tts_offline_basic(run_level):
"""Test basic Voxtral TTS offline inference with a voice preset."""
stage_config = _resolve_stage_config(run_level)
- with OmniRunner(
- MODEL,
+ omni = Omni(
+ model=MODEL,
stage_configs_path=stage_config,
+ stage_init_timeout=300,
enforce_eager=True,
- ) as runner:
- omni = runner.omni
+ )
+
+ try:
inputs = _compose_request(MODEL, TEST_TEXT, VOICE)
sampling_params = SamplingParams(max_tokens=2500)
@@ -122,6 +127,9 @@ def test_voxtral_tts_offline_basic(run_level):
# Verify audio isn't all zeros / silence
assert np.max(np.abs(audio_array)) > 0.01, "Audio appears to be silence"
+ finally:
+ omni.close()
+
@pytest.mark.advanced_model
@pytest.mark.omni
diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py
index ab330ee9a26..9d9db16a408 100644
--- a/tests/e2e/offline_inference/test_zimage_parallelism.py
+++ b/tests/e2e/offline_inference/test_zimage_parallelism.py
@@ -12,6 +12,7 @@
"""
import os
+import sys
import time
from pathlib import Path
@@ -19,15 +20,21 @@
import pytest
import torch
from PIL import Image
+from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
-from tests.helpers.env import DeviceMemoryMonitor
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniRunner
+from tests.utils import DeviceMemoryMonitor, hardware_test
+from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
+# ruff: noqa: E402
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
PROMPT = "a photo of a cat sitting on a laptop keyboard"
@@ -90,61 +97,61 @@ def _run_zimage_generate(
device_index = current_omni_platform.current_device()
monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02)
monitor.start()
+ m = Omni(
+ model=_get_zimage_model(),
+ parallel_config=DiffusionParallelConfig(
+ tensor_parallel_size=tp_size,
+ vae_patch_parallel_size=vae_patch_parallel_size,
+ ),
+ enforce_eager=enforce_eager,
+ vae_use_tiling=vae_use_tiling,
+ )
try:
- # Each run needs a distinct DiffusionParallelConfig; use OmniRunner per call (not the
- # parametrized omni_runner fixture, which is fixed per module).
- with OmniRunner(
- _get_zimage_model(),
- parallel_config=DiffusionParallelConfig(
- tensor_parallel_size=tp_size,
- vae_patch_parallel_size=vae_patch_parallel_size,
+ # NOTE: Omni closes itself when a generate() call is exhausted.
+ # To avoid measuring teardown time (process shutdown, memory cleanup),
+ # we measure the latency to produce *subsequent* outputs within a single
+ # generator run.
+ #
+ # This also serves as a warmup: the first output may include extra
+ # compilation/caching overhead, while later outputs are closer to
+ # steady-state inference.
+ gen = m.generate(
+ [PROMPT] * num_requests,
+ OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=0.0,
+ seed=seed,
+ num_outputs_per_prompt=1,
),
- enforce_eager=enforce_eager,
- vae_use_tiling=vae_use_tiling,
- ) as runner:
- # NOTE: Omni closes itself when a generate() call is exhausted.
- # To avoid measuring teardown time (process shutdown, memory cleanup),
- # we measure the latency to produce *subsequent* outputs within a single
- # generator run.
- #
- # This also serves as a warmup: the first output may include extra
- # compilation/caching overhead, while later outputs are closer to
- # steady-state inference.
- gen = runner.omni.generate(
- [PROMPT] * num_requests,
- OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_inference_steps=num_inference_steps,
- guidance_scale=0.0,
- seed=seed,
- num_outputs_per_prompt=1,
- ),
- py_generator=True,
- )
-
- warmup_output = next(gen)
-
- t_prev = time.perf_counter()
- per_request_times_s: list[float] = []
- last_output = warmup_output
- for _ in range(num_requests - 1):
- last_output = next(gen)
- t_now = time.perf_counter()
- per_request_times_s.append(t_now - t_prev)
- t_prev = t_now
-
- # Ensure the generator is fully consumed so it can clean up.
- for _ in gen:
- pass
-
- median_time_s = float(np.median(per_request_times_s))
-
- peak_memory_mb = monitor.peak_used_mb
-
- return _extract_single_image([last_output]), median_time_s, peak_memory_mb
+ py_generator=True,
+ )
+
+ warmup_output = next(gen)
+
+ t_prev = time.perf_counter()
+ per_request_times_s: list[float] = []
+ last_output = warmup_output
+ for _ in range(num_requests - 1):
+ last_output = next(gen)
+ t_now = time.perf_counter()
+ per_request_times_s.append(t_now - t_prev)
+ t_prev = t_now
+
+ # Ensure the generator is fully consumed so it can clean up.
+ for _ in gen:
+ pass
+
+ median_time_s = float(np.median(per_request_times_s))
+
+ peak_memory_mb = monitor.peak_used_mb
+
+ return _extract_single_image([last_output]), median_time_s, peak_memory_mb
finally:
monitor.stop()
+ m.close()
+ cleanup_dist_env_and_memory()
@pytest.mark.advanced_model
@@ -152,8 +159,8 @@ def _run_zimage_generate(
@pytest.mark.parallel
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
def test_zimage_tensor_parallel_tp2(tmp_path: Path):
- if current_omni_platform.is_npu():
- pytest.skip("Z-Image TP e2e test is only supported on CUDA and ROCm for now.")
+ if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
+ pytest.skip("Z-Image TP e2e test is only supported on CUDA for now.")
if not current_omni_platform.is_available() or current_omni_platform.device_count() < 2:
pytest.skip("Z-Image TP=2 requires >= 2 devices.")
@@ -204,9 +211,7 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
)
print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}")
- # ROCm is not optimized TP2 can be slower than TP1
- if not current_omni_platform.is_rocm():
- assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})"
+ assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})"
print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}")
assert tp2_peak_mem < tp1_peak_mem, (
@@ -214,13 +219,10 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
)
-@pytest.mark.advanced_model
-@pytest.mark.diffusion
-@pytest.mark.parallel
-@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
+@pytest.mark.integration
def test_zimage_vae_patch_parallel_tp2(tmp_path: Path):
- if current_omni_platform.is_npu():
- pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA and ROCm for now.")
+ if current_omni_platform.is_npu() or current_omni_platform.is_rocm():
+ pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA for now.")
if not current_omni_platform.is_available() or current_omni_platform.device_count() < 2:
pytest.skip("Z-Image VAE patch parallel TP=2 requires >= 2 devices.")
diff --git a/tests/helpers/process.py b/tests/e2e/offline_inference/utils.py
similarity index 58%
rename from tests/helpers/process.py
rename to tests/e2e/offline_inference/utils.py
index 094de965239..3113599a305 100644
--- a/tests/helpers/process.py
+++ b/tests/e2e/offline_inference/utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import contextlib
import functools
import os
import signal
@@ -9,48 +10,73 @@
import tempfile
from collections.abc import Callable
from contextlib import ExitStack, suppress
+from pathlib import Path
from typing import Any, Literal
import cloudpickle
from typing_extensions import ParamSpec
from vllm.platforms import current_platform
+VLLM_PATH = Path(__file__).parent.parent.parent
+"""Intended repo root. NOTE(review): 3 parents from tests/e2e/offline_inference/ resolve to tests/ - confirm."""
+
+
_P = ParamSpec("_P")
def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
- """Decorator to fork a new process for each test function."""
+ """Decorator to fork a new process for each test function.
+ See https://github.com/vllm-project/vllm/issues/7053 for more details.
+ """
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+ # Make the process the leader of its own process group
+ # to avoid sending SIGTERM to the parent process
os.setpgrp()
from _pytest.outcomes import Skipped
+ # Create a unique temporary file to store exception info from child
+ # process. Use test function name and process ID to avoid collisions.
with (
tempfile.NamedTemporaryFile(
- delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc"
+ delete=False,
+ mode="w+b",
+ prefix=f"vllm_test_{func.__name__}_{os.getpid()}_",
+ suffix=".exc",
) as exc_file,
ExitStack() as delete_after,
):
exc_file_path = exc_file.name
delete_after.callback(os.remove, exc_file_path)
+
pid = os.fork()
+ print(f"Fork a new process to run a test {pid}")
if pid == 0:
+ # The parent process is responsible for deleting the temp file,
+ # so drop the cleanup callback in the child.
delete_after.pop_all()
try:
func(*args, **kwargs)
except Skipped as e:
+ # convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception as e:
import traceback
tb_string = traceback.format_exc()
+
+ # Try to serialize the exception object first
exc_to_serialize: dict[str, Any]
try:
+ # First, try to pickle the actual exception with
+ # its traceback.
exc_to_serialize = {"pickled_exception": e}
+ # Test if it can be pickled
cloudpickle.dumps(exc_to_serialize)
except (Exception, KeyboardInterrupt):
+ # Fall back to string-based approach.
exc_to_serialize = {
"exception_type": type(e).__name__,
"exception_msg": str(e),
@@ -60,6 +86,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
with open(exc_file_path, "wb") as f:
cloudpickle.dump(exc_to_serialize, f)
except Exception:
+ # Fallback: just print the traceback.
print(tb_string)
os._exit(1)
else:
@@ -67,24 +94,40 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
else:
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
+ # ignore SIGTERM signal itself
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
+ # kill all child processes
os.killpg(pgid, signal.SIGTERM)
+ # restore the signal handler
signal.signal(signal.SIGTERM, old_signal_handler)
if _exitcode != 0:
+ # Try to read the exception from the child process
exc_info = {}
if os.path.exists(exc_file_path):
- with suppress(Exception), open(exc_file_path, "rb") as f:
+ with (
+ contextlib.suppress(Exception),
+ open(exc_file_path, "rb") as f,
+ ):
exc_info = cloudpickle.load(f)
- if (original_exception := exc_info.get("pickled_exception")) is not None:
- assert isinstance(original_exception, Exception)
+
+ original_exception = exc_info.get("pickled_exception")
+ if original_exception is not None and isinstance(original_exception, Exception):
+ # Re-raise the actual exception object if it was
+ # successfully pickled.
raise original_exception
+
if (original_tb := exc_info.get("traceback")) is not None:
+ # Use string-based traceback for fallback case
raise AssertionError(
- f"Test {func.__name__} failed when called with args {args} and kwargs {kwargs}"
+ f"Test {func.__name__} failed when called with"
+ f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode}):\n{original_tb}"
) from None
+
+ # Fallback to the original generic error
raise AssertionError(
- f"function {func.__name__} failed when called with args {args} and kwargs {kwargs}"
+ f"function {func.__name__} failed when called with"
+ f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode})"
) from None
@@ -96,7 +139,9 @@ def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+ # Check if we're already in a subprocess
if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
+ # If we are, just run the function directly
return f(*args, **kwargs)
import torch.multiprocessing as mp
@@ -104,18 +149,33 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
with suppress(RuntimeError):
mp.set_start_method("spawn")
+ # Get the module
module_name = f.__module__
+
+ # Create a process with environment variable set
env = os.environ.copy()
env["RUNNING_IN_SUBPROCESS"] = "1"
with tempfile.TemporaryDirectory() as tempdir:
output_filepath = os.path.join(tempdir, "new_process.tmp")
+
+ # `cloudpickle` allows pickling complex functions directly
input_bytes = cloudpickle.dumps((f, output_filepath))
+
+ repo_root = str(VLLM_PATH.resolve())
+
+ env = dict(env or os.environ)
+ env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
+
cmd = [sys.executable, "-m", f"{module_name}"]
+
returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env)
+
+ # check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
+ # wrap raised exception to provide more information
raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e
return wrapper
@@ -124,11 +184,27 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
def create_new_process_for_each_test(
method: Literal["spawn", "fork"] | None = None,
) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
- """Creates a decorator that runs each test function in a new process."""
+ """Creates a decorator that runs each test function in a new process.
+
+ Args:
+ method: The process creation method. Can be either "spawn" or "fork".
+ If not specified, it defaults to "spawn" on XPU platforms and
+ "fork" otherwise (ROCm currently uses "fork" as well).
+
+ Returns:
+ A decorator to run test functions in separate processes.
+ """
if method is None:
+ # TODO: Find out why spawn does not work correctly on ROCm:
+ # with spawn the test body never runs and the test passes immediately.
+ # Until that is fixed, ROCm uses `fork`, under which the tests
+ # actually run.
use_spawn = current_platform.is_xpu()
method = "spawn" if use_spawn else "fork"
+
assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'"
+
if method == "fork":
return fork_new_process_for_each_test
+
return spawn_new_process_for_each_test
diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py
index 21fdc314c96..e2d75e0d199 100644
--- a/tests/e2e/online_serving/test_bagel_expansion.py
+++ b/tests/e2e/online_serving/test_bagel_expansion.py
@@ -9,7 +9,6 @@
- Tensor-Parallel
- Ulysses-SP
- Ring-Attention
-- Layerwise Offloading
assert_diffusion_response validates successful generation and the expected
512x512 resolution.
@@ -17,10 +16,13 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ dummy_messages_from_mix_data,
+)
+from tests.utils import hardware_marks
PROMPT = "A futuristic city skyline at twilight, cyberpunk style, ultra-detailed, high resolution."
NEGATIVE_PROMPT = "low quality, blurry, distorted, deformed, watermark"
@@ -32,7 +34,7 @@
def _get_diffusion_feature_cases(model: str):
"""Return L4 diffusion feature cases for Bagel.
TeaCache, Cache-DiT, CFG-Parallel, Tensor-Parallel,
- Ulysses-SP, Ring-Attention, Layerwise Offloading.
+ Ulysses-SP, Ring-Attention.
"""
return [
@@ -86,7 +88,7 @@ def _get_diffusion_feature_cases(model: str):
],
),
id="parallel_tp_2",
- marks=[*PARALLEL_FEATURE_MARKS, pytest.mark.skip(reason="issue: #2862")],
+ marks=PARALLEL_FEATURE_MARKS,
),
# Ulysses-SP degree=2 (2 GPUs)
pytest.param(
@@ -112,18 +114,11 @@ def _get_diffusion_feature_cases(model: str):
id="sp_ring_2",
marks=PARALLEL_FEATURE_MARKS,
),
- # Layerwise Offloading (single-card)
- pytest.param(
- OmniServerParams(
- model=model,
- server_args=["--enable-layerwise-offload"],
- ),
- id="single_card_layerwise_offload",
- marks=SINGLE_CARD_FEATURE_MARKS,
- ),
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases("ByteDance-Seed/BAGEL-7B-MoT"),
@@ -142,9 +137,8 @@ def test_bagel(
- Tensor-Parallel (size=2)
- Ulysses-SP (degree=2)
- Ring-Attention (degree=2)
- - Layerwise Offloading
- Validation is delegated to assert_diffusion_response in tests/helpers/assertions.py,
+ Validation is delegated to assert_diffusion_response in tests.conftest,
which checks output dimensions and basic correctness.
"""
diff --git a/tests/e2e/online_serving/test_bagel_online.py b/tests/e2e/online_serving/test_bagel_online.py
index a8ec6548937..ca24f5f81f7 100644
--- a/tests/e2e/online_serving/test_bagel_online.py
+++ b/tests/e2e/online_serving/test_bagel_online.py
@@ -23,19 +23,21 @@
import base64
import os
from io import BytesIO
+from pathlib import Path
import pytest
from vllm.assets.image import ImageAsset
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.conftest import OmniServerParams
+from tests.utils import hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
MODEL = "ByteDance-Seed/BAGEL-7B-MoT"
-STAGE_CONFIGS_PATH = get_deploy_config_path("ci/bagel.yaml")
+STAGE_CONFIGS_PATH = str(
+ Path(__file__).parent.parent / "offline_inference" / "stage_configs" / "bagel_sharedmemory_ci.yaml"
+)
TEXT2IMG_PROMPT = "A cute cat"
IMG2IMG_PROMPT = "Change the grass color to red"
@@ -45,7 +47,7 @@
OmniServerParams(
model=MODEL,
stage_config_path=STAGE_CONFIGS_PATH,
- stage_init_timeout=300,
+ server_args=["--stage-init-timeout", "300"],
),
]
diff --git a/tests/e2e/online_serving/test_cosyvoice3_tts.py b/tests/e2e/online_serving/test_cosyvoice3_tts.py
index 76e5e4f49c5..976be805c27 100644
--- a/tests/e2e/online_serving/test_cosyvoice3_tts.py
+++ b/tests/e2e/online_serving/test_cosyvoice3_tts.py
@@ -12,11 +12,12 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+from pathlib import Path
+
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.conftest import OmniServerParams
+from tests.utils import hardware_test
MODEL = "FunAudioLLM/Fun-CosyVoice3-0.5B-2512"
@@ -26,8 +27,8 @@
def get_stage_config(name: str = "cosyvoice3.yaml"):
- """Get the deploy config path for CosyVoice3."""
- return get_deploy_config_path(name)
+ """Get the stage config path from vllm_omni model_executor stage_configs."""
+ return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
def get_prompt(prompt_type="zh"):
@@ -39,28 +40,18 @@ def get_prompt(prompt_type="zh"):
tts_server_params = [
- pytest.param(
- OmniServerParams(
- model=MODEL,
- stage_config_path=get_stage_config(),
- server_args=["--trust-remote-code", "--disable-log-stats", "--no-async-chunk"],
- ),
- id="cosyvoice3",
- )
-]
-
-tts_async_chunk_server_params = [
pytest.param(
OmniServerParams(
model=MODEL,
stage_config_path=get_stage_config(),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
- id="cosyvoice3_async_chunk",
+ id="cosyvoice3",
)
]
+@pytest.mark.advanced_model
@pytest.mark.core_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100"}, num_cards=1)
@@ -85,16 +76,16 @@ def test_voice_clone_zh_001(omni_server, openai_client) -> None:
openai_client.send_audio_speech_request(request_config)
-@pytest.mark.core_model
+@pytest.mark.advanced_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100"}, num_cards=1)
-@pytest.mark.parametrize("omni_server", tts_async_chunk_server_params, indirect=True)
+@pytest.mark.parametrize("omni_server", tts_server_params, indirect=True)
def test_voice_clone_zh_002(omni_server, openai_client) -> None:
"""
- Test voice cloning TTS with Chinese text via async_chunk streaming.
- Deploy Setting: cosyvoice3.yaml with default ``async_chunk: true``
+ Test voice cloning TTS with Chinese text via OpenAI API.
+ Deploy Setting: default yaml
Input Modal: text + ref_audio + ref_text
- Output Modal: audio (streamed)
+ Output Modal: audio
Input Setting: stream=True
Datasets: single request
"""
@@ -109,7 +100,7 @@ def test_voice_clone_zh_002(omni_server, openai_client) -> None:
openai_client.send_audio_speech_request(request_config)
-@pytest.mark.core_model
+@pytest.mark.advanced_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100"}, num_cards=1)
@pytest.mark.parametrize("omni_server", tts_server_params, indirect=True)
diff --git a/tests/e2e/online_serving/test_diffusers_adapter.py b/tests/e2e/online_serving/test_diffusers_adapter.py
deleted file mode 100644
index 8b41db13a53..00000000000
--- a/tests/e2e/online_serving/test_diffusers_adapter.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""End-to-end tests for DiffusersAdapterPipeline.
-
-It tests the full user flow of launching a diffusers-backed model and running inference.
-"""
-
-import pytest
-from PIL import Image
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.core_model]
-
-
-@pytest.mark.parametrize(
- "omni_server",
- [
- OmniServerParams(
- model="tiny-random/Qwen-Image",
- server_args=[
- "--diffusion-load-format",
- "diffusers",
- "--diffusers-call-kwargs",
- '{"height": 512, "width": 0}', # deliberately weird width to be overridden
- ],
- ),
- ],
- indirect=True,
-)
-@hardware_test(res={"cuda": "L4"}, num_cards=1)
-def test_t2i_with_diffusers_adapter(
- omni_server: OmniServer,
- openai_client: OpenAIClientHandler,
-):
- messages = dummy_messages_from_mix_data(content_text="a photo of an astronaut riding a horse on mars")
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "extra_body": {
- "width": 512,
- "num_inference_steps": 2,
- "negative_prompt": "blurry",
- "true_cfg_scale": 4.0,
- "seed": 42,
- },
- }
-
- response = openai_client.send_diffusion_request(request_config)
- image: Image.Image = response[0].images[0] # pyright: ignore[reportOptionalSubscript]
-
- # Request config has incomplete width/height, so internal assertion in `send_diffusion_request` is incomplete.
- assert image.size == (512, 512)
diff --git a/tests/e2e/online_serving/test_dynin_omni_expansion.py b/tests/e2e/online_serving/test_dynin_omni_expansion.py
deleted file mode 100644
index da179dbc802..00000000000
--- a/tests/e2e/online_serving/test_dynin_omni_expansion.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Example online tests for Dynin-Omni model.
-"""
-
-import base64
-import os
-from io import BytesIO
-from pathlib import Path
-
-import pytest
-from vllm.assets.image import ImageAsset
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-
-pytestmark = [pytest.mark.full_model, pytest.mark.omni]
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-
-MODEL = "snu-aidas/Dynin-Omni"
-STAGE_CONFIG = str(Path(__file__).parent.parent / "stage_configs" / "dynin_omni_ci.yaml")
-
-T2I_PROMPT = "A high quality detailed living room interior photo."
-T2S_PROMPT = "Please read this sentence naturally: Hello from online serving."
-I2I_PROMPT = "Transform this outdoor nature boardwalk scene into a painting style with vivid colors."
-
-TEST_PARAMS = [OmniServerParams(model=MODEL, stage_config_path=STAGE_CONFIG, stage_init_timeout=600)]
-_STAGE_COUNT = 3
-_I2I_STAGE_SAMPLING = {"max_tokens": 1, "temperature": 0.0, "top_p": 1.0, "detokenize": False}
-
-
-def _build_t2i_messages(prompt: str) -> list[dict]:
- return [{"role": "user", "content": [{"type": "text", "text": f"<|t2i|> {prompt}"}]}]
-
-
-def _build_t2s_messages(prompt: str) -> list[dict]:
- return [{"role": "user", "content": [{"type": "text", "text": f"<|t2s|> {prompt}"}]}]
-
-
-def _build_i2i_messages(prompt: str) -> list[dict]:
- input_image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB")
- buffer = BytesIO()
- input_image.save(buffer, format="JPEG")
- image_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
- return [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": f"<|i2i|> {prompt}"},
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
- ],
- }
- ]
-
-
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
-@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True)
-def test_send_i2i_request_001(omni_server, openai_client) -> None:
- request_config = {
- "model": omni_server.model,
- "messages": _build_i2i_messages(I2I_PROMPT),
- "modalities": ["image"],
- "extra_body": {
- "sampling_params_list": [dict(_I2I_STAGE_SAMPLING) for _ in range(_STAGE_COUNT)],
- },
- }
- openai_client.send_diffusion_request(request_config)
-
-
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
-@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True)
-def test_send_t2i_request_001(omni_server, openai_client) -> None:
- request_config = {
- "model": omni_server.model,
- "messages": _build_t2i_messages(T2I_PROMPT),
- "modalities": ["image"],
- }
- openai_client.send_diffusion_request(request_config)
-
-
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"})
-@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True)
-def test_send_t2s_request_001(omni_server, openai_client) -> None:
- request_config = {
- "model": omni_server.model,
- "messages": _build_t2s_messages(T2S_PROMPT),
- "modalities": ["audio"],
- "audio_ref_text": T2S_PROMPT,
- }
- openai_client.send_omni_request(request_config)
diff --git a/tests/e2e/online_serving/test_flux2_expansion.py b/tests/e2e/online_serving/test_flux2_expansion.py
index 9a2b164b357..0e9e8c89a6a 100644
--- a/tests/e2e/online_serving/test_flux2_expansion.py
+++ b/tests/e2e/online_serving/test_flux2_expansion.py
@@ -1,20 +1,16 @@
"""
Tests for Flux2 Klein; currently Dev is implemented separately,
but ideally these models will fold together in the future.
-
-Coverage:
-- FP8 + CacheDiT + Ulysses=2 + TP=2
-- Layerwise CPU offload + Ulysses=2 + Ring=2
-- Layerwise CPU offload + TP=2
-- Layerwise CPU offload + HSDP
"""
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+)
+from tests.utils import hardware_marks
FOUR_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=4)
POSITIVE_PROMPT = "A cat sitting on a windowsill"
@@ -46,48 +42,11 @@ def _get_diffusion_feature_cases(model: str):
),
marks=FOUR_CARD_FEATURE_MARKS,
),
- pytest.param(
- OmniServerParams(
- model=model,
- server_args=[
- "--enable-layerwise-offload",
- "--ulysses-degree",
- "2",
- "--ring",
- "2",
- ],
- ),
- id="layerwise_ulysses2_ring2",
- marks=FOUR_CARD_FEATURE_MARKS,
- ),
- pytest.param(
- OmniServerParams(
- model=model,
- server_args=[
- "--enable-layerwise-offload",
- "--tensor-parallel-size",
- "2",
- ],
- ),
- id="layerwise_tp2",
- marks=FOUR_CARD_FEATURE_MARKS,
- ),
- pytest.param(
- OmniServerParams(
- model=model,
- server_args=[
- "--enable-layerwise-offload",
- "--use-hsdp",
- "--hsdp-shard-size",
- "2",
- ],
- ),
- id="layerwise_hsdp",
- marks=FOUR_CARD_FEATURE_MARKS,
- ),
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases(
diff --git a/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py b/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py
deleted file mode 100644
index a3d4f004e8a..00000000000
--- a/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-End-to-end tests for Flux2 Klein inpainting in online serving mode.
-
-Uses /v1/images/edits endpoint which is the correct API for image inpainting.
-"""
-
-import base64
-from io import BytesIO
-
-import httpx
-import pytest
-from PIL import Image, ImageDraw
-
-from tests.helpers.runtime import OmniServer, OmniServerParams
-
-pytestmark = [pytest.mark.full_model, pytest.mark.diffusion]
-
-MODEL = "black-forest-labs/FLUX.2-klein-4B"
-
-_HEIGHT = 512
-_WIDTH = 512
-_NUM_INFERENCE_STEPS = 4
-
-
-def _get_diffusion_feature_cases(model: str):
- return [
- pytest.param(
- OmniServerParams(
- model=model,
- server_args=["--tensor-parallel-size", "2"],
- ),
- id="tp2_basic",
- ),
- ]
-
-
-def _image_to_base64_jpeg(image: Image.Image) -> str:
- buffer = BytesIO()
- image.save(buffer, format="JPEG")
- buffer.seek(0)
- return base64.b64encode(buffer.read()).decode("utf-8")
-
-
-def _create_test_mask_base64(width: int = _WIDTH, height: int = _HEIGHT) -> str:
- mask = Image.new("L", (width, height), 0)
- draw = ImageDraw.Draw(mask)
- draw.rectangle([width // 4, height // 4, width * 3 // 4, height * 3 // 4], fill=255)
- return _image_to_base64_jpeg(mask)
-
-
-def _compare_images(img1: Image.Image, img2: Image.Image) -> bool:
- return list(img1.getdata()) == list(img2.getdata())
-
-
-def _send_edit_request(host: str, port: int, model: str, image_b64: str, mask_b64: str, prompt: str, **kwargs):
- url = f"http://{host}:{port}/v1/images/edits"
- files = {
- "image": ("image.jpg", base64.b64decode(image_b64), "image/jpeg"),
- "mask_image": ("mask.jpg", base64.b64decode(mask_b64), "image/jpeg"),
- }
- data = {"prompt": prompt, "model": model, **kwargs}
- with httpx.Client(timeout=60.0) as client:
- response = client.post(url, files=files, data=data)
- response.raise_for_status()
- return response.json()
-
-
-@pytest.mark.parametrize("omni_server", _get_diffusion_feature_cases(MODEL), indirect=True)
-def test_flux2_klein_inpaint_basic(omni_server: OmniServer):
- input_image_b64 = _image_to_base64_jpeg(Image.new("RGB", (_WIDTH, _HEIGHT), (128, 128, 128)))
- mask_b64 = _create_test_mask_base64()
-
- result = _send_edit_request(
- host=omni_server.host,
- port=omni_server.port,
- model=MODEL,
- image_b64=input_image_b64,
- mask_b64=mask_b64,
- prompt="Fill in the masked area with a beautiful garden",
- guidance_scale=1.0,
- num_inference_steps=_NUM_INFERENCE_STEPS,
- n=1,
- seed=42,
- )
-
- assert "data" in result and len(result["data"]) == 1
- img_data = result["data"][0].get("b64_json") or result["data"][0].get("url", "").split(",")[-1]
- img = Image.open(BytesIO(base64.b64decode(img_data)))
- assert img.size == (_WIDTH, _HEIGHT)
-
-
-@pytest.mark.parametrize("omni_server", _get_diffusion_feature_cases(MODEL), indirect=True)
-def test_flux2_klein_inpaint_deterministic(omni_server: OmniServer):
- input_image_b64 = _image_to_base64_jpeg(Image.new("RGB", (_WIDTH, _HEIGHT), (128, 128, 128)))
- mask_b64 = _create_test_mask_base64()
- prompt = "A red flower in a field"
-
- result1 = _send_edit_request(
- host=omni_server.host,
- port=omni_server.port,
- model=MODEL,
- image_b64=input_image_b64,
- mask_b64=mask_b64,
- prompt=prompt,
- guidance_scale=1.0,
- num_inference_steps=_NUM_INFERENCE_STEPS,
- n=1,
- seed=12345,
- )
-
- result2 = _send_edit_request(
- host=omni_server.host,
- port=omni_server.port,
- model=MODEL,
- image_b64=input_image_b64,
- mask_b64=mask_b64,
- prompt=prompt,
- guidance_scale=1.0,
- num_inference_steps=_NUM_INFERENCE_STEPS,
- n=1,
- seed=12345,
- )
-
- img1_data = result1["data"][0].get("b64_json") or result1["data"][0].get("url", "").split(",")[-1]
- img2_data = result2["data"][0].get("b64_json") or result2["data"][0].get("url", "").split(",")[-1]
-
- img1 = Image.open(BytesIO(base64.b64decode(img1_data)))
- img2 = Image.open(BytesIO(base64.b64decode(img2_data)))
-
- assert _compare_images(img1, img2), (
- "Same input with same seed should produce identical output. This is critical for offline/online consistency."
- )
-
-
-@pytest.mark.parametrize("omni_server", _get_diffusion_feature_cases(MODEL), indirect=True)
-def test_flux2_klein_inpaint_multiple_outputs(omni_server: OmniServer):
- input_image_b64 = _image_to_base64_jpeg(Image.new("RGB", (_WIDTH, _HEIGHT), (128, 128, 128)))
- mask_b64 = _create_test_mask_base64()
-
- result = _send_edit_request(
- host=omni_server.host,
- port=omni_server.port,
- model=MODEL,
- image_b64=input_image_b64,
- mask_b64=mask_b64,
- prompt="A beautiful landscape",
- guidance_scale=1.0,
- num_inference_steps=_NUM_INFERENCE_STEPS,
- n=2,
- seed=42,
- )
-
- assert "data" in result and len(result["data"]) == 2
diff --git a/tests/e2e/online_serving/test_flux_2_dev_expansion.py b/tests/e2e/online_serving/test_flux_2_dev_expansion.py
index 953cb448a30..eba0fbda225 100644
--- a/tests/e2e/online_serving/test_flux_2_dev_expansion.py
+++ b/tests/e2e/online_serving/test_flux_2_dev_expansion.py
@@ -2,11 +2,13 @@
End-to-end diffusion coverage for FLUX.2-dev in online serving mode.
Coverage:
+- Cache-DiT cache acceleration backend
- CPU offload
-This test verifies that FLUX.2-dev can be launched with CPU offload enabled,
-accepts text-to-image requests through the OpenAI-compatible API, and returns
-valid generated images with the requested resolution.
+This test verifies that FLUX.2-dev can be launched with the Cache-DiT backend
+and CPU offload enabled, accepts text-to-image requests through the
+OpenAI-compatible API, and returns valid generated images with the requested
+resolution.
assert_diffusion_response validates successful generation and the expected
image resolution.
@@ -14,48 +16,42 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ dummy_messages_from_mix_data,
+)
+from tests.utils import hardware_marks
MODEL = "black-forest-labs/FLUX.2-dev"
PROMPT = "A cinematic mountain landscape at sunrise, dramatic clouds, ultra-detailed, realistic photography."
NEGATIVE_PROMPT = "low quality, blurry, distorted, deformed, watermark"
SINGLE_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "H100"})
-PARALLEL_FEATURE_MARKS = hardware_marks(res={"cuda": "H100"}, num_cards=2)
def _get_flux_2_dev_feature_cases(model: str):
- """Return FLUX.2-dev diffusion feature cases for CPU offload."""
+ """Return FLUX.2-dev diffusion feature cases for Cache-DiT + CPU offload."""
return [
pytest.param(
OmniServerParams(
model=model,
server_args=[
+ "--cache-backend",
+ "cache_dit",
"--enable-cpu-offload",
],
),
- id="cpu_offload",
+ id="cache_dit_cpu_offload",
marks=SINGLE_CARD_FEATURE_MARKS,
),
- pytest.param(
- OmniServerParams(
- model=model,
- server_args=[
- "--enable-cpu-offload",
- "--cfg-parallel-size",
- "2",
- ],
- ),
- id="parallel_cfg_2",
- marks=PARALLEL_FEATURE_MARKS,
- ),
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_flux_2_dev_feature_cases(MODEL),
@@ -65,7 +61,7 @@ def test_flux_2_dev(
omni_server: OmniServer,
openai_client: OpenAIClientHandler,
):
- """Validate FLUX.2-dev online serving with CPU offload."""
+ """Validate FLUX.2-dev online serving with Cache-DiT and CPU offload."""
messages = dummy_messages_from_mix_data(content_text=PROMPT)
diff --git a/tests/e2e/online_serving/test_flux_kontext_expansion.py b/tests/e2e/online_serving/test_flux_kontext_expansion.py
index fd7d7d1b484..c13e1e8189d 100644
--- a/tests/e2e/online_serving/test_flux_kontext_expansion.py
+++ b/tests/e2e/online_serving/test_flux_kontext_expansion.py
@@ -5,10 +5,13 @@
import pytest
-from tests.helpers.media import generate_synthetic_image
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ dummy_messages_from_mix_data,
+ generate_synthetic_image,
+)
EDIT_PROMPT = "Transform this modern, geometrist image into a Vincent van Gogh style impressionist painting."
NEGATIVE_PROMPT = "blurry, low quality, modern, geometrist"
@@ -31,6 +34,8 @@ def _get_diffusion_feature_cases(model: str):
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases(MODEL),
@@ -54,6 +59,8 @@ def test_flux_kontext_text_to_image(omni_server: OmniServer, openai_client: Open
openai_client.send_diffusion_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases(MODEL),
@@ -81,6 +88,8 @@ def test_flux_kontext_image_edit(omni_server: OmniServer, openai_client: OpenAIC
openai_client.send_diffusion_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases(MODEL),
@@ -106,6 +115,8 @@ def test_flux_kontext_image_edit_no_negative(omni_server: OmniServer, openai_cli
openai_client.send_diffusion_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases(MODEL),
@@ -129,6 +140,8 @@ def test_flux_kontext_high_resolution(omni_server: OmniServer, openai_client: Op
openai_client.send_diffusion_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases(MODEL),
diff --git a/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py b/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py
index 681d8ff23b1..de950edb900 100644
--- a/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py
+++ b/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py
@@ -11,10 +11,12 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+)
+from tests.utils import hardware_marks
PROMPT = "A cat walking across a sunlit garden, cinematic lighting, slow motion."
NEGATIVE_PROMPT = "low quality, blurry, distorted"
@@ -64,6 +66,8 @@ def _get_diffusion_feature_cases(model: str):
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases(MODEL),
diff --git a/tests/e2e/online_serving/test_image_gen_edit.py b/tests/e2e/online_serving/test_image_gen_edit.py
index 56747abd16e..7db740f2037 100644
--- a/tests/e2e/online_serving/test_image_gen_edit.py
+++ b/tests/e2e/online_serving/test_image_gen_edit.py
@@ -22,7 +22,7 @@
from vllm.assets.image import ImageAsset
from vllm.utils.network_utils import get_open_port
-from tests.helpers.mark import hardware_test
+from tests.utils import hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Increase timeout for downloading assets from S3 (default 5s is too short for CI)
diff --git a/tests/e2e/online_serving/test_images_generations_lora.py b/tests/e2e/online_serving/test_images_generations_lora.py
index 931a572878e..8c826591a56 100644
--- a/tests/e2e/online_serving/test_images_generations_lora.py
+++ b/tests/e2e/online_serving/test_images_generations_lora.py
@@ -22,13 +22,13 @@
from PIL import Image
from safetensors.torch import save_file
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServer
+from tests.conftest import OmniServer
+from tests.utils import hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODEL = "Tongyi-MAI/Z-Image-Turbo"
-DIFFUSION_INIT_TIMEOUT_S = 900
+DIFFUSION_INIT_TIMEOUT_S = 700
PROMPT = "a photo of a cat sitting on a laptop keyboard"
diff --git a/tests/e2e/online_serving/test_longcat_image_edit_expansion.py b/tests/e2e/online_serving/test_longcat_image_edit_expansion.py
index 4be5fe42d62..8a2cfbcc145 100644
--- a/tests/e2e/online_serving/test_longcat_image_edit_expansion.py
+++ b/tests/e2e/online_serving/test_longcat_image_edit_expansion.py
@@ -13,11 +13,14 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.media import generate_synthetic_image
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ dummy_messages_from_mix_data,
+ generate_synthetic_image,
+)
+from tests.utils import hardware_marks
EDIT_PROMPT = "Transform this modern image into a cinematic animation style with vibrant colors and soft lighting."
NEGATIVE_PROMPT = "blurry, low quality, distorted, oversaturated"
@@ -52,6 +55,8 @@ def _get_diffusion_feature_cases(model: str):
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases("meituan-longcat/LongCat-Image-Edit"),
diff --git a/tests/e2e/online_serving/test_longcat_image_expansion.py b/tests/e2e/online_serving/test_longcat_image_expansion.py
index b9fd858b052..161e7cd2e65 100644
--- a/tests/e2e/online_serving/test_longcat_image_expansion.py
+++ b/tests/e2e/online_serving/test_longcat_image_expansion.py
@@ -13,10 +13,13 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ dummy_messages_from_mix_data,
+)
+from tests.utils import hardware_marks
TEXT_TO_IMAGE_PROMPT = (
"A cinematic illustration of a cat typing on a silver laptop, soft window light, highly detailed."
@@ -53,6 +56,8 @@ def _get_diffusion_feature_cases(model: str):
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases("meituan-longcat/LongCat-Image"),
diff --git a/tests/e2e/online_serving/test_mimo_audio.py b/tests/e2e/online_serving/test_mimo_audio.py
index df00c64161e..43eeb773355 100644
--- a/tests/e2e/online_serving/test_mimo_audio.py
+++ b/tests/e2e/online_serving/test_mimo_audio.py
@@ -9,10 +9,13 @@
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import generate_synthetic_audio
-from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.conftest import (
+ OmniServerParams,
+ dummy_messages_from_mix_data,
+ generate_synthetic_audio,
+ modify_stage_config,
+)
+from tests.utils import hardware_test
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
@@ -26,6 +29,26 @@
models = ["XiaomiMiMo/MiMo-Audio-7B-Instruct"]
+def get_chunk_config():
+ path = modify_stage_config(
+ str(Path(__file__).parent.parent / "stage_configs" / "mimo_audio_ci.yaml"),
+ updates={
+ "async_chunk": True,
+ "stage_args": {
+ 0: {
+ "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.mimo_audio.llm2code2wav_async_chunk"
+ },
+ 1: {
+ "engine_args.max_model_len": 8192,
+ "engine_args.max_num_batched_tokens": 8192,
+ },
+ },
+ },
+ deletes={"stage_args": {1: ["custom_process_input_func"]}},
+ )
+ return path
+
+
def download_tokenizer():
tokenizer_path = os.environ.get("MIMO_AUDIO_TOKENIZER_PATH", MIMO_AUDIO_TOKENIZER_REPO)
if os.path.exists(tokenizer_path):
@@ -39,21 +62,19 @@ def download_tokenizer():
return local_path
+# CI stage config for H100 / MI325
# Guard module-level setup so test collection doesn't fail in environments
# where the model cache is read-only or models aren't available.
try:
- stage_configs = [get_deploy_config_path("mimo_audio.yaml")]
+ stage_configs = [get_chunk_config()]
tokenizer_path = download_tokenizer()
os.environ["MIMO_AUDIO_TOKENIZER_PATH"] = tokenizer_path
- # --load-format dummy applies to every stage pipeline-wide, avoiding a
- # per-stage yaml rewrite (the old approach wrote a tempfile + atexit-unlink
- # which raced with CI's process lifecycle).
test_params = [
OmniServerParams(
model=model,
stage_config_path=stage_config,
- server_args=["--chat-template", CHAT_TEMPLATE_PATH, "--load-format", "dummy"],
+ server_args=["--chat-template", CHAT_TEMPLATE_PATH],
)
for model in models
for stage_config in stage_configs
diff --git a/tests/e2e/online_serving/test_ming_flash_omni.py b/tests/e2e/online_serving/test_ming_flash_omni.py
deleted file mode 100644
index 8161c438929..00000000000
--- a/tests/e2e/online_serving/test_ming_flash_omni.py
+++ /dev/null
@@ -1,246 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-E2E online serving tests for Ming-flash-omni-2.0 model (Thinker stage).
-Tests multimodal understanding via OpenAI-compatible API.
-"""
-
-import os
-from pathlib import Path
-
-import pytest
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import (
- generate_synthetic_audio,
- generate_synthetic_image,
- generate_synthetic_video,
-)
-from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
-from tests.helpers.stage_config import modify_stage_config
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-
-models = ["Jonathan1909/Ming-flash-omni-2.0"]
-
-
-def get_eager_config():
- path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_ci.yaml"),
- updates={
- "stage_args": {
- 0: {
- "engine_args.enforce_eager": "true",
- },
- },
- },
- )
- return path
-
-
-stage_configs = [get_eager_config()]
-
-# Create parameter combinations for model and stage config
-test_params = [
- OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs
-]
-
-
-def get_system_prompt():
- return {
- "role": "system",
- "content": [
- {
- "type": "text",
- "text": "你是一个友好的AI助手。\n\ndetailed thinking off",
- }
- ],
- }
-
-
-def get_prompt(prompt_type="text_only"):
- prompts = {
- "text_only": "What is the capital of China? Answer in 20 words.",
- "text_image": "What is in this image?",
- "text_audio": "What is in this audio?",
- "text_video": "What is in this video?",
- "mix": "What is recited in the audio? What is in this image? What is in this video?",
- }
- return prompts.get(prompt_type, prompts["text_only"])
-
-
-def get_max_batch_size(size_type="few"):
- batch_sizes = {"few": 5, "medium": 100, "large": 256}
- return batch_sizes.get(size_type, 5)
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_text_to_text_001(omni_server, openai_client) -> None:
- """
- Input Modal: text
- Output Modal: text
- Input Setting: stream=False
- Datasets: single request
- """
- messages = dummy_messages_from_mix_data(
- system_prompt=get_system_prompt(),
- content_text=get_prompt("text_only"),
- )
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "stream": False,
- "modalities": ["text"],
- "key_words": {"text": ["beijing"]},
- }
-
- openai_client.send_omni_request(request_config)
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_text_to_text_stream_001(omni_server, openai_client) -> None:
- """
- Input Modal: text
- Output Modal: text
- Input Setting: stream=True
- Datasets: few requests
- """
- messages = dummy_messages_from_mix_data(
- system_prompt=get_system_prompt(),
- content_text=get_prompt("text_only"),
- )
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "stream": True,
- "modalities": ["text"],
- "key_words": {"text": ["beijing"]},
- }
-
- openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_image_to_text_001(omni_server, openai_client) -> None:
- """
- Input Modal: image + text
- Output Modal: text
- Input Setting: stream=True
- Datasets: single request
- """
- image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
- messages = dummy_messages_from_mix_data(
- system_prompt=get_system_prompt(),
- image_data_url=image_data_url,
- content_text=get_prompt("text_image"),
- )
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "stream": True,
- "modalities": ["text"],
- }
-
- openai_client.send_omni_request(request_config)
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_audio_to_text_001(omni_server, openai_client) -> None:
- """
- Input Modal: audio + text
- Output Modal: text
- Input Setting: stream=True
- Datasets: single request
- """
- audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(2, 1)['base64']}"
- messages = dummy_messages_from_mix_data(
- system_prompt=get_system_prompt(),
- audio_data_url=audio_data_url,
- content_text=get_prompt("text_audio"),
- )
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "stream": True,
- "modalities": ["text"],
- }
-
- openai_client.send_omni_request(request_config)
-
-
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_video_to_text_001(omni_server, openai_client) -> None:
- """
- Input Modal: video + text
- Output Modal: text
- Input Setting: stream=False
- Datasets: single request
- """
- video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
- messages = dummy_messages_from_mix_data(
- system_prompt=get_system_prompt(),
- video_data_url=video_data_url,
- content_text=get_prompt("text_video"),
- )
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "stream": False,
- "modalities": ["text"],
- }
-
- openai_client.send_omni_request(request_config)
-
-
-@pytest.mark.advanced_model
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100"}, num_cards=4)
-@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_mix_to_text_001(omni_server, openai_client) -> None:
- """
- Input Modal: text + audio + image + video
- Output Modal: text
- Input Setting: stream=True
- Datasets: single request
- """
- video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
- image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
- audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(2, 1)['base64']}"
- messages = dummy_messages_from_mix_data(
- system_prompt=get_system_prompt(),
- video_data_url=video_data_url,
- image_data_url=image_data_url,
- audio_data_url=audio_data_url,
- content_text=get_prompt("mix"),
- )
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "stream": True,
- "modalities": ["text"],
- }
-
- openai_client.send_omni_request(request_config)
diff --git a/tests/e2e/online_serving/test_nextstep_expansion.py b/tests/e2e/online_serving/test_nextstep_expansion.py
deleted file mode 100644
index 75290db5d29..00000000000
--- a/tests/e2e/online_serving/test_nextstep_expansion.py
+++ /dev/null
@@ -1,71 +0,0 @@
-"""
-Online serving E2E for NextStep-1.1 text-to-image (tensor parallel).
-"""
-
-import os
-
-import pytest
-
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- dummy_messages_from_mix_data,
-)
-
-pytestmark = [pytest.mark.full_model, pytest.mark.diffusion]
-
-# L4: 4 GPUs + TP=4; XPU B60: 2 cards (use num_cards={"cuda": 4, "xpu": 4} if needed)
-FOUR_CARD_MARKS = hardware_marks(
- res={"cuda": "L4", "xpu": "B60"},
- num_cards={"cuda": 2, "xpu": 2},
-)
-
-POSITIVE_PROMPT = "A small red barn in a snowy field, simple illustration."
-NEGATIVE_PROMPT = "blurry, low quality"
-
-_DEFAULT_MODEL = "stepfun-ai/NextStep-1.1"
-
-
-def _get_diffusion_feature_cases(model: str):
- """Single online config: TP=4, explicit pipeline class."""
- return [
- pytest.param(
- OmniServerParams(
- model=model,
- server_args=[
- "--tensor-parallel-size",
- "2",
- "--model-class-name",
- "NextStep11Pipeline",
- ],
- ),
- id="nextstep_tp4_pipeline",
- marks=FOUR_CARD_MARKS,
- ),
- ]
-
-
-@pytest.mark.parametrize(
- "omni_server",
- _get_diffusion_feature_cases(model=os.environ.get("VLLM_TEST_NEXTSTEP_MODEL", _DEFAULT_MODEL)),
- indirect=True,
-)
-def test_nextstep_11(omni_server: OmniServer, openai_client: OpenAIClientHandler):
- messages = dummy_messages_from_mix_data(content_text=POSITIVE_PROMPT)
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "extra_body": {
- "height": 512,
- "width": 512,
- "num_inference_steps": 2,
- "guidance_scale": 5.0,
- "guidance_scale_2": 1.0,
- "negative_prompt": NEGATIVE_PROMPT,
- "seed": 42,
- },
- }
-
- openai_client.send_diffusion_request(request_config)
diff --git a/tests/e2e/online_serving/test_omnivoice.py b/tests/e2e/online_serving/test_omnivoice.py
index 892896e05c7..ec1981aab22 100644
--- a/tests/e2e/online_serving/test_omnivoice.py
+++ b/tests/e2e/online_serving/test_omnivoice.py
@@ -17,16 +17,8 @@
import httpx
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import generate_synthetic_audio
-from tests.helpers.runtime import OmniServerParams
-
-try:
- from transformers import HiggsAudioV2TokenizerModel # noqa: F401
-
- _HAS_VOICE_CLONE = True
-except ImportError:
- _HAS_VOICE_CLONE = False
+from tests.conftest import OmniServerParams
+from tests.utils import hardware_test
MODEL = "k2-fsa/OmniVoice"
@@ -48,16 +40,6 @@
MIN_AUDIO_BYTES = 5000
-def _get_ref_audio_b64() -> str:
- """Generate synthetic speech for reference audio.
-
- Returns:
- Base64 data URL string (data:audio/wav;base64,...)
- """
- audio_data = generate_synthetic_audio(duration=2, num_channels=1, sample_rate=24000)
- return f"data:audio/wav;base64,{audio_data['base64']}"
-
-
def make_speech_request(
host: str,
port: int,
@@ -100,102 +82,3 @@ def test_speech_auto_voice(self, omni_server) -> None:
assert len(response.content) > MIN_AUDIO_BYTES, (
f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}"
)
-
-
-def make_voice_clone_request(
- host: str,
- port: int,
- text: str,
- ref_audio_b64: str,
- ref_text: str | None = None,
- timeout: float = 180.0,
-) -> httpx.Response:
- """Make a voice cloning request to the /v1/audio/speech endpoint.
-
- Args:
- host: Server host
- port: Server port
- text: Text to synthesize
- ref_audio_b64: Base64-encoded reference audio data URL
- ref_text: Optional transcript of reference audio
- timeout: Request timeout in seconds
-
- Returns:
- httpx.Response object
- """
- url = f"http://{host}:{port}/v1/audio/speech"
- payload = {
- "input": text,
- "ref_audio": ref_audio_b64,
- }
- if ref_text:
- payload["ref_text"] = ref_text
-
- with httpx.Client(timeout=timeout) as client:
- return client.post(url, json=payload)
-
-
-@pytest.mark.skipif(not _HAS_VOICE_CLONE, reason="Voice cloning requires transformers>=5.3.0")
-@pytest.mark.parametrize("omni_server", TEST_PARAMS, indirect=True)
-class TestOmniVoiceVoiceCloning:
- """E2E tests for OmniVoice voice cloning functionality."""
-
- @pytest.mark.core_model
- @pytest.mark.omni
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
- def test_voice_clone_ref_audio_only(self, omni_server) -> None:
- """Test voice cloning with ref_audio only (x_vector mode)."""
- ref_audio_b64 = _get_ref_audio_b64()
-
- response = make_voice_clone_request(
- host=omni_server.host,
- port=omni_server.port,
- text="Hello, this is a voice cloning test.",
- ref_audio_b64=ref_audio_b64,
- )
-
- assert response.status_code == 200, f"Request failed: {response.text}"
- assert response.headers.get("content-type") == "audio/wav"
- assert verify_wav_audio(response.content), "Response is not valid WAV audio"
- assert len(response.content) > MIN_AUDIO_BYTES, (
- f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}"
- )
-
- @pytest.mark.core_model
- @pytest.mark.omni
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
- def test_voice_clone_ref_audio_and_text(self, omni_server) -> None:
- """Test voice cloning with ref_audio and ref_text (in-context mode)."""
- ref_audio_b64 = _get_ref_audio_b64()
- ref_text = "This is the reference transcript."
-
- response = make_voice_clone_request(
- host=omni_server.host,
- port=omni_server.port,
- text="Hello, this is a voice cloning test with in-context learning.",
- ref_audio_b64=ref_audio_b64,
- ref_text=ref_text,
- )
-
- assert response.status_code == 200, f"Request failed: {response.text}"
- assert response.headers.get("content-type") == "audio/wav"
- assert verify_wav_audio(response.content), "Response is not valid WAV audio"
- assert len(response.content) > MIN_AUDIO_BYTES, (
- f"Audio too small ({len(response.content)} bytes), expected > {MIN_AUDIO_BYTES}"
- )
-
- @pytest.mark.core_model
- @pytest.mark.omni
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
- def test_voice_clone_invalid_ref_audio_format(self, omni_server) -> None:
- """Test that invalid ref_audio format returns a clear error."""
- response = make_voice_clone_request(
- host=omni_server.host,
- port=omni_server.port,
- text="This should fail with invalid ref_audio.",
- ref_audio_b64="not_a_valid_uri",
- )
-
- assert response.status_code in (400, 422), (
- f"Expected 400/422 for invalid ref_audio format, got {response.status_code}"
- )
diff --git a/tests/e2e/online_serving/test_qwen2_5_omni.py b/tests/e2e/online_serving/test_qwen2_5_omni.py
index 8a1a8eb9950..e2913ce0215 100644
--- a/tests/e2e/online_serving/test_qwen2_5_omni.py
+++ b/tests/e2e/online_serving/test_qwen2_5_omni.py
@@ -3,13 +3,20 @@
"""
import os
+from pathlib import Path
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video
-from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+from tests.conftest import (
+ OmniServerParams,
+ dummy_messages_from_mix_data,
+ generate_synthetic_audio,
+ generate_synthetic_image,
+ generate_synthetic_video,
+ modify_stage_config,
+)
+from tests.utils import hardware_test
+from vllm_omni.platforms import current_omni_platform
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
@@ -17,9 +24,20 @@
models = ["Qwen/Qwen2.5-Omni-7B"]
-# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
-# platforms: section in vllm_omni/deploy/ci/qwen2_5_omni.yaml.
-stage_configs = [modify_stage_config(get_deploy_config_path("ci/qwen2_5_omni.yaml"))]
+
+def get_config():
+ path = modify_stage_config(
+ str(Path(__file__).parent.parent / "stage_configs" / "qwen2_5_omni_ci.yaml"),
+ )
+ return path
+
+
+# CI stage config for 2xH100-80G GPUs or AMD GPU MI325
+if current_omni_platform.is_rocm():
+ # ROCm stage config optimized for MI325 GPU
+ stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "rocm" / "qwen2_5_omni_ci.yaml")]
+else:
+ stage_configs = [get_config()]
# Create parameter combinations for model and stage config
test_params = [
diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py
index c9494d9fe74..fcda20ba388 100644
--- a/tests/e2e/online_serving/test_qwen3_omni.py
+++ b/tests/e2e/online_serving/test_qwen3_omni.py
@@ -3,13 +3,19 @@
"""
import os
+from pathlib import Path
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video
-from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+from tests.conftest import (
+ OmniServerParams,
+ dummy_messages_from_mix_data,
+ generate_synthetic_audio,
+ generate_synthetic_image,
+ generate_synthetic_video,
+ modify_stage_config,
+)
+from tests.utils import hardware_test
from vllm_omni.platforms import current_omni_platform
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -18,64 +24,35 @@
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
-# Set VLLM_TEST_PD_MODE=1 to test PD disaggregation (follow-up — deploy overlay not yet migrated).
-_USE_PD = os.environ.get("VLLM_TEST_PD_MODE", "0") == "1"
-_CI_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
-
-
-def get_chunk_config(config_path: str | None = None):
- """Load the qwen3_omni CI deploy yaml with async_chunk modifications for streaming mode."""
- if config_path is None:
- config_path = _CI_DEPLOY
- # TODO: remove this workaround once legacy `stage_args` path is deleted.
- # The pipeline (qwen3_omni/pipeline.py) already wires
- # thinker2talker_async_chunk / talker2code2wav_async_chunk on stage 0/1,
- # so only async_chunk needs flipping. Writing nested `engine_args:` into
- # the new-schema overlay trips _parse_stage_deploy's legacy branch and
- # drops flat fields (load_format, max_num_seqs, ...).
- return modify_stage_config(config_path, updates={"async_chunk": True})
-
-
-def get_prefix_caching_config(config_path: str):
- """Create a stage config with prefix caching enabled on the thinker (stage 0)."""
+def get_chunk_config():
path = modify_stage_config(
- config_path,
+ str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml"),
updates={
+ "async_chunk": True,
"stage_args": {
- 0: {"engine_args.enable_prefix_caching": True},
+ 0: {
+ "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
+ },
+ 1: {
+ "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
+ },
},
},
+ deletes={"stage_args": {2: ["custom_process_input_func"]}},
)
return path
-# Platform-specific overrides live inside the new deploy yaml's ``platforms:``
-# section, so a single ``_CI_DEPLOY`` path serves CUDA, ROCm, and XPU.
-# TODO: re-add VLLM_TEST_PD_MODE branch once the PD-disaggregation deploy
-# overlay has been migrated to the new schema (previously used the deleted
-# ``qwen3_omni_moe_pd_ci.yaml`` stage-configs file).
if current_omni_platform.is_xpu():
- stage_configs = [_CI_DEPLOY]
-else: # CUDA + ROCm MI325 share the same deploy config
+ stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")]
+else: # MI325 GPU should share the same config as H100
stage_configs = [get_chunk_config()]
-prefix_caching_stage_configs = [get_prefix_caching_config(_CI_DEPLOY)]
# Create parameter combinations for model and stage config
test_params = [
OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs
]
-# For prefix caching, we need to enable prompt token details so that we
-# can determine if any tokens were cached.
-prefix_test_params = [
- OmniServerParams(
- model=model,
- stage_config_path=stage_config,
- server_args=["--enable-prompt-tokens-details"], # Enable prompt tokens details to get cached_tokens
- )
- for model in models
- for stage_config in prefix_caching_stage_configs
-]
def get_system_prompt():
@@ -98,7 +75,6 @@ def get_prompt(prompt_type="text_only"):
prompts = {
"text_only": "What is the capital of China? Answer in 20 words.",
"mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
- "text_image": "What color are the squares in this image?",
}
return prompts.get(prompt_type, prompts["text_only"])
@@ -111,8 +87,7 @@ def get_max_batch_size(size_type="few"):
@pytest.mark.advanced_model
@pytest.mark.core_model
@pytest.mark.omni
-@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.")
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2)
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_mix_to_text_audio_001(omni_server, openai_client) -> None:
"""
@@ -145,14 +120,13 @@ def test_mix_to_text_audio_001(omni_server, openai_client) -> None:
}
# Test single completion
- openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+ openai_client.send_omni_request(request_config)
@pytest.mark.advanced_model
@pytest.mark.core_model
@pytest.mark.omni
-@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.")
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2)
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_text_to_text_001(omni_server, openai_client) -> None:
"""
@@ -173,45 +147,3 @@ def test_text_to_text_001(omni_server, openai_client) -> None:
}
openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
-
-
-@pytest.mark.advanced_model
-@pytest.mark.core_model
-@pytest.mark.omni
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", prefix_test_params, indirect=True)
-def test_thinker_prefix_caching(omni_server, openai_client) -> None:
- """
- Test thinker prefix caching by sending identical requests with an image (i.e.,
- a large shared prefix) and verifying that the second request uses cached tokens
- & produces the same output with greedy decoding.
-
- NOTE: The seed for this test is used as a regression test for the issue linked below;
- https://github.com/vllm-project/vllm-omni/issues/2833; without passing the sampling
- params, this test will fail with the current default stage configs.
- """
- seed = 10
- img_res = generate_synthetic_image(224, 224, seed=seed)
- image_data_url = f"data:image/jpeg;base64,{img_res['base64']}"
- messages = dummy_messages_from_mix_data(
- system_prompt=get_system_prompt(),
- image_data_url=image_data_url,
- content_text=get_prompt("text_image"),
- )
-
- request_config = {
- "model": omni_server.model,
- "messages": messages,
- "stream": False,
- "modalities": ["text"],
- "sampling_params_list": [{"seed": seed, "temperature": 0, "max_tokens": 16}] * 3,
- }
-
- response_1 = openai_client.send_omni_request(request_config, request_num=1)[0]
- response_2 = openai_client.send_omni_request(request_config, request_num=1)[0]
-
- # We should cache the vast majority of the prompt (image + up to last full block),
- # and set seed + temperature, so the second request should give an identical
- # response for the generated input image, even if we use dummy weights
- assert response_2.cached_tokens is not None and response_2.cached_tokens > 0
- assert response_1.text_content == response_2.text_content
diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py
index 2ebf5c7e364..0bcc86840ba 100644
--- a/tests/e2e/online_serving/test_qwen3_omni_expansion.py
+++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py
@@ -6,16 +6,22 @@
import os
-import pytest
+from vllm_omni.platforms import current_omni_platform
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video
-from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
-from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+from pathlib import Path
-pytestmark = [pytest.mark.full_model, pytest.mark.omni]
+import pytest
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+from tests.conftest import (
+ OmniServerParams,
+ dummy_messages_from_mix_data,
+ generate_synthetic_audio,
+ generate_synthetic_image,
+ generate_synthetic_video,
+ modify_stage_config,
+)
+from tests.utils import hardware_test
model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
@@ -23,77 +29,51 @@
IMAGE_KEY = ["square", "quadrate", "rectangle"]
VIDEO_KEY = ["sphere", "globe", "circle", "round", "ball"]
-# Heavier synthetic inputs than the default expansion cases (longer timeline / more pixels).
-# Long video: 120s @ 30fps => 3600 frames (generate_synthetic_video in tests/conftest.py).
-# Use 224² spatial size to bound RAM (~W*H*num_frames*3) vs. 288² at this frame count.
-LONG_VIDEO_WIDTH = 224
-LONG_VIDEO_HEIGHT = 224
-LONG_VIDEO_FRAMES = 3600
-LARGE_IMAGE_WIDTH = 1920
-LARGE_IMAGE_HEIGHT = 1080
-LONG_AUDIO_DURATION_SEC = 120
-
-def get_batch_token_config(default_path):
- """Override stage 1's max_num_batched_tokens to exercise small-batch paths.
-
- Uses the new flat-stage schema (``stages..``); the legacy
- ``stage_args..engine_args.`` path no longer applies because
- the deploy YAML doesn't nest engine fields under ``engine_args:``.
- """
- return modify_stage_config(
+def get_chunk_config(default_path):
+ path = modify_stage_config(
default_path,
updates={
- "stages": {1: {"max_num_batched_tokens": 64}},
+ "async_chunk": True,
+ "stage_args": {
+ 0: {
+ "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
+ },
+ 1: {
+ "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
+ },
+ },
},
+ deletes={"stage_args": {2: ["custom_process_input_func"]}},
)
+ return path
-def get_async_chunk_config(default_path):
- """Flip async_chunk on and bump stage 0 thinker output to 2048 tokens.
-
- Pipeline registry (qwen3_omni/pipeline.py) already wires
- thinker2talker_async_chunk / talker2code2wav_async_chunk on stages 0/1,
- so no per-stage processor override is needed. Using only flat-schema
- writes so _parse_stage_deploy stays in its flat branch (nested
- ``engine_args:`` would drop other overlay fields).
- """
- return modify_stage_config(
+def get_batch_token_config(default_path):
+ path = modify_stage_config(
default_path,
updates={
- "stages": {0: {"default_sampling_params.max_tokens": 2048}},
+ "stage_args": {1: {"engine_args.max_num_batched_tokens": 64}},
},
)
+ return path
+
+# CI stage config for 2*H100-80G GPUs
+default_path = str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
-# CI deploy YAML (single file; xpu deltas applied via ``platforms:`` section).
-# The overlay explicitly sets ``async_chunk: False``, so ``default`` tests the
-# sync path and ``async_chunk`` tests the streaming path with a longer thinker
-# output — two distinct scenarios, kept as separate parametrizations.
-default_path = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
+if current_omni_platform.is_xpu():
+ default_path = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")
+# Create parameter combinations for model and stage config
test_params = [
- pytest.param(
- OmniServerParams(
- model=model, stage_config_path=default_path, use_stage_cli=True, server_args=["--no-async-chunk"]
- ),
- id="default",
- ),
- pytest.param(
- OmniServerParams(
- model=model,
- stage_config_path=get_async_chunk_config(default_path),
- use_stage_cli=True,
- server_args=["--async-chunk"],
- ),
- id="async_chunk",
- ),
+ pytest.param(OmniServerParams(model=model, stage_config_path=default_path), id="default"),
+ pytest.param(OmniServerParams(model=model, stage_config_path=get_chunk_config(default_path)), id="async_chunk"),
]
test_token_params = [
pytest.param(
- OmniServerParams(model=model, stage_config_path=get_batch_token_config(default_path), use_stage_cli=True),
- id="batch_token_64",
+ OmniServerParams(model=model, stage_config_path=get_batch_token_config(default_path)), id="batch_token_64"
)
]
@@ -133,6 +113,8 @@ def get_max_batch_size(size_type="few"):
return batch_sizes.get(size_type, 5)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_text_to_audio_001(omni_server, openai_client) -> None:
@@ -155,6 +137,8 @@ def test_text_to_audio_001(omni_server, openai_client) -> None:
openai_client.send_omni_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
def test_text_to_text_audio_001(omni_server, openai_client) -> None:
@@ -175,19 +159,92 @@ def test_text_to_text_audio_001(omni_server, openai_client) -> None:
openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-def test_text_video_to_text_001(omni_server, openai_client) -> None:
+def test_image_to_text_001(omni_server, openai_client) -> None:
"""
- Input Modal: long synthetic video (120s @ 30fps, LONG_VIDEO_FRAMES frames)
+ Input Modal: image
Output Modal: text
+ Input Setting: stream=True
+ Datasets: single request
+ """
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+ messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "modalities": ["text"],
+ "stream": True,
+ "key_words": {"image": IMAGE_KEY},
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.advanced_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_image_to_audio_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: image
+ Output Modal: audio
Input Setting: stream=False
Datasets: single request
"""
- video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(LONG_VIDEO_WIDTH, LONG_VIDEO_HEIGHT, LONG_VIDEO_FRAMES)['base64']}"
- messages = dummy_messages_from_mix_data(
- video_data_url=video_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_video")
- )
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+ messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "modalities": ["audio"],
+ "key_words": {"image": IMAGE_KEY},
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.advanced_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_image_to_text_audio_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: image
+ Output Modal: text, audio
+ Input Setting: stream=False
+ Datasets: few requests
+ """
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(1280, 720)['base64']}"
+
+ messages = dummy_messages_from_mix_data(image_data_url=image_data_url)
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "key_words": {"image": IMAGE_KEY},
+ }
+
+ openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+
+
+@pytest.mark.advanced_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_video_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: video
+ Output Modal: text
+ Input Setting: stream=False
+ Datasets: single request
+ """
+ video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
+ messages = dummy_messages_from_mix_data(video_data_url=video_data_url)
request_config = {
"model": omni_server.model,
@@ -196,98 +253,97 @@ def test_text_video_to_text_001(omni_server, openai_client) -> None:
"key_words": {"video": VIDEO_KEY},
}
- openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+ openai_client.send_omni_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
-def test_text_audio_to_text_audio_001(omni_server, openai_client) -> None:
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_video_to_audio_001(omni_server, openai_client) -> None:
"""
- Input Modal: text, audio
- Output Modal: text, audio
+ Input Modal: video
+ Output Modal: audio
Input Setting: stream=False
Datasets: single request
"""
- audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}"
- messages = dummy_messages_from_mix_data(
- audio_data_url=audio_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_audio")
- )
+ video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
+ messages = dummy_messages_from_mix_data(video_data_url=video_data_url)
request_config = {
"model": omni_server.model,
"messages": messages,
- "key_words": {"audio": AUDIO_KEY},
+ "modalities": ["audio"],
+ "key_words": {"video": VIDEO_KEY},
}
openai_client.send_omni_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
-@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
-def test_text_audio_to_text_audio_002(omni_server, openai_client) -> None:
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_video_to_text_audio_001(omni_server, openai_client) -> None:
"""
- Input Modal: text, long-duration audio (~LONG_AUDIO_DURATION_SEC s WAV)
+ Input Modal: video
Output Modal: text, audio
Input Setting: stream=False
- Datasets: single request
+ Datasets: few requests
"""
- audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(LONG_AUDIO_DURATION_SEC, 1)['base64']}"
- messages = dummy_messages_from_mix_data(
- audio_data_url=audio_data_url,
- system_prompt=get_system_prompt(),
- content_text=get_prompt("text_audio"),
- )
+ video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
+
+ messages = dummy_messages_from_mix_data(video_data_url=video_data_url)
request_config = {
"model": omni_server.model,
"messages": messages,
- "key_words": {"audio": AUDIO_KEY},
+ "key_words": {"video": VIDEO_KEY},
}
openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
-def test_text_image_to_text_audio_001(omni_server, openai_client) -> None:
+def test_text_audio_to_text_audio_001(omni_server, openai_client) -> None:
"""
- Input Modal: text, image
+ Input Modal: text, audio
Output Modal: text, audio
Input Setting: stream=False
Datasets: single request
"""
- image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
-
+ audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}"
messages = dummy_messages_from_mix_data(
- image_data_url=image_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_image")
+ audio_data_url=audio_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_audio")
)
request_config = {
"model": omni_server.model,
"messages": messages,
- "key_words": {"image": IMAGE_KEY},
+ "key_words": {"audio": AUDIO_KEY},
}
openai_client.send_omni_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
-def test_large_image_to_text_audio_001(omni_server, openai_client) -> None:
+def test_text_image_to_text_audio_001(omni_server, openai_client) -> None:
"""
- Input Modal: text, high-resolution image (1080p-class JPEG)
+ Input Modal: text, image
Output Modal: text, audio
Input Setting: stream=False
Datasets: single request
"""
- image_data_url = (
- f"data:image/jpeg;base64,{generate_synthetic_image(LARGE_IMAGE_WIDTH, LARGE_IMAGE_HEIGHT)['base64']}"
- )
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
messages = dummy_messages_from_mix_data(
- image_data_url=image_data_url,
- system_prompt=get_system_prompt(),
- content_text=get_prompt("text_image"),
+ image_data_url=image_data_url, system_prompt=get_system_prompt(), content_text=get_prompt("text_image")
)
request_config = {
@@ -296,9 +352,11 @@ def test_large_image_to_text_audio_001(omni_server, openai_client) -> None:
"key_words": {"image": IMAGE_KEY},
}
- openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+ openai_client.send_omni_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
def test_text_video_to_text_audio_001(omni_server, openai_client) -> None:
@@ -325,6 +383,8 @@ def test_text_video_to_text_audio_001(omni_server, openai_client) -> None:
@pytest.mark.skip(reason="There is a known issue with shape mismatch error.")
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params + test_token_params, indirect=True)
def test_mix_to_text_audio_001(omni_server, openai_client) -> None:
@@ -354,9 +414,10 @@ def test_mix_to_text_audio_001(omni_server, openai_client) -> None:
openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-@pytest.mark.skip(reason="issue: #2827")
def test_audio_in_video_001(omni_server, openai_client) -> None:
"""
Input Modal: text + video (synthetic MP4 with embedded audio; ``use_audio_in_video`` uses audio from the video).
@@ -381,6 +442,8 @@ def test_audio_in_video_001(omni_server, openai_client) -> None:
openai_client.send_omni_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_audio_in_video_002(omni_server, openai_client) -> None:
@@ -408,6 +471,8 @@ def test_audio_in_video_002(omni_server, openai_client) -> None:
openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_one_word_prompt_001(omni_server, openai_client) -> None:
@@ -429,7 +494,7 @@ def test_one_word_prompt_001(omni_server, openai_client) -> None:
"key_words": {"text": ["london"]},
}
- # Retry only when assert_omni_response fails on text/audio cosine similarity (see tests/helpers/assertions.py).
+ # Retry only when assert_omni_response fails on text/audio cosine similarity (see tests/conftest.py).
_similarity_assert_msg = "The audio content is not same as the text"
_max_retries = 3
for attempt in range(_max_retries):
@@ -442,6 +507,8 @@ def test_one_word_prompt_001(omni_server, openai_client) -> None:
print(f"Similarity assertion failed, retrying {attempt + 2}/{_max_retries}: {e!r}")
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_speaker_001(omni_server, openai_client) -> None:
@@ -467,9 +534,10 @@ def test_speaker_001(omni_server, openai_client) -> None:
openai_client.send_omni_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
-@pytest.mark.skip(reason="Known issue: occasional inaccuracy in voice recognition.")
def test_speaker_002(omni_server, openai_client) -> None:
"""
Input Modal: text only (one-word answer constraint).
@@ -490,7 +558,7 @@ def test_speaker_002(omni_server, openai_client) -> None:
"key_words": {"text": ["beijing"]},
}
- # Retry only when assert_omni_response fails on preset voice gender (see tests/helpers/assertions.py).
+ # Retry only when assert_omni_response fails on preset voice gender (see tests/conftest.py).
_gender_assert_substr = "estimated gender"
_max_retries = 3
for attempt in range(_max_retries):
@@ -503,6 +571,8 @@ def test_speaker_002(omni_server, openai_client) -> None:
print(f"Gender assertion failed, retrying {attempt + 2}/{_max_retries}: {e!r}")
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_speaker_003(omni_server, openai_client) -> None:
@@ -528,6 +598,8 @@ def test_speaker_003(omni_server, openai_client) -> None:
openai_client.send_omni_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_language_001(omni_server, openai_client) -> None:
diff --git a/tests/e2e/online_serving/test_qwen3_tts_base.py b/tests/e2e/online_serving/test_qwen3_tts_base.py
index fd7bc43b55d..002f9d99724 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_base.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_base.py
@@ -12,11 +12,12 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+from pathlib import Path
+
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.conftest import OmniServerParams
+from tests.utils import hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
@@ -24,6 +25,11 @@
REF_TEXT = "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."
+def get_stage_config(name: str = "qwen3_tts.yaml"):
+ """Get the stage config path from vllm_omni model_executor stage_configs."""
+ return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
+
+
def get_prompt(prompt_type="text"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -42,7 +48,7 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
+ stage_config_path=get_stage_config("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py b/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py
index d86f96af099..3c33485e4f4 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py
@@ -9,16 +9,15 @@
import os
-import pytest
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
+from pathlib import Path
-pytestmark = [pytest.mark.full_model, pytest.mark.omni]
+import pytest
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+from tests.conftest import OmniServerParams
+from tests.utils import hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
@@ -26,6 +25,11 @@
REF_TEXT = "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."
+def get_stage_config(name: str = "qwen3_tts.yaml"):
+ """Get the stage config path from vllm_omni model_executor stage_configs."""
+ return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
+
+
def get_prompt(prompt_type="text"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -44,25 +48,25 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
+ stage_config_path=get_stage_config("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
),
- # Synchronous (no async-chunk) variant — ``--no-async-chunk`` alone
- # flips the deploy yaml's bool and the pipeline dispatches to the
- # end-to-end codec processor. No variant yaml / pipeline needed.
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
- server_args=["--trust-remote-code", "--disable-log-stats", "--no-async-chunk"],
+ stage_config_path=get_stage_config("qwen3_tts_no_async_chunk.yaml"),
+ server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="no_async_chunk",
),
]
+@pytest.mark.advanced_model
+@pytest.mark.core_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4"}, num_cards=1)
@pytest.mark.parametrize("omni_server", tts_server_params, indirect=True)
def test_voice_clone_streaming_001(omni_server, openai_client) -> None:
@@ -88,6 +92,9 @@ def test_voice_clone_streaming_001(omni_server, openai_client) -> None:
openai_client.send_audio_speech_request(request_config, request_num=get_max_batch_size("few"))
+@pytest.mark.advanced_model
+@pytest.mark.core_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4"}, num_cards=1)
@pytest.mark.parametrize("omni_server", tts_server_params, indirect=True)
def test_response_format_001(omni_server, openai_client) -> None:
diff --git a/tests/e2e/online_serving/test_qwen3_tts_batch.py b/tests/e2e/online_serving/test_qwen3_tts_batch.py
index 3ca0688195b..d0d6336618e 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_batch.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_batch.py
@@ -22,18 +22,18 @@
import pytest
import yaml
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import convert_audio_file_to_text, cosine_similarity_text
-from tests.helpers.runtime import OmniServer
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.conftest import (
+ OmniServer,
+ convert_audio_file_to_text,
+ cosine_similarity_text,
+)
+from tests.utils import hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice"
-STAGE_INIT_TIMEOUT_S = 120
-def get_stage_config(name: str = "qwen3_tts.yaml") -> str:
- """Resolve a deploy config path under vllm_omni/deploy/."""
- return get_deploy_config_path(name)
+def get_stage_config(name: str = "qwen3_tts.yaml"):
+ return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
@pytest.fixture(scope="module")
@@ -47,7 +47,7 @@ def omni_server():
"--stage-configs-path",
stage_config_path,
"--stage-init-timeout",
- str(STAGE_INIT_TIMEOUT_S),
+ "120",
"--trust-remote-code",
"--enforce-eager",
"--disable-log-stats",
@@ -337,7 +337,7 @@ def omni_server_batch2():
"--stage-configs-path",
config_path,
"--stage-init-timeout",
- str(STAGE_INIT_TIMEOUT_S),
+ "120",
"--trust-remote-code",
"--enforce-eager",
"--disable-log-stats",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_customvoice.py b/tests/e2e/online_serving/test_qwen3_tts_customvoice.py
index 2577361a0c8..fb60df725ba 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_customvoice.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_customvoice.py
@@ -12,15 +12,21 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+from pathlib import Path
+
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.conftest import OmniServerParams
+from tests.utils import hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
+def get_stage_config(name: str = "qwen3_tts.yaml"):
+ """Get the stage config path from vllm_omni model_executor stage_configs."""
+ return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
+
+
def get_prompt(prompt_type="text"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -39,7 +45,7 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
+ stage_config_path=get_stage_config("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py b/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py
index d97e70e41dd..03a985896e4 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py
@@ -9,20 +9,24 @@
import os
-import pytest
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
+from pathlib import Path
-pytestmark = [pytest.mark.full_model, pytest.mark.omni]
+import pytest
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+from tests.conftest import OmniServerParams
+from tests.utils import hardware_test
MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
+def get_stage_config(name: str = "qwen3_tts.yaml"):
+ """Get the stage config path from vllm_omni model_executor stage_configs."""
+ return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
+
+
def get_prompt(prompt_type="english"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -42,25 +46,24 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
+ stage_config_path=get_stage_config("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
),
- # Synchronous (no async-chunk) variant — ``--no-async-chunk`` alone
- # flips the deploy yaml's bool and the pipeline dispatches to the
- # end-to-end codec processor. No variant yaml / pipeline needed.
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
- server_args=["--trust-remote-code", "--disable-log-stats", "--no-async-chunk"],
+ stage_config_path=get_stage_config("qwen3_tts_no_async_chunk.yaml"),
+ server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="no_async_chunk",
),
]
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4"}, num_cards=1)
@pytest.mark.parametrize("omni_server", tts_server_params, indirect=True)
def test_voice_001(omni_server, openai_client) -> None:
@@ -92,6 +95,8 @@ def test_voice_001(omni_server, openai_client) -> None:
raise
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4"}, num_cards=1)
@pytest.mark.parametrize("omni_server", tts_server_params, indirect=True)
def test_voice_002(omni_server, openai_client) -> None:
@@ -115,6 +120,8 @@ def test_voice_002(omni_server, openai_client) -> None:
openai_client.send_audio_speech_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4"}, num_cards=1)
@pytest.mark.parametrize("omni_server", tts_server_params, indirect=True)
def test_voice_003(omni_server, openai_client) -> None:
@@ -138,6 +145,8 @@ def test_voice_003(omni_server, openai_client) -> None:
openai_client.send_audio_speech_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4"}, num_cards=1)
@pytest.mark.parametrize("omni_server", tts_server_params, indirect=True)
def test_language_001(omni_server, openai_client) -> None:
diff --git a/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py b/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py
index b91548c5a66..64e13e1557d 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py
@@ -13,17 +13,16 @@
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
import struct
+from pathlib import Path
import httpx
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServer
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.conftest import OmniServer
+from tests.utils import hardware_test
MODEL_BASE = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
MODEL_BASE_1_7B = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
-STAGE_INIT_TIMEOUT_S = 120
# A synthetic 1024-dim speaker embedding (all 0.1 — not a real voice, but
# exercises the full code path through the talker's _build_prompt_embeds).
@@ -37,8 +36,10 @@
MAX_NEW_TOKENS = 256
-def get_stage_config() -> str:
- return get_deploy_config_path("qwen3_tts.yaml")
+def get_stage_config():
+ return str(
+ Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "qwen3_tts.yaml"
+ )
def _server_args():
@@ -46,7 +47,7 @@ def _server_args():
"--stage-configs-path",
get_stage_config(),
"--stage-init-timeout",
- str(STAGE_INIT_TIMEOUT_S),
+ "120",
"--trust-remote-code",
"--enforce-eager",
"--disable-log-stats",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_websocket.py b/tests/e2e/online_serving/test_qwen3_tts_websocket.py
index 5ac021cf88b..df051460119 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_websocket.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_websocket.py
@@ -7,23 +7,24 @@
import asyncio
import json
import os
+from pathlib import Path
import pytest
import websockets
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServer
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.conftest import OmniServer
+from tests.utils import hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice"
-STAGE_INIT_TIMEOUT_S = 120
def get_stage_config() -> str:
- return get_deploy_config_path("qwen3_tts.yaml")
+ return str(
+ Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "qwen3_tts.yaml"
+ )
@pytest.fixture(scope="module")
@@ -36,7 +37,7 @@ def omni_server():
"--stage-configs-path",
stage_config_path,
"--stage-init-timeout",
- str(STAGE_INIT_TIMEOUT_S),
+ "120",
"--trust-remote-code",
"--enforce-eager",
"--disable-log-stats",
diff --git a/tests/e2e/online_serving/test_qwen_image_edit_expansion.py b/tests/e2e/online_serving/test_qwen_image_edit_expansion.py
index bb8bf7dca49..14e4c915b6b 100644
--- a/tests/e2e/online_serving/test_qwen_image_edit_expansion.py
+++ b/tests/e2e/online_serving/test_qwen_image_edit_expansion.py
@@ -7,11 +7,14 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.media import generate_synthetic_image
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ dummy_messages_from_mix_data,
+ generate_synthetic_image,
+)
+from tests.utils import hardware_marks
EDIT_PROMPT = "Transform this modern, geometrist image into a Vincent van Gogh style impressionist painting."
MULTI_EDIT_PROMPT = (
@@ -110,6 +113,8 @@ def _get_diffusion_feature_cases(model: str):
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases("Qwen/Qwen-Image-Edit"),
@@ -138,6 +143,8 @@ def test_qwen_image_edit(omni_server: OmniServer, openai_client: OpenAIClientHan
openai_client.send_diffusion_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases("Qwen/Qwen-Image-Edit-2509"),
diff --git a/tests/e2e/online_serving/test_qwen_image_expansion.py b/tests/e2e/online_serving/test_qwen_image_expansion.py
index 0830e494495..88e56cc3e10 100644
--- a/tests/e2e/online_serving/test_qwen_image_expansion.py
+++ b/tests/e2e/online_serving/test_qwen_image_expansion.py
@@ -12,10 +12,13 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ dummy_messages_from_mix_data,
+)
+from tests.utils import hardware_marks
T2I_PROMPT = "A photo of a cat sitting on a laptop keyboard, digital art style."
NEGATIVE_PROMPT = "blurry, low quality"
@@ -119,6 +122,8 @@ def _get_diffusion_feature_cases(model: str):
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases("Qwen/Qwen-Image"),
@@ -142,6 +147,8 @@ def test_qwen_image(omni_server: OmniServer, openai_client: OpenAIClientHandler)
openai_client.send_diffusion_request(request_config)
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases("Qwen/Qwen-Image-2512"),
diff --git a/tests/e2e/online_serving/test_qwen_image_layered_expansion.py b/tests/e2e/online_serving/test_qwen_image_layered_expansion.py
index 4e79beab7a0..fc73801c0e0 100644
--- a/tests/e2e/online_serving/test_qwen_image_layered_expansion.py
+++ b/tests/e2e/online_serving/test_qwen_image_layered_expansion.py
@@ -14,11 +14,15 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.media import decode_b64_image, generate_synthetic_image
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ decode_b64_image,
+ dummy_messages_from_mix_data,
+ generate_synthetic_image,
+)
+from tests.utils import hardware_marks
MODEL = "Qwen/Qwen-Image-Layered"
EDIT_PROMPT = "Decompose this image into layers."
@@ -73,6 +77,8 @@
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize("omni_server", FEATURE_CASES, indirect=True)
def test_feature(omni_server: OmniServer, openai_client: OpenAIClientHandler):
"""Test feature combinations with Qwen-Image-Layered."""
@@ -149,6 +155,8 @@ def _collect_image_url_items(openai_client: OpenAIClientHandler, request_config:
return image_items
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server, expected_layers",
LAYERS_GUARD_CASES,
@@ -222,6 +230,8 @@ def test_layered_output_image_count(
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize("omni_server", PROMPT_CASES, indirect=True)
def test_empty_prompt(omni_server: OmniServer, openai_client: OpenAIClientHandler):
"""Test feature combinations with Qwen-Image-Layered."""
diff --git a/tests/e2e/online_serving/test_sd3_expansion.py b/tests/e2e/online_serving/test_sd3_expansion.py
index 09b50d2e501..3ed5cc5f308 100644
--- a/tests/e2e/online_serving/test_sd3_expansion.py
+++ b/tests/e2e/online_serving/test_sd3_expansion.py
@@ -4,10 +4,12 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+)
+from tests.utils import hardware_marks
FOUR_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=4)
POSITIVE_PROMPT = "A serene mountain landscape at sunset"
@@ -37,6 +39,8 @@ def _get_diffusion_feature_cases(model: str):
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases(
diff --git a/tests/e2e/online_serving/test_video_generation_api.py b/tests/e2e/online_serving/test_video_generation_api.py
index 6a8fe45875a..0711a1048e3 100644
--- a/tests/e2e/online_serving/test_video_generation_api.py
+++ b/tests/e2e/online_serving/test_video_generation_api.py
@@ -16,8 +16,8 @@
import pytest
import requests
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServer
+from tests.conftest import OmniServer
+from tests.utils import hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
diff --git a/tests/e2e/online_serving/test_voxtral_tts.py b/tests/e2e/online_serving/test_voxtral_tts.py
index 91f62bb71ed..f795288f375 100644
--- a/tests/e2e/online_serving/test_voxtral_tts.py
+++ b/tests/e2e/online_serving/test_voxtral_tts.py
@@ -12,16 +12,19 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+from pathlib import Path
+
import httpx
import pytest
-from tests.helpers.mark import hardware_test
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.conftest import OmniServerParams
+from tests.utils import hardware_test
MODEL = "mistralai/Voxtral-4B-TTS-2603"
-STAGE_CONFIG = get_deploy_config_path("voxtral_tts.yaml")
+STAGE_CONFIG = str(
+ Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "voxtral_tts.yaml"
+)
EXTRA_ARGS = ["--trust-remote-code", "--enforce-eager", "--disable-log-stats"]
TEST_PARAMS = [OmniServerParams(model=MODEL, stage_config_path=STAGE_CONFIG, server_args=EXTRA_ARGS)]
@@ -82,30 +85,6 @@ def test_speech_english_basic(self, omni_server) -> None:
f"Audio content too small ({len(response.content)} bytes), expected at least {MIN_AUDIO_BYTES} bytes"
)
- @pytest.mark.core_model
- @pytest.mark.omni
- @hardware_test(res={"cuda": "H100"}, num_cards=1)
- def test_speech_english_streaming(self, omni_server) -> None:
- """Test basic streaming English TTS generation."""
- url = f"http://{omni_server.host}:{omni_server.port}/v1/audio/speech"
- payload = {
- "input": "Hello, how are you?",
- "voice": "casual_female",
- "language": "English",
- "stream": True,
- "response_format": "pcm",
- }
-
- with httpx.Client(timeout=120.0) as client:
- with client.stream("POST", url, json=payload) as response:
- assert response.status_code == 200
- assert response.headers.get("content-type") == "audio/pcm"
- total = sum(len(c) for c in response.iter_bytes())
-
- assert total > MIN_AUDIO_BYTES, (
- f"Streamed audio too small ({total} bytes), expected at least {MIN_AUDIO_BYTES} bytes"
- )
-
@pytest.mark.advanced_model
@pytest.mark.omni
@hardware_test(res={"cuda": "H100"}, num_cards=1)
diff --git a/tests/e2e/online_serving/test_wan22_expansion.py b/tests/e2e/online_serving/test_wan22_expansion.py
index 36035f649ea..e5e2d748d58 100644
--- a/tests/e2e/online_serving/test_wan22_expansion.py
+++ b/tests/e2e/online_serving/test_wan22_expansion.py
@@ -19,11 +19,13 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.media import generate_synthetic_image
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ generate_synthetic_image,
+)
+from tests.utils import hardware_marks
PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
NEGATIVE_PROMPT = "low quality, blurry, distorted face, extra limbs, bad anatomy, watermark, logo, text, ugly, deformed, mutated, jpeg artifacts"
@@ -81,6 +83,8 @@ def _get_wan22_feature_cases():
return cases
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_wan22_feature_cases(),
diff --git a/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py b/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py
index 4f0e9644721..0de70afe862 100644
--- a/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py
+++ b/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py
@@ -23,10 +23,12 @@
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+)
+from tests.utils import hardware_marks
MODEL = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
PROMPT = "A cat walking slowly across a sunlit garden path"
@@ -133,6 +135,8 @@ def _get_vace_feature_cases():
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_vace_feature_cases(),
diff --git a/tests/e2e/online_serving/test_zimage_expansion.py b/tests/e2e/online_serving/test_zimage_expansion.py
index d730db0cabc..bef12e55d1a 100644
--- a/tests/e2e/online_serving/test_zimage_expansion.py
+++ b/tests/e2e/online_serving/test_zimage_expansion.py
@@ -3,20 +3,19 @@
for Z-Image.
Coverage is intentionally limited to the minimal 4xL4 cases that
-exercise Z-Image's supported feature combinations:
+exercise Z-Image's supported parallel feature combinations:
- CacheDiT + FP8 + Ring=2 + TP=2
- TeaCache + FP8 + Ulysses=2 + Ring=2
-- Layerwise CPU offload + Ulysses=2 + Ring=2
-- Layerwise CPU offload + TP=2
-- Layerwise CPU offload + HSDP
"""
import pytest
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
-
-pytestmark = [pytest.mark.diffusion, pytest.mark.full_model]
+from tests.conftest import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+)
+from tests.utils import hardware_marks
MODEL = "Tongyi-MAI/Z-Image-Turbo"
PROMPT = "A high-detail studio photo of an orange tabby cat sitting on a laptop keyboard."
@@ -65,44 +64,19 @@ def _get_diffusion_feature_cases():
OmniServerParams(
model=MODEL,
server_args=[
- "--enable-layerwise-offload",
- "--ulysses-degree",
- "2",
- "--ring",
- "2",
- ],
- ),
- id="layerwise_ulysses2_ring2",
- marks=FOUR_CARD_MARKS,
- ),
- pytest.param(
- OmniServerParams(
- model=MODEL,
- server_args=[
- "--enable-layerwise-offload",
- "--tensor-parallel-size",
- "2",
- ],
- ),
- id="layerwise_tp2",
- marks=FOUR_CARD_MARKS,
- ),
- pytest.param(
- OmniServerParams(
- model=MODEL,
- server_args=[
- "--enable-layerwise-offload",
"--use-hsdp",
"--hsdp-shard-size",
"2",
],
),
- id="layerwise_hsdp",
+ id="parallel_hsdp",
marks=[*FOUR_CARD_MARKS, pytest.mark.skip(reason="issue #2435")],
),
]
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
@pytest.mark.parametrize(
"omni_server",
_get_diffusion_feature_cases(),
diff --git a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml
deleted file mode 100644
index 53cc73ce09c..00000000000
--- a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml
+++ /dev/null
@@ -1,75 +0,0 @@
-# CI stage config for Ming-flash-omni-2.0 thinker+talker pipeline.
-# Stage 0: Thinker (multimodal understanding -> text generation)
-# Stage 1: Talker (text -> audio waveform via CFM + AudioVAE)
-#
-# The following config has been verified on 4x H100-80G GPUs
-stage_args:
- - stage_id: 0
- stage_type: llm
- runtime:
- devices: "0,1,2,3"
- max_batch_size: 1
- engine_args:
- model_stage: thinker
- model_arch: MingFlashOmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.74
- enforce_eager: false
- trust_remote_code: true
- # Ming Thinker -> talker bridge reads detokenised text from
- # source_output.outputs[0].text (not hidden states).
- engine_output_type: text
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- max_model_len: 32768
- tensor_parallel_size: 4
- hf_config_name: llm_config
- load_format: dummy
- mm_processor_cache_gb: 0
- compilation_config:
- pass_config:
- # Disable fused all-reduce to avoid a vllm/flashinfer version mismatch.
- fuse_allreduce_rms: false
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.0
- top_p: 0.9
- max_tokens: 100
- repetition_penalty: 1.05
- seed: 42
- detokenize: true
- ignore_eos: false
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "3"
- max_batch_size: 1
- engine_args:
- model_stage: ming_tts
- model_arch: MingFlashOmniTalkerForConditionalGeneration
- worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.18
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: audio
- enable_prefix_caching: false
- max_num_batched_tokens: 1000000
- tokenizer_subdir: talker/llm
- # The HF repo ships BailingMM2Config (thinker-only) at root;
- # OmniModelConfig treats that as "stage does not share outer mrope".
- hf_config_name: talker_config
- load_format: dummy
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.ming_flash_omni.thinker2talker
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- max_tokens: 1
- seed: 42
diff --git a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_thinker_only_ci.yaml b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_thinker_only_ci.yaml
deleted file mode 100644
index fb0c72cc513..00000000000
--- a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_thinker_only_ci.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-# Thinker stage only
-stage_args:
- - stage_id: 0
- stage_type: llm
- runtime:
- devices: "0,1,2,3"
- max_batch_size: 1
- engine_args:
- model_stage: thinker
- model_arch: MingFlashOmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- max_model_len: 32768
- tensor_parallel_size: 4
- hf_config_name: llm_config
- load_format: dummy
- mm_processor_cache_gb: 0
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- max_tokens: 100
- repetition_penalty: 1.05
- seed: 42
- detokenize: true
- ignore_eos: false
diff --git a/tests/e2e/stage_configs/dynin_omni_ci.yaml b/tests/e2e/stage_configs/dynin_omni_ci.yaml
deleted file mode 100644
index 525b7d888c2..00000000000
--- a/tests/e2e/stage_configs/dynin_omni_ci.yaml
+++ /dev/null
@@ -1,79 +0,0 @@
-# stage config for running dynin_omni with a 3-stage architecture.
-# this config is intended for e2e smoke tests.
-
-stage_args:
- - stage_id: 0
- stage_type: llm
- runtime:
- process: true
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: token2text
- model_arch: DyninOmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- engine_output_type: latent
- trust_remote_code: true
- gpu_memory_utilization: 0.5
- enforce_eager: true
- enable_prefix_caching: false
- async_scheduling: false
- max_num_batched_tokens: 4096
- is_comprehension: true
- final_output: true
- final_output_type: text
-
- - stage_id: 1
- stage_type: llm
- runtime:
- process: true
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: token2image
- model_arch: DyninOmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- engine_output_type: latent
- trust_remote_code: true
- gpu_memory_utilization: 0.2
- enforce_eager: true
- enable_prefix_caching: false
- async_scheduling: false
- max_num_batched_tokens: 4096
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image
- final_output: true
- final_output_type: image
-
- - stage_id: 2
- stage_type: llm
- runtime:
- process: true
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: token2audio
- model_arch: DyninOmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- engine_output_type: latent
- trust_remote_code: true
- gpu_memory_utilization: 0.2
- enforce_eager: true
- enable_prefix_caching: false
- async_scheduling: false
- max_num_batched_tokens: 4096
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2image_to_token2audio
- final_output: true
- final_output_type: audio
-
-runtime:
- enabled: true
- edges:
- - from: 0
- to: 1
- - from: 1
- to: 2
diff --git a/tests/e2e/stage_configs/mimo_audio_ci.yaml b/tests/e2e/stage_configs/mimo_audio_ci.yaml
new file mode 100644
index 00000000000..7127a71b499
--- /dev/null
+++ b/tests/e2e/stage_configs/mimo_audio_ci.yaml
@@ -0,0 +1,71 @@
+# CI stage config for running MiMo-Audio in multi-stage omni runtime.
+# Based on mimo_audio.yaml with load_format: dummy for CI testing.
+
+async_chunk: false
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0"
+ engine_args:
+ dtype: bfloat16
+ max_num_seqs: 1
+ model_stage: fused_thinker_talker
+ model_arch: MiMoAudioForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ tensor_parallel_size: 1
+ gpu_memory_utilization: 0.5
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ max_model_len: 8192
+ max_num_batched_tokens: 8192
+ load_format: dummy
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.6
+ top_p: 0.95
+ top_k: 50
+ max_tokens: 18192
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+ - stage_id: 1
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: MiMoAudioForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ tensor_parallel_size: 1
+ gpu_memory_utilization: 0.2
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ max_model_len: 18192
+ max_num_batched_tokens: 18192
+ async_scheduling: false
+ load_format: dummy
+ engine_input_source: [ 0 ]
+ is_comprehension: false
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.mimo_audio.llm2code2wav
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 18192
+ seed: 42
+ detokenize: false
diff --git a/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml
new file mode 100644
index 00000000000..a7c637d486a
--- /dev/null
+++ b/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml
@@ -0,0 +1,109 @@
+# Stage config for running qwen2.5-omni for multi-stage omni runtime.
+
+# Verified on 2x 24GB GPU (L4/RTX3090/RTX4090); optimized for CI e2e tests.
+# NOTE(review): the stages below use devices "0", "1" and "2" — confirm the GPU count.
+stage_args:
+  - stage_id: 0
+    runtime:
+      process: true  # Run this stage in a separate process
+      devices: "0"  # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+    engine_args:
+      model_stage: thinker
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      max_model_len: 16384
+      max_num_batched_tokens: 16384
+      max_num_seqs: 1
+      gpu_memory_utilization: 0.9
+      skip_mm_profiling: true
+      enforce_eager: true  # Now we only support eager mode
+      trust_remote_code: true
+      engine_output_type: latent
+      enable_prefix_caching: false
+      mm_processor_cache_gb: 0
+      load_format: dummy
+    is_comprehension: true
+    final_output: true
+    final_output_type: text
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 128
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.1
+  - stage_id: 1
+    runtime:
+      process: true
+      devices: "1"
+    engine_args:
+      model_stage: talker
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      max_model_len: 16384
+      max_num_batched_tokens: 16384
+      max_num_seqs: 1
+      gpu_memory_utilization: 0.4
+      skip_mm_profiling: true
+      enforce_eager: true
+      trust_remote_code: true
+      enable_prefix_caching: false
+      engine_output_type: latent
+      load_format: dummy
+    engine_input_source: [0]
+    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+    default_sampling_params:
+      temperature: 0.9
+      top_p: 0.8
+      top_k: 40
+      max_tokens: 4096
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.05
+      stop_token_ids: [8294]
+  - stage_id: 2
+    runtime:
+      process: true
+      devices: "2"  # Example: use a different GPU than the previous stage; use "0" if single GPU
+    engine_args:
+      model_stage: code2wav
+      max_num_seqs: 1
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: generation
+      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+      gpu_memory_utilization: 0.5  # Increased GPU memory utilization to enable the test on H800
+      enforce_eager: true
+      trust_remote_code: true
+      enable_prefix_caching: false
+      engine_output_type: audio
+      max_num_batched_tokens: 8192
+      max_model_len: 8192
+      load_format: dummy
+    engine_input_source: [1]
+    final_output: true
+    final_output_type: audio
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 8192
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+  enabled: true
+  defaults:
+    window_size: -1  # Simplified: trigger downstream only after full upstream completion
+    max_inflight: 1  # Simplified: process serially within each stage
+  edges:
+    - from: 0  # thinker → talker: trigger only after receiving full input (-1)
+      to: 1
+      window_size: -1
+    - from: 1  # talker → code2wav: trigger only after receiving full input (-1)
+      to: 2
+      window_size: -1
diff --git a/tests/e2e/stage_configs/qwen2_5_omni_thinker_ci.yaml b/tests/e2e/stage_configs/qwen2_5_omni_thinker_ci.yaml
new file mode 100644
index 00000000000..94013828478
--- /dev/null
+++ b/tests/e2e/stage_configs/qwen2_5_omni_thinker_ci.yaml
@@ -0,0 +1,31 @@
+stage_args:
+  - stage_id: 0
+    runtime:
+      process: true  # Run this stage in a separate process
+      devices: "0"  # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+    engine_args:
+      model_stage: thinker
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      max_model_len: 16384
+      max_num_batched_tokens: 16384
+      max_num_seqs: 1
+      gpu_memory_utilization: 0.9
+      skip_mm_profiling: true
+      enforce_eager: true  # Now we only support eager mode
+      trust_remote_code: true
+      engine_output_type: latent
+      enable_prefix_caching: false
+      mm_processor_cache_gb: 0
+    is_comprehension: true
+    final_output: true  # Only stage in this config; text is the final output
+    final_output_type: text
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 128
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.1
diff --git a/tests/e2e/stage_configs/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/qwen3_omni_ci.yaml
new file mode 100644
index 00000000000..08dd49de953
--- /dev/null
+++ b/tests/e2e/stage_configs/qwen3_omni_ci.yaml
@@ -0,0 +1,102 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs.
+stage_args:
+- stage_id: 0
+  runtime:
+    devices: "0"
+  engine_args:
+    model_stage: thinker
+    max_num_seqs: 5
+    model_arch: Qwen3OmniMoeForConditionalGeneration
+    worker_type: ar
+    scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+    gpu_memory_utilization: 0.9
+    enforce_eager: false
+    trust_remote_code: true
+    engine_output_type: latent  # Output hidden states for talker
+    distributed_executor_backend: "mp"
+    max_num_batched_tokens: 32768
+    max_model_len: 32768
+    enable_prefix_caching: false
+    mm_processor_cache_gb: 0
+    hf_config_name: thinker_config
+    tensor_parallel_size: 1
+    load_format: dummy
+  final_output: true
+  final_output_type: text
+  is_comprehension: true
+  default_sampling_params:
+    temperature: 0.4
+    top_p: 0.9
+    top_k: 1
+    max_tokens: 150
+    seed: 42
+    ignore_eos: false
+    detokenize: true
+    repetition_penalty: 1.05
+
+- stage_id: 1
+  runtime:
+    devices: "1"
+  engine_args:
+    model_stage: talker
+    max_num_seqs: 5
+    model_arch: Qwen3OmniMoeForConditionalGeneration
+    worker_type: ar
+    scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+    gpu_memory_utilization: 0.5
+    enforce_eager: false
+    trust_remote_code: true
+    engine_output_type: latent  # Output codec codes for code2wav
+    enable_prefix_caching: false
+    max_num_batched_tokens: 32768
+    max_model_len: 32768
+    distributed_executor_backend: "mp"
+    hf_config_name: talker_config
+    load_format: dummy
+  engine_input_source: [0]
+  custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+  default_sampling_params:
+    temperature: 0.9
+    top_k: 50
+    max_tokens: 1000
+    seed: 42
+    detokenize: false
+    repetition_penalty: 1.05
+    stop_token_ids: [2150]
+
+- stage_id: 2
+  runtime:
+    devices: "1"  # Same device as the talker stage (0.5 + 0.1 memory utilization)
+  engine_args:
+    model_stage: code2wav
+    max_num_seqs: 5
+    model_arch: Qwen3OmniMoeForConditionalGeneration
+    worker_type: generation
+    scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+    enforce_eager: true
+    trust_remote_code: true
+    enable_prefix_caching: false
+    engine_output_type: audio  # Final output: audio waveform
+    gpu_memory_utilization: 0.1
+    distributed_executor_backend: "mp"
+    max_num_batched_tokens: 100000
+    hf_config_name: thinker_config
+    async_scheduling: false
+    load_format: dummy
+  engine_input_source: [1]
+  custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+  final_output: true
+  final_output_type: audio
+  default_sampling_params:
+    temperature: 0.0
+    top_p: 1.0
+    top_k: -1
+    max_tokens: 2000
+    seed: 42
+    detokenize: true
+    repetition_penalty: 1.1
diff --git a/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml
new file mode 100644
index 00000000000..0c756ce56b1
--- /dev/null
+++ b/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml
@@ -0,0 +1,106 @@
+# Stage config for running qwen2.5-omni for multi-stage omni runtime.
+
+# Verified on 2x 24GB GPU (L4/RTX3090/RTX4090); optimized for CI e2e tests.
+# NOTE(review): header names NVIDIA GPUs although this file is under rocm/ — confirm the verified hardware.
+stage_args:
+  - stage_id: 0
+    runtime:
+      process: true  # Run this stage in a separate process
+      devices: "0"  # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+    engine_args:
+      model_stage: thinker
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      max_model_len: 16384
+      max_num_batched_tokens: 16384
+      max_num_seqs: 1
+      gpu_memory_utilization: 0.8
+      skip_mm_profiling: true
+      enforce_eager: true  # Now we only support eager mode
+      trust_remote_code: true
+      engine_output_type: latent
+      enable_prefix_caching: false
+      mm_processor_cache_gb: 0
+    is_comprehension: true
+    final_output: true
+    final_output_type: text
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 128
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.1
+  - stage_id: 1
+    runtime:
+      process: true
+      devices: "1"
+    engine_args:
+      model_stage: talker
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      max_model_len: 16384
+      max_num_batched_tokens: 16384
+      max_num_seqs: 1
+      gpu_memory_utilization: 0.8
+      skip_mm_profiling: true
+      enforce_eager: true
+      trust_remote_code: true
+      enable_prefix_caching: false
+      engine_output_type: latent
+    engine_input_source: [0]
+    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+    default_sampling_params:
+      temperature: 0.9
+      top_p: 0.8
+      top_k: 40
+      max_tokens: 4096
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.05
+      stop_token_ids: [8294]
+  - stage_id: 2
+    runtime:
+      process: true
+      devices: "0"  # Example: use a different GPU than the previous stage; use "0" if single GPU
+    engine_args:
+      model_stage: code2wav
+      max_num_seqs: 1
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: generation
+      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+      gpu_memory_utilization: 0.15
+      enforce_eager: true
+      trust_remote_code: true
+      enable_prefix_caching: false
+      engine_output_type: audio
+      max_num_batched_tokens: 4096
+      max_model_len: 4096
+    engine_input_source: [1]
+    final_output: true
+    final_output_type: audio
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 4096
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+  enabled: true
+  defaults:
+    window_size: -1  # Simplified: trigger downstream only after full upstream completion
+    max_inflight: 1  # Simplified: process serially within each stage
+  edges:
+    - from: 0  # thinker → talker: trigger only after receiving full input (-1)
+      to: 1
+      window_size: -1
+    - from: 1  # talker → code2wav: trigger only after receiving full input (-1)
+      to: 2
+      window_size: -1
diff --git a/tests/e2e/stage_configs/rocm/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/rocm/qwen3_omni_ci.yaml
new file mode 100644
index 00000000000..ac2b1fbd713
--- /dev/null
+++ b/tests/e2e/stage_configs/rocm/qwen3_omni_ci.yaml
@@ -0,0 +1,100 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs. NOTE(review): file is under rocm/ — confirm.
+stage_args:
+  - stage_id: 0
+    runtime:
+      devices: "0"
+    engine_args:
+      model_stage: thinker
+      max_num_seqs: 1
+      model_arch: Qwen3OmniMoeForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      gpu_memory_utilization: 0.9
+      enforce_eager: false
+      trust_remote_code: true
+      engine_output_type: latent  # Output hidden states for talker
+      distributed_executor_backend: "mp"
+      enable_prefix_caching: false
+      mm_processor_cache_gb: 0
+      hf_config_name: thinker_config
+      tensor_parallel_size: 1
+      load_format: dummy
+    final_output: true
+    final_output_type: text
+    is_comprehension: true
+    default_sampling_params:
+      temperature: 0.4
+      top_p: 0.9
+      top_k: 1
+      max_tokens: 100
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.05
+
+  - stage_id: 1
+    runtime:
+      devices: "1"
+    engine_args:
+      model_stage: talker
+      max_num_seqs: 1
+      model_arch: Qwen3OmniMoeForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      gpu_memory_utilization: 0.6
+      enforce_eager: true
+      trust_remote_code: true
+      engine_output_type: latent  # Output codec codes for code2wav
+      # tensor_parallel_size: 2
+      enable_prefix_caching: false
+      distributed_executor_backend: "mp"
+      hf_config_name: talker_config
+      load_format: dummy
+    engine_input_source: [0]
+    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+    # final_output: true
+    # final_output_type: text
+    default_sampling_params:
+      temperature: 0.9
+      top_k: 50
+      max_tokens: 100
+      seed: 42
+      detokenize: false
+      repetition_penalty: 1.05
+      stop_token_ids: [2150]
+
+  - stage_id: 2
+    runtime:
+      devices: "1"  # Same device as the talker stage (0.6 + 0.1 memory utilization)
+    engine_args:
+      model_stage: code2wav
+      max_num_seqs: 1
+      model_arch: Qwen3OmniMoeForConditionalGeneration
+      worker_type: generation
+      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+      enforce_eager: true
+      trust_remote_code: true
+      enable_prefix_caching: false
+      engine_output_type: audio  # Final output: audio waveform
+      gpu_memory_utilization: 0.1
+      distributed_executor_backend: "mp"
+      max_num_batched_tokens: 1000000  # NOTE(review): sibling qwen3_omni_ci configs use 100000 — confirm
+      hf_config_name: thinker_config
+      load_format: dummy
+      async_scheduling: false
+    engine_input_source: [1]
+    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+    final_output: true
+    final_output_type: audio
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 200
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.1
diff --git a/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml
new file mode 100644
index 00000000000..14ef3c34385
--- /dev/null
+++ b/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml
@@ -0,0 +1,108 @@
+# Stage config for running qwen2.5-omni for multi-stage omni runtime.
+
+# The following config is verified with 2 * Intel Arc Pro B60 XPU.
+stage_args:
+  - stage_id: 0
+    stage_type: llm  # Use llm stage type for AR stages
+    runtime:
+      process: true  # Run this stage in a separate process
+      devices: "0"  # Visible devices for this stage
+    engine_args:
+      model_stage: thinker
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      max_model_len: 16384
+      max_num_batched_tokens: 16384
+      max_num_seqs: 1
+      gpu_memory_utilization: 0.9  # thinker weight is around 16.74GB for Qwen2.5-Omni-7B
+      skip_mm_profiling: true
+      enforce_eager: true
+      trust_remote_code: true
+      engine_output_type: latent
+      enable_prefix_caching: false
+      mm_processor_cache_gb: 0
+    is_comprehension: true
+    final_output: true
+    final_output_type: text
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 128
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.1
+  - stage_id: 1
+    stage_type: llm  # Use llm stage type for AR stages
+    runtime:
+      process: true
+      devices: "1"
+    engine_args:
+      model_stage: talker
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: ar
+      scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+      max_model_len: 16384
+      max_num_batched_tokens: 16384
+      max_num_seqs: 1
+      gpu_memory_utilization: 0.5  # talker weight is 6.03GB for Qwen2.5-Omni-7B
+      skip_mm_profiling: true
+      enforce_eager: true
+      trust_remote_code: true
+      enable_prefix_caching: false
+      engine_output_type: latent
+    engine_input_source: [0]
+    custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+    default_sampling_params:
+      temperature: 0.9
+      top_p: 0.8
+      top_k: 40
+      max_tokens: 4096
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.05
+      stop_token_ids: [8294]
+
+  - stage_id: 2
+    stage_type: llm  # Use llm stage type for AR stages
+    runtime:
+      process: true
+      devices: "2"  # NOTE(review): third device index, but the header says 2 XPUs — confirm
+    engine_args:
+      max_num_seqs: 1
+      model_stage: code2wav
+      model_arch: Qwen2_5OmniForConditionalGeneration
+      worker_type: generation
+      scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+      gpu_memory_utilization: 0.3  # code2wav weight is around 1.46GB for Qwen2.5-Omni-7B
+      enforce_eager: true
+      trust_remote_code: true
+      enable_prefix_caching: false
+      engine_output_type: audio
+    engine_input_source: [1]
+    final_output: true
+    final_output_type: audio
+    default_sampling_params:
+      temperature: 0.0
+      top_p: 1.0
+      top_k: -1
+      max_tokens: 2048
+      seed: 42
+      detokenize: true
+      repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+  enabled: true
+  defaults:
+    window_size: -1  # Simplified: trigger downstream only after full upstream completion
+    max_inflight: 1  # Simplified: process serially within each stage
+
+  edges:
+    - from: 0  # thinker → talker: trigger only after receiving full input (-1)
+      to: 1
+      window_size: -1
+    - from: 1  # talker → code2wav: trigger only after receiving full input (-1)
+      to: 2
+      window_size: -1
diff --git a/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml
new file mode 100644
index 00000000000..c4586e06649
--- /dev/null
+++ b/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml
@@ -0,0 +1,109 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config is verified with 8 * Intel Arc Pro B60 XPU.
+stage_args:
+- stage_id: 0
+  stage_type: llm  # Use llm stage type for AR stages
+  runtime:
+    devices: "0,1,2,3"  # Four devices, matching tensor_parallel_size: 4 below
+  engine_args:
+    max_num_seqs: 1
+    model_stage: thinker
+    model_arch: Qwen3OmniMoeForConditionalGeneration
+    worker_type: ar
+    scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+    gpu_memory_utilization: 0.85  # thinker weight is around 61.08GB for Qwen3-Omni-30B-A3B-Instruct
+    skip_mm_profiling: true
+    enforce_eager: true
+    trust_remote_code: true
+    engine_output_type: latent  # Output hidden states for talker
+    distributed_executor_backend: "mp"
+    max_num_batched_tokens: 4096
+    max_model_len: 4096
+    enable_prefix_caching: false
+    hf_config_name: thinker_config
+    tensor_parallel_size: 4
+    max_cudagraph_capture_size: 0
+  final_output: true
+  final_output_type: text
+  is_comprehension: true
+  default_sampling_params:
+    temperature: 0.4
+    top_p: 0.9
+    top_k: 1
+    max_tokens: 100
+    seed: 42
+    ignore_eos: false
+    detokenize: true
+    repetition_penalty: 1.05
+
+- stage_id: 1
+  stage_type: llm  # Use llm stage type for AR stages
+  runtime:
+    devices: "4"
+  engine_args:
+    max_num_seqs: 1
+    model_stage: talker
+    model_arch: Qwen3OmniMoeForConditionalGeneration
+    worker_type: ar
+    scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+    gpu_memory_utilization: 0.6  # talker weight is around 8.5GB for Qwen3-Omni-30B-A3B-Instruct
+    skip_mm_profiling: true
+    enforce_eager: true
+    trust_remote_code: true
+    engine_output_type: latent  # Output codec codes for code2wav
+    enable_prefix_caching: false
+    max_num_batched_tokens: 4096
+    max_model_len: 4096
+    distributed_executor_backend: "mp"
+    hf_config_name: talker_config
+    max_cudagraph_capture_size: 0
+  engine_input_source: [0]
+  custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+  # final_output: true
+  # final_output_type: text
+  default_sampling_params:
+    temperature: 0.9
+    top_k: 50
+    max_tokens: 4096
+    seed: 42
+    detokenize: false
+    repetition_penalty: 1.05
+    stop_token_ids: [2150]
+
+- stage_id: 2
+  stage_type: llm  # Use llm stage type for AR stages
+  runtime:
+    devices: "5"
+  engine_args:
+    max_num_seqs: 1
+    model_stage: code2wav
+    model_arch: Qwen3OmniMoeForConditionalGeneration
+    worker_type: generation
+    scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+    enforce_eager: true
+    trust_remote_code: true
+    enable_prefix_caching: false
+    engine_output_type: audio  # Final output: audio waveform
+    gpu_memory_utilization: 0.3  # code2wav weight is around 0.4GB for Qwen3-Omni-30B-A3B-Instruct
+    skip_mm_profiling: true
+    distributed_executor_backend: "mp"
+    max_num_batched_tokens: 100000
+    hf_config_name: thinker_config
+    async_scheduling: false
+    max_cudagraph_capture_size: 0
+  engine_input_source: [1]
+  custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+  final_output: true
+  final_output_type: audio
+  default_sampling_params:
+    temperature: 0.0
+    top_p: 1.0
+    top_k: -1
+    max_tokens: 2000
+    seed: 42
+    detokenize: true
+    repetition_penalty: 1.1
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index 0d61f6a675b..7ba1cebece0 100644
--- a/tests/engine/test_arg_utils.py
+++ b/tests/engine/test_arg_utils.py
@@ -4,9 +4,7 @@
explicitly patch values that differ from vLLM.
"""
-import argparse
import inspect
-from types import SimpleNamespace
from unittest.mock import Mock
import pytest
@@ -16,7 +14,6 @@
from vllm_omni.config.model import OmniModelConfig
from vllm_omni.engine.arg_utils import OmniEngineArgs
-from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -27,40 +24,21 @@ def test_sync_config_is_omni():
assert isinstance(cfg, OmniModelConfig)
-def test_default_stage_id_is_concrete_int():
- """Ensure `stage_id` stays safe for downstream arithmetic/indexing."""
- engine_args = OmniEngineArgs()
-
- assert engine_args.stage_id == 0
- assert isinstance(engine_args.stage_id, int)
- assert engine_args.log_stats is False
-
- cfg = engine_args.create_model_config()
- assert cfg.stage_id == 0
-
-
-def test_multimodal_kwarg_overrides(mocker):
+def test_multimodal_kwarg_overrides():
"""Ensure that overrides in the multimodal config are preserved."""
+ # Get a different value than the default for a multimodal field
sig = inspect.signature(OmniEngineArgs)
default_mm_cache = sig.parameters["mm_processor_cache_gb"].default
override_val = default_mm_cache + 1
- fake_model_config = SimpleNamespace(
- multimodal_config=SimpleNamespace(mm_processor_cache_gb=override_val),
- )
-
- def _fake_parent_create_model_config(self):
- assert self.mm_processor_cache_gb == override_val
- return fake_model_config
-
- mocker.patch.object(EngineArgs, "create_model_config", _fake_parent_create_model_config)
- mocker.patch.object(OmniModelConfig, "from_vllm_model_config", side_effect=lambda model_config, **_: model_config)
-
+ # NOTE: This needs to be a model that resolves to supports_multimodal=True
+ # in vLLM, otherwise we won't have an MM config
cfg = OmniEngineArgs(
model="Qwen/Qwen2-VL-2B-Instruct",
mm_processor_cache_gb=override_val,
).create_model_config()
+ # Ensure that the override was applied correctly
assert cfg.multimodal_config is not None
assert cfg.multimodal_config.mm_processor_cache_gb == override_val
@@ -110,7 +88,7 @@ def test_qwen3_tts_codec_frame_rate_patching():
vllm_config = EngineArgs().create_model_config()
# Create a mock talking config with a dummy value for position_id_per_seconds
- mock_talker_config = SimpleNamespace()
+ mock_talker_config = Mock()
mock_talker_config.position_id_per_seconds = 12.3
vllm_config.hf_config.talker_config = mock_talker_config
@@ -126,54 +104,6 @@ def test_qwen3_tts_codec_frame_rate_patching():
assert omni_config.codec_frame_rate_hz == 12.3
-def test_from_cli_args_picks_up_stage_configs_path():
- """from_cli_args should pick up stage_configs_path from namespace."""
- ns = argparse.Namespace(
- model="facebook/opt-125m",
- stage_configs_path="/some/path.yaml",
- custom_pipeline_args=None,
- )
-
- args = OmniEngineArgs.from_cli_args(ns)
- assert args.stage_configs_path == "/some/path.yaml"
- assert args.custom_pipeline_args is None
-
-
-def test_qwen3_tts_code2wav_injects_max_position_embeddings(monkeypatch):
- """Ensure Code2Wav mirrors stage max_model_len into nested HF overrides.
-
- Qwen3-TTS Code2Wav is a pure decoder stage whose runtime max_model_len can
- legitimately exceed the base checkpoint's default text max length. Recent
- vLLM validates these values during ModelConfig creation, so we inject
- ``talker_config.max_position_embeddings`` before delegating to vLLM.
- """
- captured: dict[str, object] = {}
- baseline_config = Mock()
-
- def fake_create_model_config(self):
- captured["hf_overrides"] = self.hf_overrides
- return baseline_config
-
- monkeypatch.setattr(EngineArgs, "create_model_config", fake_create_model_config)
- monkeypatch.setattr(
- OmniModelConfig,
- "from_vllm_model_config",
- classmethod(lambda cls, model_config, **omni_kwargs: model_config),
- )
-
- OmniEngineArgs(
- model_arch="Qwen3TTSCode2Wav",
- max_model_len=65536,
- ).create_model_config()
-
- assert captured["hf_overrides"] == {
- "architectures": ["Qwen3TTSCode2Wav"],
- "talker_config": {
- "max_position_embeddings": 65536,
- },
- }
-
-
def test_stage_specific_text_config_override():
"""Ensure dependent attributes are updated when using stage-specific config."""
vllm_config = EngineArgs().create_model_config()
@@ -182,12 +112,13 @@ def test_stage_specific_text_config_override():
# Switch the created hf text config with a mock whose
# values we want to pull through the text config helper
stage_text_config = vllm_config.hf_text_config
- vllm_config.hf_text_config = SimpleNamespace()
+ vllm_config.hf_text_config = Mock()
stage_text_config.sliding_window = 4096
stage_text_config.attention_chunk_size = 2048
# Move the stage config's text config getter & thinker config
- mock_stage_config = SimpleNamespace(get_text_config=lambda: stage_text_config)
+ mock_stage_config = Mock()
+ mock_stage_config.get_text_config.return_value = stage_text_config
vllm_config.hf_config.thinker_config = mock_stage_config
# Ensure that create from a vLLM config correctly pulls the
@@ -201,92 +132,3 @@ def test_stage_specific_text_config_override():
assert omni_config.attention_chunk_size == 2048
assert omni_config.max_model_len == 4096
assert omni_config.hf_text_config.sliding_window is None
-
-
-def test_stage_configs_path_field():
- """OmniEngineArgs with stage_configs_path should construct without error."""
- args = OmniEngineArgs(stage_configs_path="/some/path.yaml")
- assert args.stage_configs_path == "/some/path.yaml"
-
-
-def test_voxcpm_model_arch_injects_model_type_override(mocker):
- """Ensure VoxCPM model_arch injects hf_overrides for config resolution."""
- mocker.patch.object(OmniEngineArgs, "_ensure_omni_models_registered", return_value=True)
- mocker.patch.object(OmniEngineArgs, "_patch_empty_hf_config")
- mocker.patch.object(EngineArgs, "create_model_config", return_value=Mock())
- mocker.patch.object(OmniModelConfig, "from_vllm_model_config", return_value=Mock())
-
- args = OmniEngineArgs(
- model="OpenBMB/VoxCPM1.5",
- model_arch="VoxCPMForConditionalGeneration",
- )
- args.create_model_config()
-
- assert args.hf_overrides["architectures"] == ["VoxCPMForConditionalGeneration"]
- assert args.hf_overrides["model_type"] == "voxcpm"
- args._patch_empty_hf_config.assert_called_once_with("voxcpm")
-
-
-def test_strip_single_engine_args():
- """_strip_single_engine_args should remove EngineArgs fields but keep omni fields."""
- kwargs = {
- # Parent EngineArgs fields — should be stripped
- "compilation_config": '{"cudagraph_mode": "FULL_AND_PIECEWISE"}',
- "tensor_parallel_size": 4,
- "gpu_memory_utilization": 0.9,
- "model": "some/model",
- # Parent field that should be kept (allowlisted)
- "worker_extension_cls": "some.Extension",
- # OmniEngineArgs-only / non-engine fields — should pass through
- "stage_configs_path": "/path/to/yaml",
- "custom_pipeline_args": {"pipeline_class": "my.Pipeline"},
- "mode": "text-to-image",
- "lora_path": "/some/lora",
- }
-
- filtered = AsyncOmniEngine._strip_single_engine_args(kwargs)
-
- # Stripped — parent EngineArgs fields
- assert "compilation_config" not in filtered
- assert "tensor_parallel_size" not in filtered
- assert "gpu_memory_utilization" not in filtered
- assert "model" not in filtered
-
- # Stripped — orchestrator-level OmniEngineArgs field
- assert "stage_configs_path" not in filtered
-
- # Kept
- assert filtered["worker_extension_cls"] == "some.Extension"
- assert filtered["custom_pipeline_args"] == {"pipeline_class": "my.Pipeline"}
- assert filtered["mode"] == "text-to-image"
- assert filtered["lora_path"] == "/some/lora"
-
-
-def test_strip_single_engine_args_model_does_not_trigger_warning(mocker):
- """model is always in kwargs (callers set it via from_cli_args/asdict),
- so it should not cause the override warning by itself or appear in it."""
- mock_warn = mocker.patch("vllm_omni.engine.async_omni_engine.logger.warning")
-
- # Typical caller kwargs: model is always present, no other parent
- # EngineArgs fields are explicitly overridden.
- AsyncOmniEngine._strip_single_engine_args(
- {
- "model": "some/model",
- "custom_pipeline_args": {"pipeline_class": "my.Pipeline"},
- }
- )
- mock_warn.assert_not_called()
-
- # When there *are* genuinely surprising overrides alongside model,
- # the warning should mention them but not model.
- AsyncOmniEngine._strip_single_engine_args(
- {
- "model": "some/model",
- "tensor_parallel_size": 4,
- "custom_pipeline_args": {"pipeline_class": "my.Pipeline"},
- }
- )
- mock_warn.assert_called_once()
- warned_args = mock_warn.call_args[0][-1] # the formatted arg list
- assert "tensor_parallel_size" in warned_args
- assert "model" not in warned_args
diff --git a/tests/engine/test_async_omni_engine_abort.py b/tests/engine/test_async_omni_engine_abort.py
index eda7a7a788e..34fdf45ea25 100644
--- a/tests/engine/test_async_omni_engine_abort.py
+++ b/tests/engine/test_async_omni_engine_abort.py
@@ -2,21 +2,20 @@
import os
import sys
from contextlib import ExitStack
+from pathlib import Path
import pytest
from vllm import SamplingParams
from vllm.inputs import PromptType
-from tests.helpers.mark import hardware_test
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.utils import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
SEED = 42
-# Single-stage thinker-only deploy, materialized from tests.helpers.stage_config._CI_OVERLAYS.
-stage_config = get_deploy_config_path("ci/qwen2_5_omni_thinker_only.yaml")
+stage_config = str(Path(__file__).parent.parent / "e2e" / "stage_configs" / "qwen2_5_omni_thinker_ci.yaml")
model = "Qwen/Qwen2.5-Omni-7B"
@@ -61,6 +60,7 @@ async def generate(
@pytest.mark.core_model
@pytest.mark.omni
+@pytest.mark.real_hf_config
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=1)
@pytest.mark.asyncio
async def test_abort():
diff --git a/tests/engine/test_async_omni_engine_input.py b/tests/engine/test_async_omni_engine_input.py
index 3700e426d42..ed6a7277b46 100644
--- a/tests/engine/test_async_omni_engine_input.py
+++ b/tests/engine/test_async_omni_engine_input.py
@@ -1,5 +1,6 @@
+from unittest.mock import Mock
+
import pytest
-from pytest_mock import MockerFixture
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreRequest
@@ -23,18 +24,18 @@ def _make_engine_core_request() -> EngineCoreRequest:
)
-def test_build_add_request_message_preserves_additional_information(mocker: MockerFixture):
+def test_build_add_request_message_preserves_additional_information():
engine = object.__new__(AsyncOmniEngine)
params = SamplingParams(max_tokens=8)
engine.default_sampling_params_list = [params]
engine.stage_metadata = [{"stage_type": "llm"}]
engine.supported_tasks = ("speech",)
- input_processor = mocker.Mock()
+ input_processor = Mock()
input_processor.process_inputs.return_value = _make_engine_core_request()
engine.input_processor = input_processor
- output_processor = mocker.Mock()
+ output_processor = Mock()
engine.output_processors = [output_processor]
prompt = {
@@ -62,18 +63,18 @@ def test_build_add_request_message_preserves_additional_information(mocker: Mock
output_processor.add_request.assert_called_once()
-def test_build_add_request_message_with_resumable_streaming(mocker: MockerFixture):
+def test_build_add_request_message_with_resumable_streaming():
engine = object.__new__(AsyncOmniEngine)
params = SamplingParams(max_tokens=8)
engine.default_sampling_params_list = [params]
engine.stage_metadata = [{"stage_type": "llm"}]
engine.supported_tasks = ("generate",)
- input_processor = mocker.Mock()
+ input_processor = Mock()
input_processor.process_inputs.return_value = _make_engine_core_request()
engine.input_processor = input_processor
- output_processor = mocker.Mock()
+ output_processor = Mock()
engine.output_processors = [output_processor]
msg = engine._build_add_request_message(
diff --git a/tests/engine/test_async_omni_engine_outputs.py b/tests/engine/test_async_omni_engine_outputs.py
index 47b3d5e9f14..ccf9e8cb6b6 100644
--- a/tests/engine/test_async_omni_engine_outputs.py
+++ b/tests/engine/test_async_omni_engine_outputs.py
@@ -5,36 +5,36 @@
"""
import queue
+from unittest.mock import MagicMock
import pytest
-from pytest_mock import MockerFixture
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-def _make_engine(output_queue, mocker: MockerFixture, *, thread_alive: bool = True) -> AsyncOmniEngine:
+def _make_engine(output_queue, *, thread_alive: bool = True) -> AsyncOmniEngine:
"""Create an AsyncOmniEngine bypassing __init__."""
engine = object.__new__(AsyncOmniEngine)
engine.output_queue = output_queue
- engine.orchestrator_thread = mocker.MagicMock(
- is_alive=mocker.MagicMock(return_value=thread_alive),
+ engine.orchestrator_thread = MagicMock(
+ is_alive=MagicMock(return_value=thread_alive),
)
return engine
-def test_try_get_output_raises_after_orchestrator_dies(mocker: MockerFixture):
+def test_try_get_output_raises_after_orchestrator_dies():
"""Draining remaining results then hitting an empty queue with a dead
orchestrator must raise RuntimeError so callers know the pipeline is gone."""
- mock_queue = mocker.MagicMock()
+ mock_queue = MagicMock()
# First call succeeds; second call finds the queue empty.
mock_queue.sync_q.get.side_effect = [
{"type": "output", "request_id": "r1"},
queue.Empty,
]
- engine = _make_engine(mock_queue, mocker, thread_alive=True)
+ engine = _make_engine(mock_queue, thread_alive=True)
# Collect the one buffered result.
assert engine.try_get_output()["request_id"] == "r1"
@@ -47,15 +47,15 @@ def test_try_get_output_raises_after_orchestrator_dies(mocker: MockerFixture):
@pytest.mark.asyncio
-async def test_try_get_output_async_raises_after_orchestrator_dies(mocker: MockerFixture):
+async def test_try_get_output_async_raises_after_orchestrator_dies():
"""Same scenario as above but for the async variant."""
- mock_queue = mocker.MagicMock()
+ mock_queue = MagicMock()
mock_queue.sync_q.get_nowait.side_effect = [
{"type": "output", "request_id": "r1"},
queue.Empty,
]
- engine = _make_engine(mock_queue, mocker, thread_alive=True)
+ engine = _make_engine(mock_queue, thread_alive=True)
assert (await engine.try_get_output_async())["request_id"] == "r1"
@@ -63,39 +63,3 @@ async def test_try_get_output_async_raises_after_orchestrator_dies(mocker: Mocke
with pytest.raises(RuntimeError, match="Orchestrator died unexpectedly"):
await engine.try_get_output_async()
-
-
-def test_fatal_error_message_surfaces_through_try_get_output(mocker: MockerFixture):
- """When the orchestrator thread crashes, it enqueues a fatal error message.
-
- ``try_get_output`` must return this message so the caller
- (``OmniBase._handle_output_message``) can detect the fatal flag.
- """
- fatal_msg = {"type": "error", "error": "Orchestrator thread crashed", "fatal": True}
-
- mock_queue = mocker.MagicMock()
- mock_queue.sync_q.get.return_value = fatal_msg
-
- engine = _make_engine(mock_queue, mocker, thread_alive=False)
-
- msg = engine.try_get_output()
- assert msg is not None
- assert msg["type"] == "error"
- assert msg["fatal"] is True
- assert "crashed" in msg["error"]
-
-
-@pytest.mark.asyncio
-async def test_fatal_error_message_surfaces_through_try_get_output_async(mocker: MockerFixture):
- """Async variant of the fatal error message test."""
- fatal_msg = {"type": "error", "error": "Orchestrator thread crashed", "fatal": True}
-
- mock_queue = mocker.MagicMock()
- mock_queue.sync_q.get_nowait.return_value = fatal_msg
-
- engine = _make_engine(mock_queue, mocker, thread_alive=False)
-
- msg = await engine.try_get_output_async()
- assert msg is not None
- assert msg["type"] == "error"
- assert msg["fatal"] is True
diff --git a/tests/engine/test_async_omni_engine_stage_init.py b/tests/engine/test_async_omni_engine_stage_init.py
index 5c2a9edb771..9f47fd449d7 100644
--- a/tests/engine/test_async_omni_engine_stage_init.py
+++ b/tests/engine/test_async_omni_engine_stage_init.py
@@ -1,6 +1,4 @@
-import importlib
import os
-import threading
import types
import pytest
@@ -10,17 +8,6 @@
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-def test_stage_engine_core_client_module_reload_keeps_forward_refs_deferred():
- """Regression test for forward references in make_async_mp_client."""
- import vllm_omni.engine.stage_engine_core_client as client_mod
-
- importlib.reload(client_mod)
-
- assert client_mod.StageEngineCoreClientBase.make_async_mp_client.__annotations__["return"] == (
- "StageEngineCoreClient | DPLBStageEngineCoreClient"
- )
-
-
def test_initialize_stages_restores_device_visibility_after_diffusion_init(monkeypatch):
"""Regression test for stage device env leakage across stage init.
@@ -36,9 +23,6 @@ def test_initialize_stages_restores_device_visibility_after_diffusion_init(monke
engine.num_stages = 1
engine.async_chunk = False
engine.diffusion_batch_size = 1
- engine.single_stage_mode = False
- engine._single_stage_id_filter = None
- engine._omni_master_server = None
engine.stage_configs = [types.SimpleNamespace(stage_id=0, stage_type="diffusion")]
env_var = current_omni_platform.device_control_env_var
@@ -65,7 +49,7 @@ def _fake_setup_stage_devices(_stage_id, _runtime_cfg):
current_omni_platform.set_device_control_env_var("1")
monkeypatch.setattr(engine_mod, "setup_stage_devices", _fake_setup_stage_devices)
- monkeypatch.setattr(engine_mod, "inject_kv_stage_info", lambda *_: None)
+ monkeypatch.setattr(engine_mod, "_inject_kv_stage_info", lambda *_: None)
monkeypatch.setattr(engine_mod, "initialize_diffusion_stage", lambda *_, **__: diffusion_client)
monkeypatch.setattr(
engine_mod,
@@ -87,238 +71,6 @@ def _fake_setup_stage_devices(_stage_id, _runtime_cfg):
os.environ[env_var] = old_env
-def test_initialize_stages_passes_stage_init_timeout_to_diffusion_handshake(monkeypatch):
- """Regression test for stage_init_timeout passing to complete_diffusion_handshake
- in the diffusion stage path.
- """
- import vllm_omni.diffusion.data as diffusion_data_mod
- import vllm_omni.diffusion.stage_diffusion_client as client_mod
- import vllm_omni.engine.async_omni_engine as engine_mod
- from vllm_omni.platforms import current_omni_platform
-
- engine = object.__new__(AsyncOmniEngine)
- engine.log_stats = False
- engine.model = "dummy-model"
- engine.config_path = "dummy-config"
- engine.num_stages = 2
- engine.async_chunk = False
- engine.diffusion_batch_size = 1
- engine.single_stage_mode = False
- engine._omni_master_server = None
- engine.stage_configs = [types.SimpleNamespace(stage_id=0, stage_type="diffusion", engine_args={})]
-
- metadata = types.SimpleNamespace(
- stage_id=0,
- stage_type="diffusion",
- runtime_cfg={"devices": "0"},
- prompt_expand_func=None,
- final_output=True,
- final_output_type="image",
- default_sampling_params=None,
- custom_process_input_func=None,
- engine_input_source=None,
- cfg_kv_collect_func=None,
- )
-
- captured_timeout = None
- device_env_var = current_omni_platform.device_control_env_var
- prev_device_env = os.environ.get(device_env_var)
- os.environ[device_env_var] = "0"
-
- monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None)
- monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None)
- monkeypatch.setattr(engine_mod, "extract_stage_metadata", lambda _cfg: metadata)
- monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
- monkeypatch.setattr(
- engine_mod,
- "finalize_initialized_stages",
- lambda stage_clients, _input_processor: (
- stage_clients,
- [types.SimpleNamespace()],
- [{"final_output_type": "image"}],
- ),
- )
- monkeypatch.setattr(
- diffusion_data_mod.OmniDiffusionConfig,
- "from_kwargs",
- classmethod(lambda cls, **kwargs: types.SimpleNamespace(parallel_config=types.SimpleNamespace(world_size=1))),
- )
- monkeypatch.setattr(
- client_mod,
- "spawn_diffusion_proc",
- lambda model, od_cfg: (object(), "ipc://handshake", "ipc://request", "ipc://response"),
- )
-
- def _capture_handshake_timeout(proc, handshake_address, handshake_timeout):
- nonlocal captured_timeout
- captured_timeout = handshake_timeout
-
- monkeypatch.setattr(client_mod, "complete_diffusion_handshake", _capture_handshake_timeout)
- monkeypatch.setattr(
- client_mod.zmq,
- "Context",
- lambda: types.SimpleNamespace(socket=lambda _: types.SimpleNamespace(connect=lambda _: None)),
- )
-
- try:
- engine._initialize_stages(stage_init_timeout=302)
- finally:
- if prev_device_env is None:
- os.environ.pop(device_env_var, None)
- else:
- os.environ[device_env_var] = prev_device_env
-
- assert captured_timeout == 302
-
-
-def test_launch_llm_stage_passes_stage_init_timeout_to_complete_stage_handshake(monkeypatch):
- """Regression test for stage_init_timeout reaching complete_stage_handshake
- in the LLM stage path.
- """
- import vllm_omni.engine.async_omni_engine as engine_mod
- from vllm_omni.platforms import current_omni_platform
-
- engine = object.__new__(AsyncOmniEngine)
- engine.log_stats = False
- engine.model = "dummy-model"
- engine.single_stage_mode = False
- engine._omni_master_server = None
- engine.stage_configs = []
-
- metadata = types.SimpleNamespace(stage_id=0, runtime_cfg={"devices": "0"})
- fake_vllm_config = types.SimpleNamespace()
- fake_addresses = types.SimpleNamespace()
- fake_proc = types.SimpleNamespace()
-
- captured_timeout = None
-
- device_env_var = current_omni_platform.device_control_env_var
- prev_device_env = os.environ.get(device_env_var)
- os.environ[device_env_var] = "0"
-
- monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
- monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
- monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (fake_vllm_config, object))
- monkeypatch.setattr(engine_mod, "acquire_device_locks", lambda *_: [])
- monkeypatch.setattr(
- engine_mod,
- "spawn_stage_core",
- lambda **_: (fake_addresses, fake_proc, "ipc://handshake"),
- )
-
- def _capture_stage_timeout(_proc, _handshake_addr, _addresses, _vllm_cfg, handshake_timeout):
- nonlocal captured_timeout
- captured_timeout = handshake_timeout
-
- monkeypatch.setattr(engine_mod, "complete_stage_handshake", _capture_stage_timeout)
-
- try:
- engine._launch_llm_stage(
- stage_cfg=types.SimpleNamespace(engine_args={}),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=302,
- llm_stage_launch_lock=threading.Lock(),
- )
- finally:
- if prev_device_env is None:
- os.environ.pop(device_env_var, None)
- else:
- os.environ[device_env_var] = prev_device_env
-
- assert captured_timeout == 302
-
-
-def test_launch_llm_stage_releases_launch_lock_before_complete_stage_handshake(monkeypatch):
- """Regression test for parallel LLM stage startup during handshake wait."""
- import vllm_omni.engine.async_omni_engine as engine_mod
- from vllm_omni.platforms import current_omni_platform
-
- engine = object.__new__(AsyncOmniEngine)
- engine.log_stats = False
- engine.model = "dummy-model"
- engine.single_stage_mode = False
- engine._omni_master_server = None
- engine.stage_configs = []
-
- fake_vllm_config = types.SimpleNamespace()
- fake_addresses = types.SimpleNamespace()
- shared_launch_lock = threading.Lock()
- counter_lock = threading.Lock()
- first_handshake_started = threading.Event()
- second_stage_spawned = threading.Event()
- allow_first_handshake_to_finish = threading.Event()
- launch_errors: list[BaseException] = []
- spawn_count = 0
-
- device_env_var = current_omni_platform.device_control_env_var
- prev_device_env = os.environ.get(device_env_var)
- os.environ[device_env_var] = "0"
-
- monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
- monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
- monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (fake_vllm_config, object))
- monkeypatch.setattr(engine_mod, "acquire_device_locks", lambda *_: [])
-
- def _spawn_stage_core(**_):
- nonlocal spawn_count
- with counter_lock:
- spawn_count += 1
- call_idx = spawn_count
- if call_idx == 2:
- second_stage_spawned.set()
- return fake_addresses, types.SimpleNamespace(), f"ipc://handshake-{call_idx}"
-
- def _complete_stage_handshake(_proc, handshake_address, _addresses, _vllm_cfg, _timeout):
- if handshake_address == "ipc://handshake-1":
- first_handshake_started.set()
- assert second_stage_spawned.wait(timeout=1), (
- "second stage did not reach spawn_stage_core while first stage waited in handshake"
- )
- assert allow_first_handshake_to_finish.wait(timeout=1), (
- "second stage did not enter handshake while first stage was still waiting"
- )
- else:
- allow_first_handshake_to_finish.set()
-
- monkeypatch.setattr(engine_mod, "spawn_stage_core", _spawn_stage_core)
- monkeypatch.setattr(engine_mod, "complete_stage_handshake", _complete_stage_handshake)
-
- def _launch_stage(stage_id: int) -> None:
- metadata = types.SimpleNamespace(stage_id=stage_id, runtime_cfg={"devices": str(stage_id)})
- try:
- engine._launch_llm_stage(
- stage_cfg=types.SimpleNamespace(engine_args={}),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=302,
- llm_stage_launch_lock=shared_launch_lock,
- )
- except BaseException as exc: # pragma: no cover - surfaced through assertion below
- launch_errors.append(exc)
-
- try:
- first_thread = threading.Thread(target=_launch_stage, args=(0,))
- first_thread.start()
- assert first_handshake_started.wait(timeout=1), "first stage never entered handshake"
-
- second_thread = threading.Thread(target=_launch_stage, args=(1,))
- second_thread.start()
-
- first_thread.join(timeout=3)
- second_thread.join(timeout=3)
- finally:
- if prev_device_env is None:
- os.environ.pop(device_env_var, None)
- else:
- os.environ[device_env_var] = prev_device_env
-
- assert not first_thread.is_alive()
- assert not second_thread.is_alive()
- assert second_stage_spawned.is_set()
- assert not launch_errors
-
-
def test_attach_llm_stage_uses_omni_input_preprocessor(monkeypatch):
"""Regression test for GLM-Image t2i preprocessing path.
@@ -349,11 +101,7 @@ def __init__(self, vllm_config, renderer=None):
self.vllm_config = vllm_config
self.renderer = renderer
- monkeypatch.setattr(
- engine_mod.StageEngineCoreClientBase,
- "make_async_mp_client",
- staticmethod(lambda **kwargs: DummyStageEngineCoreClient(**kwargs)),
- )
+ monkeypatch.setattr(engine_mod, "StageEngineCoreClient", DummyStageEngineCoreClient)
monkeypatch.setattr(engine_mod, "MultimodalOutputProcessor", DummyOutputProcessor)
monkeypatch.setattr(engine_mod, "InputProcessor", DummyInputProcessor)
monkeypatch.setattr(engine_mod, "OmniInputPreprocessor", DummyOmniInputPreprocessor)
@@ -380,70 +128,3 @@ def __init__(self, vllm_config, renderer=None):
assert input_processor is not None
assert isinstance(input_processor.input_preprocessor, DummyOmniInputPreprocessor)
assert input_processor.input_preprocessor.renderer is input_processor.renderer
-
-
-def test_inject_kv_stage_info_infers_sender_tp_topology():
- from vllm_omni.engine.stage_init_utils import inject_kv_stage_info
-
- stage0 = types.SimpleNamespace(
- stage_id=0,
- engine_args={
- "tensor_parallel_size": 4,
- "omni_kv_config": {
- "need_send_cache": True,
- "omni_from_stage": "0",
- "omni_to_stage": "1",
- },
- },
- engine_input_source=[],
- )
- stage1 = types.SimpleNamespace(
- stage_id=1,
- engine_args={
- "parallel_config": {
- "tensor_parallel_size": 2,
- "cfg_parallel_size": 1,
- },
- "omni_kv_config": {"need_recv_cache": True},
- },
- engine_input_source=[0],
- )
-
- inject_kv_stage_info(stage0, 0, [stage0, stage1])
-
- assert stage0.engine_args["omni_kv_config"]["stage_id"] == 0
- assert stage0.engine_args["omni_kv_config"]["rank_mapping"] == {"from_tp": 4, "to_tp": 2}
-
-
-def test_inject_kv_stage_info_infers_receiver_tp_topology():
- from vllm_omni.engine.stage_init_utils import inject_kv_stage_info
-
- stage0 = types.SimpleNamespace(
- stage_id=0,
- engine_args={
- "tensor_parallel_size": 4,
- "omni_kv_config": {"need_send_cache": True},
- },
- engine_input_source=[],
- )
- stage1 = types.SimpleNamespace(
- stage_id=1,
- engine_args={
- "parallel_config": {
- "tensor_parallel_size": 2,
- "cfg_parallel_size": 1,
- },
- "omni_kv_config": {
- "need_recv_cache": True,
- "omni_from_stage": "0",
- "omni_to_stage": "1",
- },
- },
- engine_input_source=[0],
- )
-
- inject_kv_stage_info(stage1, 1, [stage0, stage1])
-
- assert stage1.engine_args["omni_kv_config"]["stage_id"] == 1
- assert stage1.engine_args["omni_kv_config"]["engine_input_source"] == [0]
- assert stage1.engine_args["omni_kv_config"]["rank_mapping"] == {"from_tp": 4, "to_tp": 2}
diff --git a/tests/engine/test_cfg_companion_tracker.py b/tests/engine/test_cfg_companion_tracker.py
deleted file mode 100644
index f856a38c3e3..00000000000
--- a/tests/engine/test_cfg_companion_tracker.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import pytest
-
-from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def test_register_companion_and_cleanup():
- tracker = CfgCompanionTracker()
-
- tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
- tracker.register_companion("req1", "cfg_img", "req1__cfg_img")
-
- assert tracker.is_companion("req1__cfg_text")
- assert tracker.get_companion_request_ids("req1") == {
- "cfg_text": "req1__cfg_text",
- "cfg_img": "req1__cfg_img",
- }
-
- removed = tracker.cleanup_parent("req1")
-
- assert sorted(removed) == ["req1__cfg_img", "req1__cfg_text"]
- assert not tracker.is_companion("req1__cfg_text")
- assert tracker.get_companion_request_ids("req1") == {}
-
-
-def test_attach_cfg_request_ids_clones_diffusion_params():
- tracker = CfgCompanionTracker()
- tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
-
- params = OmniDiffusionSamplingParams()
- updated = tracker.attach_cfg_request_ids("req1", params)
-
- assert updated is not params
- assert params.cfg_kv_request_ids is None
- assert updated.cfg_kv_request_ids == {"cfg_text": "req1__cfg_text"}
-
-
-def test_abort_parent_expands_to_companions_and_cleans_up_deferred_parent():
- tracker = CfgCompanionTracker()
- tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
- tracker.defer_parent("req1", {"out": 1}, stage_id=0)
-
- aborted = tracker.abort_parents(["req1"])
-
- assert sorted(aborted) == ["req1", "req1__cfg_text"]
- assert not tracker.is_companion("req1__cfg_text")
- assert tracker.pop_pending_parent("req1") is None
-
-
-def test_abort_companion_does_not_expand_to_parent():
- tracker = CfgCompanionTracker()
- tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
-
- aborted = tracker.abort_parents(["req1__cfg_text"])
-
- assert aborted == ["req1__cfg_text"]
-
-
-def test_companion_completion_flushes_deferred_parent():
- tracker = CfgCompanionTracker()
- tracker.register_companion("req1", "cfg_text", "req1__cfg_text")
- tracker.defer_parent("req1", {"out": 123}, stage_id=0)
-
- assert not tracker.all_companions_done("req1")
- assert tracker.on_companion_completed("req1__cfg_text") == "req1"
- assert tracker.all_companions_done("req1")
-
- popped = tracker.pop_pending_parent("req1")
- assert popped is not None
- assert popped["engine_outputs"] == {"out": 123}
- assert popped["stage_id"] == 0
-
-
-def test_companion_completion_without_registered_parent_asserts():
- tracker = CfgCompanionTracker()
- tracker._companion_ids.add("req1__cfg_text")
- tracker._companion_to_parent["req1__cfg_text"] = "req1"
-
- with pytest.raises(AssertionError, match="completed before parent req1 was registered"):
- tracker.on_companion_completed("req1__cfg_text")
diff --git a/tests/engine/test_cross_stage_lora.py b/tests/engine/test_cross_stage_lora.py
deleted file mode 100644
index 1eccc5526c6..00000000000
--- a/tests/engine/test_cross_stage_lora.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for cross-stage LoRA routing in the orchestrator."""
-
-from __future__ import annotations
-
-import pytest
-from vllm.lora.request import LoRARequest
-from vllm.sampling_params import SamplingParams
-
-from vllm_omni.engine.orchestrator import build_engine_core_request_from_tokens
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-class TestBuildEngineCoreRequestLoRA:
- """Verify build_engine_core_request_from_tokens passes LoRA from params."""
-
- def test_lora_extracted_from_diffusion_params(self):
- lr = LoRARequest(lora_name="test", lora_int_id=1, lora_path="/tmp/fake")
- params = OmniDiffusionSamplingParams(lora_request=lr)
-
- # OmniDiffusionSamplingParams is not a SamplingParams, so
- # build_engine_core_request_from_tokens takes the pooling path.
- # We only care that lora_request is extracted via getattr.
- request = build_engine_core_request_from_tokens(
- request_id="req-1",
- prompt={"prompt_token_ids": [1, 2, 3]},
- params=params,
- model_config=None,
- )
- assert request.lora_request is lr
-
- def test_no_lora_on_sampling_params(self):
- params = SamplingParams(max_tokens=10)
-
- request = build_engine_core_request_from_tokens(
- request_id="req-2",
- prompt={"prompt_token_ids": [1, 2, 3]},
- params=params,
- model_config=None,
- )
- assert request.lora_request is None
diff --git a/tests/engine/test_orchestrator.py b/tests/engine/test_orchestrator.py
deleted file mode 100644
index c07762ad56a..00000000000
--- a/tests/engine/test_orchestrator.py
+++ /dev/null
@@ -1,604 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-import concurrent.futures
-import queue
-import threading
-import time
-from dataclasses import dataclass
-from types import SimpleNamespace
-from typing import Any
-
-import janus
-import psutil
-import pytest
-from vllm.outputs import CompletionOutput, RequestOutput
-from vllm.sampling_params import SamplingParams
-from vllm.v1.engine.core_client import AsyncMPClient
-
-from vllm_omni.engine.orchestrator import Orchestrator
-from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@dataclass
-class OrchestratorFixture:
- orchestrator: Orchestrator
- request_sync_q: Any
- output_sync_q: Any
- queues: tuple[janus.Queue, ...]
- thread: threading.Thread
- result_future: concurrent.futures.Future[None]
-
-
-class FakeStageClient:
- def __init__(
- self,
- *,
- stage_type: str = "llm",
- final_output: bool = False,
- final_output_type: str = "text",
- next_inputs: list[dict] | None = None,
- ) -> None:
- self.stage_type = stage_type
- self.final_output = final_output
- self.final_output_type = final_output_type
- self.next_inputs = list(next_inputs or [])
- self.custom_process_input_func = None
- self.add_request_calls: list[tuple] = []
- self.abort_calls: list[list[str]] = []
- self.shutdown_calls = 0
- self._engine_core_outputs = queue.Queue()
- self._diffusion_outputs = queue.Queue()
-
- # Orchestrator-facing interface.
- async def add_request_async(self, *args, **_kwargs) -> None:
- self.add_request_calls.append(args)
-
- async def get_output_async(self):
- try:
- return self._engine_core_outputs.get_nowait()
- except queue.Empty:
- return SimpleNamespace(outputs=[])
-
- def get_diffusion_output_nowait(self):
- try:
- return self._diffusion_outputs.get_nowait()
- except queue.Empty:
- return None
-
- def set_engine_outputs(self, outputs) -> None:
- return None
-
- def process_engine_inputs(self, stage_list, prompt=None, streaming_context=None):
- return list(self.next_inputs)
-
- async def abort_requests_async(self, request_ids: list[str]) -> None:
- self.abort_calls.append(list(request_ids))
-
- def shutdown(self) -> None:
- self.shutdown_calls += 1
-
- # Test helpers for seeding fake stage outputs.
- def push_engine_core_outputs(self, outputs) -> None:
- self._engine_core_outputs.put_nowait(outputs)
-
- def push_diffusion_output(self, output) -> None:
- self._diffusion_outputs.put_nowait(output)
-
-
-class FakeOutputProcessor:
- def __init__(self, *, request_outputs: list[object] | None = None) -> None:
- self.request_outputs = list(request_outputs or [])
-
- def add_request(self, *_args, **_kwargs) -> None:
- return None
-
- def process_outputs(self, *_args, **_kwargs):
- return SimpleNamespace(
- request_outputs=list(self.request_outputs),
- reqs_to_abort=[],
- )
-
- def update_scheduler_stats(self, _scheduler_stats) -> None:
- return None
-
-
-class _FakeProc:
- pid = 1234
-
- def __init__(self):
- self.terminated = False
- self.killed = False
- self.join_calls = []
-
- def is_alive(self):
- return not self.terminated and not self.killed
-
- def terminate(self):
- self.terminated = True
-
- def kill(self):
- self.killed = True
-
- def join(self, timeout=None):
- self.join_calls.append(timeout)
-
-
-class _FakeChildProc:
- def __init__(self):
- self.terminated = False
- self.killed = False
-
- def is_running(self):
- return not self.terminated and not self.killed
-
- def terminate(self):
- self.terminated = True
-
- def kill(self):
- self.killed = True
-
-
-def _sampling_params(max_tokens: int = 4) -> SamplingParams:
- return SamplingParams(max_tokens=max_tokens)
-
-
-def _engine_core_outputs(tag: str, timestamp: float) -> SimpleNamespace:
- return SimpleNamespace(outputs=[tag], timestamp=timestamp, scheduler_stats=None)
-
-
-def _build_request_output(
- request_id: str,
- *,
- token_ids: list[int] | None = None,
- prompt_token_ids: list[int] | None = None,
- finished: bool = True,
- text: str = "test",
-) -> RequestOutput:
- completion = CompletionOutput(
- index=0,
- text=text,
- token_ids=list(token_ids or [1, 2]),
- cumulative_logprob=0.0,
- logprobs=None,
- finish_reason="stop" if finished else None,
- stop_reason=None,
- )
- return RequestOutput(
- request_id=request_id,
- prompt="prompt",
- prompt_token_ids=list(prompt_token_ids or [10, 11]),
- prompt_logprobs=None,
- outputs=[completion],
- finished=finished,
- metrics=None,
- lora_request=None,
- )
-
-
-def _build_harness(
- stage_clients: list[object],
- *,
- output_processors: list[object] | None = None,
- stage_vllm_configs: list[object] | None = None,
- async_chunk: bool = False,
-) -> OrchestratorFixture:
- if output_processors is None:
- output_processors = [FakeOutputProcessor() for _ in stage_clients]
- if stage_vllm_configs is None:
- stage_vllm_configs = [SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)) for _ in stage_clients]
-
- ready_future: concurrent.futures.Future[tuple[Orchestrator, janus.Queue, janus.Queue, janus.Queue]] = (
- concurrent.futures.Future()
- )
- result_future: concurrent.futures.Future[None] = concurrent.futures.Future()
-
- def _runner() -> None:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- async def _run() -> None:
- request_queue = janus.Queue()
- output_queue = janus.Queue()
- rpc_queue = janus.Queue()
- orchestrator = Orchestrator(
- request_async_queue=request_queue.async_q,
- output_async_queue=output_queue.async_q,
- rpc_async_queue=rpc_queue.async_q,
- stage_clients=stage_clients,
- output_processors=output_processors,
- stage_vllm_configs=stage_vllm_configs,
- async_chunk=async_chunk,
- )
- ready_future.set_result((orchestrator, request_queue, output_queue, rpc_queue))
- await orchestrator.run()
-
- try:
- loop.run_until_complete(_run())
- result_future.set_result(None)
- except Exception as exc:
- result_future.set_exception(exc)
- finally:
- try:
- pending = [task for task in asyncio.all_tasks(loop) if not task.done()]
- for task in pending:
- task.cancel()
- if pending:
- loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
- loop.run_until_complete(loop.shutdown_asyncgens())
- finally:
- asyncio.set_event_loop(None)
- loop.close()
-
- thread = threading.Thread(target=_runner, daemon=True, name="test-orchestrator")
- thread.start()
-
- orchestrator, request_queue, output_queue, rpc_queue = ready_future.result(timeout=5)
- return OrchestratorFixture(
- orchestrator=orchestrator,
- request_sync_q=request_queue.sync_q,
- output_sync_q=output_queue.sync_q,
- queues=(request_queue, output_queue, rpc_queue),
- thread=thread,
- result_future=result_future,
- )
-
-
-async def _shutdown_orchestrator(orchestrator_fixture: OrchestratorFixture) -> None:
- orchestrator_fixture.request_sync_q.put_nowait({"type": "shutdown"})
- await asyncio.to_thread(orchestrator_fixture.thread.join, 5)
- if orchestrator_fixture.thread.is_alive():
- raise AssertionError("Timed out waiting for orchestrator thread shutdown")
- orchestrator_fixture.result_future.result(timeout=0)
-
-
-async def _wait_for(predicate, *, timeout: float = 2.0) -> None:
- deadline = time.monotonic() + timeout
- while not predicate():
- if time.monotonic() >= deadline:
- raise AssertionError("Timed out waiting for predicate")
- await asyncio.sleep(0.01)
-
-
-async def _get_output_message(orchestrator_fixture: OrchestratorFixture, *, timeout: float = 2.0) -> dict:
- deadline = time.monotonic() + timeout
- while True:
- if time.monotonic() >= deadline:
- raise AssertionError("Timed out waiting for orchestrator output")
- try:
- msg = orchestrator_fixture.output_sync_q.get_nowait()
- except queue.Empty:
- await asyncio.sleep(0.01)
- continue
- if msg.get("type") == "output":
- return msg
-
-
-async def _enqueue_add_request(
- orchestrator_fixture: OrchestratorFixture,
- *,
- request_id: str,
- prompt,
- original_prompt,
- sampling_params_list,
- final_stage_id: int,
-) -> None:
- orchestrator_fixture.request_sync_q.put_nowait(
- {
- "type": "add_request",
- "request_id": request_id,
- "prompt": prompt,
- "original_prompt": original_prompt,
- "sampling_params_list": sampling_params_list,
- "final_stage_id": final_stage_id,
- }
- )
-
-
-async def _enqueue_abort_request(orchestrator_fixture: OrchestratorFixture, request_ids: list[str]) -> None:
- orchestrator_fixture.request_sync_q.put_nowait(
- {
- "type": "abort",
- "request_ids": request_ids,
- }
- )
-
-
-def test_stage_engine_core_client_shutdown_cleans_children_if_base_shutdown_fails(monkeypatch):
- fake_proc = _FakeProc()
- fake_child = _FakeChildProc()
-
- class FakePsutilProcess:
- def __init__(self, pid):
- assert pid == fake_proc.pid
-
- def children(self, recursive=True):
- assert recursive
- return [fake_child]
-
- def fail_base_shutdown(self):
- raise RuntimeError("base shutdown failed")
-
- monkeypatch.setattr(psutil, "Process", FakePsutilProcess)
- monkeypatch.setattr(psutil, "wait_procs", lambda procs, timeout: (list(procs), []))
- monkeypatch.setattr(AsyncMPClient, "shutdown", fail_base_shutdown)
-
- client = object.__new__(StageEngineCoreClient)
- client._proc = fake_proc
-
- with pytest.raises(RuntimeError, match="base shutdown failed"):
- client.shutdown()
-
- assert fake_proc.terminated
- assert fake_proc.join_calls == [5]
- assert fake_child.terminated
-
-
-def test_stage_engine_core_client_shutdown_kills_stubborn_children(monkeypatch):
- fake_proc = _FakeProc()
- fake_child = _FakeChildProc()
-
- class FakePsutilProcess:
- def __init__(self, pid):
- assert pid == fake_proc.pid
-
- def children(self, recursive=True):
- assert recursive
- return [fake_child]
-
- monkeypatch.setattr(psutil, "Process", FakePsutilProcess)
- monkeypatch.setattr(psutil, "wait_procs", lambda procs, timeout: ([], list(procs)))
- monkeypatch.setattr(AsyncMPClient, "shutdown", lambda self: None)
-
- client = object.__new__(StageEngineCoreClient)
- client._proc = fake_proc
-
- client.shutdown()
-
- assert fake_child.terminated
- assert fake_child.killed
-
-
-@pytest.fixture
-def orchestrator_factory():
- fixtures: list[OrchestratorFixture] = []
-
- def _factory(*args, **kwargs) -> OrchestratorFixture:
- fixture = _build_harness(*args, **kwargs)
- fixtures.append(fixture)
- return fixture
-
- yield _factory
-
- for fixture in fixtures:
- if fixture.thread.is_alive():
- fixture.request_sync_q.put_nowait({"type": "shutdown"})
- fixture.thread.join(timeout=5)
- for q in fixture.queues:
- q.close()
-
-
-@pytest.mark.asyncio
-async def test_run_two_stage_llm(orchestrator_factory) -> None:
- stage0 = FakeStageClient(stage_type="llm", final_output=False)
- stage1 = FakeStageClient(
- stage_type="llm",
- final_output=True,
- next_inputs=[{"prompt_token_ids": [7, 8, 9]}],
- )
- processors = [
- FakeOutputProcessor(request_outputs=[_build_request_output("req-llm", token_ids=[3, 4], finished=True)]),
- FakeOutputProcessor(request_outputs=[_build_request_output("req-llm", token_ids=[10, 11], finished=True)]),
- ]
- orchestrator_fixture = orchestrator_factory([stage0, stage1], output_processors=processors)
- request = SimpleNamespace(request_id="req-llm", prompt_token_ids=[1, 2, 3])
-
- try:
- await _enqueue_add_request(
- orchestrator_fixture,
- request_id="req-llm",
- prompt=request,
- original_prompt={"prompt": "hello"},
- sampling_params_list=[_sampling_params(), _sampling_params()],
- final_stage_id=1,
- )
-
- await _wait_for(lambda: len(stage0.add_request_calls) == 1)
- stage0.push_engine_core_outputs(_engine_core_outputs("stage0-raw", 1.0))
-
- await _wait_for(lambda: len(stage1.add_request_calls) == 1)
- stage1_request = stage1.add_request_calls[0][0]
- assert stage1_request.request_id == "req-llm"
- assert stage1_request.prompt_token_ids == [7, 8, 9]
-
- stage1.push_engine_core_outputs(_engine_core_outputs("stage1-raw", 2.0))
-
- output_msg = await _get_output_message(orchestrator_fixture)
-
- assert output_msg["request_id"] == "req-llm"
- assert output_msg["stage_id"] == 1
- assert output_msg["finished"] is True
- assert output_msg["engine_outputs"].request_id == "req-llm"
- assert "req-llm" not in orchestrator_fixture.orchestrator.request_states
- finally:
- await _shutdown_orchestrator(orchestrator_fixture)
-
-
-@pytest.mark.asyncio
-async def test_run_single_stage_diffusion(orchestrator_factory) -> None:
- stage0 = FakeStageClient(stage_type="diffusion", final_output=True, final_output_type="image")
- orchestrator_fixture = orchestrator_factory([stage0])
- params = OmniDiffusionSamplingParams()
-
- try:
- await _enqueue_add_request(
- orchestrator_fixture,
- request_id="req-diff",
- prompt={"prompt": "draw a cat"},
- original_prompt={"prompt": "draw a cat"},
- sampling_params_list=[params],
- final_stage_id=0,
- )
-
- await _wait_for(lambda: len(stage0.add_request_calls) == 1)
- stage0.push_diffusion_output(
- OmniRequestOutput.from_diffusion(
- request_id="req-diff",
- images=[],
- final_output_type="image",
- )
- )
-
- output_msg = await _get_output_message(orchestrator_fixture)
-
- assert output_msg["request_id"] == "req-diff"
- assert output_msg["stage_id"] == 0
- assert output_msg["finished"] is True
- assert output_msg["engine_outputs"].request_id == "req-diff"
- assert "req-diff" not in orchestrator_fixture.orchestrator.request_states
- finally:
- await _shutdown_orchestrator(orchestrator_fixture)
-
-
-@pytest.mark.asyncio
-async def test_run_llm_to_diffusion(orchestrator_factory) -> None:
- stage0 = FakeStageClient(stage_type="llm", final_output=False)
- stage1 = FakeStageClient(stage_type="diffusion", final_output=True, final_output_type="image")
- processors = [
- FakeOutputProcessor(request_outputs=[_build_request_output("req-img", token_ids=[3, 4], finished=True)]),
- FakeOutputProcessor(),
- ]
- orchestrator_fixture = orchestrator_factory([stage0, stage1], output_processors=processors)
- request = SimpleNamespace(request_id="req-img", prompt_token_ids=[1, 2, 3])
- params = OmniDiffusionSamplingParams()
- original_prompt = {"prompt": "draw a fox"}
-
- try:
- await _enqueue_add_request(
- orchestrator_fixture,
- request_id="req-img",
- prompt=request,
- original_prompt=original_prompt,
- sampling_params_list=[_sampling_params(), params],
- final_stage_id=1,
- )
-
- await _wait_for(lambda: len(stage0.add_request_calls) == 1)
- stage0.push_engine_core_outputs(_engine_core_outputs("stage0-raw", 1.0))
-
- await _wait_for(lambda: len(stage1.add_request_calls) == 1)
- assert stage1.add_request_calls[0] == ("req-img", original_prompt, params)
-
- stage1.push_diffusion_output(
- OmniRequestOutput.from_diffusion(
- request_id="req-img",
- images=[],
- final_output_type="image",
- )
- )
-
- output_msg = await _get_output_message(orchestrator_fixture)
-
- assert output_msg["request_id"] == "req-img"
- assert output_msg["stage_id"] == 1
- assert output_msg["finished"] is True
- assert output_msg["engine_outputs"].request_id == "req-img"
- assert "req-img" not in orchestrator_fixture.orchestrator.request_states
- finally:
- await _shutdown_orchestrator(orchestrator_fixture)
-
-
-@pytest.mark.asyncio
-async def test_run_async_chunk(orchestrator_factory) -> None:
- stage0 = FakeStageClient(stage_type="llm", final_output=False)
- stage1 = FakeStageClient(stage_type="llm", final_output=True)
- processors = [
- FakeOutputProcessor(request_outputs=[_build_request_output("req-async", token_ids=[1], finished=True)]),
- FakeOutputProcessor(request_outputs=[_build_request_output("req-async", token_ids=[20, 21], finished=True)]),
- ]
- orchestrator_fixture = orchestrator_factory(
- [stage0, stage1],
- output_processors=processors,
- async_chunk=True,
- )
- request = SimpleNamespace(request_id="req-async", prompt_token_ids=[1, 2, 3, 4])
-
- try:
- await _enqueue_add_request(
- orchestrator_fixture,
- request_id="req-async",
- prompt=request,
- original_prompt={"prompt": "hello async"},
- sampling_params_list=[_sampling_params(), _sampling_params()],
- final_stage_id=1,
- )
-
- await _wait_for(lambda: len(stage1.add_request_calls) == 1)
- prewarmed_request = stage1.add_request_calls[0][0]
- assert prewarmed_request.request_id == "req-async"
- assert prewarmed_request.prompt_token_ids
- assert all(token_id == 0 for token_id in prewarmed_request.prompt_token_ids)
-
- stage1.push_engine_core_outputs(_engine_core_outputs("stage1-final", 3.0))
-
- output_msg = await _get_output_message(orchestrator_fixture)
-
- assert output_msg["request_id"] == "req-async"
- assert output_msg["stage_id"] == 1
- assert output_msg["finished"] is True
- assert "req-async" not in orchestrator_fixture.orchestrator.request_states
- finally:
- await _shutdown_orchestrator(orchestrator_fixture)
-
-
-@pytest.mark.asyncio
-async def test_run_shutdown(orchestrator_factory) -> None:
- stages = [
- FakeStageClient(stage_type="llm", final_output=False),
- FakeStageClient(stage_type="diffusion", final_output=True, final_output_type="image"),
- ]
- orchestrator_fixture = orchestrator_factory(stages)
-
- await _shutdown_orchestrator(orchestrator_fixture)
-
- assert not orchestrator_fixture.thread.is_alive()
- for stage in stages:
- assert stage.shutdown_calls == 1
-
-
-@pytest.mark.asyncio
-async def test_run_abort(orchestrator_factory) -> None:
- stages = [
- FakeStageClient(stage_type="llm", final_output=False),
- FakeStageClient(stage_type="llm", final_output=True),
- ]
- processors = [
- FakeOutputProcessor(request_outputs=[_build_request_output("req-abort", token_ids=[1], finished=True)]),
- FakeOutputProcessor(request_outputs=[_build_request_output("req-abort", token_ids=[2], finished=True)]),
- ]
- orchestrator_fixture = orchestrator_factory(stages, output_processors=processors)
- request = SimpleNamespace(request_id="req-abort", prompt_token_ids=[1, 2, 3])
-
- try:
- await _enqueue_add_request(
- orchestrator_fixture,
- request_id="req-abort",
- prompt=request,
- original_prompt={"prompt": "cancel me"},
- sampling_params_list=[_sampling_params(), _sampling_params()],
- final_stage_id=1,
- )
- await _wait_for(lambda: len(stages[0].add_request_calls) == 1)
-
- await _enqueue_abort_request(orchestrator_fixture, ["req-abort"])
- await _wait_for(lambda: all(stage.abort_calls for stage in stages))
-
- for stage in stages:
- assert stage.abort_calls == [["req-abort"]]
- assert "req-abort" not in orchestrator_fixture.orchestrator.request_states
- finally:
- await _shutdown_orchestrator(orchestrator_fixture)
diff --git a/tests/engine/test_orchestrator_error_handling.py b/tests/engine/test_orchestrator_error_handling.py
deleted file mode 100644
index 6131c5e9bc4..00000000000
--- a/tests/engine/test_orchestrator_error_handling.py
+++ /dev/null
@@ -1,160 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for error propagation paths within the Orchestrator.
-
-Covers:
-- EngineDeadError from an LLM stage poll → fatal error broadcast + shutdown
-- Diffusion stage error output (OmniRequestOutput.from_error) → routed correctly
-"""
-
-from __future__ import annotations
-
-import asyncio
-import queue
-import time
-from types import SimpleNamespace
-
-import pytest
-from vllm.v1.engine.exceptions import EngineDeadError
-
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
-
-from .test_orchestrator import (
- FakeStageClient,
- OrchestratorFixture,
- _build_harness,
- _enqueue_add_request,
- _wait_for,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _sampling_params(max_tokens: int = 4):
- from vllm.sampling_params import SamplingParams
-
- return SamplingParams(max_tokens=max_tokens)
-
-
-async def _get_any_output_message(fixture: OrchestratorFixture, *, timeout: float = 2.0) -> dict:
- """Like _get_output_message but returns any message type (including errors)."""
- deadline = time.monotonic() + timeout
- while True:
- if time.monotonic() >= deadline:
- raise AssertionError("Timed out waiting for orchestrator output")
- try:
- return fixture.output_sync_q.get_nowait()
- except queue.Empty:
- await asyncio.sleep(0.01)
-
-
-@pytest.fixture
-def orchestrator_factory():
- fixtures: list[OrchestratorFixture] = []
-
- def _factory(*args, **kwargs) -> OrchestratorFixture:
- fixture = _build_harness(*args, **kwargs)
- fixtures.append(fixture)
- return fixture
-
- yield _factory
-
- for fixture in fixtures:
- if fixture.thread.is_alive():
- fixture.request_sync_q.put_nowait({"type": "shutdown"})
- fixture.thread.join(timeout=5)
- for q in fixture.queues:
- q.close()
-
-
-# ───────── EngineDeadError from LLM stage poll ─────────
-
-
-class FakeDeadLLMStageClient(FakeStageClient):
- """LLM stage client that raises EngineDeadError on get_output_async."""
-
- async def get_output_async(self):
- raise EngineDeadError("Stage-0 engine core is dead")
-
-
-@pytest.mark.asyncio
-async def test_engine_dead_error_broadcasts_fatal_and_shuts_down(orchestrator_factory) -> None:
- """When a stage raises EngineDeadError during poll, the orchestrator must:
- 1. Enqueue a fatal error message for each affected request
- 2. Shut itself down (thread exits)
- """
- stage0 = FakeDeadLLMStageClient(stage_type="llm", final_output=True)
- orchestrator_fixture = orchestrator_factory([stage0])
- request = SimpleNamespace(request_id="req-dead", prompt_token_ids=[1, 2])
-
- try:
- await _enqueue_add_request(
- orchestrator_fixture,
- request_id="req-dead",
- prompt=request,
- original_prompt={"prompt": "hello"},
- sampling_params_list=[_sampling_params()],
- final_stage_id=0,
- )
-
- # Collect the fatal error message.
- msg = await _get_any_output_message(orchestrator_fixture)
-
- assert msg["type"] == "error"
- assert msg["fatal"] is True
- assert msg["request_id"] == "req-dead"
- assert "Stage-0 engine core is dead" in msg["error"]
-
- # The orchestrator thread should exit after the fatal error.
- orchestrator_fixture.thread.join(timeout=5)
- assert not orchestrator_fixture.thread.is_alive()
-
- # Request state should be cleaned up.
- assert "req-dead" not in orchestrator_fixture.orchestrator.request_states
- finally:
- if orchestrator_fixture.thread.is_alive():
- orchestrator_fixture.request_sync_q.put_nowait({"type": "shutdown"})
- orchestrator_fixture.thread.join(timeout=5)
-
-
-# ───────── Diffusion stage error output routing ─────────
-
-
-@pytest.mark.asyncio
-async def test_diffusion_error_output_routed_as_finished(orchestrator_factory) -> None:
- """When a diffusion stage returns an OmniRequestOutput with a non-None
- error, the orchestrator must route it as a finished output message and
- clean up the request state.
- """
- stage0 = FakeStageClient(stage_type="diffusion", final_output=True, final_output_type="image")
- orchestrator_fixture = orchestrator_factory([stage0])
- params = OmniDiffusionSamplingParams()
-
- try:
- await _enqueue_add_request(
- orchestrator_fixture,
- request_id="req-err",
- prompt={"prompt": "draw a cat"},
- original_prompt={"prompt": "draw a cat"},
- sampling_params_list=[params],
- final_stage_id=0,
- )
-
- await _wait_for(lambda: len(stage0.add_request_calls) == 1)
-
- # Push an error output from the diffusion stage.
- stage0.push_diffusion_output(OmniRequestOutput.from_error("req-err", "gpu fault"))
-
- msg = await _get_any_output_message(orchestrator_fixture)
-
- assert msg["type"] == "output"
- assert msg["request_id"] == "req-err"
- assert msg["finished"] is True
- assert msg["engine_outputs"].error == "gpu fault"
-
- # Request state should be cleaned up.
- await _wait_for(lambda: "req-err" not in orchestrator_fixture.orchestrator.request_states)
- finally:
- orchestrator_fixture.request_sync_q.put_nowait({"type": "shutdown"})
- orchestrator_fixture.thread.join(timeout=5)
diff --git a/tests/engine/test_orchestrator_kv_sender_info.py b/tests/engine/test_orchestrator_kv_sender_info.py
deleted file mode 100644
index ec3a42e3546..00000000000
--- a/tests/engine/test_orchestrator_kv_sender_info.py
+++ /dev/null
@@ -1,254 +0,0 @@
-import asyncio
-from types import SimpleNamespace
-
-import pytest
-from vllm import SamplingParams
-
-from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker
-from vllm_omni.engine.orchestrator import Orchestrator, OrchestratorRequestState
-from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-class _DummySenderStage:
- def __init__(self, sender_info):
- self._sender_info = sender_info
- self.engine_outputs = None
-
- def set_engine_outputs(self, outputs):
- self.engine_outputs = outputs
-
- def get_kv_sender_info(self):
- return self._sender_info
-
-
-class _DummyDiffusionStage:
- stage_type = "diffusion"
- custom_process_input_func = None
-
- def __init__(self, engine_input_source=None):
- self.engine_input_source = engine_input_source or [0]
- self.calls = []
-
- async def add_request_async(self, request_id, prompt, sampling_params, kv_sender_info=None):
- self.calls.append(
- {
- "request_id": request_id,
- "prompt": prompt,
- "sampling_params": sampling_params,
- "kv_sender_info": kv_sender_info,
- }
- )
-
-
-def test_stage_engine_core_client_builds_kv_sender_info_from_tcp_address():
- client = object.__new__(StageEngineCoreClient)
- client.stage_id = 0
- client.client_addresses = {"input_address": "tcp://10.20.30.40:1234"}
- client._omni_kv_config = None
- client._kv_sender_info = None
- client._kv_sender_initialized = False
- client._kv_sender_host = client._resolve_contact_host()
- client._initialize_kv_sender_endpoint()
-
- assert client.get_kv_sender_info() == {
- "host": "10.20.30.40",
- "zmq_port": 50151,
- }
-
-
-def test_stage_engine_core_client_falls_back_to_detected_ip_for_loopback(monkeypatch):
- client = object.__new__(StageEngineCoreClient)
- client.stage_id = 1
- client.client_addresses = {"input_address": "tcp://127.0.0.1:1234"}
- client._omni_kv_config = None
- client._kv_sender_info = None
- client._kv_sender_initialized = False
- monkeypatch.setattr(client, "_detect_local_ip", lambda: "192.168.0.12")
- client._kv_sender_host = client._resolve_contact_host()
- client._initialize_kv_sender_endpoint()
-
- assert client.get_kv_sender_info() == {
- "host": "192.168.0.12",
- "zmq_port": 50152,
- }
-
-
-def test_stage_engine_core_client_uses_connector_config_for_sender_port():
- client = object.__new__(StageEngineCoreClient)
- client.stage_id = 3
- client.client_addresses = {"input_address": "tcp://10.20.30.40:1234"}
- client._kv_sender_info = None
- client._kv_sender_initialized = False
- client._omni_kv_config = {
- "omni_from_stage": "3",
- "connector_config": {
- "type": "MooncakeTransferEngineConnector",
- "role": "sender",
- "host": "10.20.30.99",
- "zmq_port": 51000,
- },
- }
- client._kv_sender_host = client._resolve_contact_host()
- client._initialize_kv_sender_endpoint()
-
- assert client.get_kv_sender_info() == {
- "host": "10.20.30.99",
- "zmq_port": 51103,
- }
-
-
-def test_stage_engine_core_client_preserves_explicit_loopback_sender_host():
- client = object.__new__(StageEngineCoreClient)
- client.stage_id = 2
- client.client_addresses = {"input_address": "tcp://10.20.30.40:1234"}
- client._kv_sender_info = None
- client._kv_sender_initialized = False
- client._omni_kv_config = {
- "omni_from_stage": "2",
- "connector_config": {
- "type": "MooncakeTransferEngineConnector",
- "role": "sender",
- "host": "127.0.0.1",
- "zmq_port": 51000,
- },
- }
- client._kv_sender_host = client._resolve_contact_host()
- client._initialize_kv_sender_endpoint()
-
- assert client.get_kv_sender_info() == {
- "host": "127.0.0.1",
- "zmq_port": 51102,
- }
-
-
-def test_forward_to_diffusion_attaches_kv_sender_info():
- orchestrator = object.__new__(Orchestrator)
- sender_stage = _DummySenderStage({"host": "10.0.0.2", "zmq_port": 50151})
- diffusion_stage = _DummyDiffusionStage(engine_input_source=[0])
-
- orchestrator.num_stages = 2
- orchestrator.stage_clients = [sender_stage, diffusion_stage]
- orchestrator._cfg_tracker = CfgCompanionTracker()
- orchestrator.stage_vllm_configs = [None, None]
- orchestrator.output_processors = [None, None]
-
- params = OmniDiffusionSamplingParams()
- req_state = OrchestratorRequestState(
- request_id="req-1",
- prompt={"prompt": "hello"},
- sampling_params_list=[SamplingParams(max_tokens=4), params],
- final_stage_id=1,
- )
-
- output = SimpleNamespace(request_id="req-1", finished=True)
- asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-1", 0, output, req_state))
-
- assert sender_stage.engine_outputs == [output]
- assert diffusion_stage.calls[0]["request_id"] == "req-1"
- assert diffusion_stage.calls[0]["kv_sender_info"] == {
- 0: {"host": "10.0.0.2", "zmq_port": 50151},
- }
- assert req_state.stage_submit_ts[1] > 0
-
-
-def test_forward_to_diffusion_uses_engine_input_source_for_kv_sender_info():
- orchestrator = object.__new__(Orchestrator)
- source_stage = _DummySenderStage({"host": "10.0.0.2", "zmq_port": 50151})
- previous_stage = _DummySenderStage({"host": "10.0.0.9", "zmq_port": 59999})
- diffusion_stage = _DummyDiffusionStage(engine_input_source=[0])
-
- orchestrator.num_stages = 3
- orchestrator.stage_clients = [source_stage, previous_stage, diffusion_stage]
- orchestrator._cfg_tracker = CfgCompanionTracker()
- orchestrator.stage_vllm_configs = [None, None, None]
- orchestrator.output_processors = [None, None, None]
-
- params = OmniDiffusionSamplingParams()
- req_state = OrchestratorRequestState(
- request_id="req-3",
- prompt={"prompt": "hello"},
- sampling_params_list=[SamplingParams(max_tokens=4), SamplingParams(max_tokens=4), params],
- final_stage_id=2,
- )
-
- output = SimpleNamespace(request_id="req-3", finished=True)
- asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-3", 1, output, req_state))
-
- assert previous_stage.engine_outputs == [output]
- assert diffusion_stage.calls[0]["kv_sender_info"] == {
- 0: {"host": "10.0.0.2", "zmq_port": 50151},
- }
-
-
-def test_forward_to_diffusion_returns_terminal_error_for_empty_custom_inputs():
- orchestrator = object.__new__(Orchestrator)
- sender_stage = _DummySenderStage({"host": "10.0.0.2", "zmq_port": 50151})
- diffusion_stage = _DummyDiffusionStage(engine_input_source=[0])
- diffusion_stage.custom_process_input_func = lambda *_args, **_kwargs: []
-
- class _AsyncQueue:
- def __init__(self):
- self.items = []
-
- async def put(self, item):
- self.items.append(item)
-
- orchestrator.num_stages = 2
- orchestrator.stage_clients = [sender_stage, diffusion_stage]
- orchestrator._cfg_tracker = CfgCompanionTracker()
- orchestrator.stage_vllm_configs = [None, None]
- orchestrator.output_processors = [None, None]
- orchestrator.output_async_queue = _AsyncQueue()
- orchestrator.request_states = {}
- orchestrator._pd_kv_params = {}
-
- params = OmniDiffusionSamplingParams()
- req_state = OrchestratorRequestState(
- request_id="req-empty",
- prompt={"prompt": "hello"},
- sampling_params_list=[SamplingParams(max_tokens=4), params],
- final_stage_id=1,
- )
- orchestrator.request_states["req-empty"] = req_state
-
- output = SimpleNamespace(request_id="req-empty", finished=True)
- asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-empty", 0, output, req_state))
-
- assert sender_stage.engine_outputs == [output]
- assert diffusion_stage.calls == []
- assert len(orchestrator.output_async_queue.items) == 1
- terminal_msg = orchestrator.output_async_queue.items[0]
- assert terminal_msg["type"] == "output"
- assert terminal_msg["request_id"] == "req-empty"
- assert terminal_msg["stage_id"] == 1
- assert terminal_msg["finished"] is True
- assert "produced no valid inputs" in terminal_msg["engine_outputs"].error
- assert "req-empty" not in orchestrator.request_states
-
-
-def test_prewarm_diffusion_attaches_kv_sender_info():
- orchestrator = object.__new__(Orchestrator)
- sender_stage = _DummySenderStage({"host": "10.0.0.3", "zmq_port": 50151})
- diffusion_stage = _DummyDiffusionStage(engine_input_source=[0])
-
- orchestrator.stage_clients = [sender_stage, diffusion_stage]
- orchestrator.num_stages = 2
-
- req_state = OrchestratorRequestState(
- request_id="req-2",
- prompt={"prompt": "hello"},
- sampling_params_list=[SamplingParams(max_tokens=4), OmniDiffusionSamplingParams()],
- final_stage_id=1,
- )
-
- stage0_request = SimpleNamespace(prompt_token_ids=[1, 2, 3])
- asyncio.run(Orchestrator._prewarm_async_chunk_stages(orchestrator, "req-2", stage0_request, req_state))
-
- assert diffusion_stage.calls[0]["request_id"] == "req-2"
- assert diffusion_stage.calls[0]["kv_sender_info"] == {
- 0: {"host": "10.0.0.3", "zmq_port": 50151},
- }
- assert req_state.stage_submit_ts[1] > 0
diff --git a/tests/engine/test_output_modality.py b/tests/engine/test_output_modality.py
index 7a9c765028f..5a2a5dfc575 100644
--- a/tests/engine/test_output_modality.py
+++ b/tests/engine/test_output_modality.py
@@ -12,7 +12,6 @@
import torch
# ── Load modules without triggering vllm_omni.__init__ ─────────────
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
_ENGINE_DIR = Path(__file__).resolve().parents[2] / "vllm_omni" / "engine"
diff --git a/tests/engine/test_output_processor.py b/tests/engine/test_output_processor.py
deleted file mode 100644
index 4576da5a12f..00000000000
--- a/tests/engine/test_output_processor.py
+++ /dev/null
@@ -1,211 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Regression tests for OmniRequestState multimodal DELTA drain and consolidation guard."""
-
-from unittest.mock import MagicMock
-
-import pytest
-import torch
-from vllm.outputs import PoolingRequestOutput
-from vllm.sampling_params import RequestOutputKind
-from vllm.v1.engine import FinishReason
-
-from vllm_omni.engine.output_modality import OutputModalityNames
-from vllm_omni.engine.output_processor import OmniRequestState
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-# Audio is explicitly listed as a drainable modality
-AUDIO = OutputModalityNames.AUDIO
-
-# Latent is explicitly not drainable, but the choice doesn't matter here as
-# long as isn't listed as drainable. I.e., could also be arbitrary keys for
-# the purposes of these tests
-LATENT = OutputModalityNames.LATENT
-
-# NOTE: detokenizer and logprobs aren't really used here, but we mock them since
-# some of the utils called in vLLM superclass assert require them to be None.
-_DETOK = MagicMock(
- output_token_ids=[0],
- get_next_output_text=MagicMock(return_value=""),
- num_output_tokens=MagicMock(return_value=1),
-)
-_LOGPROBS = MagicMock(logprobs=None, cumulative_logprob=None, prompt_logprobs=None)
-
-_DEFAULT_STATE_KWARGS = dict(
- request_id="r",
- external_req_id="r",
- parent_req=None,
- request_index=0,
- lora_request=None,
- prompt=None,
- prompt_token_ids=[0],
- prompt_embeds=None,
- logprobs_processor=_LOGPROBS,
- detokenizer=_DETOK,
- max_tokens_param=None,
- arrival_time=0.0,
- queue=None,
- log_stats=False,
- stream_interval=1,
-)
-
-
-def _make_state(output_kind: RequestOutputKind):
- return OmniRequestState(**_DEFAULT_STATE_KWARGS, output_kind=output_kind)
-
-
-def test_init_empty_dict():
- """Ensure mm_accumulated is initially empty."""
- assert _make_state(RequestOutputKind.CUMULATIVE).mm_accumulated == {}
- assert _make_state(RequestOutputKind.DELTA).mm_accumulated == {}
-
-
-def test_delta_drains_output_modality_per_step():
- """DELTA drains the mm_type key (output modality) but preserves hidden-state keys."""
- s = _make_state(RequestOutputKind.DELTA)
- audio1, audio2, hs1, hs2 = [torch.ones(num_elem) for num_elem in range(1, 5)]
-
- # Add audio and hidden state tensors
- s.add_multimodal_tensor(audio1, mm_type=AUDIO) # should be drained
- s.add_multimodal_tensor(hs1, mm_type=LATENT) # shouldn't be drained
-
- out1 = s._new_completion_output([1], None, None)
- out1_audio = out1.multimodal_output[AUDIO]
- out1_hidden = out1.multimodal_output[LATENT]
- assert isinstance(out1_audio, torch.Tensor)
- assert torch.equal(out1.multimodal_output[AUDIO], audio1)
- assert isinstance(out1_hidden, torch.Tensor)
- assert torch.equal(out1.multimodal_output[LATENT], hs1)
-
- # After emission, hidden states should remain, but audio is drained
- assert set(s.mm_accumulated.keys()) == {LATENT}
-
- s.add_multimodal_tensor(audio2, AUDIO)
- s.add_multimodal_tensor(hs2, mm_type=LATENT)
- out2 = s._new_completion_output([2], None, None)
- out2_audio = out2.multimodal_output[AUDIO]
- out2_hidden = out2.multimodal_output[LATENT]
- assert isinstance(out2_audio, torch.Tensor)
- assert torch.equal(out2_audio, audio2)
- # Since hidden isn't drained, it's grown to a list
- assert isinstance(out2_hidden, list) and len(out2_hidden) == 2
- assert torch.equal(hs1, out2_hidden[0])
- assert torch.equal(hs2, out2_hidden[1])
-
-
-def test_cumulative_emits_consolidated_audio_each_step():
- """Ensure cumulative accumulates and consolidates modality keys every step."""
- s = _make_state(RequestOutputKind.CUMULATIVE)
- # NOTE: audio is usually emitted as (1, size) chunks; we need to be sure
- # to not change the tensor dimension when we consolidate
- audio1 = torch.ones(1, 500)
- s.add_multimodal_tensor(audio1, mm_type=AUDIO)
- req_out = s.make_request_output([1], None, None, None)
- assert req_out is not None
- cons_audio = req_out.outputs[0].multimodal_output[AUDIO]
- # Single chunk keeps original shape [1, 500]
- assert isinstance(cons_audio, torch.Tensor) and cons_audio.shape == audio1.shape
-
- audio2 = torch.ones(1, 300)
- s.add_multimodal_tensor(audio2, mm_type=AUDIO)
- req_out = s.make_request_output([2], None, None, None)
- assert req_out is not None
- cons_audio = req_out.outputs[0].multimodal_output[AUDIO]
- # After consolidation, audio chunks are concatenated on last axis,
- # preserving the [1, N] channel dimension
- total_audio_len = audio1.shape[-1] + audio2.shape[-1]
- assert isinstance(audio2, torch.Tensor) and cons_audio.shape == (1, total_audio_len)
-
- assert "audio" in s.mm_accumulated
-
-
-def test_finish_consolidates_hidden_states():
- """Ensure consolidation merges hidden-state tensor lists on finish."""
- s = _make_state(RequestOutputKind.CUMULATIVE)
- s.add_multimodal_tensor(torch.ones(5, 4), mm_type=LATENT)
- s.add_multimodal_tensor(torch.ones(3, 4), mm_type=LATENT)
-
- result = s.make_request_output([1], None, FinishReason.STOP, None)
- assert result is not None and not isinstance(result, PoolingRequestOutput)
-
- hs = result.outputs[0].multimodal_output[LATENT]
- assert isinstance(hs, torch.Tensor) and hs.shape[0] == 8
-
-
-def test_finish_consolidation_for_hs_delta():
- """Ensure finish doesn't drop the accumulated hidden states."""
- s = _make_state(RequestOutputKind.DELTA)
- # hidden state accumulation (nothing drained)
- s.add_multimodal_tensor({"foo": torch.ones(5, 4)}, mm_type=LATENT)
- result = s.make_request_output([0], None, FinishReason.STOP, None)
- assert result is not None and not isinstance(result, PoolingRequestOutput)
- hs = result.outputs[0].multimodal_output["foo"]
- assert isinstance(hs, torch.Tensor) and hs.shape[0] == 5
-
- # Since we don't drain the hidden states, if we add 3 elements, we should get 8
- s.add_multimodal_tensor({"foo": torch.ones(3, 4)}, mm_type=LATENT)
- result = s.make_request_output([0], None, FinishReason.STOP, None)
- assert result is not None and not isinstance(result, PoolingRequestOutput)
- hs = result.outputs[0].multimodal_output["foo"]
- assert isinstance(hs, torch.Tensor) and hs.shape[0] == 8
- assert "foo" in s.mm_accumulated
-
-
-def test_finish_consolidation_drains_mm_delta():
- """Ensure making the request output drains modality deltas (e.g., audio)."""
- s = _make_state(RequestOutputKind.DELTA)
- # multimodal data accumulation (drained)
- s.add_multimodal_tensor({AUDIO: torch.ones(5, 4)}, mm_type=AUDIO)
- result = s.make_request_output([0], None, FinishReason.STOP, None)
- assert result is not None and not isinstance(result, PoolingRequestOutput)
- hs = result.outputs[0].multimodal_output[AUDIO]
- assert isinstance(hs, torch.Tensor) and hs.shape[0] == 5
-
- # Since we did drain the hidden states, we no longer get the 5 back
- s.add_multimodal_tensor({AUDIO: torch.ones(3, 4)}, mm_type=AUDIO)
- result = s.make_request_output([0], None, FinishReason.STOP, None)
- assert result is not None and not isinstance(result, PoolingRequestOutput)
- hs = result.outputs[0].multimodal_output[AUDIO]
- assert isinstance(hs, torch.Tensor) and hs.shape[0] == 3
- assert AUDIO not in s.mm_accumulated # drained
-
-
-@pytest.mark.parametrize("mm_type", [AUDIO, "hidden"])
-def test_final_only_consolidates_drainable_keys(mm_type):
- """FINAL_ONLY never drains per-step, so modality keys and hidden state
- keys both accumulate and are consolidated on finish."""
- s = _make_state(RequestOutputKind.FINAL_ONLY)
-
- # NOTE: Currently there is brittlness in the tensor stacking, so we just
- # test a 1D tensor here. The intention is just to ensure audio /hidden
- # behave the same.
- s.add_multimodal_tensor(torch.ones(500), mm_type=mm_type)
- # Non-finish step returns None without calling _new_completion_output
- assert s.make_request_output([1], None, None, None) is None
- assert mm_type in s.mm_accumulated
-
- s.add_multimodal_tensor(torch.ones(300), mm_type=mm_type)
- result = s.make_request_output([2], None, FinishReason.STOP, None)
- assert result is not None and not isinstance(result, PoolingRequestOutput)
-
- audio = result.outputs[0].multimodal_output[mm_type]
- assert isinstance(audio, torch.Tensor)
- assert audio.shape == (800,)
-
-
-def test_cumulative_token_ids_always_set():
- """cumulative_token_ids is set for all output kinds."""
- for kind in (RequestOutputKind.DELTA, RequestOutputKind.CUMULATIVE, RequestOutputKind.FINAL_ONLY):
- s = _make_state(kind)
- out = s._new_completion_output([42], None, None)
- assert hasattr(out, "cumulative_token_ids")
- # The mock detokenizer has output_token_ids=[0]
- assert list(out.cumulative_token_ids) == [0]
-
-
-def test_cumulative_token_ids_is_a_copy():
- """cumulative_token_ids must be a snapshot, not a live reference."""
- s = _make_state(RequestOutputKind.DELTA)
- out = s._new_completion_output([42], None, None)
- assert out.cumulative_token_ids is not _DETOK.output_token_ids
diff --git a/tests/engine/test_single_stage_mode.py b/tests/engine/test_single_stage_mode.py
deleted file mode 100644
index 8169f586a66..00000000000
--- a/tests/engine/test_single_stage_mode.py
+++ /dev/null
@@ -1,1879 +0,0 @@
-"""Unit tests for AsyncOmniEngine single-stage mode and OmniMasterServer.
-
-These tests cover:
-- OmniMasterServer address pre-allocation & ZMQ registration handshake
-- AsyncOmniEngine single_stage_mode detection / _single_stage_id_filter setup
-- _initialize_stages stage routing (local launch vs. remote-wait) in
- single_stage_mode
-- _create_remote_llm_stage delegation to connect_remote_engine_cores
-- _launch_llm_stage delegation to launch_omni_core_engines in
- single_stage_mode
-
-All tests run without real hardware by mocking ZMQ, vllm_config, and the
-heavy initialization helpers.
-"""
-
-from __future__ import annotations
-
-import threading
-from contextlib import contextmanager
-from types import SimpleNamespace
-from typing import Any
-
-import pytest
-from pytest_mock import MockerFixture
-from vllm.v1.engine.utils import EngineZmqAddresses
-
-from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
-from vllm_omni.engine.stage_engine_startup import (
- OmniMasterServer,
- StageAllocation,
- StageCoordinatorAddresses,
- connect_remote_engine_cores,
- launch_omni_core_engines,
-)
-from vllm_omni.engine.stage_init_utils import StartedLlmStage
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-# ---------------------------------------------------------------------------
-# Helpers
-# ---------------------------------------------------------------------------
-
-
-def _make_stage_cfg(stage_id: int, stage_type: str = "llm"):
- """Return a lightweight stage config mock."""
- return SimpleNamespace(
- stage_id=stage_id,
- stage_type=stage_type,
- engine_args=SimpleNamespace(
- async_chunk=False,
- model_stage=None,
- engine_output_type=None,
- ),
- )
-
-
-def _make_started_llm_stage(stage_id: int) -> StartedLlmStage:
- """Return a minimal StartedLlmStage for mocking."""
- addresses = SimpleNamespace(
- inputs=["tcp://127.0.0.1:5000"],
- outputs=["tcp://127.0.0.1:5001"],
- frontend_stats_publish_address=None,
- )
- return StartedLlmStage(
- stage_id=stage_id,
- metadata=SimpleNamespace(stage_id=stage_id),
- vllm_config=SimpleNamespace(),
- executor_class=SimpleNamespace(),
- engine_manager=SimpleNamespace(),
- coordinator=SimpleNamespace(),
- addresses=addresses,
- )
-
-
-# ---------------------------------------------------------------------------
-# OmniMasterServer – address pre-allocation
-# ---------------------------------------------------------------------------
-
-
-class TestOmniMasterServerAllocation:
- """Test address pre-allocation in OmniMasterServer.__init__."""
-
- def test_public_address_and_port_properties_expose_registration_endpoint(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15000,
- stage_ids=[0],
- )
- assert server.address == "127.0.0.1"
- assert server.port == 15000
-
- def test_allocations_created_for_each_stage_id(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15000,
- stage_ids=[0, 1, 2],
- )
- assert set(server._allocations.keys()) == {0, 1, 2}
-
- def test_each_allocation_is_stage_allocation(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15000,
- stage_ids=[0, 1],
- )
- for sid in (0, 1):
- alloc = server._allocations[sid]
- assert isinstance(alloc, StageAllocation)
-
- def test_allocation_addresses_reference_master_address(self):
- server = OmniMasterServer(
- master_address="192.168.1.10",
- master_port=20000,
- stage_ids=[0],
- )
- alloc = server._allocations[0]
- for addr in (
- alloc.handshake_bind_address,
- alloc.handshake_connect_address,
- alloc.input_bind_address,
- alloc.input_connect_address,
- alloc.output_bind_address,
- alloc.output_connect_address,
- ):
- assert "192.168.1.10" in addr, f"Expected master address in {addr}"
-
- def test_port_uniqueness_within_single_allocation(self):
- """Each allocation uses three distinct ports."""
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15001,
- stage_ids=[0],
- )
- alloc = server._allocations[0]
- hs_port = int(alloc.handshake_bind_address.split(":")[-1])
- inp_port = int(alloc.input_bind_address.split(":")[-1])
- out_port = int(alloc.output_bind_address.split(":")[-1])
- assert len({hs_port, inp_port, out_port}) == 3, "Expected three distinct ports per stage allocation"
-
- def test_get_zmq_addresses_returns_bind_addresses(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15002,
- stage_ids=[0],
- )
- alloc = server._allocations[0]
- zmq_addrs = server.get_zmq_addresses(0)
- assert zmq_addrs.inputs == [alloc.input_bind_address]
- assert zmq_addrs.outputs == [alloc.output_bind_address]
-
- def test_get_engine_zmq_addresses_returns_connect_addresses(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15003,
- stage_ids=[0],
- )
- alloc = server._allocations[0]
- engine_addrs = server.get_engine_zmq_addresses(0)
- assert engine_addrs.inputs == [alloc.input_connect_address]
- assert engine_addrs.outputs == [alloc.output_connect_address]
-
- def test_get_allocation_returns_correct_object(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15004,
- stage_ids=[3],
- )
- assert server.get_allocation(3) is server._allocations[3]
-
-
-# ---------------------------------------------------------------------------
-# OmniMasterServer – ZMQ registration flow
-# ---------------------------------------------------------------------------
-
-
-class TestOmniMasterServerRegistration:
- """Test that the server correctly handles a stage registration."""
-
- def test_registration_reply_contains_handshake_address(self):
- """A DEALER client that sends a registration msg gets the handshake
- address back from the ROUTER registration socket."""
- import msgspec
- import zmq
- from vllm.utils.network_utils import get_open_port
-
- master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[0],
- )
- server.start()
- expected_hs = server._allocations[0].handshake_connect_address
-
- ctx = zmq.Context()
- try:
- sock = ctx.socket(zmq.DEALER)
- sock.connect(f"tcp://127.0.0.1:{master_port}")
- sock.send(msgspec.msgpack.encode({"stage_id": 0}))
- if not sock.poll(timeout=5_000):
- pytest.fail("No reply received from OmniMasterServer within 5 s")
- reply = msgspec.msgpack.decode(sock.recv())
- assert reply["handshake_address"] == expected_hs
- finally:
- sock.close(linger=0)
- ctx.term()
- server.stop()
-
- def test_server_handles_unknown_stage_id_gracefully(self):
- """A registration for an unrecognised stage_id must not crash the server."""
- import msgspec
- import zmq
- from vllm.utils.network_utils import get_open_port
-
- master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[0],
- )
- server.start()
-
- ctx = zmq.Context()
- try:
- bad_sock = ctx.socket(zmq.DEALER)
- bad_sock.connect(f"tcp://127.0.0.1:{master_port}")
- # Send unknown stage_id=99
- bad_sock.send(msgspec.msgpack.encode({"stage_id": 99}))
- # Server should NOT reply for an unknown id; wait briefly
- has_reply = bad_sock.poll(timeout=500)
- assert not has_reply, "Server should not reply to unknown stage_id"
- # Then register the valid stage so the server thread can exit
- good_sock = ctx.socket(zmq.DEALER)
- good_sock.connect(f"tcp://127.0.0.1:{master_port}")
- good_sock.send(msgspec.msgpack.encode({"stage_id": 0}))
- good_sock.poll(timeout=2_000)
- finally:
- for s in (bad_sock, good_sock):
- try:
- s.close(linger=0)
- except Exception:
- pass
- ctx.term()
- server.stop()
-
- def test_registration_stores_stage_config(self):
- """Stage registration should persist the sender's stage config."""
- import msgspec
- import zmq
- from vllm.utils.network_utils import get_open_port
-
- master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[0],
- )
- server.start()
-
- payload = {
- "stage_id": 0,
- "stage_config": {
- "stage_id": 0,
- "stage_type": "llm",
- "engine_args": {"model": "fake-model"},
- },
- }
-
- ctx = zmq.Context()
- try:
- sock = ctx.socket(zmq.DEALER)
- sock.connect(f"tcp://127.0.0.1:{master_port}")
- sock.send(msgspec.msgpack.encode(payload))
- assert sock.poll(timeout=5_000)
- sock.recv()
-
- stored = server.get_stage_config(0, timeout_s=0.1)
- assert stored == payload["stage_config"]
- finally:
- sock.close(linger=0)
- ctx.term()
- server.stop()
-
- def test_registration_stores_coordinator_addresses(self):
- """Stage registration should persist optional coordinator addresses."""
- import msgspec
- import zmq
- from vllm.utils.network_utils import get_open_port
-
- master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[0],
- )
- server.start()
-
- payload = {
- "stage_id": 0,
- "stage_config": {"stage_id": 0},
- "coordinator_input": "tcp://127.0.0.1:31001",
- "coordinator_output": "tcp://127.0.0.1:31002",
- "frontend_stats_publish_address": "tcp://127.0.0.1:31003",
- }
-
- ctx = zmq.Context()
- try:
- sock = ctx.socket(zmq.DEALER)
- sock.connect(f"tcp://127.0.0.1:{master_port}")
- sock.send(msgspec.msgpack.encode(payload))
- assert sock.poll(timeout=5_000)
- sock.recv()
-
- stored = server.get_stage_coordinator_addresses(0, timeout_s=0.1)
- assert stored == StageCoordinatorAddresses(
- coordinator_input=payload["coordinator_input"],
- coordinator_output=payload["coordinator_output"],
- frontend_stats_publish_address=payload["frontend_stats_publish_address"],
- )
- finally:
- sock.close(linger=0)
- ctx.term()
- server.stop()
-
- def test_stop_joins_server_thread(self):
- from vllm.utils.network_utils import get_open_port
-
- master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[], # no stages → thread exits immediately
- )
- server.start()
- assert server._thread is not None
- server.stop()
- # Thread should have exited (joined with timeout=10 inside stop())
- assert not server._thread.is_alive()
-
-
-# ---------------------------------------------------------------------------
-# AsyncOmniEngine – single_stage_mode detection in __init__
-# ---------------------------------------------------------------------------
-
-
-class TestSingleStageModeDetection:
- """Test __init__ single_stage_mode / _single_stage_id_filter setup.
-
- We bypass the real __init__ by patching _resolve_stage_configs and
- the orchestrator thread, so no actual engines are started.
- """
-
- def _make_engine_no_thread(self, mocker: MockerFixture, **kwargs: Any) -> AsyncOmniEngine:
- """Create an AsyncOmniEngine without starting the orchestrator thread."""
- stage_cfg = _make_stage_cfg(0)
- mock_stage_configs = [stage_cfg]
-
- mocker.patch.object(
- AsyncOmniEngine,
- "_resolve_stage_configs",
- return_value=("/fake/path", mock_stage_configs),
- )
- mocker.patch.object(
- AsyncOmniEngine,
- "_bootstrap_orchestrator",
- )
- mock_thread_cls = mocker.patch("threading.Thread")
- mock_future_cls = mocker.patch("concurrent.futures.Future")
-
- mock_future = mocker.Mock()
- mock_future.result.return_value = mocker.Mock() # simulates a loop
- mock_future_cls.return_value = mock_future
-
- mock_thread = mocker.Mock()
- mock_thread.is_alive.return_value = False
- mock_thread_cls.return_value = mock_thread
-
- engine = AsyncOmniEngine(model="fake-model", **kwargs)
- return engine
-
- def test_explicit_single_stage_mode_true(self, mocker: MockerFixture):
- engine = self._make_engine_no_thread(
- mocker,
- single_stage_mode=True,
- omni_master_address="127.0.0.1",
- omni_master_port=20000,
- )
- assert engine.single_stage_mode is True
-
- def test_stage_id_kwarg_promotes_to_single_stage_mode(self, mocker: MockerFixture):
- engine = self._make_engine_no_thread(
- mocker,
- stage_id=0,
- omni_master_address="127.0.0.1",
- omni_master_port=20001,
- )
- assert engine.single_stage_mode is True
-
- def test_stage_id_kwarg_sets_filter(self, mocker: MockerFixture):
- engine = self._make_engine_no_thread(
- mocker,
- stage_id=1,
- omni_master_address="127.0.0.1",
- omni_master_port=20002,
- )
- assert engine._single_stage_id_filter == 1
-
- def test_no_stage_id_no_single_stage_mode(self, mocker: MockerFixture):
- engine = self._make_engine_no_thread(
- mocker,
- )
- assert engine.single_stage_mode is False
- assert engine._single_stage_id_filter is None
-
- def test_single_stage_mode_without_stage_id_has_no_filter(self, mocker: MockerFixture):
- engine = self._make_engine_no_thread(
- mocker,
- single_stage_mode=True,
- omni_master_address="127.0.0.1",
- omni_master_port=20003,
- )
- assert engine._single_stage_id_filter is None
-
- def test_engine_args_create_only_forwards_explicit_fields(self, mocker: MockerFixture):
- from vllm_omni.engine.arg_utils import OmniEngineArgs
-
- captured: dict[str, Any] = {}
-
- def fake_resolve(self, model: str, kwargs: dict[str, Any]):
- captured.update(kwargs)
- return "/fake/path", [_make_stage_cfg(0)]
-
- mocker.patch.object(AsyncOmniEngine, "_resolve_stage_configs", fake_resolve)
- mocker.patch.object(AsyncOmniEngine, "_bootstrap_orchestrator")
- mock_thread_cls = mocker.patch("threading.Thread")
- mock_future_cls = mocker.patch("concurrent.futures.Future")
- mock_future = mocker.Mock()
- mock_future.result.return_value = mocker.Mock()
- mock_future_cls.return_value = mock_future
- mock_thread = mocker.Mock()
- mock_thread.is_alive.return_value = False
- mock_thread_cls.return_value = mock_thread
-
- ea = OmniEngineArgs.create(model="ignored", gpu_memory_utilization=0.5)
- AsyncOmniEngine(model="fake-model", engine_args=ea)
-
- assert captured["gpu_memory_utilization"] == 0.5
- assert "model" not in captured
- assert "max_num_seqs" not in captured
-
- def test_bare_engine_args_rejected(self, mocker: MockerFixture):
- from vllm_omni.engine.arg_utils import OmniEngineArgs
-
- with pytest.raises(TypeError, match="OmniEngineArgs.create"):
- self._make_engine_no_thread(mocker, engine_args=OmniEngineArgs(model="fake-model"))
-
- def test_master_address_and_port_stored(self, mocker: MockerFixture):
- engine = self._make_engine_no_thread(
- mocker,
- stage_id=0,
- omni_master_address="10.0.0.1",
- omni_master_port=12345,
- )
- assert engine._omni_master_address == "10.0.0.1"
- assert engine._omni_master_port == 12345
-
- def test_omni_master_server_starts_as_none(self, mocker: MockerFixture):
- engine = self._make_engine_no_thread(
- mocker,
- )
- assert engine._omni_master_server is None
-
-
-# ---------------------------------------------------------------------------
-# AsyncOmniEngine – _initialize_stages stage routing
-# ---------------------------------------------------------------------------
-
-
-class TestInitializeStagesRouting:
- """Verify that _initialize_stages routes each stage to the correct launch
- function depending on single_stage_mode and _single_stage_id_filter."""
-
- _COMMON_PATCHES = [
- "vllm_omni.engine.async_omni_engine.prepare_engine_environment",
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- ]
-
- def _build_engine_skeleton(
- self,
- stage_cfgs: list[Any],
- single_stage_mode: bool,
- stage_id_filter: int | None,
- omni_master_address: str = "127.0.0.1",
- omni_master_port: int = 25000,
- ) -> AsyncOmniEngine:
- """Build a bare AsyncOmniEngine without launching any threads."""
- engine = object.__new__(AsyncOmniEngine)
- engine.model = "fake-model"
- engine.config_path = "/fake"
- engine.stage_configs = stage_cfgs
- engine.num_stages = len(stage_cfgs)
- engine.async_chunk = False
- engine.single_stage_mode = single_stage_mode
- engine._single_stage_id_filter = stage_id_filter
- engine._omni_master_address = omni_master_address
- engine._omni_master_port = omni_master_port
- engine._omni_master_server = None
- engine._llm_stage_launch_lock = __import__("threading").Lock()
- engine.diffusion_batch_size = 1
- engine.stage_clients = []
- engine.stage_vllm_configs = []
- engine.output_processors = []
- engine.input_processor = None
- engine.supported_tasks = ("generate",)
- engine.default_sampling_params_list = []
- engine.stage_metadata = []
- engine.prompt_expand_func = None
- return engine
-
- def _fake_metadata(self, mocker: MockerFixture, stage_id: int, stage_type: str = "llm") -> Any:
- meta = mocker.Mock()
- meta.stage_id = stage_id
- meta.stage_type = stage_type
- meta.runtime_cfg = {}
- meta.prompt_expand_func = None
- meta.engine_output_type = None
- meta.is_comprehension = False
- meta.final_output = True if stage_id == 0 else False
- meta.final_output_type = None
- return meta
-
- def _run_initialize_stages_mocked(
- self,
- mocker: MockerFixture,
- engine: AsyncOmniEngine,
- stage_cfgs: list[Any],
- *,
- launch_side_effect: Any = None,
- remote_side_effect: Any = None,
- attach_result: Any = None,
- ) -> tuple[Any, Any]:
- """Execute _initialize_stages with all heavy helpers mocked.
-
- Returns (mock_launch_llm_stage, mock_create_remote_llm_stage).
- """
- started_by_stage: dict[int, StartedLlmStage] = {
- cfg.stage_id: _make_started_llm_stage(cfg.stage_id)
- for cfg in stage_cfgs
- if getattr(cfg, "stage_type", "llm") != "diffusion"
- }
-
- default_attach = (mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock())
-
- mock_launch = mocker.Mock(
- side_effect=launch_side_effect
- or (lambda cfg, meta, spec, timeout, llm_stage_launch_lock, kv: started_by_stage[meta.stage_id])
- )
- mock_remote = mocker.Mock(
- side_effect=remote_side_effect or (lambda cfg, meta, spec, timeout, srv: started_by_stage[meta.stage_id])
- )
- mock_attach = mocker.Mock(return_value=attach_result or default_attach)
-
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.side_effect = lambda sid: mocker.Mock()
-
- finalized = (
- [mocker.Mock() for _ in stage_cfgs],
- [mocker.Mock() for _ in stage_cfgs],
- [{"final_output": True, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
- )
-
- mocker.patch.object(engine, "_launch_llm_stage", mock_launch)
- mocker.patch.object(engine, "_create_remote_llm_stage", mock_remote)
- mocker.patch.object(engine, "_attach_llm_stage", mock_attach)
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.prepare_engine_environment",
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(
- mocker,
- cfg.stage_id,
- getattr(cfg, "stage_type", "llm"),
- ),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
-
- engine._initialize_stages(stage_init_timeout=60)
-
- return mock_launch, mock_remote
-
- # -- single-stage mode: stage matches filter → local launch ---------------
-
- def test_matching_stage_uses_launch_llm_stage(self, mocker: MockerFixture):
- """stage_id == _single_stage_id_filter → _launch_llm_stage is called."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
-
- launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
- assert 0 in launched_ids, "_launch_llm_stage should be called for stage 0"
-
- def test_non_matching_stage_uses_create_remote_llm_stage(self, mocker: MockerFixture):
- """stage_id != _single_stage_id_filter → _create_remote_llm_stage is called."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
-
- remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list]
- assert 1 in remote_ids, "_create_remote_llm_stage should be called for stage 1"
-
- def test_filter_1_routes_correctly(self, mocker: MockerFixture):
- """With filter=1, stage 0 is remote and stage 1 is local."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=1)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
-
- launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
- remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list]
- assert 1 in launched_ids, "stage 1 should be launched locally with filter=1"
- assert 0 in remote_ids, "stage 0 should use remote path with filter=1"
-
- def test_no_filter_all_stages_use_launch_path(self, mocker: MockerFixture):
- """single_stage_mode=True but no filter → all stages use _launch_llm_stage."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=None)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
-
- assert mock_remote.call_count == 0, "No remote launches without a filter"
- launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
- assert set(launched_ids) == {0, 1}
-
- def test_non_single_stage_mode_never_calls_create_remote(self, mocker: MockerFixture):
- """Outside single_stage_mode, _create_remote_llm_stage must not be called."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
-
- assert mock_remote.call_count == 0
-
- def test_omni_master_server_started_in_single_stage_mode(self, mocker: MockerFixture):
- """OmniMasterServer.start() must be called when single_stage_mode=True."""
- stage_cfgs = [_make_stage_cfg(0)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.return_value = mocker.Mock()
- finalized = (
- [mocker.Mock()],
- [mocker.Mock()],
- [{"final_output": True, "final_output_type": None, "stage_type": "llm"}],
- )
-
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
- mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0))
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
-
- engine._initialize_stages(stage_init_timeout=60)
-
- mock_oms.start.assert_called_once()
-
- def test_omni_master_server_uses_configured_stage_ids(self, mocker: MockerFixture):
- """Configured stage IDs, not list indexes, should drive pre-allocation."""
- stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.return_value = mocker.Mock()
- finalized = (
- [mocker.Mock(), mocker.Mock()],
- [mocker.Mock(), mocker.Mock()],
- [{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
- )
-
- mocker.patch.object(
- engine,
- "_launch_llm_stage",
- side_effect=[_make_started_llm_stage(7), _make_started_llm_stage(11)],
- )
- mocker.patch.object(
- engine,
- "_create_remote_llm_stage",
- return_value=_make_started_llm_stage(11),
- )
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mock_oms_cls = mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
-
- engine._initialize_stages(stage_init_timeout=60)
-
- mock_oms_cls.assert_called_once_with(
- master_address=engine._omni_master_address,
- master_port=engine._omni_master_port,
- stage_ids=[7, 11],
- )
-
- def test_single_stage_filter_uses_configured_stage_ids(self, mocker: MockerFixture):
- """Local/remote dispatch should compare against configured stage IDs."""
- stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- finalized = (
- [mocker.Mock(), mocker.Mock()],
- [mocker.Mock(), mocker.Mock()],
- [{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
- )
-
- mock_launch = mocker.patch.object(
- engine,
- "_launch_llm_stage",
- side_effect=[_make_started_llm_stage(7)],
- )
- mock_remote = mocker.patch.object(
- engine,
- "_create_remote_llm_stage",
- return_value=_make_started_llm_stage(11),
- )
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
-
- engine._initialize_stages(stage_init_timeout=60)
-
- assert [call.args[1].stage_id for call in mock_launch.call_args_list] == [7]
- assert [call.args[1].stage_id for call in mock_remote.call_args_list] == [11]
-
- def test_omni_master_server_preallocates_diffusion_stage_ids(self, mocker: MockerFixture):
- """Diffusion stages should also receive OmniMasterServer allocations."""
- stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11, stage_type="diffusion")]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- finalized = (
- [mocker.Mock(), mocker.Mock()],
- [mocker.Mock(), mocker.Mock()],
- [
- {"final_output": False, "final_output_type": None, "stage_type": "llm"},
- {"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
- ],
- )
-
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(7))
- mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(7))
- mocker.patch.object(engine, "_launch_diffusion_stage", return_value=mocker.Mock())
- mocker.patch.object(
- engine,
- "_create_remote_diffusion_stage",
- return_value=mocker.Mock(),
- )
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mock_oms_cls = mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(
- mocker,
- cfg.stage_id,
- getattr(cfg, "stage_type", "llm"),
- ),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
-
- engine._initialize_stages(stage_init_timeout=60)
-
- mock_oms_cls.assert_called_once_with(
- master_address=engine._omni_master_address,
- master_port=engine._omni_master_port,
- stage_ids=[7, 11],
- )
-
- def test_duplicate_llm_stage_ids_raise(self, mocker: MockerFixture):
- """Duplicate configured LLM stage IDs should fail fast."""
- stage_cfgs = [_make_stage_cfg(3), _make_stage_cfg(3)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=3)
-
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- with pytest.raises(ValueError, match="Duplicate stage_id"):
- engine._initialize_stages(stage_init_timeout=60)
-
- def test_omni_master_server_not_started_in_normal_mode(self, mocker: MockerFixture):
- """OmniMasterServer must NOT be instantiated outside single_stage_mode."""
- stage_cfgs = [_make_stage_cfg(0)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None)
- finalized = (
- [mocker.Mock()],
- [mocker.Mock()],
- [{"final_output": True, "final_output_type": None, "stage_type": "llm"}],
- )
-
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mock_oms_cls = mocker.patch("vllm_omni.engine.async_omni_engine.OmniMasterServer")
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
-
- engine._initialize_stages(stage_init_timeout=60)
-
- mock_oms_cls.assert_not_called()
-
- def test_single_stage_mode_missing_master_address_raises(self, mocker: MockerFixture):
- """single_stage_mode without master address/port raises ValueError."""
- stage_cfgs = [_make_stage_cfg(0)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- engine._omni_master_address = None # missing
- engine._omni_master_port = None
-
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- with pytest.raises(ValueError, match="omni_master_address"):
- engine._initialize_stages(stage_init_timeout=60)
-
- def test_matching_diffusion_stage_uses_local_registered_launch(self, mocker: MockerFixture):
- """A local diffusion stage should use the registered single-stage launch path."""
- stage_cfgs = [_make_stage_cfg(0, stage_type="diffusion"), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- diffusion_client = mocker.Mock(stage_type="diffusion")
- finalized = (
- [diffusion_client, mocker.Mock()],
- [mocker.Mock(), mocker.Mock()],
- [
- {"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
- {"final_output": False, "final_output_type": None, "stage_type": "llm"},
- ],
- )
-
- mock_local_diff = mocker.patch.object(
- engine,
- "_launch_diffusion_stage",
- return_value=diffusion_client,
- )
- mock_remote_diff = mocker.patch.object(engine, "_create_remote_diffusion_stage")
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(1))
- mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(1))
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(
- mocker,
- cfg.stage_id,
- getattr(cfg, "stage_type", "llm"),
- ),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
-
- engine._initialize_stages(stage_init_timeout=60)
-
- assert mock_local_diff.call_count == 1
- assert mock_local_diff.call_args.args[1].stage_id == 0
- mock_remote_diff.assert_not_called()
-
- def test_non_matching_diffusion_stage_uses_remote_diffusion_client(self, mocker: MockerFixture):
- """A non-local diffusion stage should attach via the remote diffusion path."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1, stage_type="diffusion")]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- remote_diffusion_client = mocker.Mock(stage_type="diffusion")
- finalized = (
- [mocker.Mock(), remote_diffusion_client],
- [mocker.Mock(), mocker.Mock()],
- [
- {"final_output": False, "final_output_type": None, "stage_type": "llm"},
- {"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
- ],
- )
-
- mock_local_diff = mocker.patch.object(engine, "_launch_diffusion_stage")
- mock_remote_diff = mocker.patch.object(
- engine,
- "_create_remote_diffusion_stage",
- return_value=remote_diffusion_client,
- )
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
- mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0))
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(
- mocker,
- cfg.stage_id,
- getattr(cfg, "stage_type", "llm"),
- ),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
-
- engine._initialize_stages(stage_init_timeout=60)
-
- mock_local_diff.assert_not_called()
- assert mock_remote_diff.call_count == 1
- assert mock_remote_diff.call_args.args[0].stage_id == 1
-
-
-# ---------------------------------------------------------------------------
-# AsyncOmniEngine – _launch_diffusion_stage
-# ---------------------------------------------------------------------------
-
-
-class TestLaunchDiffusionStage:
- """Test local diffusion stage launch wiring."""
-
- def test_registers_stage_with_public_master_properties(self, mocker: MockerFixture):
- engine = object.__new__(AsyncOmniEngine)
- engine.model = "fake-model"
- engine.diffusion_batch_size = 4
-
- stage_cfg = _make_stage_cfg(5, stage_type="diffusion")
- metadata = mocker.Mock(stage_id=5)
- omni_master_server = mocker.Mock(spec=OmniMasterServer)
- omni_master_server.address = "127.0.0.1"
- omni_master_server.port = 25000
-
- proc = mocker.Mock()
- diffusion_client = mocker.Mock()
-
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_diffusion_config",
- return_value="diffusion-config",
- )
- mock_register = mocker.patch(
- "vllm_omni.engine.async_omni_engine.register_stage_with_omni_master",
- return_value=(
- "tcp://127.0.0.1:25001",
- "tcp://127.0.0.1:25002",
- "tcp://127.0.0.1:25003",
- ),
- )
- mock_spawn = mocker.patch(
- "vllm_omni.engine.async_omni_engine.spawn_diffusion_proc",
- return_value=(proc, None, None, None),
- )
- mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_diffusion_handshake")
- mock_from_addresses = mocker.patch(
- "vllm_omni.engine.async_omni_engine.StageDiffusionClient.from_addresses",
- return_value=diffusion_client,
- )
-
- result = engine._launch_diffusion_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- omni_master_server=omni_master_server,
- )
-
- mock_register.assert_called_once_with(
- omni_master_address="127.0.0.1",
- omni_master_port=25000,
- omni_stage_id=5,
- omni_stage_config=stage_cfg,
- return_addresses=True,
- )
- mock_spawn.assert_called_once_with(
- "fake-model",
- "diffusion-config",
- handshake_address="tcp://127.0.0.1:25001",
- request_address="tcp://127.0.0.1:25002",
- response_address="tcp://127.0.0.1:25003",
- )
- mock_handshake.assert_called_once_with(proc, "tcp://127.0.0.1:25001")
- mock_from_addresses.assert_called_once_with(
- metadata,
- request_address="tcp://127.0.0.1:25002",
- response_address="tcp://127.0.0.1:25003",
- proc=proc,
- batch_size=4,
- )
- assert result is diffusion_client
-
-
-# ---------------------------------------------------------------------------
-# AsyncOmniEngine – _create_remote_llm_stage
-# ---------------------------------------------------------------------------
-
-
-class TestCreateRemoteLlmStage:
- """Test _create_remote_llm_stage delegates correctly."""
-
- def _engine(self, mocker: MockerFixture) -> AsyncOmniEngine:
- engine = object.__new__(AsyncOmniEngine)
- engine.model = "fake-model"
- engine.single_stage_mode = True
- engine._single_stage_id_filter = 0
- engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
- engine._omni_master_server.get_zmq_addresses.return_value = mocker.Mock()
- engine._omni_master_server.get_allocation.return_value = mocker.Mock()
- engine._omni_master_server.get_stage_config.return_value = {
- "stage_id": 0,
- "stage_type": "llm",
- "engine_args": {},
- }
- return engine
-
- def _mock_build_and_connect(self, mocker: MockerFixture, stage_id: int):
- fake_vllm_config = mocker.Mock()
- fake_executor_cls = mocker.Mock()
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- eng_mgr = mocker.Mock()
- coordinator = mocker.Mock()
-
- @contextmanager
- def fake_connect_cm(*args, **kwargs):
- yield eng_mgr, coordinator, fake_addresses
-
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": stage_id},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- )
- mock_connect = mocker.patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
- return_value=fake_connect_cm(),
- )
-
- return mock_connect, fake_vllm_config, fake_executor_cls, fake_addresses
-
- def test_returns_started_llm_stage_with_correct_stage_id(self, mocker: MockerFixture):
- engine = self._engine(mocker)
- stage_cfg = _make_stage_cfg(1)
- metadata = mocker.Mock(stage_id=1)
- omni_ms = engine._omni_master_server
- omni_ms.get_stage_config.return_value = {
- "stage_id": 1,
- "stage_type": "llm",
- "engine_args": {},
- }
-
- self._mock_build_and_connect(mocker, 1)
- result = engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
- assert isinstance(result, StartedLlmStage)
- assert result.stage_id == 1
-
- def test_connect_remote_engine_cores_called_with_stage_id(self, mocker: MockerFixture):
- engine = self._engine(mocker)
- stage_cfg = _make_stage_cfg(2)
- metadata = mocker.Mock(stage_id=2)
- omni_ms = engine._omni_master_server
- omni_ms.get_zmq_addresses.return_value = mocker.Mock(inputs=["x"], outputs=["y"])
- omni_ms.get_stage_config.return_value = {
- "stage_id": 2,
- "stage_type": "llm",
- "engine_args": {},
- }
-
- fake_vllm_config = mocker.Mock()
- fake_executor_cls = mocker.Mock()
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- @contextmanager
- def fake_connect_cm(*args, **kwargs):
- yield mocker.Mock(), mocker.Mock(), fake_addresses
-
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 2},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- )
- mock_connect = mocker.patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
- return_value=fake_connect_cm(),
- )
-
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
-
- mock_connect.assert_called_once()
- _, kwargs = mock_connect.call_args
- assert kwargs.get("stage_id") == 2 or mock_connect.call_args.args[-1] == 2
- omni_ms.get_stage_config.assert_called_once_with(2, timeout_s=60)
-
- def test_missing_registered_stage_config_raises_value_error(self, mocker: MockerFixture):
- engine = self._engine(mocker)
- stage_cfg = _make_stage_cfg(3)
- metadata = mocker.Mock(stage_id=3)
- omni_ms = engine._omni_master_server
- omni_ms.get_stage_config.return_value = None
-
- mock_build_args = mocker.patch("vllm_omni.engine.async_omni_engine.build_engine_args_dict")
- with pytest.raises(
- ValueError,
- match="Remote stage 3 registered without stage config",
- ):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
-
- mock_build_args.assert_not_called()
-
- def test_exception_during_connect_closes_started_stage(self, mocker: MockerFixture):
- """If an error occurs after StartedLlmStage creation, close_started_llm_stage is called."""
- engine = self._engine(mocker)
- stage_cfg = _make_stage_cfg(1)
- metadata = mocker.Mock(stage_id=1)
- omni_ms = engine._omni_master_server
- omni_ms.get_stage_config.return_value = {
- "stage_id": 1,
- "stage_type": "llm",
- "engine_args": {},
- }
-
- @contextmanager
- def boom(*args, **kwargs):
- yield mocker.Mock(), mocker.Mock(), mocker.Mock()
- raise RuntimeError("handshake failed")
-
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 1},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
- return_value=boom(),
- )
- mock_close = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage")
- with pytest.raises(RuntimeError, match="handshake failed"):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
- mock_close.assert_called_once()
-
-
-class TestConnectRemoteEngineCoresCoordinator:
- """Test coordinator launch parity with launch_core_engines."""
-
- @staticmethod
- def _build_vllm_config(
- mocker: MockerFixture, *, dp_rank: int = 0, offline_mode: bool = False, needs_dp_coordinator: bool = True
- ) -> Any:
- parallel_config = mocker.Mock()
- parallel_config.data_parallel_size_local = 1
- parallel_config.data_parallel_size = 2
- parallel_config.data_parallel_rank = dp_rank
- parallel_config.data_parallel_rank_local = 0 if offline_mode else None
-
- vllm_config = mocker.Mock()
- vllm_config.parallel_config = parallel_config
- vllm_config.needs_dp_coordinator = needs_dp_coordinator
- vllm_config.model_config = mocker.Mock(is_moe=False)
- return vllm_config
-
- def test_uses_registered_coordinator_addresses(self, mocker: MockerFixture):
- vllm_config = self._build_vllm_config(mocker, dp_rank=0, offline_mode=False, needs_dp_coordinator=True)
-
- omni_master_server = mocker.Mock(spec=OmniMasterServer)
- omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
- inputs=["tcp://client-in"], outputs=["tcp://client-out"]
- )
- omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
- omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses(
- coordinator_input="tcp://coord-in",
- coordinator_output="tcp://coord-out",
- frontend_stats_publish_address="tcp://stats",
- )
-
- @contextmanager
- def fake_socket_ctx(*args, **kwargs):
- yield mocker.Mock()
-
- mocker.patch(
- "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
- return_value=fake_socket_ctx(),
- )
- mock_wait = mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup")
- with connect_remote_engine_cores(
- vllm_config=vllm_config,
- omni_master_server=omni_master_server,
- stage_id=7,
- ) as (_, yielded_coordinator, yielded_addresses):
- assert yielded_coordinator is None
- assert yielded_addresses.coordinator_input == "tcp://coord-in"
- assert yielded_addresses.coordinator_output == "tcp://coord-out"
- assert yielded_addresses.frontend_stats_publish_address == "tcp://stats"
-
- omni_master_server.get_stage_coordinator_addresses.assert_called_once_with(7)
- mock_wait.assert_called_once()
-
- def test_defaults_to_no_coordinator_addresses_when_none_registered(self, mocker: MockerFixture):
- vllm_config = self._build_vllm_config(
- mocker,
- dp_rank=0,
- offline_mode=False,
- needs_dp_coordinator=True,
- )
-
- omni_master_server = mocker.Mock(spec=OmniMasterServer)
- omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
- inputs=["tcp://client-in"], outputs=["tcp://client-out"]
- )
- omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
- omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses()
-
- @contextmanager
- def fake_socket_ctx(*args, **kwargs):
- yield mocker.Mock()
-
- mocker.patch(
- "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
- return_value=fake_socket_ctx(),
- )
- mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup")
- with connect_remote_engine_cores(
- vllm_config=vllm_config,
- omni_master_server=omni_master_server,
- stage_id=7,
- ) as (_, yielded_coordinator, yielded_addresses):
- assert yielded_coordinator is None
- assert yielded_addresses.coordinator_input is None
- assert yielded_addresses.coordinator_output is None
- assert yielded_addresses.frontend_stats_publish_address is None
-
-
-class TestLaunchOmniCoreEngines:
- """Tests for local omni engine launch wiring."""
-
- def test_registers_stage_once_and_reuses_handshake_for_all_local_engines(self, mocker: MockerFixture):
- parallel_config = mocker.Mock(
- data_parallel_size_local=2,
- data_parallel_size=4,
- data_parallel_rank=3,
- )
- vllm_config = mocker.Mock(parallel_config=parallel_config)
-
- omni_master_server = mocker.Mock(spec=OmniMasterServer)
- omni_master_server.address = "127.0.0.1"
- omni_master_server.port = 26000
- omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
-
- stage_config = {"stage_id": 7, "stage_type": "llm"}
- local_engine_manager = mocker.Mock()
-
- @contextmanager
- def fake_socket_ctx(*args, **kwargs):
- yield mocker.Mock()
-
- mock_register = mocker.patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- )
- mocker.patch(
- "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
- return_value=fake_socket_ctx(),
- )
- mock_manager_cls = mocker.patch(
- "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
- return_value=local_engine_manager,
- )
- mocker.patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup")
- with launch_omni_core_engines(
- vllm_config=vllm_config,
- executor_class=mocker.Mock(),
- log_stats=False,
- omni_master_server=omni_master_server,
- stage_id=7,
- stage_config=stage_config,
- ) as (yielded_manager, yielded_coordinator, yielded_addresses):
- assert yielded_manager is local_engine_manager
- assert yielded_coordinator is None
-
- mock_register.assert_called_once_with(
- omni_master_address="127.0.0.1",
- omni_master_port=26000,
- omni_stage_id=7,
- omni_stage_config=stage_config,
- coordinator=None,
- )
- mock_manager_cls.assert_called_once()
- manager_kwargs = mock_manager_cls.call_args.kwargs
- assert manager_kwargs["local_engine_count"] == 2
- assert manager_kwargs["start_index"] == 3
- assert manager_kwargs["local_start_index"] == 0
- assert manager_kwargs["vllm_config"] is vllm_config
- assert manager_kwargs["local_client"] is True
- assert manager_kwargs["handshake_address"] == "tcp://127.0.0.1:26001"
- assert manager_kwargs["executor_class"] is not None
-
- def test_registers_stage_with_coordinator_when_started(self, mocker: MockerFixture):
- parallel_config = mocker.Mock(
- data_parallel_size_local=1,
- data_parallel_size=2,
- data_parallel_rank=0,
- )
- vllm_config = mocker.Mock(parallel_config=parallel_config)
- vllm_config.needs_dp_coordinator = True
- vllm_config.model_config = mocker.Mock(is_moe=False)
-
- omni_master_server = mocker.Mock(spec=OmniMasterServer)
- omni_master_server.address = "127.0.0.1"
- omni_master_server.port = 26000
- omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
- inputs=["tcp://client-in"], outputs=["tcp://client-out"]
- )
- omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
-
- coordinator = mocker.Mock()
- coordinator.proc.pid = 1234
- coordinator.get_engine_socket_addresses.return_value = ("tcp://coord-in", "tcp://coord-out")
- coordinator.get_stats_publish_address.return_value = "tcp://stats"
-
- @contextmanager
- def fake_socket_ctx(*args, **kwargs):
- yield mocker.Mock()
-
- mocker.patch("vllm_omni.engine.stage_engine_startup.DPCoordinator", return_value=coordinator)
- mock_register = mocker.patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- )
- mocker.patch(
- "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
- return_value=fake_socket_ctx(),
- )
- mock_manager_cls = mocker.patch(
- "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
- return_value=mocker.Mock(),
- )
- mock_wait = mocker.patch("vllm_omni.engine.stage_engine_startup.wait_for_engine_startup")
- with launch_omni_core_engines(
- vllm_config=vllm_config,
- executor_class=mocker.Mock(),
- log_stats=False,
- omni_master_server=omni_master_server,
- stage_id=7,
- stage_config={"stage_id": 7},
- ):
- pass
-
- mock_register.assert_called_once_with(
- omni_master_address="127.0.0.1",
- omni_master_port=26000,
- omni_stage_id=7,
- omni_stage_config={"stage_id": 7},
- coordinator=coordinator,
- )
- manager_kwargs = mock_manager_cls.call_args.kwargs
- assert manager_kwargs["log_stats"] is False
- mock_wait.assert_called_once()
-
-
-# ---------------------------------------------------------------------------
-# AsyncOmniEngine – _launch_llm_stage single_stage_mode codepath
-# ---------------------------------------------------------------------------
-
-
-class TestLaunchLlmStageSingleStageMode:
- """Test that _launch_llm_stage selects launch_omni_core_engines when
- single_stage_mode=True and _omni_master_server is set."""
-
- def _build_engine_with_oms(self, mocker: MockerFixture) -> AsyncOmniEngine:
- engine = object.__new__(AsyncOmniEngine)
- engine.model = "fake-model"
- engine.single_stage_mode = True
- engine._single_stage_id_filter = 0
- engine._llm_stage_launch_lock = threading.Lock()
- engine.stage_configs = []
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- mock_oms.address = "127.0.0.1"
- mock_oms.port = 25000
- alloc = mocker.Mock()
- alloc.handshake_bind_address = "tcp://127.0.0.1:25001"
- mock_oms.get_allocation.return_value = alloc
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
- mock_oms.get_zmq_addresses.return_value = fake_addresses
- engine._omni_master_server = mock_oms
- return engine
-
- def _mock_launch_omni(self, mocker: MockerFixture, stage_id: int):
- fake_vllm_config = mocker.Mock()
- fake_executor_cls = mocker.Mock()
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- eng_mgr = mocker.Mock()
-
- @contextmanager
- def fake_launch_omni(*args, **kwargs):
- yield eng_mgr, None, fake_addresses
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": stage_id},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- return mocker.patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- )
-
- def test_launch_omni_core_engines_used_in_single_stage_mode(self, mocker: MockerFixture):
- """single_stage_mode + _omni_master_server → launch_omni_core_engines."""
- engine = self._build_engine_with_oms(mocker)
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
- stage_cfg = _make_stage_cfg(0)
-
- mock_launch_omni = self._mock_launch_omni(mocker, 0)
- result = engine._launch_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- mock_launch_omni.assert_called_once()
- assert mock_launch_omni.call_args.kwargs["stage_config"] is stage_cfg
- assert isinstance(result, StartedLlmStage)
- assert result.stage_id == 0
-
- def test_spawn_stage_core_used_in_normal_mode(self, mocker: MockerFixture):
- """~single_stage_mode → spawn_stage_core + complete_stage_handshake."""
- engine = object.__new__(AsyncOmniEngine)
- engine.model = "fake-model"
- engine.single_stage_mode = False
- engine._omni_master_server = None
- engine._llm_stage_launch_lock = threading.Lock()
- engine.stage_configs = []
-
- fake_vllm_config = mocker.Mock()
- fake_executor_cls = mocker.Mock()
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- fake_proc = mocker.Mock()
- fake_handshake_address = "ipc:///tmp/fake-handshake"
- stage_init_timeout = 60
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- mock_spawn = mocker.patch(
- "vllm_omni.engine.async_omni_engine.spawn_stage_core",
- return_value=(fake_addresses, fake_proc, fake_handshake_address),
- )
- mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_stage_handshake")
- mock_omni = mocker.patch("vllm_omni.engine.async_omni_engine.launch_omni_core_engines")
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
- result = engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=stage_init_timeout,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- mock_spawn.assert_called_once_with(
- vllm_config=fake_vllm_config,
- executor_class=fake_executor_cls,
- log_stats=False,
- )
- mock_handshake.assert_called_once_with(
- fake_proc,
- fake_handshake_address,
- fake_addresses,
- fake_vllm_config,
- stage_init_timeout,
- )
- mock_omni.assert_not_called()
- assert isinstance(result, StartedLlmStage)
- assert result.proc is fake_proc
-
- def test_launch_omni_passes_stage_id_and_master_server(self, mocker: MockerFixture):
- """launch_omni_core_engines receives the correct stage_id and omni_master_server."""
- engine = self._build_engine_with_oms(mocker)
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
-
- captured_kwargs: dict[str, Any] = {}
-
- @contextmanager
- def capturing_launch(*args, **kwargs):
- captured_kwargs.update(kwargs)
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
- yield mocker.Mock(), None, fake_addresses
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- side_effect=capturing_launch,
- )
-
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- assert captured_kwargs.get("stage_id") == 0
- assert captured_kwargs.get("omni_master_server") is engine._omni_master_server
-
- def test_launch_omni_context_exits_before_stage_cleanup_on_error(self, mocker: MockerFixture):
- """Errors after entering the omni launch context still unwind it first."""
- engine = self._build_engine_with_oms(mocker)
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
-
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- events: list[str] = []
-
- @contextmanager
- def fake_launch_omni(*args, **kwargs):
- try:
- yield mocker.Mock(), None, fake_addresses
- finally:
- events.append("launch_exit")
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.logger.info", side_effect=RuntimeError("boom"))
- mock_close_stage = mocker.patch(
- "vllm_omni.engine.async_omni_engine.close_started_llm_stage",
- side_effect=lambda _started: events.append("stage_close"),
- )
- with pytest.raises(RuntimeError, match="boom"):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- mock_close_stage.assert_called_once()
- assert events == ["launch_exit", "stage_close"]
-
- def test_base_exception_propagates_without_started_stage_cleanup(self, mocker: MockerFixture):
- """BaseException subclasses should bypass the Exception cleanup path."""
- engine = self._build_engine_with_oms(mocker)
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
-
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- events: list[str] = []
-
- class FatalLaunchInterrupt(BaseException):
- pass
-
- @contextmanager
- def fake_launch_omni(*args, **kwargs):
- try:
- yield mocker.Mock(), None, fake_addresses
- finally:
- events.append("launch_exit")
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.logger.info",
- side_effect=FatalLaunchInterrupt("stop"),
- )
- mock_close_stage = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage")
- with pytest.raises(FatalLaunchInterrupt, match="stop"):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- mock_close_stage.assert_not_called()
- assert events == ["launch_exit"]
diff --git a/tests/engine/test_stage_engine_core_client.py b/tests/engine/test_stage_engine_core_client.py
deleted file mode 100644
index dde0927af2d..00000000000
--- a/tests/engine/test_stage_engine_core_client.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for StageEngineCoreClient.check_health().
-
-Uses object.__new__ to construct a minimal client — check_health only
-touches self.resources, self.stage_id, and self._proc.
-"""
-
-from __future__ import annotations
-
-from types import SimpleNamespace
-from unittest.mock import MagicMock
-
-import pytest
-from vllm.v1.engine.exceptions import EngineDeadError
-
-from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _make_client(*, engine_dead=False, proc_alive=True):
- client = object.__new__(StageEngineCoreClient)
- client.stage_id = 0
- client.resources = SimpleNamespace(engine_dead=engine_dead)
- client._proc = MagicMock(is_alive=MagicMock(return_value=proc_alive), exitcode=1)
- return client
-
-
-def test_check_health_passes_when_alive():
- client = _make_client(engine_dead=False, proc_alive=True)
- client.check_health() # no exception
-
-
-def test_check_health_raises_when_resources_engine_dead():
- client = _make_client(engine_dead=True, proc_alive=True)
- with pytest.raises(EngineDeadError, match="engine core is dead"):
- client.check_health()
-
-
-def test_check_health_raises_when_proc_not_alive():
- client = _make_client(engine_dead=False, proc_alive=False)
- with pytest.raises(EngineDeadError, match="not alive"):
- client.check_health()
- # Verify it set resources.engine_dead as a side effect
- assert client.resources.engine_dead is True
diff --git a/tests/entrypoints/openai_api/conftest_video.py b/tests/entrypoints/openai_api/conftest_video.py
deleted file mode 100644
index 3d5e0d85510..00000000000
--- a/tests/entrypoints/openai_api/conftest_video.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Shared test helpers for video streaming tests."""
-
-from __future__ import annotations
-
-import io
-
-import pytest
-
-np = pytest.importorskip("numpy", reason="numpy required for video stream tests")
-PIL = pytest.importorskip("PIL", reason="Pillow required for video stream tests")
-from PIL import Image # noqa: E402
-
-
-def make_jpeg(r: int = 128, g: int = 128, b: int = 128, size: int = 64) -> bytes:
- """Create a solid-colour JPEG image."""
- img = Image.new("RGB", (size, size), (r, g, b))
- buf = io.BytesIO()
- img.save(buf, format="JPEG", quality=95)
- return buf.getvalue()
-
-
-def make_gradient_jpeg(seed: int, size: int = 64) -> bytes:
- """Create a random-gradient JPEG that varies based on *seed*."""
- rng = np.random.RandomState(seed)
- arr = rng.randint(0, 256, (size, size, 3), dtype=np.uint8)
- img = Image.fromarray(arr, "RGB")
- buf = io.BytesIO()
- img.save(buf, format="JPEG", quality=95)
- return buf.getvalue()
diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py
index b5ff891f8f6..d68143dae8f 100644
--- a/tests/entrypoints/openai_api/test_image_server.py
+++ b/tests/entrypoints/openai_api/test_image_server.py
@@ -13,20 +13,15 @@
from types import SimpleNamespace
import pytest
-from fastapi import FastAPI
from fastapi.testclient import TestClient
from PIL import Image
from pytest_mock import MockerFixture
from vllm import SamplingParams
-from vllm.entrypoints.openai.models.protocol import BaseModelPath
-from vllm_omni.entrypoints.async_omni import AsyncOmni
-from vllm_omni.entrypoints.openai.api_server import _DiffusionServingModels, router
from vllm_omni.entrypoints.openai.image_api_utils import (
encode_image_base64,
parse_size,
)
-from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -111,13 +106,10 @@ def test_encode_image_base64():
class MockGenerationResult:
- """Mock result object compatible with current diffusion output shape."""
+ """Mock result object from AsyncOmni.generate()"""
def __init__(self, images):
self.images = images
- self.request_output = SimpleNamespace(images=images)
- self.stage_durations = {}
- self.peak_memory_mb = 0.0
class FakeAsyncOmni:
@@ -125,26 +117,20 @@ class FakeAsyncOmni:
def __init__(self, images=None):
self.stage_configs = [
- SimpleNamespace(stage_type="llm", is_comprehension=True),
- SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ SimpleNamespace(stage_type="llm"),
+ SimpleNamespace(stage_type="diffusion"),
]
self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()]
self.captured_sampling_params_list = None
self.captured_prompt = None
self._images = images or [Image.new("RGB", (64, 64), color="green")]
- async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
- if sampling_params_list is not None:
- self.captured_sampling_params_list = sampling_params_list
- else:
- self.captured_sampling_params_list = [sampling_params]
+ async def generate(self, prompt, request_id, sampling_params_list):
+ self.captured_sampling_params_list = sampling_params_list
self.captured_prompt = prompt
images = [img.copy() for img in self._images]
yield MockGenerationResult(images)
- def __class_getitem__(cls, item):
- return cls
-
@pytest.fixture
def mock_async_diffusion(mocker: MockerFixture):
@@ -191,8 +177,8 @@ def test_client(mock_async_diffusion):
[BaseModelPath(name="Qwen/Qwen-Image", model_path="Qwen/Qwen-Image")]
)
app.state.args = Namespace(
- default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
- max_generated_image_size=1024 * 1792,
+ default_sampling_params='{"0": {"num_inference_steps":4, "guidance_scale":7.5}}',
+ max_generated_image_size=4096, # 64*64
)
return TestClient(app)
@@ -203,64 +189,18 @@ def async_omni_test_client():
"""Create test client with mocked AsyncOmni engine."""
from fastapi import FastAPI
- from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- class FakeAsyncOmniClass(AsyncOmni):
- def __init__(self):
- stage_configs = [
- SimpleNamespace(stage_type="llm", is_comprehension=True),
- SimpleNamespace(stage_type="diffusion", is_comprehension=False),
- ]
- default_sampling_params_list = [
- SamplingParams(temperature=0.1),
- OmniDiffusionSamplingParams(
- num_inference_steps=4,
- guidance_scale=7.5,
- generator_device="cpu",
- ),
- ]
- self.engine = SimpleNamespace(
- stage_configs=stage_configs,
- default_sampling_params_list=default_sampling_params_list,
- )
- self.default_sampling_params_list = default_sampling_params_list
- self.captured_sampling_params_list = None
- self.captured_prompt = None
- self._images = [Image.new("RGB", (64, 64), color="green")]
- self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
-
- async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
- if sampling_params_list is not None:
- self.captured_sampling_params_list = sampling_params_list
- else:
- self.captured_sampling_params_list = [sampling_params]
- self.captured_prompt = prompt
- images = [img.copy() for img in self._images]
- yield MockGenerationResult(images)
-
- def __class_getitem__(cls, item):
- return cls
-
- def get_diffusion_od_config(self):
- return self.od_config
app = FastAPI()
app.include_router(router)
- engine = FakeAsyncOmniClass()
- chat_handler = object.__new__(OmniOpenAIServingChat)
- chat_handler.engine_client = engine
- chat_handler._diffusion_engine = None
- app.state.openai_serving_chat = chat_handler
- app.state.engine_client = engine
+ app.state.engine_client = FakeAsyncOmni()
app.state.stage_configs = [
SimpleNamespace(stage_type="llm"),
SimpleNamespace(stage_type="diffusion"),
]
app.state.args = Namespace(
- default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
+ default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
max_generated_image_size=1048576, # 1024*1024 to support resolution tests
)
return TestClient(app)
@@ -271,60 +211,18 @@ def async_omni_rgba_test_client():
"""Create test client with mocked AsyncOmni engine returning RGBA output."""
from fastapi import FastAPI
- from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- class FakeAsyncOmniClass(AsyncOmni):
- def __init__(self):
- stage_configs = [
- SimpleNamespace(stage_type="llm", is_comprehension=True),
- SimpleNamespace(stage_type="diffusion", is_comprehension=False),
- ]
- default_sampling_params_list = [
- SamplingParams(temperature=0.1),
- OmniDiffusionSamplingParams(),
- ]
- self.engine = SimpleNamespace(
- stage_configs=stage_configs,
- default_sampling_params_list=default_sampling_params_list,
- )
- self.default_sampling_params_list = default_sampling_params_list
- self.captured_sampling_params_list = None
- self.captured_prompt = None
- self._images = [Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))]
- self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
-
- async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
- if sampling_params_list is not None:
- self.captured_sampling_params_list = sampling_params_list
- else:
- self.captured_sampling_params_list = [sampling_params]
- self.captured_prompt = prompt
- images = [img.copy() for img in self._images]
- yield MockGenerationResult(images)
-
- def __class_getitem__(cls, item):
- return cls
-
- def get_diffusion_od_config(self):
- return self.od_config
app = FastAPI()
app.include_router(router)
- engine = FakeAsyncOmniClass()
- chat_handler = object.__new__(OmniOpenAIServingChat)
- chat_handler.engine_client = engine
- chat_handler._diffusion_engine = None
- app.state.openai_serving_chat = chat_handler
- app.state.engine_client = engine
+ app.state.engine_client = FakeAsyncOmni(images=[Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))])
app.state.stage_configs = [
SimpleNamespace(stage_type="llm"),
SimpleNamespace(stage_type="diffusion"),
]
app.state.args = Namespace(
- default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
+ default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
max_generated_image_size=1048576,
)
return TestClient(app)
@@ -335,58 +233,19 @@ def async_omni_stage_configs_only_client():
"""Create test client with refactored AsyncOmni compatibility surface only."""
from fastapi import FastAPI
- from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- class FakeAsyncOmniClass(AsyncOmni):
- def __init__(self):
- stage_configs = [
- SimpleNamespace(stage_type="llm", is_comprehension=True),
- SimpleNamespace(stage_type="diffusion", is_comprehension=False),
- ]
- default_sampling_params_list = [
- SamplingParams(temperature=0.1),
- OmniDiffusionSamplingParams(),
- ]
- self.engine = SimpleNamespace(
- stage_configs=stage_configs,
- default_sampling_params_list=default_sampling_params_list,
- )
- self.default_sampling_params_list = default_sampling_params_list
- self.captured_sampling_params_list = None
- self.captured_prompt = None
- self._images = [Image.new("RGB", (64, 64), color="green")]
- self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
-
- async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
- if sampling_params_list is not None:
- self.captured_sampling_params_list = sampling_params_list
- else:
- self.captured_sampling_params_list = [sampling_params]
- self.captured_prompt = prompt
- images = [img.copy() for img in self._images]
- yield MockGenerationResult(images)
-
- def __class_getitem__(cls, item):
- return cls
-
- def get_diffusion_od_config(self):
- return self.od_config
app = FastAPI()
app.include_router(router)
- engine = FakeAsyncOmniClass()
+ engine = FakeAsyncOmni()
assert not hasattr(engine, "stage_list")
app.state.engine_client = engine
- chat_handler = object.__new__(OmniOpenAIServingChat)
- chat_handler.engine_client = engine
- chat_handler._diffusion_engine = None
- app.state.openai_serving_chat = chat_handler
+ # Intentionally do not populate app.state.stage_configs. Refactored
+ # AsyncOmni exposes stage_configs on the engine instance.
app.state.args = Namespace(
- default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
- max_generated_image_size=1024 * 1792,
+ default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5}}',
+ max_generated_image_size=4096, # 64*64
)
return TestClient(app)
@@ -416,29 +275,6 @@ def test_health_endpoint_no_engine():
assert data["status"] == "unhealthy"
-def test_health_endpoint_dead_engine():
- """Health returns 503 when the engine raises EngineDeadError."""
- from unittest.mock import AsyncMock
-
- from fastapi import FastAPI
- from vllm.v1.engine.exceptions import EngineDeadError
-
- from vllm_omni.entrypoints.openai.api_server import router
-
- app = FastAPI()
- app.include_router(router)
-
- dead_engine = AsyncMock()
- dead_engine.check_health = AsyncMock(side_effect=EngineDeadError())
- app.state.engine_client = dead_engine
-
- client = TestClient(app)
- response = client.get("/health")
- assert response.status_code == 503
- data = response.json()
- assert data["status"] == "unhealthy"
-
-
def test_models_endpoint(test_client):
"""Test /v1/models endpoint for diffusion mode"""
response = test_client.get("/v1/models")
@@ -470,9 +306,6 @@ def test_models_endpoint_no_engine():
def test_generate_single_image(test_client):
"""Test generating a single image"""
- # Single-stage path should not require openai_serving_chat.
- assert not hasattr(test_client.app.state, "openai_serving_chat")
-
response = test_client.post(
"/v1/images/generations",
json={
@@ -541,127 +374,6 @@ def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_
assert captured[1].seed == 11
-def test_multistage_images_async_omni_construction(async_omni_test_client):
- """Regression: multistage image generation builds the expected chat-style payload."""
- response = async_omni_test_client.post(
- "/v1/images/generations",
- json={
- "prompt": "a cat",
- "n": 2,
- "size": "128x256",
- "seed": 7,
- "num_inference_steps": 12,
- "guidance_scale": 6.5,
- },
- )
- assert response.status_code == 200
-
- engine = async_omni_test_client.app.state.engine_client
- captured_prompt = engine.captured_prompt
- assert captured_prompt["prompt"] == "a cat"
- assert captured_prompt["modalities"] == ["image"]
- assert captured_prompt["mm_processor_kwargs"] == {
- "target_h": 256,
- "target_w": 128,
- }
-
- captured = engine.captured_sampling_params_list
- assert captured is not None
- assert len(captured) == 2
- assert captured[0].temperature == 0.1
- assert captured[0].seed == 7
- assert captured[1].num_outputs_per_prompt == 2
- assert captured[1].width == 128
- assert captured[1].height == 256
- assert captured[1].seed == 7
- assert captured[1].num_inference_steps == 12
- assert captured[1].guidance_scale == 6.5
-
-
-def test_generate_images_async_omni_glm_image_sets_stage0_max_tokens():
- """GLM-Image multistage: stage-0 gets target_h/w from requested size.
-
- max_tokens comes from the deploy YAML default (upper-bound ceiling),
- NOT computed dynamically from height/width.
- """
-
- class FakeAsyncOmniClass(AsyncOmni):
- def __init__(self):
- stage_configs = [
- SimpleNamespace(stage_type="llm", is_comprehension=True, model_arch="GlmImageForConditionalGeneration"),
- SimpleNamespace(stage_type="diffusion", is_comprehension=False, model_arch="GlmImagePipeline"),
- ]
- # YAML default max_tokens for GLM-Image AR stage (upper bound for 2048x2048 t2i)
- default_sampling_params_list = [
- SamplingParams(temperature=0.1, seed=42, max_tokens=4353),
- OmniDiffusionSamplingParams(height=1024, width=1024),
- ]
- self.engine = SimpleNamespace(
- stage_configs=stage_configs,
- default_sampling_params_list=default_sampling_params_list,
- )
- self.default_sampling_params_list = default_sampling_params_list
- self.captured_sampling_params_list = None
- self.captured_prompt = None
- self._images = [Image.new("RGB", (64, 64), color="green")]
- self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
-
- async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
- self.captured_sampling_params_list = (
- sampling_params_list if sampling_params_list is not None else [sampling_params]
- )
- self.captured_prompt = prompt
- yield MockGenerationResult([img.copy() for img in self._images])
-
- def __class_getitem__(cls, item):
- return cls
-
- def get_diffusion_od_config(self):
- return self.od_config
-
- app = FastAPI()
- app.include_router(router)
- engine = FakeAsyncOmniClass()
- chat_handler = object.__new__(OmniOpenAIServingChat)
- chat_handler.engine_client = engine
- chat_handler._diffusion_engine = None
- app.state.openai_serving_chat = chat_handler
- app.state.engine_client = engine
- app.state.stage_configs = [
- SimpleNamespace(stage_type="llm", model_arch="GlmImageForConditionalGeneration"),
- SimpleNamespace(stage_type="diffusion", model_arch="GlmImagePipeline"),
- ]
- app.state.openai_serving_models = _DiffusionServingModels(
- [BaseModelPath(name="THUDM/GLM-4.5V", model_path="THUDM/GLM-4.5V")]
- )
- app.state.args = Namespace(
- default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
- max_generated_image_size=1048576,
- )
- client = TestClient(app)
-
- response = client.post(
- "/v1/images/generations",
- json={
- "prompt": "a coral reef",
- "n": 1,
- "size": "1024x1024",
- "seed": 7,
- },
- )
- assert response.status_code == 200
-
- captured = engine.captured_sampling_params_list
- assert captured is not None
- assert len(captured) == 2
- # max_tokens comes from YAML default, not computed dynamically
- assert captured[0].max_tokens == 4353
- assert captured[0].extra_args["target_h"] == 1024
- assert captured[0].extra_args["target_w"] == 1024
- assert captured[1].height == 1024
- assert captured[1].width == 1024
-
-
def test_image_edits_async_omni_stage_configs_only(async_omni_stage_configs_only_client):
"""Regression: image edits accepts refactored AsyncOmni without stage_list."""
img_bytes = make_test_image_bytes((16, 16))
@@ -680,18 +392,6 @@ def test_image_edits_async_omni_stage_configs_only(async_omni_stage_configs_only
assert len(captured) == 2
-def test_generate_images_max_size_rejected(async_omni_test_client):
- """Test that a size exceeding max_generated_image_size returns 400."""
- response = async_omni_test_client.post(
- "/v1/images/generations",
- json={
- "prompt": "a cat",
- "size": "2048x2048", # 4,194,304 pixels > max_generated_image_size (1,048,576)
- },
- )
- assert response.status_code == 400
-
-
def test_generate_multiple_images(test_client):
"""Test generating multiple images"""
response = test_client.post(
@@ -967,19 +667,6 @@ def test_model_field_omitted_works(test_client):
assert response.status_code == 200
-def test_generate_images_rejects_model_mismatch(test_client):
- response = test_client.post(
- "/v1/images/generations",
- json={
- "prompt": "test",
- "model": "Qwen/Qwen-Image-2512",
- "size": "1024x1024",
- },
- )
- assert response.status_code == 400
- assert "model mismatch" in response.json()["detail"].lower()
-
-
def make_test_image_bytes(size=(64, 64)) -> bytes:
img = Image.new(
"RGB",
@@ -1083,99 +770,6 @@ def test_image_edit_rejects_multiple_images_when_model_does_not_support_them(asy
assert engine.captured_prompt is None
-def test_image_edit_rejects_model_mismatch(test_client):
- img_bytes = make_test_image_bytes((16, 16))
- response = test_client.post(
- "/v1/images/edits",
- files=[("image", img_bytes)],
- data={
- "prompt": "edit me",
- "model": "Qwen/Qwen-Image-Edit",
- },
- )
- assert response.status_code == 400
- assert "model mismatch" in response.json()["detail"].lower()
-
-
-def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511(async_omni_test_client):
- engine = async_omni_test_client.app.state.engine_client
- engine.get_diffusion_od_config = lambda: SimpleNamespace(
- supports_multimodal_inputs=True,
- max_multimodal_image_inputs=4,
- )
-
- response = async_omni_test_client.post(
- "/v1/images/edits",
- files=[
- ("image", make_test_image_bytes((16, 16))),
- ("image", make_test_image_bytes((16, 16))),
- ("image", make_test_image_bytes((16, 16))),
- ("image", make_test_image_bytes((16, 16))),
- ("image", make_test_image_bytes((16, 16))),
- ],
- data={"prompt": "hello world."},
- )
-
- assert response.status_code == 400
- assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model."
- assert engine.captured_prompt is None
-
-
-def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511_before_loading(
- async_omni_test_client, monkeypatch: pytest.MonkeyPatch
-):
- import vllm_omni.entrypoints.openai.api_server as api_server_module
-
- engine = async_omni_test_client.app.state.engine_client
- engine.get_diffusion_od_config = lambda: SimpleNamespace(
- supports_multimodal_inputs=True,
- max_multimodal_image_inputs=4,
- )
-
- def _fail_load(*args, **kwargs):
- raise AssertionError("_load_input_images should not run for over-limit requests")
-
- monkeypatch.setattr(api_server_module, "_load_input_images", _fail_load)
-
- response = async_omni_test_client.post(
- "/v1/images/edits",
- files=[
- ("image", make_test_image_bytes((16, 16))),
- ("image", make_test_image_bytes((16, 16))),
- ("image", make_test_image_bytes((16, 16))),
- ("image", make_test_image_bytes((16, 16))),
- ("image", make_test_image_bytes((16, 16))),
- ],
- data={"prompt": "hello world."},
- )
-
- assert response.status_code == 400
- assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model."
- assert engine.captured_prompt is None
-
-
-def test_image_edit_ignores_mock_like_multimodal_limit(async_omni_test_client):
- engine = async_omni_test_client.app.state.engine_client
- engine.get_diffusion_od_config = lambda: SimpleNamespace(
- supports_multimodal_inputs=SimpleNamespace(),
- max_multimodal_image_inputs=SimpleNamespace(),
- )
-
- response = async_omni_test_client.post(
- "/v1/images/edits",
- files=[("image", make_test_image_bytes((16, 16)))],
- data={"prompt": "hello world."},
- )
-
- assert response.status_code == 200
- captured_prompt = engine.captured_prompt
- assert captured_prompt is not None
- # Multi-stage path uses "img2img" key for single reference image
- processed_images = captured_prompt["multi_modal_data"]["img2img"]
- assert isinstance(processed_images, Image.Image)
- assert processed_images.size == (16, 16)
-
-
def test_image_edit_parameter_pass(async_omni_test_client):
img_bytes_1 = make_test_image_bytes((16, 16))
@@ -1354,7 +948,6 @@ def test_image_edit_parameter_default(async_omni_test_client):
assert captured_sampling_params.num_outputs_per_prompt == 1
assert captured_sampling_params.num_inference_steps == 4
assert captured_sampling_params.guidance_scale == 7.5
- assert captured_sampling_params.generator_device == "cpu"
# Test that a size exceeding max_generated_image_size returns 400
response = async_omni_test_client.post(
@@ -1388,15 +981,13 @@ def test_image_edit_parameter_default_single_stage(test_client):
assert captured_sampling_params.num_outputs_per_prompt == 1
assert captured_sampling_params.num_inference_steps == 4
assert captured_sampling_params.guidance_scale == 7.5
- assert captured_sampling_params.generator_device == "cpu"
- # Size exceeding max_generated_image_size (1024*1792) returns 400
response = test_client.post(
"/v1/images/edits",
files=[("image", img_bytes_1)],
data={
"prompt": "hello world.",
- "size": "2048x2048",
+ "size": "96x96",
},
)
assert response.status_code == 400
@@ -1561,91 +1152,3 @@ def test_image_edit_with_seed_zero_single_stage(test_client):
f"Expected seed=0, but got seed={captured_sampling_params.seed}. "
"This indicates the bug where seed=0 is treated as falsy."
)
-
-
-def test_normalize_image():
- """Test _normalize_image with various input types"""
- import numpy as np
-
- from vllm_omni.entrypoints.openai.api_server import _normalize_image
-
- # Test PIL Image input
- img = Image.new("RGB", (64, 64), color="red")
- result = _normalize_image(img)
- assert isinstance(result, Image.Image)
- assert result.size == (64, 64)
-
- # Test uint8 numpy array
- arr = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
- result = _normalize_image(arr)
- assert isinstance(result, Image.Image)
- assert result.size == (64, 64)
-
- # Test float [0, 1] numpy array
- arr = np.random.rand(64, 64, 3).astype(np.float32)
- result = _normalize_image(arr)
- assert isinstance(result, Image.Image)
- assert result.size == (64, 64)
-
- # Test float [-1, 1] numpy array
- arr = np.random.rand(64, 64, 3).astype(np.float32) * 2 - 1
- result = _normalize_image(arr)
- assert isinstance(result, Image.Image)
- assert result.size == (64, 64)
-
- # Test batch dimensions (1, 1, H, W, C)
- arr = np.random.randint(0, 255, (1, 1, 64, 64, 3), dtype=np.uint8)
- result = _normalize_image(arr)
- assert isinstance(result, Image.Image)
- assert result.size == (64, 64)
-
-
-def test_extract_images_from_result():
- """Test _extract_images_from_result with various result formats"""
- import numpy as np
-
- from vllm_omni.entrypoints.openai.api_server import _extract_images_from_result
-
- # Test empty result
- class EmptyResult:
- pass
-
- result = EmptyResult()
- images = _extract_images_from_result(result)
- assert images == []
-
- # Test nested batch: [np.array(shape=(3, 64, 64, 3))]
- batch = np.random.randint(0, 255, (3, 1, 64, 64, 3), dtype=np.uint8)
-
- class BatchResult:
- def __init__(self):
- self.images = [batch]
-
- result = BatchResult()
- images = _extract_images_from_result(result)
- assert len(images) == 3
- assert all(isinstance(img, Image.Image) for img in images)
- assert all(img.size == (64, 64) for img in images)
-
- # Test dict path: result.request_output["images"]
- class DictRequestOutput:
- def __init__(self):
- self.request_output = {"images": [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)]}
-
- result = DictRequestOutput()
- images = _extract_images_from_result(result)
- assert len(images) == 1
- assert isinstance(images[0], Image.Image)
-
- # Test attribute path: result.request_output.images
- class AttrRequestOutput:
- def __init__(self):
- self.request_output = type(
- "obj", (), {"images": [np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)]}
- )()
-
- result = AttrRequestOutput()
- images = _extract_images_from_result(result)
- assert len(images) == 1
- assert isinstance(images[0], Image.Image)
- assert images[0].size == (32, 32)
diff --git a/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py b/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py
deleted file mode 100644
index 90f8897c58f..00000000000
--- a/tests/entrypoints/openai_api/test_qwen3_omni_realtime_websocket.py
+++ /dev/null
@@ -1,207 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-E2E online tests for Qwen3-Omni /v1/realtime WebSocket (streaming PCM in, audio out).
-"""
-
-from __future__ import annotations
-
-import asyncio
-import base64
-import io
-import json
-import os
-import wave
-
-import pytest
-import websockets
-
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import (
- convert_audio_bytes_to_text,
- cosine_similarity_text,
- generate_synthetic_audio,
-)
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
-MODEL = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
-
-# Synthetic input for realtime E2E (``generate_synthetic_audio``); distinct cache file per phrase.
-REALTIME_SYNTH_PHRASE_TEXT = "Translate into Chinese: Beijing is the Capital of China"
-
-# The new-schema CI overlay bakes in async_chunk: False and covers CUDA/ROCm/XPU
-# via its ``platforms:`` section, so one path serves all three.
-default_stage_config = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
-
-realtime_server_params = [
- pytest.param(
- OmniServerParams(
- model=MODEL,
- stage_config_path=default_stage_config,
- use_stage_cli=True,
- server_args=["--no-async-chunk"],
- ),
- id="default",
- ),
-]
-
-
-def _pcm16_mono_16k_from_wav_bytes(wav_bytes: bytes) -> bytes:
- with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
- if wf.getnchannels() != 1:
- raise ValueError(f"Expected mono WAV, got {wf.getnchannels()} channels")
- if wf.getsampwidth() != 2:
- raise ValueError(f"Expected 16-bit PCM, sampwidth={wf.getsampwidth()}")
- if wf.getframerate() != 16000:
- raise ValueError(f"Expected 16 kHz input for /v1/realtime, got {wf.getframerate()} Hz")
- if wf.getcomptype() != "NONE":
- raise ValueError(f"Expected uncompressed PCM, comptype={wf.getcomptype()!r}")
- return wf.readframes(wf.getnframes())
-
-
-def _wav_bytes_from_pcm16(pcm: bytes, sample_rate_hz: int) -> bytes:
- buf = io.BytesIO()
- with wave.open(buf, "wb") as wf:
- wf.setnchannels(1)
- wf.setsampwidth(2)
- wf.setframerate(sample_rate_hz)
- wf.writeframes(pcm)
- return buf.getvalue()
-
-
-async def _run_realtime_audio_roundtrip(
- host: str,
- port: int,
- model: str,
- pcm16: bytes,
- *,
- chunk_ms: int = 100,
-) -> dict:
- uri = f"ws://{host}:{port}/v1/realtime"
- incremental: list[bytes] = []
- output_sr = 24000
- text_chunks: list[str] = []
- final_text = ""
- delta_events = 0
-
- bytes_per_ms = 16000 * 2 // 1000
- chunk_bytes = max(bytes_per_ms * chunk_ms, 2)
-
- async with websockets.connect(uri, max_size=64 * 1024 * 1024) as ws:
- await ws.send(json.dumps({"type": "session.update", "model": model}))
- await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": False}))
-
- for i in range(0, len(pcm16), chunk_bytes):
- chunk = pcm16[i : i + chunk_bytes]
- await ws.send(
- json.dumps(
- {
- "type": "input_audio_buffer.append",
- "audio": base64.b64encode(chunk).decode("utf-8"),
- }
- )
- )
-
- await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
-
- while True:
- message = await asyncio.wait_for(ws.recv(), timeout=600)
- if isinstance(message, bytes):
- continue
-
- event = json.loads(message)
- event_type = event.get("type")
-
- if event_type == "session.created":
- continue
-
- if event_type == "response.audio.delta":
- delta_events += 1
- sr = event.get("sample_rate_hz")
- if isinstance(sr, int) and sr > 0:
- output_sr = sr
- audio_b64 = event.get("audio", "")
- if audio_b64:
- incremental.append(base64.b64decode(audio_b64))
- continue
-
- if event_type == "transcription.delta":
- d = event.get("delta", "")
- if d:
- text_chunks.append(d)
- continue
-
- if event_type == "transcription.done":
- final_text = event.get("text", "") or "".join(text_chunks)
- continue
-
- if event_type == "response.audio.done":
- break
-
- if event_type == "error":
- raise AssertionError(f"WebSocket error: {event}")
-
- raise AssertionError(f"Unexpected WebSocket event: {event}")
-
- out_pcm = b"".join(incremental)
- return {
- "output_pcm": out_pcm,
- "output_sample_rate": output_sr,
- "transcription_text": final_text if final_text else "".join(text_chunks),
- "delta_events": delta_events,
- }
-
-
-class TestQwen3OmniRealtimeWebSocket:
- @pytest.mark.advanced_model
- @pytest.mark.omni
- @hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
- @pytest.mark.parametrize("omni_server", realtime_server_params, indirect=True)
- def test_streaming_audio_input_pcm_output(self, omni_server) -> None:
- """
- Short streamed 16 kHz mono PCM16 input; expect streamed PCM16 audio deltas and
- transcription. Verify Whisper(output audio) aligns with model text (same idea
- as multimodal omni e2e). Input speech is synthesized from
- ``REALTIME_SYNTH_PHRASE_TEXT``.
- """
- syn = generate_synthetic_audio(
- 10,
- 1,
- sample_rate=16000,
- phrase_text=REALTIME_SYNTH_PHRASE_TEXT,
- )
- wav_bytes = base64.b64decode(syn["base64"])
- pcm16 = _pcm16_mono_16k_from_wav_bytes(wav_bytes)
-
- result = asyncio.run(
- _run_realtime_audio_roundtrip(
- omni_server.host,
- omni_server.port,
- omni_server.model,
- pcm16,
- chunk_ms=100,
- )
- )
-
- out_pcm = result["output_pcm"]
- assert result["delta_events"] >= 1
- assert out_pcm, "No output PCM from response.audio.delta"
- assert len(out_pcm) % 2 == 0
- assert len(out_pcm) >= 4096, "Output audio unexpectedly small"
- assert result["output_sample_rate"] > 0
-
- final_text = (result["transcription_text"] or "").strip()
- assert final_text, "Expected non-empty transcription (model text stream)"
-
- wav_out = _wav_bytes_from_pcm16(out_pcm, result["output_sample_rate"])
- whisper_text = convert_audio_bytes_to_text(wav_out).strip()
- assert whisper_text, "Whisper returned empty string for synthesized output audio"
-
- sim = cosine_similarity_text(whisper_text.lower(), final_text.lower())
- assert sim > 0.9, (
- f"Output audio transcript should match model text (sim={sim:.3f}): "
- f"whisper={whisper_text!r}, model_text={final_text!r}"
- )
diff --git a/tests/entrypoints/openai_api/test_serving_chat_metrics.py b/tests/entrypoints/openai_api/test_serving_chat_metrics.py
index 0647c40a33f..d25af6c3843 100644
--- a/tests/entrypoints/openai_api/test_serving_chat_metrics.py
+++ b/tests/entrypoints/openai_api/test_serving_chat_metrics.py
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for OmniChatCompletionResponse/StreamResponse metrics field."""
-from types import SimpleNamespace
-
import pytest
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -52,36 +50,3 @@ def test_omni_chat_completion_stream_response_metrics():
)
assert response.modality == "audio"
assert response.metrics == {"stage_latency": 0.5}
-
-
-def test_create_image_choice_exposes_diffusion_metrics():
- """Ensure image chat content exposes profiler metrics for clients."""
- from PIL import Image
-
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- stage_durations = {"prefill": 0.12, "diffusion": 1.23}
- peak_memory_mb = 3210.5
- omni_outputs = SimpleNamespace(
- request_output=None,
- stage_durations=stage_durations,
- peak_memory_mb=peak_memory_mb,
- images=[Image.new("RGB", (2, 2), color=(255, 0, 0))],
- )
-
- choices = OmniOpenAIServingChat._create_image_choice( # type: ignore[misc]
- None,
- omni_outputs=omni_outputs,
- role="assistant",
- request=SimpleNamespace(return_token_ids=False),
- )
-
- assert len(choices) == 1
- content = choices[0].message.content
- assert isinstance(content, list)
- assert len(content) == 1
- first_item = content[0]
- assert first_item["type"] == "image_url"
- assert first_item["image_url"]["url"].startswith("data:image/png;base64,")
- assert first_item["stage_durations"] == stage_durations
- assert first_item["peak_memory_mb"] == peak_memory_mb
diff --git a/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py
deleted file mode 100644
index a9b9f53ba8a..00000000000
--- a/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-"""Regression tests for multistage diffusion generation input construction."""
-
-from __future__ import annotations
-
-from types import SimpleNamespace
-
-import pytest
-from PIL import Image
-from vllm.sampling_params import SamplingParams
-
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.fixture
-def serving_chat():
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- return object.__new__(OmniOpenAIServingChat)
-
-
-def test_build_multistage_generation_inputs_applies_stage_specific_overrides(serving_chat):
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- engine = SimpleNamespace(
- stage_configs=[
- SimpleNamespace(stage_type="llm", is_comprehension=True),
- SimpleNamespace(stage_type="diffusion", is_comprehension=False),
- SimpleNamespace(stage_type="diffusion", is_comprehension=False),
- ],
- default_sampling_params_list=[
- SamplingParams(temperature=0.2, seed=11),
- OmniDiffusionSamplingParams(),
- OmniDiffusionSamplingParams(),
- ],
- )
- reference_image = Image.new("RGB", (24, 24), color="green")
- extra_body = {
- "negative_prompt": "blurry",
- "num_inference_steps": 28,
- "guidance_scale": 7.5,
- "true_cfg_scale": 5.0,
- "guidance_scale_2": 1.25,
- "layers": 6,
- "resolution": 1024,
- "lora": {"name": "adapter-a", "path": "/tmp/adapter-a", "scale": 0.6},
- }
- gen_params = OmniDiffusionSamplingParams(height=768, width=1024, seed=0, num_outputs_per_prompt=2)
-
- engine_prompt, sampling_params_list = OmniOpenAIServingChat._build_multistage_generation_inputs(
- serving_chat,
- engine=engine,
- prompt="draw a robot",
- extra_body=extra_body,
- reference_images=[reference_image],
- gen_params=gen_params,
- )
-
- assert engine_prompt["prompt"] == "draw a robot"
- assert engine_prompt["modalities"] == ["img2img"]
- assert engine_prompt["negative_prompt"] == "blurry"
- assert engine_prompt["mm_processor_kwargs"] == {"target_h": 768, "target_w": 1024}
- assert engine_prompt["multi_modal_data"]["img2img"].size == (24, 24)
-
- assert len(sampling_params_list) == 3
- assert sampling_params_list[0].temperature == 0.2
- assert sampling_params_list[0].seed == 0
- assert sampling_params_list[1].height == 768
- assert sampling_params_list[1].width == 1024
- assert sampling_params_list[1].seed == 0
- assert sampling_params_list[1].num_inference_steps == 28
- assert sampling_params_list[1].guidance_scale == 7.5
- assert sampling_params_list[1].num_outputs_per_prompt == 2
- assert sampling_params_list[1].true_cfg_scale == 5.0
- assert sampling_params_list[1].lora_request.name == "adapter-a"
- assert sampling_params_list[2].height == 768
- assert sampling_params_list[2].width == 1024
- assert sampling_params_list[2].num_inference_steps == 28
- assert engine.default_sampling_params_list[1].height is None
- assert engine.default_sampling_params_list[2].resolution == 640
diff --git a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
index c885d907ca4..fa4c1e195db 100644
--- a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
+++ b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
@@ -38,7 +38,7 @@ def default_comprehension_params():
temperature=0.4,
top_p=0.9,
top_k=1,
- max_tokens=4353,
+ max_tokens=2048,
seed=42,
repetition_penalty=1.05,
)
@@ -100,10 +100,6 @@ def mock_request(mocker: MockerFixture):
request.stop_token_ids = None
request.frequency_penalty = None
request.presence_penalty = None
- # Must be real Python objects (not MagicMock) so the code's explicit-field
- # and extra_body checks work correctly.
- request.model_fields_set = set()
- request.extra_body = {}
return request
@@ -146,7 +142,7 @@ def test_preserves_yaml_defaults_when_no_request_params(serving_chat, mock_reque
assert comprehension_params.temperature == 0.4
assert comprehension_params.top_p == 0.9
assert comprehension_params.top_k == 1 # YAML custom param preserved
- assert comprehension_params.max_tokens == 4353
+ assert comprehension_params.max_tokens == 2048
assert comprehension_params.seed == 42
assert comprehension_params.repetition_penalty == 1.05 # YAML custom param preserved
@@ -154,7 +150,6 @@ def test_preserves_yaml_defaults_when_no_request_params(serving_chat, mock_reque
def test_request_temperature_overrides_yaml_default(serving_chat, mock_request):
"""Test that request temperature overrides YAML default."""
mock_request.temperature = 0.8
- mock_request.model_fields_set = {"temperature"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -167,7 +162,6 @@ def test_request_temperature_overrides_yaml_default(serving_chat, mock_request):
def test_request_top_p_overrides_yaml_default(serving_chat, mock_request):
"""Test that request top_p overrides YAML default."""
mock_request.top_p = 0.95
- mock_request.model_fields_set = {"top_p"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -179,7 +173,6 @@ def test_request_top_p_overrides_yaml_default(serving_chat, mock_request):
def test_request_max_tokens_overrides_yaml_default(serving_chat, mock_request):
"""Test that request max_tokens overrides YAML default."""
mock_request.max_tokens = 100
- mock_request.model_fields_set = {"max_tokens"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -190,13 +183,12 @@ def test_max_tokens_uses_yaml_default_when_not_specified(serving_chat, mock_requ
"""Test that max_tokens falls back to YAML default when not in request."""
result = serving_chat._build_sampling_params_list_from_request(mock_request)
- assert result[0].max_tokens == 4353
+ assert result[0].max_tokens == 2048
def test_request_seed_overrides_yaml_default(serving_chat, mock_request):
"""Test that request seed overrides YAML default."""
mock_request.seed = 123
- mock_request.model_fields_set = {"seed"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -208,7 +200,6 @@ def test_request_seed_overrides_yaml_default(serving_chat, mock_request):
def test_request_frequency_penalty_overrides(serving_chat, mock_request):
"""Test that request frequency_penalty is applied."""
mock_request.frequency_penalty = 0.5
- mock_request.model_fields_set = {"frequency_penalty"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -218,7 +209,6 @@ def test_request_frequency_penalty_overrides(serving_chat, mock_request):
def test_request_presence_penalty_overrides(serving_chat, mock_request):
"""Test that request presence_penalty is applied."""
mock_request.presence_penalty = 0.3
- mock_request.model_fields_set = {"presence_penalty"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -245,7 +235,6 @@ def test_multiple_params_override_together(serving_chat, mock_request):
mock_request.temperature = 0.7
mock_request.top_p = 0.85
mock_request.seed = 999
- mock_request.model_fields_set = {"max_tokens", "temperature", "top_p", "seed"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -286,7 +275,6 @@ def test_apply_request_overrides_applies_values(serving_chat, mock_request, defa
"""Test that _apply_request_overrides applies non-None request values."""
mock_request.temperature = 0.8
mock_request.seed = 123
- mock_request.model_fields_set = {"temperature", "seed"}
result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request)
@@ -296,197 +284,6 @@ def test_apply_request_overrides_applies_values(serving_chat, mock_request, defa
assert result.top_k == 1 # YAML custom param preserved
-# =============================================================================
-# Tests for empty-list handling in _apply_request_overrides
-# =============================================================================
-
-
-def test_apply_overrides_empty_stop_list_preserves_default(serving_chat, mocker):
- """Test that request.stop=[] does NOT override YAML default stop words."""
- default_params = SamplingParams(temperature=0.5, stop=["<|im_end|>"])
- request = mocker.MagicMock()
- request.temperature = None
- request.top_p = None
- request.top_k = None
- request.max_tokens = None
- request.min_tokens = None
- request.seed = None
- request.ignore_eos = None
- request.stop = [] # empty list — should be treated as "not set"
- request.stop_token_ids = None
- request.frequency_penalty = None
- request.presence_penalty = None
- request.model_fields_set = {"stop"}
- request.extra_body = {}
-
- result = serving_chat._apply_request_overrides(default_params, request)
-
- assert result.stop == ["<|im_end|>"] # YAML default preserved
-
-
-def test_apply_overrides_nonempty_stop_list_overrides_default(serving_chat, mocker):
- """Test that request.stop=["\\n"] overrides YAML default stop words."""
- default_params = SamplingParams(temperature=0.5, stop=["<|im_end|>"])
- request = mocker.MagicMock()
- request.temperature = None
- request.top_p = None
- request.top_k = None
- request.max_tokens = None
- request.min_tokens = None
- request.seed = None
- request.ignore_eos = None
- request.stop = ["\n"] # non-empty list — should override
- request.stop_token_ids = None
- request.frequency_penalty = None
- request.presence_penalty = None
- request.model_fields_set = {"stop"}
- request.extra_body = {}
-
- result = serving_chat._apply_request_overrides(default_params, request)
-
- assert result.stop == ["\n"] # Overridden by request
-
-
-def test_apply_overrides_empty_stop_token_ids_preserves_default(serving_chat, mocker):
- """Test that request.stop_token_ids=[] does NOT override YAML default."""
- default_params = SamplingParams(temperature=0.5, stop_token_ids=[2, 3])
- request = mocker.MagicMock()
- request.temperature = None
- request.top_p = None
- request.top_k = None
- request.max_tokens = None
- request.min_tokens = None
- request.seed = None
- request.ignore_eos = None
- request.stop = None
- request.stop_token_ids = [] # empty list — should be treated as "not set"
- request.frequency_penalty = None
- request.presence_penalty = None
-
- result = serving_chat._apply_request_overrides(default_params, request)
-
- assert result.stop_token_ids == [2, 3] # YAML default preserved
-
-
-def test_apply_overrides_nonempty_stop_token_ids_overrides_default(serving_chat, mocker):
- """Test that request.stop_token_ids=[100] overrides YAML default."""
- default_params = SamplingParams(temperature=0.5, stop_token_ids=[2, 3])
- request = mocker.MagicMock()
- request.temperature = None
- request.top_p = None
- request.top_k = None
- request.max_tokens = None
- request.min_tokens = None
- request.seed = None
- request.ignore_eos = None
- request.stop = None
- request.stop_token_ids = [100] # non-empty list — should override
- request.frequency_penalty = None
- request.presence_penalty = None
- request.model_fields_set = {"stop_token_ids"}
- request.extra_body = {}
-
- result = serving_chat._apply_request_overrides(default_params, request)
-
- assert result.stop_token_ids == [100] # Overridden by request
-
-
-def test_apply_overrides_mixed_empty_and_nonempty_lists(serving_chat, mocker):
- """Test mixing empty and non-empty list fields with scalar fields."""
- default_params = SamplingParams(
- temperature=0.4,
- stop=["<|end|>"],
- stop_token_ids=[2],
- )
- request = mocker.MagicMock()
- request.temperature = 0.9
- request.top_p = None
- request.top_k = None
- request.max_tokens = None
- request.min_tokens = None
- request.seed = None
- request.ignore_eos = None
- request.stop = [] # empty — should NOT override
- request.stop_token_ids = [100, 200] # non-empty — SHOULD override
- request.frequency_penalty = None
- request.presence_penalty = None
- request.model_fields_set = {"temperature", "stop", "stop_token_ids"}
- request.extra_body = {}
-
- result = serving_chat._apply_request_overrides(default_params, request)
-
- assert result.temperature == 0.9 # Scalar override works
- assert result.stop == ["<|end|>"] # Empty list did NOT override
- assert result.stop_token_ids == [100, 200] # Non-empty list DID override
-
-
-def test_apply_overrides_none_scalar_still_preserves_default(serving_chat, mocker):
- """Regression: ensure None scalar values still don't override defaults."""
- default_params = SamplingParams(temperature=0.5, max_tokens=100, seed=42)
- request = mocker.MagicMock()
- request.temperature = None
- request.top_p = None
- request.top_k = None
- request.max_tokens = None
- request.min_tokens = None
- request.seed = None
- request.ignore_eos = None
- request.stop = None
- request.stop_token_ids = None
- request.frequency_penalty = None
- request.presence_penalty = None
- request.model_fields_set = set()
- request.extra_body = {}
-
- result = serving_chat._apply_request_overrides(default_params, request)
-
- assert result.temperature == 0.5
- assert result.max_tokens == 100
- assert result.seed == 42
-
-
-def test_apply_overrides_both_lists_empty_preserves_defaults(serving_chat, mocker):
- """Test that both stop=[] and stop_token_ids=[] preserve YAML defaults."""
- default_params = SamplingParams(
- temperature=0.5,
- stop=["<|end|>", "\\n"],
- stop_token_ids=[2, 32000],
- )
- request = mocker.MagicMock()
- request.temperature = None
- request.top_p = None
- request.top_k = None
- request.max_tokens = None
- request.min_tokens = None
- request.seed = None
- request.ignore_eos = None
- request.stop = []
- request.stop_token_ids = []
- request.frequency_penalty = None
- request.presence_penalty = None
- request.model_fields_set = {"stop", "stop_token_ids"}
- request.extra_body = {}
-
- result = serving_chat._apply_request_overrides(default_params, request)
-
- assert result.stop == ["<|end|>", "\\n"]
- assert result.stop_token_ids == [2, 32000]
-
-
-def test_build_sampling_params_list_empty_stop_preserves_yaml(serving_chat, mock_request):
- """Test that empty stop list in request preserves YAML defaults via
- _build_sampling_params_list_from_request."""
- mock_request.stop = []
- mock_request.stop_token_ids = []
-
- result = serving_chat._build_sampling_params_list_from_request(mock_request)
-
- comprehension_params = result[0]
- # Empty lists should NOT override — YAML defaults are preserved
- assert comprehension_params.stop == []
- assert comprehension_params.stop_token_ids == []
-
-
# =============================================================================
# Tests for _get_comprehension_stage_index
# =============================================================================
@@ -535,174 +332,3 @@ def test_get_comprehension_stage_index_raises_when_not_found(mocker: MockerFixtu
with pytest.raises(ValueError, match="No comprehension stage"):
instance._get_comprehension_stage_index()
-
-
-# =============================================================================
-# Tests for _resolve_height_width_from_extra_body
-# =============================================================================
-
-
-class TestResolveHeightWidth:
- def test_explicit_height_width(self):
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"height": 512, "width": 768})
- assert h == 512
- assert w == 768
-
- def test_size_string(self):
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "768x512"})
- assert w == 768
- assert h == 512
-
- def test_size_string_uppercase(self):
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "768X512"})
- assert w == 768
- assert h == 512
-
- def test_size_fallback_when_height_missing(self):
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "512x512", "width": 1024})
- # height is None -> size fallback fires and sets BOTH width and height
- assert h == 512
- assert w == 512
-
- def test_empty_extra_body(self):
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({})
- assert h is None
- assert w is None
-
- def test_invalid_size_format_ignored(self):
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "invalid"})
- assert h is None
- assert w is None
-
-
-# =============================================================================
-# Tests for _apply_request_overrides with GLM-Image (target_h/w injection)
-# =============================================================================
-
-
-class TestApplyRequestOverridesGLMImage:
- """Test target_h/w injection for GLM-Image AR stage.
-
- max_tokens is NOT computed dynamically — it comes from the deploy YAML
- default (e.g. 4353). _apply_request_overrides only injects target_h/w
- into extra_args so the model can build M-RoPE position grids.
- """
-
- @pytest.fixture
- def glm_serving_chat(self, mock_engine_client, mocker: MockerFixture):
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- instance = object.__new__(OmniOpenAIServingChat)
- instance.engine_client = mock_engine_client
- instance._extract_diffusion_prompt_and_images_from_messages = mocker.MagicMock(return_value=("a cat", []))
- return instance
-
- @pytest.fixture
- def glm_request(self, mocker: MockerFixture):
- req = mocker.MagicMock()
- req.temperature = None
- req.top_p = None
- req.top_k = None
- req.max_tokens = None
- req.min_tokens = None
- req.seed = None
- req.ignore_eos = None
- req.stop = None
- req.stop_token_ids = None
- req.frequency_penalty = None
- req.presence_penalty = None
- req.extra_body = {"height": 1024, "width": 1024}
- req.model_fields_set = set()
- return req
-
- def test_t2i_injects_target_h_w(self, glm_serving_chat, glm_request, default_comprehension_params):
- """t2i mode: target_h/w injected into extra_args, max_tokens unchanged."""
- result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request)
- assert result.extra_args["target_h"] == 1024
- assert result.extra_args["target_w"] == 1024
- # max_tokens stays at YAML default (not dynamically computed)
- assert result.max_tokens == 4353
-
- def test_i2i_injects_target_h_w(
- self, glm_serving_chat, glm_request, default_comprehension_params, mocker: MockerFixture
- ):
- """i2i mode: target_h/w injected, max_tokens unchanged."""
- glm_serving_chat._extract_diffusion_prompt_and_images_from_messages = mocker.MagicMock(
- return_value=("edit this", ["fake_image"])
- )
- result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request)
- assert result.extra_args["target_h"] == 1024
- assert result.extra_args["target_w"] == 1024
- # max_tokens stays at YAML default regardless of t2i/i2i
- assert result.max_tokens == 4353
-
- def test_user_max_tokens_preserved(self, glm_serving_chat, glm_request, default_comprehension_params):
- """User-provided max_tokens is respected (not overridden by dynamic computation)."""
- glm_request.max_tokens = 500
- glm_request.model_fields_set = {"max_tokens"}
-
- result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request)
- assert result.max_tokens == 500
- assert result.extra_args["target_h"] == 1024
- assert result.extra_args["target_w"] == 1024
-
- def test_no_height_width_preserves_default(
- self, glm_serving_chat, mocker: MockerFixture, default_comprehension_params
- ):
- """When no height/width in extra_body, keep YAML default max_tokens, no target_h/w."""
- req = mocker.MagicMock()
- req.temperature = None
- req.top_p = None
- req.top_k = None
- req.max_tokens = None
- req.min_tokens = None
- req.seed = None
- req.ignore_eos = None
- req.stop = None
- req.stop_token_ids = None
- req.frequency_penalty = None
- req.presence_penalty = None
- req.extra_body = {}
- req.model_fields_set = set()
-
- result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req)
- assert result.max_tokens == 4353 # YAML default
- # No target_h/w injected when dimensions not provided
- assert not result.extra_args or "target_h" not in (result.extra_args or {})
-
- def test_size_string_parsed_for_glm_image(
- self, glm_serving_chat, mocker: MockerFixture, default_comprehension_params
- ):
- """'size' in extra_body is parsed as fallback for height/width."""
- req = mocker.MagicMock()
- req.temperature = None
- req.top_p = None
- req.top_k = None
- req.max_tokens = None
- req.min_tokens = None
- req.seed = None
- req.ignore_eos = None
- req.stop = None
- req.stop_token_ids = None
- req.frequency_penalty = None
- req.presence_penalty = None
- req.extra_body = {"size": "512x512"}
- req.model_fields_set = set()
-
- result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req)
- assert result.extra_args["target_h"] == 512
- assert result.extra_args["target_w"] == 512
- # max_tokens stays at YAML default (not dynamically computed)
- assert result.max_tokens == 4353
diff --git a/tests/entrypoints/openai_api/test_serving_chat_speaker.py b/tests/entrypoints/openai_api/test_serving_chat_speaker.py
deleted file mode 100644
index 97c05e45b41..00000000000
--- a/tests/entrypoints/openai_api/test_serving_chat_speaker.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for chat endpoint speaker validation."""
-
-import asyncio
-from types import SimpleNamespace
-
-import pytest
-from pytest_mock import MockerFixture
-
-from vllm_omni.entrypoints.openai.utils import (
- get_supported_speakers_from_hf_config,
- validate_requested_speaker,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.fixture
-def serving_chat():
- from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
-
- instance = object.__new__(OmniOpenAIServingChat)
- instance._supported_speakers = None
- return instance
-
-
-def _make_hf_config(mocker: MockerFixture, *, speaker_id: dict | None = None, spk_id: dict | None = None):
- hf_config = mocker.MagicMock()
- talker_config = mocker.MagicMock()
- talker_config.speaker_id = speaker_id
- talker_config.spk_id = spk_id
- hf_config.talker_config = talker_config
- return hf_config
-
-
-def test_validate_requested_speaker_accepts_case_insensitive_value():
- supported = {"vivian", "ethan"}
- assert validate_requested_speaker("Vivian", supported) == "vivian"
- assert validate_requested_speaker(" vivian ", supported) == "vivian"
-
-
-def test_validate_requested_speaker_rejects_invalid_value_with_supported_list():
- supported = {"vivian", "ethan"}
- with pytest.raises(ValueError, match="Invalid speaker 'uncle_fu'. Supported: ethan, vivian"):
- validate_requested_speaker("uncle_fu", supported)
-
-
-def test_validate_requested_speaker_skips_validation_when_supported_empty():
- assert validate_requested_speaker("anything", set()) == "anything"
- assert validate_requested_speaker(" ", {"vivian"}) is None
-
-
-def test_get_supported_speakers_from_hf_config_uses_spk_id_fallback(mocker: MockerFixture):
- hf_config = _make_hf_config(mocker, speaker_id=None, spk_id={"Serena": 0})
- assert get_supported_speakers_from_hf_config(hf_config) == {"serena"}
-
-
-def test_get_supported_speakers_caches_normalized_keys(mocker: MockerFixture, serving_chat):
- serving_chat.model_config = mocker.MagicMock()
- serving_chat.model_config.hf_config = _make_hf_config(mocker, speaker_id={"Vivian": 0, "Ethan": 1})
-
- assert serving_chat._get_supported_speakers() == {"vivian", "ethan"}
-
- # Cached value should be reused even if the config changes afterwards.
- serving_chat.model_config.hf_config.talker_config.speaker_id = {"Serena": 2}
- assert serving_chat._get_supported_speakers() == {"vivian", "ethan"}
-
-
-def test_create_chat_completion_converts_value_error_to_error_response(mocker: MockerFixture, serving_chat):
- serving_chat._diffusion_mode = False
- serving_chat._check_model = mocker.AsyncMock(return_value=None)
- serving_chat.engine_client = mocker.MagicMock(errored=False)
- serving_chat._maybe_get_adapters = mocker.MagicMock(return_value=None)
- serving_chat.models = mocker.MagicMock()
- serving_chat.models.model_name.return_value = "test-model"
- serving_chat.renderer = mocker.MagicMock()
- serving_chat.renderer.get_tokenizer.return_value = mocker.MagicMock()
- serving_chat.reasoning_parser_cls = None
- serving_chat.tool_parser = None
- serving_chat.use_harmony = False
- serving_chat.enable_auto_tools = False
- serving_chat.exclude_tools_when_tool_choice_none = False
- serving_chat.trust_request_chat_template = False
- serving_chat.chat_template = None
- serving_chat.chat_template_content_format = "string"
- serving_chat.default_chat_template_kwargs = {}
- serving_chat._validate_chat_template = mocker.MagicMock(return_value=None)
- serving_chat._prepare_extra_chat_template_kwargs = mocker.MagicMock(return_value={})
- serving_chat._preprocess_chat = mocker.AsyncMock(
- side_effect=ValueError("Invalid speaker 'uncle_fu'. Supported: ethan, vivian")
- )
- serving_chat.create_error_response = mocker.MagicMock(return_value="error-response")
-
- request = SimpleNamespace(
- tool_choice=None,
- tools=None,
- chat_template=None,
- chat_template_kwargs=None,
- reasoning_effort=None,
- messages=[],
- add_generation_prompt=False,
- continue_final_message=False,
- add_special_tokens=False,
- request_id="speaker-test",
- )
-
- result = asyncio.run(serving_chat.create_chat_completion(request))
-
- assert result == "error-response"
- serving_chat.create_error_response.assert_called_once_with("Invalid speaker 'uncle_fu'. Supported: ethan, vivian")
diff --git a/tests/entrypoints/openai_api/test_serving_speech.py b/tests/entrypoints/openai_api/test_serving_speech.py
index b78d62d9eda..b140b7a0468 100644
--- a/tests/entrypoints/openai_api/test_serving_speech.py
+++ b/tests/entrypoints/openai_api/test_serving_speech.py
@@ -6,6 +6,7 @@
from inspect import Signature, signature
from pathlib import Path
from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
@@ -18,7 +19,6 @@
from pytest_mock import MockerFixture
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
-from vllm_omni.entrypoints.omni_base import OmniEngineDeadError
from vllm_omni.entrypoints.openai import api_server as api_server_module
from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.protocol.audio import (
@@ -63,11 +63,14 @@ def test_stereo_to_mono_conversion(self, audio_mixin, mocker: MockerFixture):
adjusted_tensor = mock_speed.call_args[0][0]
assert len(adjusted_tensor) == 24000
- def test_speed_adjustment(self, audio_mixin):
+ def test_speed_adjustment(self, audio_mixin, mocker: MockerFixture):
+ mock_time_stretch = mocker.patch("librosa.effects.time_stretch")
+ mock_time_stretch.return_value = np.zeros(12000)
audio_tensor = np.random.rand(24000).astype(np.float32)
adjusted_audio, _ = audio_mixin._apply_speed_adjustment(audio_tensor, speed=2.0, sample_rate=24000)
+ mock_time_stretch.assert_called_with(y=audio_tensor, rate=2.0)
assert adjusted_audio.shape == (12000,)
def test_unsupported_format_fallback(self, audio_mixin, caplog, mocker: MockerFixture):
@@ -114,22 +117,30 @@ def test_stereo_audio_preservation(self, audio_mixin, mocker: MockerFixture):
assert np.array_equal(output_tensor, stereo_tensor)
def test_speed_adjustment_bypass(self, audio_mixin, mocker: MockerFixture):
- """Test that speed=1.0 bypasses the expensive torchaudio time stretching."""
+ """Test that speed=1.0 bypasses the expensive librosa time stretching."""
audio_tensor = np.random.rand(24000).astype(np.float32)
- mock_time_stretch = mocker.patch("torchaudio.transforms.TimeStretch")
- # speed=1.0 should return immediately without calling torchaudio
+ mock_time_stretch = mocker.patch("librosa.effects.time_stretch")
+ # speed=1.0 should return immediately without calling librosa
result, _ = audio_mixin._apply_speed_adjustment(audio_tensor, speed=1.0, sample_rate=24000)
mock_time_stretch.assert_not_called()
assert np.array_equal(result, audio_tensor)
- def test_speed_adjustment_stereo_handling(self, audio_mixin):
- """Test that speed adjustment handles stereo (channels-last) input."""
+ def test_speed_adjustment_stereo_handling(self, audio_mixin, mocker: MockerFixture):
+ """Test that speed adjustment is attempted on stereo inputs."""
+ mock_time_stretch = mocker.patch("librosa.effects.time_stretch")
stereo_tensor = np.random.rand(24000, 2).astype(np.float32)
+ # Mock return value representing a sped-up version (half length)
+ mock_time_stretch.return_value = np.zeros((12000, 2), dtype=np.float32)
result, _ = audio_mixin._apply_speed_adjustment(stereo_tensor, speed=2.0, sample_rate=24000)
+ mock_time_stretch.assert_called_once()
+ # Ensure the stereo tensor was passed to librosa
+ call_args = mock_time_stretch.call_args
+ assert np.array_equal(call_args.kwargs["y"], stereo_tensor)
+ assert call_args.kwargs["rate"] == 2.0
assert result.shape == (12000, 2)
@@ -647,13 +658,11 @@ def speech_server(self, mocker: MockerFixture):
mock_engine_client.tts_max_instructions_length = None
mock_models = mocker.MagicMock()
mock_models.is_base_model.return_value = True
- server = OmniOpenAIServingSpeech(
+ return OmniOpenAIServingSpeech(
engine_client=mock_engine_client,
models=mock_models,
request_logger=mocker.MagicMock(),
)
- yield server
- server.shutdown()
def test_is_tts_detection_no_stage(self, speech_server):
"""Test TTS model detection when no TTS stage exists."""
@@ -685,32 +694,6 @@ def test_is_tts_detection_with_tts_stage(self, mocker: MockerFixture):
assert server._is_tts is True
assert server._tts_stage is mock_stage
- def test_prepare_speech_rejects_non_tts_omni_model(self, mocker: MockerFixture):
- """Multi-stage omni models (e.g. Qwen3-Omni) must not use /v1/audio/speech."""
- mock_engine_client = mocker.MagicMock()
- mock_engine_client.errored = False
- mock_engine_client.tts_max_instructions_length = None
-
- # Simulate Qwen3-Omni: multiple stages, none in _TTS_MODEL_STAGES
- thinker = SimpleNamespace(engine_args=SimpleNamespace(model_stage="thinker"), tts_args={})
- talker = SimpleNamespace(engine_args=SimpleNamespace(model_stage="talker"), tts_args={})
- code2wav = SimpleNamespace(engine_args=SimpleNamespace(model_stage="code2wav"), tts_args={})
- mock_engine_client.stage_configs = [thinker, talker, code2wav]
-
- mock_models = mocker.MagicMock()
- mock_models.is_base_model.return_value = True
- server = OmniOpenAIServingSpeech(
- engine_client=mock_engine_client,
- models=mock_models,
- request_logger=mocker.MagicMock(),
- )
- assert server._is_tts is False
-
- request = OpenAICreateSpeechRequest(input="Hello world")
- with pytest.raises(ValueError, match="only supported for dedicated TTS models"):
- asyncio.run(server._prepare_speech_generation(request))
- server.shutdown()
-
def test_estimate_prompt_len_fallback(self, speech_server):
"""Test prompt length estimation falls back to 2048 when model is unavailable."""
tts_params = {"text": ["Hello"], "task_type": ["CustomVoice"]}
@@ -778,26 +761,6 @@ def test_validate_tts_request_base_empty_ref_text(self, speech_server):
)
assert speech_server._validate_tts_request(req) is None
- @pytest.mark.parametrize(
- "ref_text",
- [None, "", " "],
- ids=["none", "empty", "whitespace"],
- )
- def test_validate_base_task_missing_ref_text_returns_400(self, speech_server, ref_text):
- """Regression: Base task without ref_text must return 400, not crash EngineCore.
-
- See https://github.com/vllm-project/vllm-omni/pull/2203
- """
- req = OpenAICreateSpeechRequest(
- input="Hello",
- task_type="Base",
- ref_audio="data:audio/wav;base64,abc",
- ref_text=ref_text,
- )
- result = speech_server._validate_tts_request(req)
- assert result is not None, f"ref_text={ref_text!r} should be rejected"
- assert "ref_text" in result
-
def test_validate_tts_request_customvoice_no_speakers(self, speech_server):
"""CustomVoice on a model with no speakers returns 400 instead of crashing engine."""
req = OpenAICreateSpeechRequest(input="Hello", task_type="CustomVoice")
@@ -927,7 +890,7 @@ def test_load_supported_speakers(self, mocker: MockerFixture):
# Verify speakers are normalized to lowercase
assert server.supported_speakers == {"ryan", "vivian", "aiden"}
- def test_build_tts_params_with_uploaded_voice(self, speech_server, mocker: MockerFixture):
+ def test_build_tts_params_with_uploaded_voice(self, speech_server):
"""Test _build_tts_params auto-sets ref_audio for uploaded voices (x_vector only)."""
speech_server.uploaded_speakers = {
"custom_voice": {
@@ -940,18 +903,18 @@ def test_build_tts_params_with_uploaded_voice(self, speech_server, mocker: Mocke
}
speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"}
- mock_get_audio = mocker.patch.object(speech_server, "_get_uploaded_audio_data")
- mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
- req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
- params = speech_server._build_tts_params(req)
+ with patch.object(speech_server, "_get_uploaded_audio_data") as mock_get_audio:
+ mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
+ req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
+ params = speech_server._build_tts_params(req)
- assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
- assert params["x_vector_only_mode"] == [True]
- assert params["task_type"] == ["Base"]
- assert params["voice_created_at"] == [1711234567.89]
- assert "ref_text" not in params
+ assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
+ assert params["x_vector_only_mode"] == [True]
+ assert params["task_type"] == ["Base"]
+ assert params["voice_created_at"] == [1711234567.89]
+ assert "ref_text" not in params
- def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server, mocker: MockerFixture):
+ def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server):
"""Test _build_tts_params enables in-context cloning when ref_text is stored."""
speech_server.uploaded_speakers = {
"custom_voice": {
@@ -964,16 +927,16 @@ def test_build_tts_params_with_uploaded_voice_ref_text(self, speech_server, mock
}
speech_server.supported_speakers = {"ryan", "vivian", "custom_voice"}
- mock_get_audio = mocker.patch.object(speech_server, "_get_uploaded_audio_data")
- mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
- req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
- params = speech_server._build_tts_params(req)
+ with patch.object(speech_server, "_get_uploaded_audio_data") as mock_get_audio:
+ mock_get_audio.return_value = "data:audio/wav;base64,ZmFrZWF1ZGlv"
+ req = OpenAICreateSpeechRequest(input="Hello", voice="custom_voice")
+ params = speech_server._build_tts_params(req)
- assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
- assert params["x_vector_only_mode"] == [False]
- assert params["task_type"] == ["Base"]
- assert params["ref_text"] == ["Hello world transcript"]
- assert params["voice_created_at"] == [1711234567.89]
+ assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZWF1ZGlv"]
+ assert params["x_vector_only_mode"] == [False]
+ assert params["task_type"] == ["Base"]
+ assert params["ref_text"] == ["Hello world transcript"]
+ assert params["voice_created_at"] == [1711234567.89]
def test_build_tts_params_without_uploaded_voice(self, speech_server):
"""Test _build_tts_params does not auto-set ref_audio for non-uploaded voices."""
@@ -1015,43 +978,45 @@ def test_build_tts_params_with_explicit_ref_audio(self, speech_server):
# x_vector_only_mode should not be set when explicit ref_audio is provided
assert "x_vector_only_mode" not in params
- def test_get_uploaded_audio_data(self, speech_server, mocker: MockerFixture):
+ def test_get_uploaded_audio_data(self, speech_server):
"""Test _get_uploaded_audio_data function."""
# Mock file operations
- mock_open = mocker.patch("builtins.open", create=True)
- mock_b64encode = mocker.patch("base64.b64encode")
- mock_exists = mocker.patch("pathlib.Path.exists")
- mock_exists.return_value = True
- mock_b64encode.return_value = b"ZmFrZWF1ZGlv"
-
- # Setup mock file
- mock_file = mocker.MagicMock()
- mock_file.read.return_value = b"fakeaudio"
- mock_open.return_value.__enter__.return_value = mock_file
-
- # Setup uploaded speaker
- speech_server.uploaded_speakers = {
- "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
- }
- result = speech_server._get_uploaded_audio_data("test_voice")
+ with (
+ patch("builtins.open", create=True) as mock_open,
+ patch("base64.b64encode") as mock_b64encode,
+ patch("pathlib.Path.exists") as mock_exists,
+ ):
+ mock_exists.return_value = True
+ mock_b64encode.return_value = b"ZmFrZWF1ZGlv"
+
+ # Setup mock file
+ mock_file = MagicMock()
+ mock_file.read.return_value = b"fakeaudio"
+ mock_open.return_value.__enter__.return_value = mock_file
+
+ # Setup uploaded speaker
+ speech_server.uploaded_speakers = {
+ "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
+ }
+ result = speech_server._get_uploaded_audio_data("test_voice")
- assert result == "data:audio/wav;base64,ZmFrZWF1ZGlv"
- mock_open.assert_called_once_with(Path("/tmp/test.wav"), "rb")
- mock_b64encode.assert_called_once_with(b"fakeaudio")
+ assert result == "data:audio/wav;base64,ZmFrZWF1ZGlv"
+ mock_open.assert_called_once_with(Path("/tmp/test.wav"), "rb")
+ mock_b64encode.assert_called_once_with(b"fakeaudio")
- def test_get_uploaded_audio_data_missing_file(self, speech_server, mocker: MockerFixture):
+ def test_get_uploaded_audio_data_missing_file(self, speech_server):
"""Test _get_uploaded_audio_data when file is missing."""
- mock_exists = mocker.patch("pathlib.Path.exists")
- mock_exists.return_value = False
+ with patch("pathlib.Path.exists") as mock_exists:
+ mock_exists.return_value = False
- # Setup uploaded speaker
- speech_server.uploaded_speakers = {
- "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
- }
+ # Setup uploaded speaker
+ speech_server.uploaded_speakers = {
+ "test_voice": {"name": "test_voice", "file_path": "/tmp/test.wav", "mime_type": "audio/wav"}
+ }
- result = speech_server._get_uploaded_audio_data("test_voice")
+ result = speech_server._get_uploaded_audio_data("test_voice")
- assert result is None
+ assert result is None
def test_get_uploaded_audio_data_voice_not_found(self, speech_server):
"""Test _get_uploaded_audio_data when voice is not in uploaded_speakers."""
@@ -1061,181 +1026,6 @@ def test_get_uploaded_audio_data_voice_not_found(self, speech_server):
assert result is None
- # ── speaker field alias ──
-
- def test_speaker_alias_accepted_as_voice(self):
- """The 'speaker' JSON key should be accepted as an alias for 'voice'."""
- req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "speaker": "custom_voice"})
- assert req.voice == "custom_voice"
-
- def test_voice_field_still_accepted(self):
- """The canonical 'voice' JSON key should still work."""
- req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "voice": "custom_voice"})
- assert req.voice == "custom_voice"
-
- def test_speaker_alias_in_base_task_with_uploaded_voice(self, speech_server, mocker: MockerFixture):
- """Using 'speaker' key with an uploaded voice should work for Base task."""
- speech_server.uploaded_speakers = {
- "utesf": {
- "name": "UTESF",
- "file_path": "/tmp/voice_samples/utesf.wav",
- "mime_type": "audio/wav",
- "ref_text": None,
- }
- }
- req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "speaker": "UTESF", "task_type": "Base"})
- assert req.voice == "UTESF"
- mocker.patch("pathlib.Path.exists", return_value=True)
- result = speech_server._validate_qwen_tts_request(req)
- assert result is None
-
- # ── uploaded voice with embedding ──
-
- def test_build_tts_params_with_uploaded_voice_embedding(self, speech_server, mocker: MockerFixture):
- """Test _build_tts_params loads embedding for embedding-uploaded voices."""
- speech_server.uploaded_speakers = {
- "emb_voice": {
- "name": "emb_voice",
- "file_path": "/tmp/voice_samples/emb_voice.safetensors",
- "mime_type": "application/x-safetensors",
- "embedding_source": "direct",
- "embedding_dim": 1024,
- "cache_status": "ready",
- "cache_file": "/tmp/voice_samples/emb_voice.safetensors",
- }
- }
- speech_server.supported_speakers = {"ryan", "vivian", "emb_voice"}
-
- fake_embedding = [0.1] * 1024
- mock_get_emb = mocker.patch.object(speech_server, "_get_uploaded_speaker_embedding")
- mock_get_emb.return_value = fake_embedding
- req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice")
- params = speech_server._build_tts_params(req)
-
- assert "voice_clone_prompt" in params
- assert params["voice_clone_prompt"][0]["ref_spk_embedding"] == fake_embedding
- assert params["task_type"] == ["Base"]
- assert params["x_vector_only_mode"] == [True]
- assert "ref_audio" not in params
-
- # ── regression: full flow from issue #1603 ──
-
- def test_regression_1603_speaker_key_with_uploaded_audio_voice(self, speech_server, mocker: MockerFixture):
- """Regression test for #1603: upload audio voice, then invoke TTS with 'speaker' key.
-
- Verifies the full validate → build_params pipeline works end-to-end.
- """
- speech_server.uploaded_speakers = {
- "utesf": {
- "name": "UTESF",
- "file_path": "/tmp/voice_samples/utesf.wav",
- "mime_type": "audio/wav",
- "ref_text": "Hola, esta es una prueba.",
- }
- }
- # Parse with 'speaker' alias (the key users actually send)
- req = OpenAICreateSpeechRequest.model_validate(
- {"input": "Hello world", "speaker": "UTESF", "task_type": "Base"}
- )
- assert req.voice == "UTESF"
-
- # Validation should pass (file exists)
- mocker.patch("pathlib.Path.exists", return_value=True)
- err = speech_server._validate_qwen_tts_request(req)
- assert err is None, f"Validation failed: {err}"
-
- # Build params should auto-set ref_audio from stored file
- mock_audio = mocker.patch.object(speech_server, "_get_uploaded_audio_data")
- mock_audio.return_value = "data:audio/wav;base64,ZmFrZQ=="
- params = speech_server._build_tts_params(req)
-
- assert params["task_type"] == ["Base"]
- assert params["ref_audio"] == ["data:audio/wav;base64,ZmFrZQ=="]
- assert params["ref_text"] == ["Hola, esta es una prueba."]
- assert params["x_vector_only_mode"] == [False]
- assert params["speaker"] == ["utesf"]
-
- def test_regression_1603_speaker_key_with_uploaded_embedding_voice(self, speech_server, mocker: MockerFixture):
- """Regression test for #1603: upload embedding voice, then invoke TTS with 'speaker' key.
-
- Verifies embedding-uploaded voices are loaded as voice_clone_prompt, not as audio.
- """
- speech_server.uploaded_speakers = {
- "myvoice": {
- "name": "myvoice",
- "file_path": "/tmp/voice_samples/myvoice.safetensors",
- "mime_type": "application/x-safetensors",
- "embedding_source": "direct",
- "embedding_dim": 1024,
- "cache_status": "ready",
- "cache_file": "/tmp/voice_samples/myvoice.safetensors",
- }
- }
- # Parse with 'speaker' alias
- req = OpenAICreateSpeechRequest.model_validate(
- {"input": "Hello world", "speaker": "myvoice", "task_type": "Base"}
- )
- assert req.voice == "myvoice"
-
- # Validation should pass
- mocker.patch("pathlib.Path.exists", return_value=True)
- err = speech_server._validate_qwen_tts_request(req)
- assert err is None, f"Validation failed: {err}"
-
- # Build params should use embedding, NOT audio
- fake_emb = [0.1] * 1024
- mock_emb = mocker.patch.object(speech_server, "_get_uploaded_speaker_embedding")
- mock_emb.return_value = fake_emb
- params = speech_server._build_tts_params(req)
-
- assert params["task_type"] == ["Base"]
- assert params["x_vector_only_mode"] == [True]
- assert "voice_clone_prompt" in params
- assert params["voice_clone_prompt"][0]["ref_spk_embedding"] == fake_emb
- # Must NOT have ref_audio — that would fail for safetensors files
- assert "ref_audio" not in params
-
- def test_validate_rejects_embedding_voice_with_pending_cache(self, speech_server, mocker: MockerFixture):
- """Validation should reject embedding voices whose cache is not yet ready."""
- speech_server.uploaded_speakers = {
- "myvoice": {
- "name": "myvoice",
- "file_path": "/tmp/myvoice.safetensors",
- "mime_type": "application/x-safetensors",
- "embedding_source": "direct",
- "cache_status": "pending",
- "cache_file": None,
- }
- }
- req = OpenAICreateSpeechRequest.model_validate({"input": "Hello", "speaker": "myvoice", "task_type": "Base"})
- mocker.patch("pathlib.Path.exists", return_value=True)
- err = speech_server._validate_qwen_tts_request(req)
- assert err is not None
- assert "not yet ready" in err
-
- def test_x_vector_only_mode_not_overwritten_for_uploaded_embedding(self, speech_server, mocker: MockerFixture):
- """x_vector_only_mode set by uploaded embedding must not be overwritten by request field."""
- speech_server.uploaded_speakers = {
- "emb_voice": {
- "name": "emb_voice",
- "file_path": "/tmp/emb_voice.safetensors",
- "mime_type": "application/x-safetensors",
- "embedding_source": "direct",
- "embedding_dim": 1024,
- "cache_status": "ready",
- "cache_file": "/tmp/emb_voice.safetensors",
- }
- }
- fake_emb = [0.1] * 1024
- mock_emb = mocker.patch.object(speech_server, "_get_uploaded_speaker_embedding")
- mock_emb.return_value = fake_emb
- # Client explicitly sends x_vector_only_mode=False, but embedding requires True
- req = OpenAICreateSpeechRequest(input="Hello", voice="emb_voice", x_vector_only_mode=False)
- params = speech_server._build_tts_params(req)
-
- assert params["x_vector_only_mode"] == [True]
- assert "voice_clone_prompt" in params
-
def test_max_instructions_length_default(self, speech_server):
"""Test default max instructions length (500) when no config provided."""
# Fixture creates server with no CLI override and no TTS stage
@@ -1678,9 +1468,9 @@ async def test_omni_model_includes_generate(self):
assert "generate" in tasks
-def test_api_server_create_speech_wraps_error_response_status(mocker: MockerFixture):
- handler = mocker.MagicMock()
- handler.create_speech = mocker.AsyncMock(
+def test_api_server_create_speech_wraps_error_response_status():
+ handler = MagicMock()
+ handler.create_speech = AsyncMock(
return_value=ErrorResponse(
error=ErrorInfo(message="bad request", type="BadRequestError", param=None, code=400),
)
@@ -1708,77 +1498,6 @@ def test_api_server_create_speech_wraps_error_response_status(mocker: MockerFixt
assert response.status_code == 400
-def test_api_server_create_speech_engine_error_response_includes_request_and_stage_id(mocker: MockerFixture):
- handler = mocker.MagicMock()
- handler.create_speech = mocker.AsyncMock(
- side_effect=OmniEngineDeadError(
- "engine dead",
- error_stage_id=1,
- )
- )
-
- terminate_mock = mocker.patch.object(api_server_module, "terminate_if_errored")
-
- app = FastAPI()
- app.state.args = SimpleNamespace(log_error_stack=False)
- app.state.openai_serving_speech = handler
- app.state.engine_client = SimpleNamespace(
- engine=SimpleNamespace(is_alive=lambda: False),
- errored=True,
- )
- app.state.server = SimpleNamespace()
- scope = {
- "type": "http",
- "app": app,
- "method": "POST",
- "path": "/v1/audio/speech",
- "headers": [],
- "query_string": b"",
- "client": ("127.0.0.1", 12345),
- "server": ("testserver", 80),
- "scheme": "http",
- }
- raw_request = Request(scope)
- raw_request.state.request_metadata = SimpleNamespace(request_id="speech-req-1")
- request = OpenAICreateSpeechRequest(input="Hello")
-
- response = asyncio.run(api_server_module.create_speech(request, raw_request))
-
- assert isinstance(response, JSONResponse)
- assert response.status_code == 500
- assert response.body.decode("utf-8") == (
- '{"error":{"message":"engine dead","type":"InternalServerError","param":null,'
- '"code":500,"request_id":"speech-req-1","error_stage_id":1}}'
- )
- terminate_mock.assert_called_once()
-
-
-def test_omni_engine_error_handler_includes_request_and_stage_id(mocker: MockerFixture):
- app = FastAPI()
- app.state.args = SimpleNamespace(log_error_stack=False)
- app.state.engine_client = SimpleNamespace(
- engine=SimpleNamespace(is_alive=lambda: False),
- errored=True,
- )
- app.state.server = SimpleNamespace()
-
- terminate_mock = mocker.patch.object(api_server_module, "terminate_if_errored")
- api_server_module._register_omni_exception_handlers(app)
-
- @app.get("/boom")
- async def boom(request: Request):
- request.state.request_metadata = SimpleNamespace(request_id="speech-req-1")
- exc = OmniEngineDeadError("engine dead", error_stage_id=1)
- raise exc
-
- response = TestClient(app).get("/boom")
-
- assert response.status_code == 500
- assert response.json()["error"]["request_id"] == "speech-req-1"
- assert response.json()["error"]["error_stage_id"] == 1
- terminate_mock.assert_called_once()
-
-
class TestWAVHeaderGeneration:
"""Unit tests for WAV header generation with placeholder values."""
@@ -1920,13 +1639,11 @@ def fish_speech_server(mocker: MockerFixture):
mock_models = mocker.MagicMock()
mock_models.is_base_model.return_value = True
- server = OmniOpenAIServingSpeech(
+ return OmniOpenAIServingSpeech(
engine_client=mock_engine_client,
models=mock_models,
request_logger=mocker.MagicMock(),
)
- yield server
- server.shutdown()
class TestFishSpeechServing:
@@ -1946,9 +1663,9 @@ def test_build_fish_prompt_normalizes_legacy_speaker_tags(self, fish_speech_serv
assert "<|speaker:0|>你好,[laughing]欢迎回来。<|speaker:1|>我也来了。" in encoded_texts
assert all(allowed_special is None for _, _, allowed_special in tokenizer.calls)
- def test_build_fish_clone_prompt_normalizes_text_fields(self, fish_speech_server, mocker: MockerFixture):
+ def test_build_fish_clone_prompt_normalizes_text_fields(self, fish_speech_server):
fish_speech_server._fish_speech_tokenizer = _FakeFishTokenizer()
- fish_speech_server._estimate_fish_prompt_len = mocker.MagicMock(return_value=123)
+ fish_speech_server._estimate_fish_prompt_len = MagicMock(return_value=123)
request = OpenAICreateSpeechRequest(
input="你好,欢迎回来。",
@@ -1965,8 +1682,8 @@ def test_build_fish_clone_prompt_normalizes_text_fields(self, fish_speech_server
assert info["text"] == "<|speaker:1|>你好,欢迎回来。"
assert info["ref_text"] == "<|speaker:0|>参考音频的原始文本。"
assert info["fish_structured_voice_clone"] is True
- assert isinstance(info["ref_audio_wav"], torch.Tensor)
- assert info["ref_audio_wav"].dtype == torch.float32
+ assert os.path.exists(info["ref_audio_path"])
+ os.remove(info["ref_audio_path"])
fish_speech_server._estimate_fish_prompt_len.assert_called_once_with(
"<|speaker:1|>你好,欢迎回来。",
"<|speaker:0|>参考音频的原始文本。",
@@ -1999,10 +1716,8 @@ def test_build_fish_prompt_rejects_unsafe_control_tokens(self, fish_speech_serve
with pytest.raises(ValueError, match="unsupported control token"):
fish_speech_server._build_fish_speech_prompt(request)
- def test_prepare_speech_generation_overrides_fish_default_max_tokens(
- self, fish_speech_server, mocker: MockerFixture
- ):
- fish_speech_server._build_fish_speech_prompt_async = mocker.AsyncMock(
+ def test_prepare_speech_generation_overrides_fish_default_max_tokens(self, fish_speech_server):
+ fish_speech_server._build_fish_speech_prompt = MagicMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {},
@@ -2015,14 +1730,13 @@ def test_prepare_speech_generation_overrides_fish_default_max_tokens(
assert request_id.startswith("speech-")
assert generator == "generator"
- fish_speech_server._build_fish_speech_prompt_async.assert_awaited_once()
fish_speech_server.engine_client.generate.assert_called_once()
sampling_params_list = fish_speech_server.engine_client.generate.call_args.kwargs["sampling_params_list"]
assert sampling_params_list[0].max_tokens == 4096
assert fish_speech_server.engine_client.default_sampling_params_list[0].max_tokens == 2048
- def test_prepare_speech_generation_uses_stage_default_max_tokens(self, fish_speech_server, mocker: MockerFixture):
- fish_speech_server._build_fish_speech_prompt_async = mocker.AsyncMock(
+ def test_prepare_speech_generation_uses_stage_default_max_tokens(self, fish_speech_server):
+ fish_speech_server._build_fish_speech_prompt = MagicMock(
return_value={
"prompt_token_ids": [1, 2, 3],
"additional_information": {},
@@ -2053,9 +1767,9 @@ def test_prepare_speech_generation_rejects_invalid_fish_max_new_tokens(self, fis
fish_speech_server.engine_client.generate.assert_not_called()
- def test_create_speech_batch_allows_fish_text_only_items(self, fish_speech_server, mocker: MockerFixture):
- fish_speech_server._check_model = mocker.AsyncMock(return_value=None)
- fish_speech_server._generate_audio_bytes = mocker.AsyncMock(return_value=("YWJj", "audio/wav"))
+ def test_create_speech_batch_allows_fish_text_only_items(self, fish_speech_server):
+ fish_speech_server._check_model = AsyncMock(return_value=None)
+ fish_speech_server._generate_audio_bytes = AsyncMock(return_value=("YWJj", "audio/wav"))
batch = BatchSpeechRequest(items=[SpeechBatchItem(input="hello fish")])
response = asyncio.run(fish_speech_server.create_speech_batch(batch))
@@ -2251,8 +1965,8 @@ def test_validate_cosyvoice3_max_new_tokens_range(self, cosyvoice3_server):
assert error is not None
assert "max_new_tokens" in error
- def test_prepare_speech_generation_cosyvoice3(self, cosyvoice3_server, mocker: MockerFixture):
- cosyvoice3_server._build_cosyvoice3_prompt = mocker.AsyncMock(
+ def test_prepare_speech_generation_cosyvoice3(self, cosyvoice3_server):
+ cosyvoice3_server._build_cosyvoice3_prompt = AsyncMock(
return_value={
"prompt": "Hello",
"multi_modal_data": {"audio": (np.zeros(24000), 24000)},
@@ -2271,115 +1985,3 @@ def test_prepare_speech_generation_cosyvoice3(self, cosyvoice3_server, mocker: M
assert generator == "generator"
assert tts_params == {}
cosyvoice3_server._build_cosyvoice3_prompt.assert_awaited_once()
-
-
-class TestTTSAsyncOffloading:
- """Tests for event-loop-safe offloading of blocking TTS operations."""
-
- def test_build_voxtral_prompt_is_sync(self):
- """_build_voxtral_prompt should be a regular function, not a coroutine."""
- assert not asyncio.iscoroutinefunction(OmniOpenAIServingSpeech._build_voxtral_prompt)
-
- @pytest.fixture
- def voxtral_server(self, mocker: MockerFixture):
- mocker.patch.object(OmniOpenAIServingSpeech, "_load_supported_speakers", return_value=set())
- mocker.patch.object(OmniOpenAIServingSpeech, "_load_codec_frame_rate", return_value=None)
- mock_engine_client = mocker.MagicMock()
- mock_engine_client.errored = False
- mock_engine_client.model_config = mocker.MagicMock(model="mistralai/Voxtral")
- mock_engine_client.default_sampling_params_list = [SimpleNamespace(max_tokens=2048)]
- mock_engine_client.tts_batch_max_items = 32
- mock_engine_client.generate = mocker.MagicMock(return_value="generator")
- mock_engine_client.stage_configs = [
- SimpleNamespace(
- engine_args=SimpleNamespace(model_stage="audio_generation"),
- tts_args={},
- )
- ]
- mock_models = mocker.MagicMock()
- mock_models.is_base_model.return_value = True
- server = OmniOpenAIServingSpeech(
- engine_client=mock_engine_client,
- models=mock_models,
- request_logger=mocker.MagicMock(),
- )
- yield server
- server.shutdown()
-
- @pytest.fixture
- def qwen3_tts_server(self, mocker: MockerFixture):
- mocker.patch.object(OmniOpenAIServingSpeech, "_load_supported_speakers", return_value=set())
- mocker.patch.object(OmniOpenAIServingSpeech, "_load_codec_frame_rate", return_value=None)
- mock_engine_client = mocker.MagicMock()
- mock_engine_client.errored = False
- mock_engine_client.model_config = mocker.MagicMock(model="Qwen/Qwen3-TTS", hf_config=mocker.MagicMock())
- mock_engine_client.default_sampling_params_list = [SimpleNamespace(max_tokens=2048)]
- mock_engine_client.tts_batch_max_items = 32
- mock_engine_client.generate = mocker.MagicMock(return_value="generator")
- mock_engine_client.tts_max_instructions_length = None
- mock_engine_client.stage_configs = [
- SimpleNamespace(
- engine_args=SimpleNamespace(model_stage="qwen3_tts"),
- tts_args={},
- )
- ]
- mock_models = mocker.MagicMock()
- mock_models.is_base_model.return_value = True
- server = OmniOpenAIServingSpeech(
- engine_client=mock_engine_client,
- models=mock_models,
- request_logger=mocker.MagicMock(),
- )
- yield server
- server.shutdown()
-
- def test_prepare_speech_generation_awaits_voxtral_async(self, voxtral_server, mocker: MockerFixture):
- """Voxtral path in _prepare_speech_generation should call the async wrapper."""
- voxtral_server._build_voxtral_prompt_async = mocker.AsyncMock(
- return_value={
- "prompt_token_ids": [1, 2, 3],
- "additional_information": {"voice": ["test"]},
- }
- )
- request = OpenAICreateSpeechRequest(input="hello", voice="test")
- asyncio.run(voxtral_server._prepare_speech_generation(request))
- voxtral_server._build_voxtral_prompt_async.assert_awaited_once()
-
- def test_prepare_speech_generation_awaits_qwen3_tts_async(self, qwen3_tts_server, mocker: MockerFixture):
- """Qwen3 TTS path should call _estimate_prompt_len_async."""
- qwen3_tts_server._validate_tts_request = mocker.MagicMock(return_value=None)
- qwen3_tts_server._build_tts_params = mocker.MagicMock(
- return_value={"text": ["hello"], "task_type": ["CustomVoice"], "speaker": ["Vivian"]}
- )
- qwen3_tts_server._estimate_prompt_len_async = mocker.AsyncMock(return_value=512)
- request = OpenAICreateSpeechRequest(input="hello")
- asyncio.run(qwen3_tts_server._prepare_speech_generation(request))
- qwen3_tts_server._build_tts_params.assert_called_once()
- qwen3_tts_server._estimate_prompt_len_async.assert_awaited_once()
-
- def test_shutdown_is_idempotent(self, mocker: MockerFixture):
- """Calling shutdown() twice should not raise."""
- mocker.patch.object(OmniOpenAIServingSpeech, "_load_supported_speakers", return_value=set())
- mocker.patch.object(OmniOpenAIServingSpeech, "_load_codec_frame_rate", return_value=None)
- mock_engine_client = mocker.MagicMock()
- mock_engine_client.errored = False
- mock_engine_client.stage_configs = []
- mock_engine_client.tts_max_instructions_length = None
- mock_models = mocker.MagicMock()
- mock_models.is_base_model.return_value = True
- server = OmniOpenAIServingSpeech(
- engine_client=mock_engine_client,
- models=mock_models,
- request_logger=mocker.MagicMock(),
- )
- assert server._tts_executor is not None
- server.shutdown()
- assert server._tts_executor is None
- server.shutdown() # Should not raise
- assert server._tts_executor is None
-
- def test_diffusion_instance_shutdown_safe(self, mocker: MockerFixture):
- """Diffusion instances (created via for_diffusion) should have safe shutdown."""
- server = OmniOpenAIServingSpeech.for_diffusion(diffusion_engine=mocker.MagicMock(), model_name="test-model")
- assert server._tts_executor is None
- server.shutdown() # Should not raise
diff --git a/tests/entrypoints/openai_api/test_serving_speech_stream.py b/tests/entrypoints/openai_api/test_serving_speech_stream.py
index 1b93ef58e24..bd136ac7272 100644
--- a/tests/entrypoints/openai_api/test_serving_speech_stream.py
+++ b/tests/entrypoints/openai_api/test_serving_speech_stream.py
@@ -1,8 +1,8 @@
import asyncio
+from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI, WebSocket
-from pytest_mock import MockerFixture
from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
@@ -13,26 +13,19 @@
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-def _build_test_app(
- speech_service=None,
- *,
- idle_timeout=30.0,
- config_timeout=10.0,
- mocker: MockerFixture | None = None,
-):
+def _build_test_app(speech_service=None, *, idle_timeout=30.0, config_timeout=10.0):
if speech_service is None:
- assert mocker is not None
- speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"RIFF" + b"\x00" * 32, "audio/wav"))
- speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-1", object(), {}))
+ speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = AsyncMock(return_value=(b"RIFF" + b"\x00" * 32, "audio/wav"))
+ speech_service._prepare_speech_generation = AsyncMock(return_value=("req-1", object(), {}))
async def mock_generate_pcm_chunks(_generator, _request_id):
for chunk in (b"\x01\x02", b"\x03\x04\x05"):
yield chunk
speech_service._generate_pcm_chunks = mock_generate_pcm_chunks
- speech_service.engine_client = mocker.MagicMock()
- speech_service.engine_client.abort = mocker.AsyncMock()
+ speech_service.engine_client = MagicMock()
+ speech_service.engine_client.abort = AsyncMock()
handler = OmniStreamingSpeechHandler(
speech_service=speech_service,
@@ -49,8 +42,8 @@ async def ws_endpoint(websocket: WebSocket):
class TestStreamingSpeechWebSocket:
- def test_non_streaming_single_frame(self, mocker: MockerFixture):
- app, speech_service = _build_test_app(mocker=mocker)
+ def test_non_streaming_single_frame(self):
+ app, speech_service = _build_test_app()
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -75,13 +68,13 @@ def test_non_streaming_single_frame(self, mocker: MockerFixture):
assert speech_service._generate_audio_bytes.await_count == 1
- def test_streaming_multiple_binary_frames(self, mocker: MockerFixture):
+ def test_streaming_multiple_binary_frames(self):
captured_requests = []
- speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"", "audio/wav"))
- speech_service.engine_client = mocker.MagicMock()
- speech_service.engine_client.abort = mocker.AsyncMock()
+ speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav"))
+ speech_service.engine_client = MagicMock()
+ speech_service.engine_client.abort = AsyncMock()
async def mock_prepare_speech_generation(request):
captured_requests.append(request)
@@ -130,8 +123,8 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
assert captured_requests[0].initial_codec_chunk_frames == 12
assert speech_service._generate_audio_bytes.await_count == 0
- def test_flush_on_input_done(self, mocker: MockerFixture):
- app, _ = _build_test_app(mocker=mocker)
+ def test_flush_on_input_done(self):
+ app, _ = _build_test_app()
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -149,8 +142,8 @@ def test_flush_on_input_done(self, mocker: MockerFixture):
}
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_invalid_streaming_config(self, mocker: MockerFixture):
- app, _ = _build_test_app(mocker=mocker)
+ def test_invalid_streaming_config(self):
+ app, _ = _build_test_app()
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -166,8 +159,8 @@ def test_invalid_streaming_config(self, mocker: MockerFixture):
assert error["type"] == "error"
assert "response_format='pcm'" in error["message"]
- def test_empty_input_text_emits_no_audio(self, mocker: MockerFixture):
- app, speech_service = _build_test_app(mocker=mocker)
+ def test_empty_input_text_emits_no_audio(self):
+ app, speech_service = _build_test_app()
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -179,8 +172,8 @@ def test_empty_input_text_emits_no_audio(self, mocker: MockerFixture):
assert speech_service._generate_audio_bytes.await_count == 0
- def test_multiple_sentences_increment_indices(self, mocker: MockerFixture):
- app, _ = _build_test_app(mocker=mocker)
+ def test_multiple_sentences_increment_indices(self):
+ app, _ = _build_test_app()
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -210,8 +203,8 @@ def test_multiple_sentences_increment_indices(self, mocker: MockerFixture):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 2}
- def test_unknown_message_type_keeps_session_open(self, mocker: MockerFixture):
- app, _ = _build_test_app(mocker=mocker)
+ def test_unknown_message_type_keeps_session_open(self):
+ app, _ = _build_test_app()
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -234,21 +227,21 @@ def test_unknown_message_type_keeps_session_open(self, mocker: MockerFixture):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_config_timeout_closes_session(self, mocker: MockerFixture):
- app, _ = _build_test_app(config_timeout=0.01, mocker=mocker)
+ def test_config_timeout_closes_session(self):
+ app, _ = _build_test_app(config_timeout=0.01)
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
error = ws.receive_json()
assert error == {"type": "error", "message": "Timeout waiting for session.config"}
- def test_generation_error_marks_audio_done(self, mocker: MockerFixture):
- speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = mocker.AsyncMock(side_effect=RuntimeError("boom"))
- speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-err", object(), {}))
- speech_service._generate_pcm_chunks = mocker.AsyncMock()
- speech_service.engine_client = mocker.MagicMock()
- speech_service.engine_client.abort = mocker.AsyncMock()
+ def test_generation_error_marks_audio_done(self):
+ speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = AsyncMock(side_effect=RuntimeError("boom"))
+ speech_service._prepare_speech_generation = AsyncMock(return_value=("req-err", object(), {}))
+ speech_service._generate_pcm_chunks = AsyncMock()
+ speech_service.engine_client = MagicMock()
+ speech_service.engine_client.abort = AsyncMock()
app, _ = _build_test_app(speech_service)
with TestClient(app) as client:
@@ -263,12 +256,12 @@ def test_generation_error_marks_audio_done(self, mocker: MockerFixture):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_streaming_generation_error_marks_audio_done(self, mocker: MockerFixture):
- speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"", "audio/wav"))
- speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-stream-err", object(), {}))
- speech_service.engine_client = mocker.MagicMock()
- speech_service.engine_client.abort = mocker.AsyncMock()
+ def test_streaming_generation_error_marks_audio_done(self):
+ speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav"))
+ speech_service._prepare_speech_generation = AsyncMock(return_value=("req-stream-err", object(), {}))
+ speech_service.engine_client = MagicMock()
+ speech_service.engine_client.abort = AsyncMock()
async def mock_generate_pcm_chunks(_generator, _request_id):
yield b"\x01\x02"
@@ -305,8 +298,8 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
ws.send_json({"type": "input.done"})
assert ws.receive_json() == {"type": "session.done", "total_sentences": 1}
- def test_invalid_input_text_type_returns_validation_error(self, mocker: MockerFixture):
- app, speech_service = _build_test_app(mocker=mocker)
+ def test_invalid_input_text_type_returns_validation_error(self):
+ app, speech_service = _build_test_app()
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -323,9 +316,9 @@ def test_invalid_input_text_type_returns_validation_error(self, mocker: MockerFi
assert speech_service._generate_audio_bytes.await_count == 0
- def test_input_text_message_too_large(self, monkeypatch, mocker: MockerFixture):
+ def test_input_text_message_too_large(self, monkeypatch):
monkeypatch.setattr(streaming_speech_module, "_MAX_INPUT_TEXT_MESSAGE_SIZE", 32)
- app, speech_service = _build_test_app(mocker=mocker)
+ app, speech_service = _build_test_app()
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -342,9 +335,9 @@ def test_input_text_message_too_large(self, monkeypatch, mocker: MockerFixture):
assert speech_service._generate_audio_bytes.await_count == 0
- def test_session_config_message_too_large(self, monkeypatch, mocker: MockerFixture):
+ def test_session_config_message_too_large(self, monkeypatch):
monkeypatch.setattr(streaming_speech_module, "_MAX_CONFIG_MESSAGE_SIZE", 64)
- app, _ = _build_test_app(mocker=mocker)
+ app, _ = _build_test_app()
with TestClient(app) as client:
with client.websocket_connect("/v1/audio/speech/stream") as ws:
@@ -355,12 +348,12 @@ def test_session_config_message_too_large(self, monkeypatch, mocker: MockerFixtu
"message": "session.config message too large",
}
- def test_disconnect_aborts_streaming_request(self, mocker: MockerFixture):
- speech_service = mocker.MagicMock(spec=OmniOpenAIServingSpeech)
- speech_service._generate_audio_bytes = mocker.AsyncMock(return_value=(b"", "audio/wav"))
- speech_service._prepare_speech_generation = mocker.AsyncMock(return_value=("req-abort", object(), {}))
- speech_service.engine_client = mocker.MagicMock()
- speech_service.engine_client.abort = mocker.AsyncMock()
+ def test_disconnect_aborts_streaming_request(self):
+ speech_service = MagicMock(spec=OmniOpenAIServingSpeech)
+ speech_service._generate_audio_bytes = AsyncMock(return_value=(b"", "audio/wav"))
+ speech_service._prepare_speech_generation = AsyncMock(return_value=("req-abort", object(), {}))
+ speech_service.engine_client = MagicMock()
+ speech_service.engine_client.abort = AsyncMock()
async def mock_generate_pcm_chunks(_generator, _request_id):
yield b"\x01\x02"
@@ -368,11 +361,11 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
speech_service._generate_pcm_chunks = mock_generate_pcm_chunks
handler = OmniStreamingSpeechHandler(speech_service=speech_service)
- websocket = mocker.MagicMock()
- websocket.send_json = mocker.AsyncMock(side_effect=[None, WebSocketDisconnect()])
- websocket.send_bytes = mocker.AsyncMock(side_effect=WebSocketDisconnect())
+ websocket = MagicMock()
+ websocket.send_json = AsyncMock(side_effect=[None, WebSocketDisconnect()])
+ websocket.send_bytes = AsyncMock(side_effect=WebSocketDisconnect())
- config = mocker.MagicMock()
+ config = MagicMock()
config.model = None
config.voice = "Vivian"
config.task_type = None
@@ -392,18 +385,3 @@ async def mock_generate_pcm_chunks(_generator, _request_id):
speech_service.engine_client.abort.assert_awaited_once_with("req-abort")
assert websocket.send_json.await_count == 2
-
-
-class TestGeneratePcmChunksContract:
- """Guard: _generate_pcm_chunks must exist on OmniOpenAIServingSpeech.
-
- The WebSocket handler calls speech_service._generate_pcm_chunks()
- at runtime. If the method is removed, all WS TTS streaming breaks
- with an AttributeError. This test catches that at CI time.
- """
-
- def test_generate_pcm_chunks_defined(self):
- assert hasattr(OmniOpenAIServingSpeech, "_generate_pcm_chunks")
- assert asyncio.iscoroutinefunction(OmniOpenAIServingSpeech._generate_pcm_chunks) or callable(
- OmniOpenAIServingSpeech._generate_pcm_chunks
- )
diff --git a/tests/entrypoints/openai_api/test_serving_speech_voxcpm.py b/tests/entrypoints/openai_api/test_serving_speech_voxcpm.py
deleted file mode 100644
index 48660b6d1cd..00000000000
--- a/tests/entrypoints/openai_api/test_serving_speech_voxcpm.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""UTs for VoxCPM OpenAI speech serving behavior."""
-
-import asyncio
-from types import SimpleNamespace
-from unittest.mock import AsyncMock
-
-import pytest
-from pytest_mock import MockerFixture
-
-from vllm_omni.entrypoints.openai.protocol.audio import OpenAICreateSpeechRequest
-from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.fixture
-def voxcpm_server(mocker: MockerFixture):
- mocker.patch.object(OmniOpenAIServingSpeech, "_load_supported_speakers", return_value=set())
- mocker.patch.object(OmniOpenAIServingSpeech, "_load_codec_frame_rate", return_value=None)
-
- mock_engine_client = mocker.MagicMock()
- mock_engine_client.errored = False
- mock_engine_client.model_config = mocker.MagicMock(model="OpenBMB/VoxCPM1.5")
- mock_engine_client.default_sampling_params_list = [SimpleNamespace(max_tokens=2048)]
- mock_engine_client.tts_batch_max_items = 32
- mock_engine_client.generate = mocker.MagicMock(return_value="generator")
- mock_engine_client.stage_configs = [
- SimpleNamespace(
- engine_args=SimpleNamespace(
- model_stage="latent_generator",
- model_arch="VoxCPMForConditionalGeneration",
- ),
- tts_args={},
- ),
- SimpleNamespace(
- engine_args=SimpleNamespace(model_stage="vae"),
- tts_args={},
- ),
- ]
-
- mock_models = mocker.MagicMock()
- mock_models.is_base_model.return_value = True
-
- return OmniOpenAIServingSpeech(
- engine_client=mock_engine_client,
- models=mock_models,
- request_logger=mocker.MagicMock(),
- )
-
-
-class TestVoxCPMServing:
- def test_voxcpm_model_type_detection(self, voxcpm_server):
- assert voxcpm_server._tts_model_type == "voxcpm"
- assert voxcpm_server._is_tts is True
- assert voxcpm_server.supported_speakers == set()
-
- @pytest.mark.parametrize(
- ("request_kwargs", "expected_substring"),
- [
- ({"voice": "alice"}, "voice"),
- ({"instructions": "whisper"}, "instructions"),
- ({"language": "en"}, "language"),
- ({"task_type": "CustomVoice"}, "plain tts"),
- ({"x_vector_only_mode": True}, "x_vector_only_mode"),
- ({"speaker_embedding": [0.1, 0.2]}, "speaker_embedding"),
- ({"initial_codec_chunk_frames": 4}, "initial_codec_chunk_frames"),
- ({"ref_text": "reference"}, "ref_audio"),
- ],
- )
- def test_validate_voxcpm_rejects_unsupported_fields(self, voxcpm_server, request_kwargs, expected_substring):
- request = OpenAICreateSpeechRequest(input="hello voxcpm", **request_kwargs)
- error = voxcpm_server._validate_voxcpm_request(request)
- assert error is not None
- assert expected_substring in error.lower()
-
- def test_validate_voxcpm_accepts_plain_tts_request(self, voxcpm_server):
- request = OpenAICreateSpeechRequest(input="hello voxcpm", max_new_tokens=256)
- assert voxcpm_server._validate_voxcpm_request(request) is None
-
- def test_validate_voxcpm_accepts_voice_clone_request(self, voxcpm_server):
- request = OpenAICreateSpeechRequest(
- input="clone this voice",
- ref_audio="data:audio/wav;base64,QUJD",
- ref_text="reference transcript",
- max_new_tokens=256,
- )
- assert voxcpm_server._validate_voxcpm_request(request) is None
-
- def test_prepare_speech_generation_voxcpm_text_only(self, voxcpm_server):
- request = OpenAICreateSpeechRequest(input="hello voxcpm", max_new_tokens=321)
-
- request_id, generator, tts_params = asyncio.run(voxcpm_server._prepare_speech_generation(request))
-
- assert request_id.startswith("speech-")
- assert generator == "generator"
- assert tts_params == {
- "text": ["hello voxcpm"],
- "cfg_value": [2.0],
- "inference_timesteps": [10],
- "min_len": [2],
- "max_new_tokens": [321],
- }
-
- voxcpm_server.engine_client.generate.assert_called_once()
- call = voxcpm_server.engine_client.generate.call_args
- assert call.kwargs["prompt"] == {
- "prompt_token_ids": [1],
- "additional_information": tts_params,
- }
- assert call.kwargs["output_modalities"] == ["audio"]
-
- def test_prepare_speech_generation_voxcpm_voice_clone_resolves_ref_audio(self, voxcpm_server):
- voxcpm_server._resolve_ref_audio = AsyncMock(return_value=([0.1, -0.1, 0.2], 16000))
- request = OpenAICreateSpeechRequest(
- input="clone this voice",
- ref_audio="data:audio/wav;base64,QUJD",
- ref_text="reference transcript",
- max_new_tokens=512,
- )
-
- request_id, generator, tts_params = asyncio.run(voxcpm_server._prepare_speech_generation(request))
-
- assert request_id.startswith("speech-")
- assert generator == "generator"
- assert tts_params == {
- "text": ["clone this voice"],
- "cfg_value": [2.0],
- "inference_timesteps": [10],
- "min_len": [2],
- "max_new_tokens": [512],
- "ref_text": ["reference transcript"],
- "ref_audio": [[[0.1, -0.1, 0.2], 16000]],
- }
-
- voxcpm_server._resolve_ref_audio.assert_awaited_once_with("data:audio/wav;base64,QUJD")
- call = voxcpm_server.engine_client.generate.call_args
- assert call.kwargs["prompt"] == {
- "prompt_token_ids": [1],
- "additional_information": tts_params,
- }
diff --git a/tests/entrypoints/openai_api/test_serving_video_stream.py b/tests/entrypoints/openai_api/test_serving_video_stream.py
deleted file mode 100644
index 0787eb8562c..00000000000
--- a/tests/entrypoints/openai_api/test_serving_video_stream.py
+++ /dev/null
@@ -1,650 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for the serving-layer streaming video WebSocket handler."""
-
-from __future__ import annotations
-
-import asyncio
-import base64
-import io
-import json
-import threading
-from typing import Any
-
-import pytest
-from PIL import Image
-
-from vllm_omni.entrypoints.openai import serving_video_stream, video_stream_envs
-from vllm_omni.entrypoints.openai.serving_video_stream import (
- OmniStreamingVideoHandler,
- StreamingVideoSessionConfig,
-)
-from vllm_omni.outputs import OmniRequestOutput
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _make_jpeg(r: int = 128, g: int = 128, b: int = 128) -> bytes:
- img = Image.new("RGB", (64, 64), (r, g, b))
- buf = io.BytesIO()
- img.save(buf, format="JPEG", quality=95)
- return buf.getvalue()
-
-
-def _b64(data: bytes) -> str:
- return base64.b64encode(data).decode()
-
-
-def _text_result(text: str) -> OmniRequestOutput:
- class Output:
- pass
-
- class RequestOutput:
- pass
-
- output = Output()
- output.text = text
- request_output = RequestOutput()
- request_output.outputs = [output]
- return OmniRequestOutput(final_output_type="text", request_output=request_output)
-
-
-def _audio_result(audio_data: Any) -> OmniRequestOutput:
- class Output:
- pass
-
- class RequestOutput:
- pass
-
- output = Output()
- output.multimodal_output = {"audio": audio_data}
- request_output = RequestOutput()
- request_output.outputs = [output]
- return OmniRequestOutput(final_output_type="audio", request_output=request_output)
-
-
-class MockWebSocket:
- def __init__(self, messages: list[str] | None = None):
- self._messages = list(messages or [])
- self._idx = 0
- self.accepted = False
- self.sent: list[dict[str, Any]] = []
-
- async def accept(self):
- self.accepted = True
-
- async def receive_text(self) -> str:
- if self._idx >= len(self._messages):
- await asyncio.sleep(999)
- msg = self._messages[self._idx]
- self._idx += 1
- return msg
-
- async def send_json(self, data: dict[str, Any]):
- self.sent.append(data)
-
-
-class TimedWebSocket:
- def __init__(self):
- self._q: asyncio.Queue[str] = asyncio.Queue()
- self.accepted = False
- self.sent: list[dict[str, Any]] = []
-
- async def accept(self):
- self.accepted = True
-
- async def receive_text(self) -> str:
- return await self._q.get()
-
- async def send_json(self, data: dict[str, Any]):
- self.sent.append(data)
-
- def put(self, msg: dict[str, Any]):
- self._q.put_nowait(json.dumps(msg))
-
- def sent_types(self) -> list[str]:
- return [m.get("type", "") for m in self.sent]
-
-
-def test_api_server_registers_video_stream_route():
- from vllm_omni.entrypoints.openai.api_server import router
-
- assert any(getattr(route, "path", None) == "/v1/video/chat/stream" for route in router.routes)
-
-
-@pytest.mark.asyncio
-async def test_receive_config_accepts_client_legacy_aliases():
- ws = MockWebSocket(
- [
- json.dumps(
- {
- "type": "session.config",
- "model": "test",
- "num_sample_frames": 7,
- "evs_enabled": False,
- "evs_threshold": 0.87,
- }
- )
- ]
- )
- handler = OmniStreamingVideoHandler(chat_service=object())
-
- config = await handler._receive_config(ws)
-
- assert config is not None
- assert config.num_frames == 7
- assert config.enable_frame_filter is False
- assert config.frame_filter_threshold == 0.87
-
-
-@pytest.mark.asyncio
-async def test_audio_in_video_sets_mm_processor_kwargs():
- captured_requests = []
-
- class EmptyEngine:
- def generate(self, **_kwargs):
- async def _gen():
- if False:
- yield None
-
- return _gen()
-
- class CapturingHandler(OmniStreamingVideoHandler):
- async def _preprocess_to_engine_prompt(self, request):
- captured_requests.append(request)
- return {"prompt": "x"}
-
- ws = MockWebSocket()
- handler = CapturingHandler(chat_service=object(), engine_client=EmptyEngine())
- config = StreamingVideoSessionConfig(model="test", modalities=["text", "audio"], use_audio_in_video=True)
-
- await handler._process_query_engine(
- ws,
- config,
- [_b64(_make_jpeg())],
- bytearray(b"\x00\x00"),
- [],
- "what is happening?",
- "req-1",
- asyncio.Event(),
- {},
- )
-
- assert captured_requests
- assert captured_requests[0].mm_processor_kwargs == {"use_audio_in_video": True}
-
-
-@pytest.mark.asyncio
-async def test_audio_in_video_disabled_omits_mm_processor_kwargs():
- captured_requests = []
-
- class EmptyEngine:
- def generate(self, **_kwargs):
- async def _gen():
- if False:
- yield None
-
- return _gen()
-
- class CapturingHandler(OmniStreamingVideoHandler):
- async def _preprocess_to_engine_prompt(self, request):
- captured_requests.append(request)
- return {"prompt": "x"}
-
- ws = MockWebSocket()
- handler = CapturingHandler(chat_service=object(), engine_client=EmptyEngine())
- config = StreamingVideoSessionConfig(model="test", modalities=["text", "audio"], use_audio_in_video=False)
-
- await handler._process_query_engine(
- ws,
- config,
- [_b64(_make_jpeg())],
- bytearray(b"\x00\x00"),
- [],
- "what is happening?",
- "req-1",
- asyncio.Event(),
- {},
- )
-
- assert captured_requests
- assert captured_requests[0].mm_processor_kwargs is None
-
-
-@pytest.mark.asyncio
-async def test_query_inline_audio_data_sets_mm_processor_kwargs():
- captured_requests = []
-
- class EmptyEngine:
- def generate(self, **_kwargs):
- async def _gen():
- if False:
- yield None
-
- return _gen()
-
- class CapturingHandler(OmniStreamingVideoHandler):
- async def _preprocess_to_engine_prompt(self, request):
- captured_requests.append(request)
- return {"prompt": "x"}
-
- ws = MockWebSocket(
- [
- json.dumps({"type": "session.config", "model": "test"}),
- json.dumps({"type": "video.frame", "data": _b64(_make_jpeg())}),
- json.dumps(
- {
- "type": "video.query",
- "text": "describe",
- "audio_data": _b64(b"\x00\x00"),
- }
- ),
- json.dumps({"type": "video.done"}),
- ]
- )
- handler = CapturingHandler(chat_service=object(), engine_client=EmptyEngine(), idle_timeout=2.0)
-
- await handler.handle_session(ws)
-
- assert captured_requests
- assert captured_requests[0].mm_processor_kwargs == {"use_audio_in_video": True}
- assert "session.done" in [m.get("type") for m in ws.sent]
-
-
-def test_audio_delta_mode_is_read_by_serving_code_at_runtime(monkeypatch):
- handler = OmniStreamingVideoHandler(chat_service=object())
- result = _audio_result([object()])
-
- monkeypatch.setattr(
- OmniStreamingVideoHandler,
- "_delta_fast",
- classmethod(lambda cls, audio_data, chunks_drained: ("fast-path", chunks_drained)),
- )
- monkeypatch.setattr(
- OmniStreamingVideoHandler,
- "_delta_slow",
- classmethod(lambda cls, audio_data, chunks_drained: ("slow-path", chunks_drained)),
- )
-
- monkeypatch.setenv("VLLM_VIDEO_AUDIO_DELTA_MODE", "fast")
- assert handler._extract_audio_delta_b64(result, 0)[0] == "fast-path"
-
- monkeypatch.setenv("VLLM_VIDEO_AUDIO_DELTA_MODE", "slow")
- assert handler._extract_audio_delta_b64(result, 0)[0] == "slow-path"
-
-
-def test_video_stream_envs_strip_and_warn_once_per_invalid_value(monkeypatch):
- warnings = []
-
- video_stream_envs._warned_invalid_envs.clear()
- try:
- monkeypatch.setattr(
- video_stream_envs.logger,
- "warning",
- lambda message, *args, **_kwargs: warnings.append((message, args)),
- )
-
- monkeypatch.setenv("VLLM_VIDEO_ASYNC_CHUNK", " off ")
- assert video_stream_envs.VLLM_VIDEO_ASYNC_CHUNK == "off"
- assert not warnings
-
- monkeypatch.setenv("VLLM_VIDEO_ASYNC_CHUNK", "bad")
- assert video_stream_envs.VLLM_VIDEO_ASYNC_CHUNK == "on"
- assert video_stream_envs.VLLM_VIDEO_ASYNC_CHUNK == "on"
- assert len(warnings) == 1
-
- monkeypatch.setenv("VLLM_VIDEO_ASYNC_CHUNK", "still_bad")
- assert video_stream_envs.VLLM_VIDEO_ASYNC_CHUNK == "on"
- assert len(warnings) == 2
- finally:
- video_stream_envs._warned_invalid_envs.clear()
-
-
-@pytest.mark.asyncio
-async def test_async_chunk_mode_is_read_by_engine_path_at_runtime(monkeypatch):
- class TextEngine:
- def generate(self, **_kwargs):
- async def _gen():
- yield _text_result("hello")
-
- return _gen()
-
- class CapturingHandler(OmniStreamingVideoHandler):
- async def _preprocess_to_engine_prompt(self, request):
- return {"prompt": "x"}
-
- handler = CapturingHandler(chat_service=object(), engine_client=TextEngine())
- config = StreamingVideoSessionConfig(model="test", modalities=["text"])
-
- monkeypatch.setenv("VLLM_VIDEO_ASYNC_CHUNK", "on")
- ws_on = MockWebSocket()
- await handler._process_query_engine(
- ws_on,
- config,
- [_b64(_make_jpeg())],
- bytearray(),
- [],
- "describe",
- "req-on",
- asyncio.Event(),
- {},
- )
- assert {"type": "response.text.delta", "delta": "hello"} in ws_on.sent
-
- monkeypatch.setenv("VLLM_VIDEO_ASYNC_CHUNK", "off")
- ws_off = MockWebSocket()
- await handler._process_query_engine(
- ws_off,
- config,
- [_b64(_make_jpeg())],
- bytearray(),
- [],
- "describe",
- "req-off",
- asyncio.Event(),
- {},
- )
- assert {"type": "response.text.done", "text": "hello"} in ws_off.sent
- assert not any(m.get("type") == "response.text.delta" for m in ws_off.sent)
-
-
-@pytest.mark.asyncio
-async def test_query_without_engine_client_sends_error():
- ws = MockWebSocket()
- handler = OmniStreamingVideoHandler(chat_service=object(), engine_client=None)
-
- await handler._process_query(
- ws,
- StreamingVideoSessionConfig(model="test"),
- [],
- bytearray(),
- [],
- "describe",
- "req-1",
- asyncio.Event(),
- {},
- )
-
- assert {"type": "error", "message": "Streaming video requires an engine client"} in ws.sent
-
-
-@pytest.mark.asyncio
-async def test_new_query_cancels_in_flight_query():
- query_started = asyncio.Event()
- query_cancelled = asyncio.Event()
- calls = 0
-
- class BlockingHandler(OmniStreamingVideoHandler):
- async def _process_query(self, *args, **kwargs):
- nonlocal calls
- calls += 1
- if calls > 1:
- return
- query_started.set()
- try:
- await asyncio.sleep(999)
- except asyncio.CancelledError:
- query_cancelled.set()
- raise
-
- ws = TimedWebSocket()
- handler = BlockingHandler(chat_service=object(), idle_timeout=5.0)
- task = asyncio.create_task(handler.handle_session(ws))
-
- ws.put({"type": "session.config", "model": "test"})
- await asyncio.sleep(0)
- ws.put({"type": "video.frame", "data": _b64(_make_jpeg())})
- await asyncio.sleep(0)
- ws.put({"type": "video.query", "text": "describe"})
- await asyncio.wait_for(query_started.wait(), timeout=2.0)
-
- ws.put({"type": "video.query", "text": "interrupt"})
- await asyncio.wait_for(query_cancelled.wait(), timeout=2.0)
- ws.put({"type": "video.done"})
-
- await asyncio.wait_for(task, timeout=2.0)
- assert "session.done" in ws.sent_types()
-
-
-@pytest.mark.asyncio
-async def test_video_done_waits_for_in_flight_query():
- query_started = asyncio.Event()
- allow_finish = asyncio.Event()
- query_finished = asyncio.Event()
-
- class BlockingHandler(OmniStreamingVideoHandler):
- async def _process_query(self, *args, **kwargs):
- query_started.set()
- await allow_finish.wait()
- query_finished.set()
-
- ws = TimedWebSocket()
- handler = BlockingHandler(chat_service=object(), idle_timeout=5.0)
- task = asyncio.create_task(handler.handle_session(ws))
-
- ws.put({"type": "session.config", "model": "test"})
- await asyncio.sleep(0)
- ws.put({"type": "video.frame", "data": _b64(_make_jpeg())})
- await asyncio.sleep(0)
- ws.put({"type": "video.query", "text": "describe"})
- await asyncio.wait_for(query_started.wait(), timeout=2.0)
-
- ws.put({"type": "video.done"})
- await asyncio.sleep(0.05)
- assert not task.done()
- assert not query_finished.is_set()
-
- allow_finish.set()
- await asyncio.wait_for(task, timeout=2.0)
-
- assert query_finished.is_set()
- assert "session.done" in ws.sent_types()
-
-
-@pytest.mark.asyncio
-async def test_frame_prewarm_does_not_block_following_query(monkeypatch):
- decode_started = threading.Event()
- release_decode = threading.Event()
- query_started = asyncio.Event()
-
- def blocked_decode(raw_bytes: bytes):
- decode_started.set()
- release_decode.wait(timeout=2.0)
- return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
-
- class BlockingHandler(OmniStreamingVideoHandler):
- async def _process_query(self, *args, **kwargs):
- query_started.set()
-
- monkeypatch.setattr(serving_video_stream, "_decode_frame_bytes", blocked_decode)
-
- ws = TimedWebSocket()
- handler = BlockingHandler(chat_service=object(), idle_timeout=5.0)
- task = asyncio.create_task(handler.handle_session(ws))
-
- ws.put({"type": "session.config", "model": "test"})
- await asyncio.sleep(0)
- ws.put({"type": "video.frame", "data": _b64(_make_jpeg())})
-
- for _ in range(100):
- if decode_started.is_set():
- break
- await asyncio.sleep(0.01)
- assert decode_started.is_set()
-
- ws.put({"type": "video.query", "text": "describe"})
- await asyncio.wait_for(query_started.wait(), timeout=2.0)
-
- release_decode.set()
- ws.put({"type": "video.done"})
- await asyncio.wait_for(task, timeout=2.0)
- assert "session.done" in ws.sent_types()
-
-
-@pytest.mark.asyncio
-async def test_client_cannot_send_internal_frame_decode_failed_message():
- captured_frames: list[list[str]] = []
- frame = _b64(_make_jpeg())
-
- class CapturingHandler(OmniStreamingVideoHandler):
- async def _process_query(
- self,
- websocket,
- config,
- frame_buffer,
- audio_buffer,
- message_history,
- query_text,
- request_id,
- interrupt_event,
- prewarmed_frames,
- ):
- captured_frames.append(list(frame_buffer))
-
- ws = TimedWebSocket()
- handler = CapturingHandler(chat_service=object(), idle_timeout=5.0)
- task = asyncio.create_task(handler.handle_session(ws))
-
- ws.put({"type": "session.config", "model": "test"})
- await asyncio.sleep(0)
- ws.put({"type": "video.frame", "data": frame})
- await asyncio.sleep(0)
- ws.put({"type": "_internal.frame_decode_failed", "b64": frame})
- await asyncio.sleep(0)
- ws.put({"type": "video.query", "text": "describe"})
- await asyncio.sleep(0)
- ws.put({"type": "video.done"})
- await asyncio.wait_for(task, timeout=2.0)
-
- assert {"type": "error", "message": "Unknown type: _internal.frame_decode_failed"} in ws.sent
- assert captured_frames == [[frame]]
-
-
-@pytest.mark.asyncio
-async def test_failed_frame_prewarm_removes_frame_before_query():
- ws = TimedWebSocket()
- handler = OmniStreamingVideoHandler(chat_service=object(), idle_timeout=5.0)
- task = asyncio.create_task(handler.handle_session(ws))
-
- ws.put({"type": "session.config", "model": "test", "enable_frame_filter": False})
- await asyncio.sleep(0)
- ws.put({"type": "video.frame", "data": _b64(b"not-a-jpeg")})
-
- for _ in range(100):
- if any(m.get("message") == "Frame decode failed" for m in ws.sent):
- break
- await asyncio.sleep(0.01)
-
- assert {"type": "error", "message": "Frame decode failed"} in ws.sent
-
- ws.put({"type": "video.query", "text": "describe"})
- await asyncio.sleep(0)
- ws.put({"type": "video.done"})
- await asyncio.wait_for(task, timeout=2.0)
-
- assert {"type": "error", "message": "No frames buffered"} in ws.sent
-
-
-@pytest.mark.asyncio
-async def test_frame_filter_error_sends_invalid_image(monkeypatch):
- def fail_should_retain(self, frame_jpeg):
- raise ValueError("decode failed")
-
- monkeypatch.setattr(serving_video_stream.FrameSimilarityFilter, "should_retain", fail_should_retain)
-
- ws = TimedWebSocket()
- handler = OmniStreamingVideoHandler(chat_service=object(), idle_timeout=5.0)
- task = asyncio.create_task(handler.handle_session(ws))
-
- ws.put({"type": "session.config", "model": "test"})
- await asyncio.sleep(0)
- ws.put({"type": "video.frame", "data": _b64(_make_jpeg())})
- await asyncio.sleep(0)
- ws.put({"type": "video.done"})
- await asyncio.wait_for(task, timeout=2.0)
-
- assert {"type": "error", "message": "Invalid image data"} in ws.sent
- assert "session.done" in ws.sent_types()
-
-
-@pytest.mark.asyncio
-async def test_audio_buffer_overflow_clears_buffer_before_query(monkeypatch):
- captured_audio_lengths: list[int] = []
-
- class EmptyEngine:
- def generate(self, **_kwargs):
- async def _gen():
- if False:
- yield None
-
- return _gen()
-
- class CapturingHandler(OmniStreamingVideoHandler):
- async def _process_query_engine(
- self,
- websocket,
- config,
- frame_buffer,
- audio_buffer,
- message_history,
- query_text,
- request_id,
- interrupt_event,
- prewarmed_frames,
- ):
- captured_audio_lengths.append(len(audio_buffer))
-
- monkeypatch.setattr(serving_video_stream, "_MAX_AUDIO_BUFFER_BYTES", 4)
-
- ws = TimedWebSocket()
- handler = CapturingHandler(chat_service=object(), engine_client=EmptyEngine(), idle_timeout=5.0)
- task = asyncio.create_task(handler.handle_session(ws))
-
- ws.put({"type": "session.config", "model": "test"})
- await asyncio.sleep(0)
- ws.put({"type": "audio.chunk", "data": _b64(b"1234")})
- await asyncio.sleep(0)
- ws.put({"type": "audio.chunk", "data": _b64(b"5")})
- await asyncio.sleep(0)
- ws.put({"type": "video.frame", "data": _b64(_make_jpeg())})
- await asyncio.sleep(0)
- ws.put({"type": "video.query", "text": "describe"})
- await asyncio.sleep(0)
- ws.put({"type": "video.done"})
- await asyncio.wait_for(task, timeout=2.0)
-
- assert {"type": "error", "message": "Audio buffer overflow"} in ws.sent
- assert captured_audio_lengths == [0]
-
-
-def test_build_messages_keeps_recent_history_text_only():
- handler = OmniStreamingVideoHandler(chat_service=object())
- old_frame = _b64(_make_jpeg(1, 2, 3))
- current_frame = _b64(_make_jpeg(4, 5, 6))
- history = [
- {"role": "user", "content": [{"type": "text", "text": "old question"}]},
- {"role": "assistant", "content": "old answer"},
- {
- "role": "user",
- "content": [
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{old_frame}"}},
- {"type": "input_audio", "input_audio": {"data": "ignored", "format": "wav"}},
- {"type": "text", "text": "recent question"},
- ],
- },
- {"role": "assistant", "content": "recent answer"},
- ]
-
- messages, user_message = handler._build_messages(
- StreamingVideoSessionConfig(model="test", num_frames=1),
- [current_frame],
- bytearray(),
- history,
- "current question",
- {},
- )
-
- assert messages[0] == {"role": "user", "content": "recent question"}
- assert messages[1] == {"role": "assistant", "content": "recent answer"}
- assert messages[2] == user_message
- assert user_message["content"][-1] == {"type": "text", "text": "current question"}
diff --git a/tests/entrypoints/openai_api/test_text_splitter.py b/tests/entrypoints/openai_api/test_text_splitter.py
index b9022e015dd..23d4d191fc2 100644
--- a/tests/entrypoints/openai_api/test_text_splitter.py
+++ b/tests/entrypoints/openai_api/test_text_splitter.py
@@ -4,7 +4,7 @@
from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+pytestmark = [pytest.mark.openai, pytest.mark.speech]
class TestSentenceSplitterEnglish:
diff --git a/tests/entrypoints/openai_api/test_video_api_utils.py b/tests/entrypoints/openai_api/test_video_api_utils.py
deleted file mode 100644
index 9e732403fbb..00000000000
--- a/tests/entrypoints/openai_api/test_video_api_utils.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for OpenAI-compatible video API encoding helpers."""
-
-import numpy as np
-import pytest
-import torch
-
-from vllm_omni.diffusion.postprocess import rife_interpolator
-from vllm_omni.entrypoints.openai import video_api_utils
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _install_fake_video_mux(monkeypatch, mux_calls):
- def _fake_mux_video_audio_bytes(frames, audio, fps, audio_sample_rate, video_codec_options=None):
- mux_calls.append(
- {
- "frames": frames,
- "audio": audio,
- "fps": fps,
- "audio_sample_rate": audio_sample_rate,
- "video_codec_options": video_codec_options,
- }
- )
- return b"fake-video"
-
- monkeypatch.setattr(
- "vllm_omni.diffusion.utils.media_utils.mux_video_audio_bytes",
- _fake_mux_video_audio_bytes,
- )
-
-
-def test_encode_video_bytes_exports_frames_without_interpolation(monkeypatch):
- mux_calls = []
- _install_fake_video_mux(monkeypatch, mux_calls)
-
- frames = [np.full((2, 2, 3), fill_value=i / 5, dtype=np.float32) for i in range(5)]
- video_bytes = video_api_utils._encode_video_bytes(
- frames,
- fps=8,
- )
-
- assert video_bytes == b"fake-video"
- assert mux_calls[0]["frames"].shape == (5, 2, 2, 3)
- assert mux_calls[0]["frames"].dtype == np.uint8
- assert mux_calls[0]["fps"] == 8.0
- assert mux_calls[0]["audio"] is None
-
-
-def test_rife_model_inference_runs_on_dummy_tensors():
- model = rife_interpolator.Model().eval()
- img0 = torch.rand(1, 3, 32, 32)
- img1 = torch.rand(1, 3, 32, 32)
-
- output = model.inference(img0, img1, scale=1.0)
-
- assert output.shape == (1, 3, 32, 32)
- assert torch.isfinite(output).all()
-
-
-def test_frame_interpolator_runs_actual_torch_tensor_path(monkeypatch):
- model = rife_interpolator.Model().eval()
- interpolator = rife_interpolator.FrameInterpolator()
- monkeypatch.setattr(interpolator, "_ensure_model_loaded", lambda preferred_device=None: model)
-
- video = torch.zeros(1, 3, 2, 32, 32)
- output_video, multiplier = interpolator.interpolate_tensor(video, exp=1, scale=1.0)
-
- assert multiplier == 2
- assert output_video.shape == (1, 3, 3, 32, 32)
- assert torch.isfinite(output_video).all()
-
-
-def test_frame_interpolator_uses_platform_device_when_tensor_is_cpu(monkeypatch):
- chosen_devices = []
- model = rife_interpolator.Model().eval()
-
- def _fake_ensure_model_loaded(*, preferred_device=None):
- chosen_devices.append(preferred_device)
- return model
-
- interpolator = rife_interpolator.FrameInterpolator()
- monkeypatch.setattr(interpolator, "_ensure_model_loaded", _fake_ensure_model_loaded)
- monkeypatch.setattr(model.flownet, "to", lambda device: model.flownet)
- monkeypatch.setattr(rife_interpolator, "_select_torch_device", lambda: torch.device("cuda"))
-
- video = torch.zeros(1, 3, 2, 32, 32)
- output_video, multiplier = interpolator.interpolate_tensor(video, exp=1, scale=1.0)
-
- assert chosen_devices == [torch.device("cuda")]
- assert multiplier == 2
- assert output_video.shape == (1, 3, 3, 32, 32)
diff --git a/tests/entrypoints/openai_api/test_video_frame_filter.py b/tests/entrypoints/openai_api/test_video_frame_filter.py
deleted file mode 100644
index a734aec75c0..00000000000
--- a/tests/entrypoints/openai_api/test_video_frame_filter.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for FrameSimilarityFilter (Phase 2 EVS)."""
-
-from __future__ import annotations
-
-import pytest
-
-from tests.entrypoints.openai_api.conftest_video import (
- make_gradient_jpeg,
- make_jpeg,
-)
-from vllm_omni.entrypoints.openai.video_frame_filter import FrameSimilarityFilter
-
-
-class TestFrameSimilarityFilter:
- def test_first_frame_always_retained(self):
- f = FrameSimilarityFilter(threshold=0.99)
- assert f.should_retain(make_jpeg()) is True
-
- def test_identical_frames_dropped(self):
- f = FrameSimilarityFilter(threshold=0.90)
- frame = make_jpeg(100, 100, 100)
- assert f.should_retain(frame) is True
- assert f.should_retain(frame) is False
-
- def test_very_different_frames_retained(self):
- f = FrameSimilarityFilter(threshold=0.95)
- assert f.should_retain(make_jpeg(255, 255, 255)) is True
- assert f.should_retain(make_jpeg(0, 0, 0)) is True
-
- def test_low_threshold_keeps_slightly_different(self):
- f = FrameSimilarityFilter(threshold=0.50)
- assert f.should_retain(make_jpeg(100, 100, 100)) is True
- assert f.should_retain(make_jpeg(110, 110, 110)) is True
-
- def test_random_frames_all_retained(self):
- f = FrameSimilarityFilter(threshold=0.95)
- for i in range(5):
- assert f.should_retain(make_gradient_jpeg(seed=i)) is True
-
- def test_reset_clears_state(self):
- f = FrameSimilarityFilter(threshold=0.90)
- frame = make_jpeg(128, 128, 128)
- assert f.should_retain(frame) is True
- assert f.should_retain(frame) is False
- f.reset()
- assert f.should_retain(frame) is True
-
- def test_stats_counting(self):
- f = FrameSimilarityFilter(threshold=0.90)
- frame = make_jpeg(50, 50, 50)
- f.should_retain(frame) # retained
- f.should_retain(frame) # dropped
- f.should_retain(frame) # dropped
- stats = f.stats
- assert stats["retained_count"] == 1
- assert stats["dropped_count"] == 2
- assert stats["total_count"] == 3
- assert abs(stats["drop_rate"] - 2.0 / 3.0) < 1e-6
-
- def test_stats_empty(self):
- stats = FrameSimilarityFilter().stats
- assert stats["total_count"] == 0
- assert stats["drop_rate"] == 0.0
-
- def test_stats_reset(self):
- f = FrameSimilarityFilter(threshold=0.90)
- f.should_retain(make_jpeg())
- f.reset()
- assert f.stats["total_count"] == 0
-
- def test_invalid_threshold(self):
- with pytest.raises(ValueError, match="threshold"):
- FrameSimilarityFilter(threshold=1.5)
- with pytest.raises(ValueError, match="threshold"):
- FrameSimilarityFilter(threshold=-0.1)
-
- def test_invalid_thumbnail_size(self):
- with pytest.raises(ValueError, match="thumbnail_size"):
- FrameSimilarityFilter(thumbnail_size=0)
-
- def test_different_image_sizes_same_colour(self):
- """Filter should handle frames of varying resolutions."""
- f = FrameSimilarityFilter(threshold=0.90)
- small = make_jpeg(100, 100, 100, size=32)
- large = make_jpeg(100, 100, 100, size=256)
- assert f.should_retain(small) is True
- assert f.should_retain(large) is False
diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py
index a29f4493c28..7200b38abb8 100644
--- a/tests/entrypoints/openai_api/test_video_server.py
+++ b/tests/entrypoints/openai_api/test_video_server.py
@@ -29,39 +29,22 @@
from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo
from vllm_omni.entrypoints.openai.storage import LocalStorageManager
from vllm_omni.entrypoints.openai.stores import AsyncDictStore, TaskRegistry
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
class MockVideoResult:
- def __init__(
- self,
- videos,
- audios=None,
- sample_rate=None,
- custom_output=None,
- stage_durations=None,
- peak_memory_mb=0.0,
- ):
+ def __init__(self, videos, audios=None, sample_rate=None):
self.multimodal_output = {"video": videos}
if audios is not None:
self.multimodal_output["audio"] = audios
if sample_rate is not None:
self.multimodal_output["audio_sample_rate"] = sample_rate
- self._custom_output = custom_output or {}
- self.stage_durations = stage_durations or {}
- self.peak_memory_mb = peak_memory_mb
-
- @property
- def custom_output(self):
- return self._custom_output
class FakeAsyncOmni:
def __init__(self):
self.stage_configs = [SimpleNamespace(stage_type="diffusion")]
- self.default_sampling_params_list = [OmniDiffusionSamplingParams()]
self.captured_prompt = None
self.captured_sampling_params_list = None
@@ -84,7 +67,7 @@ def set_stage_configs_if_missing(self, stage_configs):
if self.stage_configs is None:
self.stage_configs = stage_configs
- async def generate_video_bytes(self, request, reference_id, *, reference_image=None):
+ async def generate_videos(self, request, reference_id, *, reference_image=None):
self.started.set()
try:
await asyncio.Future()
@@ -135,6 +118,7 @@ def _wait_for_status(client: TestClient, video_id: str, status: str, timeout_s:
last_payload = None
while time.time() < deadline:
response = client.get(f"/v1/videos/{video_id}")
+ assert response.status_code == 200
last_payload = response.json()
if last_payload["status"] == status:
return last_payload
@@ -151,81 +135,15 @@ def _wait_until(predicate, timeout_s: float = 2.0, interval_s: float = 0.02):
raise AssertionError("Timed out waiting for condition")
-def test_async_video_generation_bypasses_base64(test_client, mocker: MockerFixture):
- """Regression test: Ensure async video generation saves raw bytes directly
- without bouncing through base64 encoding."""
- # We mock _encode_video_bytes (the correct path)
- mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"raw-mp4-bytes",
- )
-
- # We assert that encode_video_base64 is never called
- mock_base64 = mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- side_effect=RuntimeError("Regression: async video path should not base64 encode"),
- )
-
- response = test_client.post(
- "/v1/videos",
- data={"prompt": "A base64 test."},
- )
- assert response.status_code == 200
- video_id = response.json()["id"]
-
- # Wait for completion. If it used base64, the RuntimeError would fail the task
- _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
- mock_base64.assert_not_called()
-
-
-def test_async_video_generation_with_audio_bypasses_base64(test_client, mocker: MockerFixture):
- """Regression test: Ensure async video generation passes audio through
- generate_video_bytes without bouncing through base64 encoding."""
- mock_encode = mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"raw-mp4-bytes",
- )
-
- mock_base64 = mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
- side_effect=RuntimeError("Regression: async video path should not base64 encode"),
- )
-
- engine = test_client.app.state.openai_serving_video._engine_client
-
- async def _generate(prompt, request_id, sampling_params_list):
- engine.captured_prompt = prompt
- engine.captured_sampling_params_list = sampling_params_list
- yield MockVideoResult([object()], audios=[object()], sample_rate=48000)
-
- engine.generate = _generate
-
- response = test_client.post(
- "/v1/videos",
- data={"prompt": "A base64 test with audio."},
- )
- assert response.status_code == 200
- video_id = response.json()["id"]
-
- _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
- mock_base64.assert_not_called()
-
- mock_encode.assert_called_once()
- kwargs = mock_encode.call_args.kwargs
- assert "audio" in kwargs
- assert kwargs["audio"] is not None
- assert kwargs["audio_sample_rate"] == 48000
-
-
def test_t2v_video_generation_form(test_client, mocker: MockerFixture):
fps_values = []
- def _fake_encode(video, fps, audio=None, audio_sample_rate=None, **kwargs):
+ def _fake_encode(video, fps):
fps_values.append(fps)
- return b"fake-video"
+ return "Zg=="
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
side_effect=_fake_encode,
)
response = test_client.post(
@@ -257,8 +175,8 @@ def test_i2v_video_generation_form(test_client, mocker: MockerFixture):
image_bytes = _make_test_image_bytes((48, 32))
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
response = test_client.post(
"/v1/videos",
@@ -283,8 +201,8 @@ def test_i2v_video_generation_resizes_input_to_requested_dimensions(test_client,
image_bytes = _make_test_image_bytes((48, 32))
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
response = test_client.post(
"/v1/videos",
@@ -309,8 +227,8 @@ def test_i2v_video_generation_resizes_input_to_requested_dimensions(test_client,
def test_i2v_video_generation_with_image_reference_form(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
response = test_client.post(
"/v1/videos",
@@ -334,12 +252,12 @@ def test_i2v_video_generation_with_image_reference_form(test_client, mocker: Moc
def test_seconds_defaults_fps_and_frames(test_client, mocker: MockerFixture):
fps_values = []
- def _fake_encode(video, fps, audio=None, audio_sample_rate=None, **kwargs):
+ def _fake_encode(video, fps):
fps_values.append(fps)
- return b"fake-video"
+ return "Zg=="
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
side_effect=_fake_encode,
)
response = test_client.post(
@@ -363,8 +281,8 @@ def _fake_encode(video, fps, audio=None, audio_sample_rate=None, **kwargs):
def test_size_param_sets_width_height(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
response = test_client.post(
"/v1/videos",
@@ -385,8 +303,8 @@ def test_size_param_sets_width_height(test_client, mocker: MockerFixture):
def test_sampling_params_pass_through(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
response = test_client.post(
"/v1/videos",
@@ -414,155 +332,13 @@ def test_sampling_params_pass_through(test_client, mocker: MockerFixture):
assert captured.extra_args["flow_shift"] == 0.25
-def test_frame_interpolation_params_pass_to_diffusion_sampling_params(test_client, mocker: MockerFixture):
- """Frame interpolation parameters should be forwarded to diffusion worker sampling params."""
- mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
- )
- response = test_client.post(
- "/v1/videos",
- data={
- "prompt": "smooth motion",
- "fps": "8",
- "enable_frame_interpolation": "true",
- "frame_interpolation_exp": "2",
- "frame_interpolation_scale": "0.5",
- "frame_interpolation_model_path": "local-rife",
- },
- )
-
- assert response.status_code == 200
- video_id = response.json()["id"]
- _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
-
- engine = test_client.app.state.openai_serving_video._engine_client
- captured = engine.captured_sampling_params_list[0]
- assert captured.enable_frame_interpolation is True
- assert captured.frame_interpolation_exp == 2
- assert captured.frame_interpolation_scale == 0.5
- assert captured.frame_interpolation_model_path == "local-rife"
-
-
-def test_default_sampling_params_apply_to_video_requests(test_client, mocker: MockerFixture):
- mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
- )
- engine = test_client.app.state.openai_serving_video._engine_client
- engine.default_sampling_params_list = [
- OmniDiffusionSamplingParams(
- num_inference_steps=4,
- guidance_scale=7.5,
- generator_device="cpu",
- enable_frame_interpolation=True,
- frame_interpolation_exp=2,
- frame_interpolation_scale=0.5,
- frame_interpolation_model_path="default-rife",
- )
- ]
-
- response = test_client.post(
- "/v1/videos",
- data={
- "prompt": "default param pass-through",
- },
- )
-
- assert response.status_code == 200
- video_id = response.json()["id"]
- _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
-
- captured = engine.captured_sampling_params_list[0]
- assert captured.num_inference_steps == 4
- assert captured.guidance_scale == 7.5
- assert captured.generator_device == "cpu"
- assert captured.enable_frame_interpolation is True
- assert captured.frame_interpolation_exp == 2
- assert captured.frame_interpolation_scale == 0.5
- assert captured.frame_interpolation_model_path == "default-rife"
-
-
-def test_request_params_override_default_video_sampling_params(test_client, mocker: MockerFixture):
- mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
- )
- engine = test_client.app.state.openai_serving_video._engine_client
- engine.default_sampling_params_list = [
- OmniDiffusionSamplingParams(
- num_inference_steps=4,
- guidance_scale=7.5,
- enable_frame_interpolation=True,
- frame_interpolation_exp=2,
- frame_interpolation_scale=0.5,
- frame_interpolation_model_path="default-rife",
- )
- ]
-
- response = test_client.post(
- "/v1/videos",
- data={
- "prompt": "explicit override",
- "num_inference_steps": "8",
- "enable_frame_interpolation": "false",
- "frame_interpolation_exp": "1",
- "frame_interpolation_scale": "1.0",
- "frame_interpolation_model_path": "custom-rife",
- },
- )
-
- assert response.status_code == 200
- video_id = response.json()["id"]
- _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
-
- captured = engine.captured_sampling_params_list[0]
- assert captured.num_inference_steps == 8
- assert captured.guidance_scale == 7.5
- assert captured.enable_frame_interpolation is False
- assert captured.frame_interpolation_exp == 1
- assert captured.frame_interpolation_scale == 1.0
- assert captured.frame_interpolation_model_path == "custom-rife"
-
-
-def test_worker_fps_multiplier_is_applied_to_async_encoding(test_client, mocker: MockerFixture):
- fps_values = []
- engine = test_client.app.state.openai_serving_video._engine_client
-
- async def _generate(prompt, request_id, sampling_params_list):
- engine.captured_prompt = prompt
- engine.captured_sampling_params_list = sampling_params_list
- import numpy as np
-
- yield MockVideoResult([np.zeros((1, 64, 64, 3), dtype=np.uint8)], custom_output={"video_fps_multiplier": 2})
-
- engine.generate = _generate
-
- def _fake_encode(video, fps, **kwargs):
- del video, kwargs
- fps_values.append(fps)
- return b"fake-video"
-
- mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- side_effect=_fake_encode,
- )
-
- response = test_client.post("/v1/videos", data={"prompt": "fps multiplier", "fps": "8"})
-
- assert response.status_code == 200
- video_id = response.json()["id"]
- _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
- assert fps_values == [16]
-
-
def test_audio_sample_rate_comes_from_model_config(test_client, mocker: MockerFixture):
audio_sample_rates = []
- def _fake_encode(video, fps, audio=None, audio_sample_rate=None, video_codec_options=None):
- del video, fps, audio, video_codec_options
+ def _fake_encode(video, fps, audio=None, audio_sample_rate=None):
+ del video, fps, audio
audio_sample_rates.append(audio_sample_rate)
- return b"fake-video"
+ return "Zg=="
engine = test_client.app.state.openai_serving_video._engine_client
engine.model_config = SimpleNamespace(
@@ -576,14 +352,12 @@ def _fake_encode(video, fps, audio=None, audio_sample_rate=None, video_codec_opt
async def _generate(prompt, request_id, sampling_params_list):
engine.captured_prompt = prompt
engine.captured_sampling_params_list = sampling_params_list
- import numpy as np
-
- yield MockVideoResult([np.zeros((1, 64, 64, 3), dtype=np.uint8)], audios=[object()])
+ yield MockVideoResult([object()], audios=[object()])
engine.generate = _generate
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
side_effect=_fake_encode,
)
response = test_client.post(
@@ -597,33 +371,6 @@ async def _generate(prompt, request_id, sampling_params_list):
assert audio_sample_rates == [16000]
-def test_video_job_persists_profiler_metadata(test_client, mocker: MockerFixture):
- engine = test_client.app.state.openai_serving_video._engine_client
-
- async def _generate(prompt, request_id, sampling_params_list):
- engine.captured_prompt = prompt
- engine.captured_sampling_params_list = sampling_params_list
- yield MockVideoResult(
- [object()],
- stage_durations={"diffuse": 2.5, "vae.decode": 0.3},
- peak_memory_mb=4096.5,
- )
-
- engine.generate = _generate
- mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
- )
-
- response = test_client.post("/v1/videos", data={"prompt": "profile me"})
- assert response.status_code == 200
- video_id = response.json()["id"]
- completed = _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
-
- assert completed["stage_durations"] == {"diffuse": 2.5, "vae.decode": 0.3}
- assert completed["peak_memory_mb"] == 4096.5
-
-
def test_missing_handler_returns_503():
app = FastAPI()
app.include_router(router)
@@ -646,18 +393,6 @@ def test_missing_prompt_returns_422(test_client):
assert response.status_code == 422
-def test_video_generation_rejects_model_mismatch(test_client):
- response = test_client.post(
- "/v1/videos",
- data={
- "prompt": "bad model",
- "model": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
- },
- )
- assert response.status_code == 400
- assert "model mismatch" in response.json()["detail"].lower()
-
-
def test_invalid_size_parse_returns_422(test_client):
response = test_client.post(
"/v1/videos",
@@ -693,8 +428,8 @@ def test_invalid_seconds_returns_422(test_client):
def test_negative_prompt_and_seed_pass_through(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
response = test_client.post(
"/v1/videos",
@@ -726,7 +461,7 @@ def test_invalid_lora_returns_400(test_client):
assert response.status_code == 200
video_id = response.json()["id"]
failed = _wait_for_status(test_client, video_id, VideoGenerationStatus.FAILED.value)
- assert failed["error"]["code"] == 400
+ assert failed["error"]["code"] == "HTTPException"
assert "lora object" in failed["error"]["message"].lower()
@@ -763,16 +498,12 @@ def test_video_request_validation():
with pytest.raises(ValueError):
VideoGenerationRequest(prompt="test", image_reference={"file_id": "file-1", "image_url": "https://example.com"})
- with pytest.raises(ValueError):
- VideoGenerationRequest(prompt="test", frame_interpolation_exp=0)
- with pytest.raises(ValueError):
- VideoGenerationRequest(prompt="test", frame_interpolation_scale=0)
def test_list_videos_supports_order_after_and_limit(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
ids = []
for i in range(3):
@@ -840,8 +571,8 @@ def test_list_videos_supports_order_after_and_limit(test_client, mocker: MockerF
def test_delete_completed_job_removes_file_and_metadata(test_client, mocker: MockerFixture):
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
create_resp = test_client.post("/v1/videos", data={"prompt": "Delete this video"})
assert create_resp.status_code == 200
@@ -912,8 +643,8 @@ def test_video_response_file_extension_is_robust():
def test_extra_params_merged_into_extra_args(test_client, mocker: MockerFixture):
"""extra_params JSON object is merged into sampling_params.extra_args."""
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
extra_params = {
"is_enable_stage2": True,
@@ -943,8 +674,8 @@ def test_extra_params_merged_into_extra_args(test_client, mocker: MockerFixture)
def test_extra_params_none_by_default(test_client, mocker: MockerFixture):
"""When extra_params is omitted, extra_args stays empty."""
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
response = test_client.post(
"/v1/videos",
@@ -984,8 +715,8 @@ def test_extra_params_invalid_json(test_client):
def test_extra_params_merged_with_existing_extra_args(test_client, mocker: MockerFixture):
"""extra_params is merged on top of existing extra_args (e.g. flow_shift)."""
mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
+ "vllm_omni.entrypoints.openai.serving_video.encode_video_base64",
+ return_value="Zg==",
)
response = test_client.post(
"/v1/videos",
@@ -1006,28 +737,6 @@ def test_extra_params_merged_with_existing_extra_args(test_client, mocker: Mocke
assert captured.extra_args["zero_steps"] == 2
-def test_sample_solver_forwarded_via_extra_params(test_client, mocker: MockerFixture):
- """sample_solver can be passed through existing extra_params for Wan2.2 online serving."""
- mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- return_value=b"fake-video",
- )
- response = test_client.post(
- "/v1/videos",
- data={
- "prompt": "A fox running through snow.",
- "extra_params": json.dumps({"sample_solver": "euler"}),
- },
- )
-
- assert response.status_code == 200
- video_id = response.json()["id"]
- _wait_for_status(test_client, video_id, VideoGenerationStatus.COMPLETED.value)
- engine = test_client.app.state.openai_serving_video._engine_client
- captured = engine.captured_sampling_params_list[0]
- assert captured.extra_args["sample_solver"] == "euler"
-
-
# ---------------------------------------------------------------------------
# Sync endpoint tests (POST /v1/videos/sync)
# ---------------------------------------------------------------------------
@@ -1061,31 +770,6 @@ def test_sync_t2v_returns_video_bytes(test_client, mocker: MockerFixture):
assert response.headers["x-request-id"].startswith("video_sync-")
assert response.headers["x-model"] == "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
assert float(response.headers["x-inference-time-s"]) >= 0
- assert json.loads(response.headers["x-stage-durations"]) == {}
- assert float(response.headers["x-peak-memory-mb"]) == 0.0
-
-
-def test_sync_t2v_returns_profiler_headers(test_client, mocker: MockerFixture):
- engine = test_client.app.state.openai_serving_video._engine_client
-
- async def _generate(prompt, request_id, sampling_params_list):
- engine.captured_prompt = prompt
- engine.captured_sampling_params_list = sampling_params_list
- yield MockVideoResult(
- [object()],
- stage_durations={"diffuse": 1.75},
- peak_memory_mb=1234.25,
- )
-
- engine.generate = _generate
- _mock_encode_video_bytes(mocker, b"profiled-video")
-
- response = test_client.post("/v1/videos/sync", data={"prompt": "sync profile"})
-
- assert response.status_code == 200
- assert response.content == b"profiled-video"
- assert json.loads(response.headers["x-stage-durations"]) == {"diffuse": 1.75}
- assert float(response.headers["x-peak-memory-mb"]) == pytest.approx(1234.25, rel=0, abs=1e-3)
def test_sync_i2v_returns_video_bytes(test_client, mocker: MockerFixture):
@@ -1204,90 +888,3 @@ def test_sync_sampling_params_pass_through(test_client, mocker: MockerFixture):
assert captured.num_inference_steps == 30
assert captured.guidance_scale == 6.5
assert captured.seed == 42
-
-
-def test_sync_frame_interpolation_params_pass_to_sampling_params(test_client, mocker: MockerFixture):
- """Frame interpolation parameters should be forwarded on the sync path."""
- encode_mock = _mock_encode_video_bytes(mocker)
- response = test_client.post(
- "/v1/videos/sync",
- data={
- "prompt": "smooth sync",
- "fps": "8",
- "enable_frame_interpolation": "true",
- "frame_interpolation_exp": "2",
- "frame_interpolation_scale": "0.5",
- "frame_interpolation_model_path": "local-rife",
- },
- )
-
- assert response.status_code == 200
- engine = test_client.app.state.openai_serving_video._engine_client
- captured = engine.captured_sampling_params_list[0]
- assert captured.enable_frame_interpolation is True
- assert captured.frame_interpolation_exp == 2
- assert captured.frame_interpolation_scale == 0.5
- assert captured.frame_interpolation_model_path == "local-rife"
- _, kwargs = encode_mock.call_args
- assert kwargs["fps"] == 8
-
-
-def test_sync_default_sampling_params_apply_to_video_requests(test_client, mocker: MockerFixture):
- _mock_encode_video_bytes(mocker)
- engine = test_client.app.state.openai_serving_video._engine_client
- engine.default_sampling_params_list = [
- OmniDiffusionSamplingParams(
- num_inference_steps=4,
- guidance_scale=7.5,
- enable_frame_interpolation=True,
- frame_interpolation_exp=2,
- frame_interpolation_scale=0.5,
- frame_interpolation_model_path="default-rife",
- )
- ]
-
- response = test_client.post(
- "/v1/videos/sync",
- data={
- "prompt": "sync default param pass-through",
- "fps": "8",
- },
- )
-
- assert response.status_code == 200
- engine = test_client.app.state.openai_serving_video._engine_client
- captured = engine.captured_sampling_params_list[0]
- assert captured.num_inference_steps == 4
- assert captured.guidance_scale == 7.5
- assert captured.enable_frame_interpolation is True
- assert captured.frame_interpolation_exp == 2
- assert captured.frame_interpolation_scale == 0.5
- assert captured.frame_interpolation_model_path == "default-rife"
-
-
-def test_worker_fps_multiplier_is_applied_to_sync_encoding(test_client, mocker: MockerFixture):
- engine = test_client.app.state.openai_serving_video._engine_client
- fps_values = []
-
- async def _generate(prompt, request_id, sampling_params_list):
- engine.captured_prompt = prompt
- engine.captured_sampling_params_list = sampling_params_list
- yield MockVideoResult([object()], custom_output={"video_fps_multiplier": 2})
-
- engine.generate = _generate
-
- def _fake_encode(video, fps, **kwargs):
- del video, kwargs
- fps_values.append(fps)
- return b"fps-multiplied"
-
- mocker.patch(
- "vllm_omni.entrypoints.openai.serving_video._encode_video_bytes",
- side_effect=_fake_encode,
- )
-
- response = test_client.post("/v1/videos/sync", data={"prompt": "fps multiplier", "fps": "8"})
-
- assert response.status_code == 200
- assert response.content == b"fps-multiplied"
- assert fps_values == [16]
diff --git a/tests/entrypoints/openai_api/test_video_stream_handler.py b/tests/entrypoints/openai_api/test_video_stream_handler.py
deleted file mode 100644
index bee5e0de3e4..00000000000
--- a/tests/entrypoints/openai_api/test_video_stream_handler.py
+++ /dev/null
@@ -1,593 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Runtime behavior tests for VideoStreamHandler.
-
-Every test creates a mock WebSocket, drives handle_session through a
-specific code path, and asserts on the JSON messages actually sent back.
-No inspect.getsource() tricks — these test real async execution.
-"""
-
-from __future__ import annotations
-
-import asyncio
-import base64
-import io
-import json
-from unittest.mock import AsyncMock
-
-import pytest
-from PIL import Image
-
-from vllm_omni.entrypoints.openai.video_stream_session import (
- VideoStreamHandler,
-)
-
-# ---------------------------------------------------------------------------
-# Helpers
-# ---------------------------------------------------------------------------
-
-
-def _make_jpeg(r: int = 128, g: int = 128, b: int = 128) -> bytes:
- img = Image.new("RGB", (64, 64), (r, g, b))
- buf = io.BytesIO()
- img.save(buf, format="JPEG", quality=95)
- return buf.getvalue()
-
-
-def _b64(data: bytes) -> str:
- return base64.b64encode(data).decode()
-
-
-def _config_msg(**overrides) -> str:
- msg = {"type": "session.config", "model": "test", "evs_enabled": False}
- msg.update(overrides)
- return json.dumps(msg)
-
-
-def _frame_msg(jpeg: bytes | None = None) -> str:
- if jpeg is None:
- jpeg = _make_jpeg()
- return json.dumps({"type": "video.frame", "data": _b64(jpeg)})
-
-
-def _query_msg(text: str = "What do you see?") -> str:
- return json.dumps({"type": "video.query", "text": text})
-
-
-def _done_msg() -> str:
- return json.dumps({"type": "video.done"})
-
-
-def _audio_msg(pcm: bytes = b"\x00" * 320) -> str:
- return json.dumps({"type": "audio.chunk", "data": _b64(pcm)})
-
-
-def _make_sse_line(content: str) -> str:
- payload = {"choices": [{"delta": {"content": content}}]}
- return f"data: {json.dumps(payload)}\n\n"
-
-
-class MockWebSocket:
- """Async-compatible mock WebSocket that feeds messages from a list."""
-
- def __init__(self, messages: list[str]):
- self._messages = list(messages)
- self._idx = 0
- self._sent: list[dict | bytes] = []
- self._accepted = False
-
- async def accept(self):
- self._accepted = True
-
- async def receive_text(self) -> str:
- if self._idx >= len(self._messages):
- # Simulate connection close by hanging forever (will be timed out)
- await asyncio.sleep(999)
- msg = self._messages[self._idx]
- self._idx += 1
- return msg
-
- async def send_json(self, data: dict):
- self._sent.append(data)
-
- async def send_bytes(self, data: bytes):
- self._sent.append(data)
-
- @property
- def sent_messages(self) -> list[dict]:
- return [m for m in self._sent if isinstance(m, dict)]
-
- def sent_types(self) -> list[str]:
- return [m.get("type", "") for m in self.sent_messages]
-
-
-class MockChatHandler:
- """Mock OmniOpenAIServingChat for testing _handle_query paths."""
-
- def __init__(self, response=None):
- self._response = response
-
- async def create_chat_completion(self, request, raw_request=None):
- if self._response is not None:
- return self._response
-
- # Default: return a simple streaming generator
- async def _gen():
- yield _make_sse_line("Hello")
- yield _make_sse_line("World")
- yield "data: [DONE]\n\n"
-
- return _gen()
-
-
-# ---------------------------------------------------------------------------
-# _receive_config path tests
-# ---------------------------------------------------------------------------
-
-
-class TestReceiveConfig:
- @pytest.mark.asyncio
- async def test_config_timeout(self):
- """No message within config_timeout → error sent, session ends."""
- ws = MockWebSocket([]) # no messages — will hang
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), config_timeout=0.05)
- await handler.handle_session(ws)
-
- assert ws._accepted
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert len(errors) == 1
- assert "Timeout" in errors[0]["message"]
-
- @pytest.mark.asyncio
- async def test_config_invalid_json(self):
- """Non-JSON config message → error, session ends."""
- ws = MockWebSocket(["not json"])
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), config_timeout=1.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("Invalid JSON" in e["message"] for e in errors)
-
- @pytest.mark.asyncio
- async def test_config_wrong_type(self):
- """Message with wrong type → error, session ends."""
- ws = MockWebSocket([json.dumps({"type": "wrong"})])
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), config_timeout=1.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("Expected session.config" in e["message"] for e in errors)
-
- @pytest.mark.asyncio
- async def test_config_invalid_field_type(self):
- """Config with wrong field type → error, session ends."""
- ws = MockWebSocket(
- [
- json.dumps(
- {
- "type": "session.config",
- "model": "test",
- "max_frames": "potato",
- }
- )
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), config_timeout=1.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("max_frames" in e["message"] for e in errors)
-
-
-# ---------------------------------------------------------------------------
-# Normal session flow
-# ---------------------------------------------------------------------------
-
-
-class TestNormalFlow:
- @pytest.mark.asyncio
- async def test_frame_query_done(self):
- """Happy path: config → frame → query → done → session.done."""
- ws = MockWebSocket(
- [
- _config_msg(),
- _frame_msg(),
- _query_msg("Describe."),
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=2.0)
- await handler.handle_session(ws)
-
- types = ws.sent_types()
- assert "response.start" in types
- assert "response.text.delta" in types
- assert "response.text.done" in types
- assert "session.done" in types
-
- # Check text content was streamed
- deltas = [m["delta"] for m in ws.sent_messages if m.get("type") == "response.text.delta"]
- assert "Hello" in deltas
- assert "World" in deltas
-
- # response.text.done has full text
- done_msg = next(m for m in ws.sent_messages if m.get("type") == "response.text.done")
- assert done_msg["text"] == "HelloWorld"
-
- @pytest.mark.asyncio
- async def test_multiple_frames_before_query(self):
- """Multiple frames accumulate, query uses sampled frames."""
- frames = [_frame_msg(_make_jpeg(r=i * 50)) for i in range(5)]
- ws = MockWebSocket(
- [
- _config_msg(num_sample_frames=3),
- *frames,
- _query_msg(),
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=2.0)
- await handler.handle_session(ws)
- assert "session.done" in ws.sent_types()
-
- @pytest.mark.asyncio
- async def test_evs_stats_in_done(self):
- """EVS stats sent before session.done when evs_enabled=True."""
- frame = _frame_msg(_make_jpeg(100, 100, 100))
- ws = MockWebSocket(
- [
- _config_msg(evs_enabled=True, evs_threshold=0.90),
- frame,
- frame, # second should be dropped
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=2.0)
- await handler.handle_session(ws)
-
- evs_msgs = [m for m in ws.sent_messages if m.get("type") == "response.evs_stats"]
- assert len(evs_msgs) == 1
- assert evs_msgs[0]["retained_count"] == 1
- assert evs_msgs[0]["dropped_count"] == 1
-
-
-# ---------------------------------------------------------------------------
-# _handle_query error paths
-# ---------------------------------------------------------------------------
-
-
-class TestHandleQueryErrors:
- @pytest.mark.asyncio
- async def test_query_with_no_frames(self):
- """Query before any frames → error, session continues."""
- ws = MockWebSocket(
- [
- _config_msg(),
- _query_msg("Hello?"),
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=2.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("No frames" in e["message"] for e in errors)
- # Session still ends cleanly
- assert "session.done" in ws.sent_types()
-
- @pytest.mark.asyncio
- async def test_query_returns_error_response(self):
- """create_chat_completion returns ErrorResponse → error forwarded."""
- from vllm.entrypoints.openai.engine.protocol import ErrorResponse
-
- err = ErrorResponse(
- message="model overloaded",
- type="server_error",
- code=503,
- )
-
- chat = AsyncMock()
- chat.create_chat_completion = AsyncMock(return_value=err)
-
- ws = MockWebSocket(
- [
- _config_msg(),
- _frame_msg(),
- _query_msg(),
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=chat, idle_timeout=2.0)
- await handler.handle_session(ws)
-
- # Should see response.start, then error, then response.text.done
- assert "response.start" in ws.sent_types()
- assert "response.text.done" in ws.sent_types()
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("model overloaded" in e["message"] for e in errors)
-
- @pytest.mark.asyncio
- async def test_query_generator_raises(self):
- """Exception during streaming → error sent, session continues."""
-
- async def _exploding_gen():
- yield _make_sse_line("partial")
- raise RuntimeError("CUDA OOM")
-
- chat = MockChatHandler()
- chat._response = None # override below
-
- class BoomChat:
- async def create_chat_completion(self, req, raw_request=None):
- return _exploding_gen()
-
- ws = MockWebSocket(
- [
- _config_msg(),
- _frame_msg(),
- _query_msg(),
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=BoomChat(), idle_timeout=2.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("Query processing failed" in e["message"] for e in errors)
- # Crucially: no raw exception message leaked
- assert not any("CUDA OOM" in e["message"] for e in errors)
- # Protocol: response.text.done always sent after response.start
- assert "response.text.done" in ws.sent_types()
- # Session still ends cleanly
- assert "session.done" in ws.sent_types()
-
- @pytest.mark.asyncio
- async def test_empty_query_text_rejected(self):
- """video.query with empty text → error, session continues."""
- ws = MockWebSocket(
- [
- _config_msg(),
- _frame_msg(),
- json.dumps({"type": "video.query", "text": ""}),
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=2.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("non-empty" in e["message"] for e in errors)
- assert "session.done" in ws.sent_types()
-
-
-# ---------------------------------------------------------------------------
-# Timeout paths
-# ---------------------------------------------------------------------------
-
-
-class TestTimeouts:
- @pytest.mark.asyncio
- async def test_idle_timeout(self):
- """No message after config within idle_timeout → error, session ends."""
- ws = MockWebSocket([_config_msg()]) # config then nothing
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=0.05)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("Idle timeout" in e["message"] for e in errors)
-
-
-# ---------------------------------------------------------------------------
-# Reader error paths
-# ---------------------------------------------------------------------------
-
-
-class TestReaderErrors:
- @pytest.mark.asyncio
- async def test_invalid_json_mid_session(self):
- """Invalid JSON mid-session → error sent, session continues."""
- ws = MockWebSocket(
- [
- _config_msg(),
- "{{bad json",
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=2.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("Invalid JSON" in e["message"] for e in errors)
- assert "session.done" in ws.sent_types()
-
- @pytest.mark.asyncio
- async def test_non_dict_message(self):
- """Non-object JSON → error sent, continues."""
- ws = MockWebSocket(
- [
- _config_msg(),
- json.dumps([1, 2, 3]),
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=2.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("JSON objects" in e["message"] for e in errors)
-
- @pytest.mark.asyncio
- async def test_unknown_message_type(self):
- """Unknown type → error, session continues."""
- ws = MockWebSocket(
- [
- _config_msg(),
- json.dumps({"type": "teleport.now"}),
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=2.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("Unknown message type" in e["message"] for e in errors)
- assert "session.done" in ws.sent_types()
-
- @pytest.mark.asyncio
- async def test_invalid_base64_frame(self):
- """Invalid base64 in video.frame → error, continues."""
- ws = MockWebSocket(
- [
- _config_msg(),
- json.dumps({"type": "video.frame", "data": "!!!not-base64!!!"}),
- _done_msg(),
- ]
- )
- handler = VideoStreamHandler(chat_handler=MockChatHandler(), idle_timeout=2.0)
- await handler.handle_session(ws)
-
- errors = [m for m in ws.sent_messages if m["type"] == "error"]
- assert any("Invalid base64" in e["message"] for e in errors)
- assert "session.done" in ws.sent_types()
-
-
-# ---------------------------------------------------------------------------
-# Concurrency: frames arrive during query processing
-# ---------------------------------------------------------------------------
-
-
-class TestConcurrency:
- @pytest.mark.asyncio
- async def test_frames_during_query(self):
- """Frames sent while a query is processing must be buffered."""
- query_started = asyncio.Event()
- query_can_finish = asyncio.Event()
- captured_frame_count: list[int] = []
-
- class SlowChat:
- async def create_chat_completion(self, req, raw_request=None):
- query_started.set()
- await query_can_finish.wait()
-
- async def _gen():
- yield _make_sse_line("answer")
- yield "data: [DONE]\n\n"
-
- return _gen()
-
- class InspectingSlowChat(SlowChat):
- """On the second query, record how many frames are in the request."""
-
- _call_count = 0
-
- async def create_chat_completion(self, req, raw_request=None):
- self._call_count += 1
- if self._call_count == 2:
- # Count image_url blocks in the request to verify buffering
- for msg in req.messages:
- if isinstance(msg.get("content", msg), list):
- content = msg.get("content", msg)
- else:
- content = getattr(msg, "content", [])
- if isinstance(content, list):
- n = sum(
- 1
- for p in content
- if (isinstance(p, dict) and p.get("type") == "image_url")
- or (hasattr(p, "type") and p.type == "image_url")
- )
- captured_frame_count.append(n)
- return await super().create_chat_completion(req, raw_request)
-
- # Custom WebSocket that can inject frames after query starts
- class TimedWebSocket(MockWebSocket):
- def __init__(self):
- self._q: asyncio.Queue[str] = asyncio.Queue()
- self._sent: list[dict | bytes] = []
- self._accepted = False
-
- async def accept(self):
- self._accepted = True
-
- async def receive_text(self) -> str:
- return await self._q.get()
-
- def put(self, msg: str):
- self._q.put_nowait(msg)
-
- ws = TimedWebSocket()
- chat = InspectingSlowChat()
- handler = VideoStreamHandler(chat_handler=chat, idle_timeout=5.0)
-
- session_task = asyncio.create_task(handler.handle_session(ws))
-
- # 1. Config
- ws.put(_config_msg())
- await asyncio.sleep(0.01)
-
- # 2. Initial frame + query
- ws.put(_frame_msg(_make_jpeg(10, 10, 10)))
- await asyncio.sleep(0.01)
- ws.put(_query_msg("first"))
-
- # 3. Wait for query to start processing
- await asyncio.wait_for(query_started.wait(), timeout=2.0)
-
- # 4. Send more frames WHILE query is running
- ws.put(_frame_msg(_make_jpeg(200, 200, 200)))
- ws.put(_frame_msg(_make_jpeg(50, 50, 50)))
- await asyncio.sleep(0.05)
-
- # 5. Let query finish
- query_can_finish.set()
- await asyncio.sleep(0.1)
-
- # 6. Second query — should see all 3 frames (1 initial + 2 during)
- query_started.clear()
- query_can_finish.clear()
- ws.put(_query_msg("second"))
- await asyncio.wait_for(query_started.wait(), timeout=2.0)
- query_can_finish.set()
- await asyncio.sleep(0.1)
-
- # 7. Done
- ws.put(_done_msg())
-
- await asyncio.wait_for(session_task, timeout=5.0)
-
- assert "session.done" in ws.sent_types()
- # Verify: the 2 frames sent during query 1 were buffered and
- # included in query 2 (total 3 frames).
- assert captured_frame_count == [3]
-
-
-# ---------------------------------------------------------------------------
-# SSE parsing (kept from original, these are valid unit tests)
-# ---------------------------------------------------------------------------
-
-
-class TestParseSSEDeltas:
- def test_single_delta(self):
- assert VideoStreamHandler._parse_sse_deltas(_make_sse_line("hello")) == ["hello"]
-
- def test_multiple_deltas(self):
- chunk = _make_sse_line("a") + _make_sse_line("b")
- assert VideoStreamHandler._parse_sse_deltas(chunk) == ["a", "b"]
-
- def test_done_skipped(self):
- chunk = _make_sse_line("x") + "data: [DONE]\n\n"
- assert VideoStreamHandler._parse_sse_deltas(chunk) == ["x"]
-
- def test_empty_content_skipped(self):
- p = json.dumps({"choices": [{"delta": {"content": ""}}]})
- assert VideoStreamHandler._parse_sse_deltas(f"data: {p}\n\n") == []
-
- def test_malformed_json_skipped(self):
- chunk = "data: {bad}\n\n" + _make_sse_line("ok")
- assert VideoStreamHandler._parse_sse_deltas(chunk) == ["ok"]
-
- def test_empty_string(self):
- assert VideoStreamHandler._parse_sse_deltas("") == []
-
- def test_unicode(self):
- assert VideoStreamHandler._parse_sse_deltas(_make_sse_line("你好🌍")) == ["你好🌍"]
diff --git a/tests/entrypoints/openai_api/test_video_stream_session.py b/tests/entrypoints/openai_api/test_video_stream_session.py
deleted file mode 100644
index b0a12b1e42c..00000000000
--- a/tests/entrypoints/openai_api/test_video_stream_session.py
+++ /dev/null
@@ -1,289 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for VideoStreamSession (Phase 2 + 3)."""
-
-from __future__ import annotations
-
-import pytest
-
-from tests.entrypoints.openai_api.conftest_video import (
- make_gradient_jpeg,
- make_jpeg,
-)
-from vllm_omni.entrypoints.openai.video_stream_session import (
- VideoStreamConfig,
- VideoStreamSession,
-)
-
-# ---------------------------------------------------------------------------
-# VideoStreamConfig
-# ---------------------------------------------------------------------------
-
-
-class TestVideoStreamConfig:
- def test_defaults(self):
- cfg = VideoStreamConfig()
- assert cfg.max_frames == 64
- assert cfg.num_sample_frames == 16
- assert cfg.evs_enabled is True
- assert cfg.evs_threshold == 0.95
-
- def test_from_dict(self):
- cfg = VideoStreamConfig.from_dict(
- {
- "model": "test-model",
- "max_frames": 32,
- "evs_threshold": 0.90,
- "unknown_field": "ignored",
- }
- )
- assert cfg.model == "test-model"
- assert cfg.max_frames == 32
- assert cfg.evs_threshold == 0.90
-
- def test_from_dict_empty(self):
- cfg = VideoStreamConfig.from_dict({})
- assert cfg.model == ""
- assert cfg.max_frames == 64
-
- def test_from_dict_invalid_type(self):
- with pytest.raises(TypeError, match="max_frames.*expected int.*got str"):
- VideoStreamConfig.from_dict({"max_frames": "potato"})
-
- def test_from_dict_invalid_bool(self):
- with pytest.raises(TypeError, match="evs_enabled.*expected bool"):
- VideoStreamConfig.from_dict({"evs_enabled": "yes"})
-
- def test_from_dict_invalid_modalities(self):
- with pytest.raises(TypeError, match="modalities.*expected list"):
- VideoStreamConfig.from_dict({"modalities": 42})
-
- def test_from_dict_evs_threshold_int_accepted(self):
- """JSON doesn't distinguish int/float — int 1 is a valid threshold."""
- cfg = VideoStreamConfig.from_dict({"evs_threshold": 1})
- assert cfg.evs_threshold == 1
-
-
-# ---------------------------------------------------------------------------
-# Frame buffer & sliding window (uses deque)
-# ---------------------------------------------------------------------------
-
-
-class TestFrameBuffer:
- def test_add_frame_basic(self):
- cfg = VideoStreamConfig(evs_enabled=False, max_frames=10)
- session = VideoStreamSession(cfg)
- assert session.add_frame(make_jpeg()) is True
- assert session.frame_count == 1
-
- def test_add_frame_too_large(self):
- cfg = VideoStreamConfig(evs_enabled=False, max_frames=10)
- session = VideoStreamSession(cfg)
- huge = b"\xff" * (10 * 1024 * 1024 + 1) # just over 10 MB
- with pytest.raises(ValueError, match="Frame too large"):
- session.add_frame(huge)
-
- def test_sliding_window(self):
- cfg = VideoStreamConfig(evs_enabled=False, max_frames=3)
- session = VideoStreamSession(cfg)
- for i in range(5):
- session.add_frame(make_jpeg(r=i * 50))
- assert session.frame_count == 3
-
- def test_sliding_window_keeps_newest(self):
- cfg = VideoStreamConfig(evs_enabled=False, max_frames=2)
- session = VideoStreamSession(cfg)
- f1 = make_jpeg(10, 10, 10)
- f2 = make_jpeg(20, 20, 20)
- f3 = make_jpeg(30, 30, 30)
- session.add_frame(f1)
- session.add_frame(f2)
- session.add_frame(f3)
- sampled = session.sample_frames()
- assert len(sampled) == 2
- assert sampled[0] == f2
- assert sampled[1] == f3
-
-
-# ---------------------------------------------------------------------------
-# EVS integration
-# ---------------------------------------------------------------------------
-
-
-class TestEVSIntegration:
- def test_evs_drops_identical_frames(self):
- cfg = VideoStreamConfig(evs_enabled=True, evs_threshold=0.90)
- session = VideoStreamSession(cfg)
- frame = make_jpeg(100, 100, 100)
- assert session.add_frame(frame) is True
- assert session.add_frame(frame) is False
- assert session.frame_count == 1
-
- def test_evs_keeps_different_frames(self):
- cfg = VideoStreamConfig(evs_enabled=True, evs_threshold=0.95)
- session = VideoStreamSession(cfg)
- for i in range(5):
- assert session.add_frame(make_gradient_jpeg(seed=i)) is True
- assert session.frame_count == 5
-
- def test_evs_disabled(self):
- cfg = VideoStreamConfig(evs_enabled=False)
- session = VideoStreamSession(cfg)
- frame = make_jpeg()
- assert session.add_frame(frame) is True
- assert session.add_frame(frame) is True
- assert session.frame_count == 2
-
- def test_evs_stats(self):
- cfg = VideoStreamConfig(evs_enabled=True, evs_threshold=0.90)
- session = VideoStreamSession(cfg)
- frame = make_jpeg()
- session.add_frame(frame)
- session.add_frame(frame)
- stats = session.evs_stats
- assert stats is not None
- assert stats["retained_count"] == 1
- assert stats["dropped_count"] == 1
-
- def test_evs_stats_none_when_disabled(self):
- cfg = VideoStreamConfig(evs_enabled=False)
- session = VideoStreamSession(cfg)
- assert session.evs_stats is None
-
-
-# ---------------------------------------------------------------------------
-# Uniform sampling
-# ---------------------------------------------------------------------------
-
-
-class TestSampling:
- def test_sample_exact(self):
- cfg = VideoStreamConfig(evs_enabled=False, num_sample_frames=3, max_frames=10)
- session = VideoStreamSession(cfg)
- frames = [make_gradient_jpeg(i) for i in range(3)]
- for f in frames:
- session.add_frame(f)
- sampled = session.sample_frames()
- assert len(sampled) == 3
- assert sampled == frames
-
- def test_sample_fewer_than_requested(self):
- cfg = VideoStreamConfig(evs_enabled=False, num_sample_frames=10, max_frames=64)
- session = VideoStreamSession(cfg)
- for i in range(3):
- session.add_frame(make_gradient_jpeg(i))
- assert len(session.sample_frames()) == 3
-
- def test_sample_uniform(self):
- cfg = VideoStreamConfig(evs_enabled=False, num_sample_frames=4, max_frames=64)
- session = VideoStreamSession(cfg)
- for i in range(10):
- session.add_frame(make_gradient_jpeg(i))
- sampled = session.sample_frames()
- assert len(sampled) == 4
- expected = [session._frames[i] for i in [0, 3, 6, 9]]
- assert sampled == expected
-
- def test_sample_empty(self):
- cfg = VideoStreamConfig(evs_enabled=False)
- session = VideoStreamSession(cfg)
- assert session.sample_frames() == []
-
- def test_sample_single_frame_from_multi_frame_buffer(self):
- cfg = VideoStreamConfig(evs_enabled=False, num_sample_frames=1, max_frames=64)
- session = VideoStreamSession(cfg)
- first = make_gradient_jpeg(0)
- second = make_gradient_jpeg(1)
- session.add_frame(first)
- session.add_frame(second)
-
- assert session.sample_frames() == [second]
-
-
-# ---------------------------------------------------------------------------
-# Audio buffer (Phase 3)
-# ---------------------------------------------------------------------------
-
-
-class TestAudioBuffer:
- def test_add_audio_chunk(self):
- session = VideoStreamSession(VideoStreamConfig())
- assert session.has_audio is False
- session.add_audio_chunk(b"\x00" * 100)
- assert session.has_audio is True
-
- def test_clear_audio(self):
- session = VideoStreamSession(VideoStreamConfig())
- session.add_audio_chunk(b"\x00" * 100)
- session.clear_audio()
- assert session.has_audio is False
-
-
-# ---------------------------------------------------------------------------
-# build_chat_request
-# ---------------------------------------------------------------------------
-
-
-class TestBuildChatRequest:
- def test_video_only_request(self):
- cfg = VideoStreamConfig(model="test-model", evs_enabled=False, num_sample_frames=4)
- session = VideoStreamSession(cfg)
- for i in range(4):
- session.add_frame(make_gradient_jpeg(i))
-
- request = session.build_chat_request("Describe this scene.")
- assert request.model == "test-model"
- assert request.stream is True
-
- content = request.messages[0]["content"]
- assert len(content) == 5 # 4 image_url + 1 text
- image_parts = [p for p in content if p["type"] == "image_url"]
- text_parts = [p for p in content if p["type"] == "text"]
- assert len(image_parts) == 4
- assert len(text_parts) == 1
- assert text_parts[0]["text"] == "Describe this scene."
-
- mm_kw = getattr(request, "mm_processor_kwargs", None)
- assert mm_kw is None or not mm_kw.get("use_audio_in_video", False)
-
- def test_video_plus_audio_request(self):
- cfg = VideoStreamConfig(model="test-model", evs_enabled=False, num_sample_frames=2)
- session = VideoStreamSession(cfg)
- session.add_frame(make_gradient_jpeg(0))
- session.add_frame(make_gradient_jpeg(1))
- session.add_audio_chunk(b"\x00" * 3200)
-
- request = session.build_chat_request("What is being said?")
-
- content = request.messages[0]["content"]
- assert len(content) == 4 # 2 image_url + 1 audio_url + 1 text
- audio_parts = [p for p in content if p["type"] == "audio_url"]
- assert len(audio_parts) == 1
- # RFC 3551: audio/L16 for linear 16-bit PCM
- assert audio_parts[0]["audio_url"]["url"].startswith("data:audio/L16;rate=16000;base64,")
-
- mm_kw = getattr(request, "mm_processor_kwargs", None) or {}
- assert mm_kw.get("use_audio_in_video") is True
-
- def test_image_url_is_valid_base64(self):
- cfg = VideoStreamConfig(evs_enabled=False, num_sample_frames=1)
- session = VideoStreamSession(cfg)
- session.add_frame(make_jpeg(200, 100, 50))
- request = session.build_chat_request("test")
- content = request.messages[0]["content"]
- img_url = content[0]["image_url"]["url"]
- assert img_url.startswith("data:image/jpeg;base64,")
- import base64
-
- b64_data = img_url.split(",", 1)[1]
- decoded = base64.b64decode(b64_data)
- assert len(decoded) > 0
-
- def test_clear_audio_after_query(self):
- session = VideoStreamSession(VideoStreamConfig(evs_enabled=False))
- session.add_frame(make_jpeg())
- session.add_audio_chunk(b"\x00" * 100)
- session.build_chat_request("test")
- session.clear_audio()
- assert session.has_audio is False
- assert session.frame_count == 1
diff --git a/tests/entrypoints/test_async_omni.py b/tests/entrypoints/test_async_omni.py
deleted file mode 100644
index c8ef5eddf20..00000000000
--- a/tests/entrypoints/test_async_omni.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import asyncio
-from types import SimpleNamespace
-
-import pytest
-from vllm.sampling_params import RequestOutputKind, SamplingParams
-
-from vllm_omni.entrypoints.async_omni import AsyncOmni
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-async def _noop(**kw):
- pass
-
-
-def get_fake_add_request(submitted_request_ids):
- async def fake_add_request_async(*, request_id, prompt, sampling_params_list, final_stage_id, **kwargs):
- del prompt, sampling_params_list, final_stage_id, kwargs
- submitted_request_ids.append(request_id)
-
- return fake_add_request_async
-
-
-def get_fake_abort(aborted_request_batches):
- async def fake_abort_async(request_ids):
- aborted_request_batches.append(list(request_ids))
-
- return fake_abort_async
-
-
-async def fake_process_results(request_id, metrics, final_stage_id_for_e2e, req_start_ts, wall_start_ts):
- del metrics, final_stage_id_for_e2e, req_start_ts, wall_start_ts
- if request_id.startswith("cancel-"):
- await asyncio.Future()
- return
- yield SimpleNamespace(
- stage_id=0,
- request_output=SimpleNamespace(outputs=[]),
- finished=True,
- )
-
-
-def get_async_omni_instance(fake_add_request=_noop, fake_abort_request=_noop) -> AsyncOmni:
- omni = object.__new__(AsyncOmni)
- omni._pause_cond = asyncio.Condition()
- omni._paused = False
- omni.engine = SimpleNamespace(
- num_stages=1,
- add_request_async=fake_add_request,
- abort_async=fake_abort_request,
- )
- omni.log_stats = False
- omni.request_states = {}
- omni._final_output_handler = lambda: None
- omni.resolve_sampling_params_list = lambda params, allow_delta_coercion: params
- omni._compute_final_stage_id = lambda output_modalities: 0
- omni._process_orchestrator_results = fake_process_results
- omni._log_summary_and_cleanup = lambda request_id: omni.request_states.pop(request_id, None)
- return omni
-
-
-def test_generate_accepts_request_after_repeated_cancellations():
- async def run_test():
- submitted_request_ids = []
- aborted_request_batches = []
-
- async def collect_outputs(request_id):
- outputs = []
- async for output in AsyncOmni.generate(
- omni,
- prompt={"prompt": "prompt"},
- request_id=request_id,
- sampling_params_list=[SimpleNamespace()],
- output_modalities=["image"],
- ):
- outputs.append(output)
- return outputs
-
- omni = get_async_omni_instance(
- fake_add_request=get_fake_add_request(submitted_request_ids),
- fake_abort_request=get_fake_abort(aborted_request_batches),
- )
-
- assert len(await collect_outputs("baseline")) == 1
-
- for idx in range(3):
- task = asyncio.create_task(collect_outputs(f"cancel-{idx}"))
- await asyncio.sleep(0)
- task.cancel()
- with pytest.raises(asyncio.CancelledError):
- await task
-
- assert len(await collect_outputs("after-cancel")) == 1
- assert submitted_request_ids == [
- "baseline",
- "cancel-0",
- "cancel-1",
- "cancel-2",
- "after-cancel",
- ]
- assert aborted_request_batches == [
- ["cancel-0"],
- ["cancel-1"],
- ["cancel-2"],
- ]
-
- asyncio.run(run_test())
-
-
-@pytest.mark.parametrize(
- "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY, RequestOutputKind.CUMULATIVE]
-)
-def test_output_kind_is_preserved_with_explicit_sampling_params(output_kind):
- """Ensure we don't change the output kind in async generate if params are provided directly."""
-
- captured_params = []
-
- async def capturing_add_request(*, request_id, prompt, sampling_params_list, final_stage_id, **kwargs):
- del prompt, final_stage_id, kwargs
- captured_params.extend(sampling_params_list)
-
- async def run():
- omni = get_async_omni_instance(fake_add_request=capturing_add_request)
- sp = SamplingParams(output_kind=output_kind)
- async for _ in omni.generate(
- prompt={"prompt": "test"},
- request_id="test-req",
- sampling_params_list=[sp],
- output_modalities=["text"],
- ):
- pass
-
- asyncio.run(run())
- assert captured_params[0].output_kind == output_kind
diff --git a/tests/entrypoints/test_async_omni_abort.py b/tests/entrypoints/test_async_omni_abort.py
new file mode 100644
index 00000000000..b34652162d0
--- /dev/null
+++ b/tests/entrypoints/test_async_omni_abort.py
@@ -0,0 +1,85 @@
+import asyncio
+from types import SimpleNamespace
+
+import pytest
+
+from vllm_omni.entrypoints.async_omni import AsyncOmni
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def test_generate_accepts_request_after_repeated_cancellations():
+    async def run_test():
+        submitted = []  # request ids received by the fake add_request_async
+        abort_batches = []  # one list of request ids per fake abort_async call
+
+        async def fake_add_request_async(*, request_id, prompt, sampling_params_list, final_stage_id, **kwargs):
+            del prompt, sampling_params_list, final_stage_id, kwargs
+            submitted.append(request_id)
+
+        async def fake_abort_async(request_ids):
+            abort_batches.append(list(request_ids))
+
+        async def fake_process_results(request_id, metrics, final_stage_id_for_e2e, req_start_ts, wall_start_ts):
+            del metrics, final_stage_id_for_e2e, req_start_ts, wall_start_ts
+            if request_id.startswith("cancel-"):
+                await asyncio.Future()  # never resolved: parks the request until its task is cancelled
+                return
+            yield SimpleNamespace(
+                stage_id=0,
+                request_output=SimpleNamespace(outputs=[]),
+                finished=True,
+            )
+
+        async def collect_outputs(request_id):
+            stream = AsyncOmni.generate(
+                omni,
+                prompt={"prompt": "prompt"},
+                request_id=request_id,
+                sampling_params_list=[SimpleNamespace()],
+                output_modalities=["image"],
+            )
+            # Drain the async generator into a list; when the surrounding task is
+            # cancelled, the CancelledError propagates out of this comprehension.
+            return [item async for item in stream]
+
+        omni = object.__new__(AsyncOmni)  # bypass __init__; every attribute is stubbed below
+        omni._pause_cond = asyncio.Condition()
+        omni._paused = False
+        omni.engine = SimpleNamespace(
+            num_stages=1,
+            add_request_async=fake_add_request_async,
+            abort_async=fake_abort_async,
+        )
+        omni.log_stats = False
+        omni.request_states = {}
+        omni._final_output_handler = lambda: None
+        omni.resolve_sampling_params_list = lambda params: params
+        omni._compute_final_stage_id = lambda output_modalities: 0
+        omni._process_orchestrator_results = fake_process_results
+        omni._log_summary_and_cleanup = lambda request_id: omni.request_states.pop(request_id, None)
+
+        assert len(await collect_outputs("baseline")) == 1
+
+        for attempt in range(3):
+            pending = asyncio.create_task(collect_outputs(f"cancel-{attempt}"))
+            await asyncio.sleep(0)  # yield once so the task gets to run before it is cancelled
+            pending.cancel()
+            with pytest.raises(asyncio.CancelledError):
+                await pending
+
+        assert len(await collect_outputs("after-cancel")) == 1
+        assert submitted == [
+            "baseline",
+            "cancel-0",
+            "cancel-1",
+            "cancel-2",
+            "after-cancel",
+        ]
+        assert abort_batches == [
+            ["cancel-0"],
+            ["cancel-1"],
+            ["cancel-2"],
+        ]
+
+    asyncio.run(run_test())
diff --git a/tests/entrypoints/test_async_omni_diffusion_config.py b/tests/entrypoints/test_async_omni_diffusion_config.py
index 7ed8128260e..ca5624f2d4c 100644
--- a/tests/entrypoints/test_async_omni_diffusion_config.py
+++ b/tests/entrypoints/test_async_omni_diffusion_config.py
@@ -69,20 +69,6 @@ def test_default_stage_config_propagates_ulysses_mode():
assert parallel_config.ulysses_mode == "advanced_uaa"
-def test_default_stage_config_includes_default_sampling_params():
- """Ensure default sampling params survive the default diffusion-stage builder."""
- stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg(
- {
- "default_sampling_params": '{"0": {"generator_device":"cpu", "guidance_scale":7.5}}',
- }
- )[0]
-
- assert stage_cfg["default_sampling_params"] == {
- "generator_device": "cpu",
- "guidance_scale": 7.5,
- }
-
-
def test_serve_cli_accepts_ulysses_mode():
"""Ensure diffusion serve CLI exposes ulysses_mode and wires it to parallel_config."""
parser = FlexibleArgumentParser()
@@ -107,24 +93,3 @@ def test_serve_cli_accepts_ulysses_mode():
assert args.ulysses_mode == "advanced_uaa"
assert parallel_config.ulysses_degree == 4
assert parallel_config.ulysses_mode == "advanced_uaa"
-
-
-def test_serve_cli_accepts_diffusion_pipeline_profiler_flag():
- """Ensure diffusion serve CLI exposes the profiler switch."""
- parser = FlexibleArgumentParser()
- subparsers = parser.add_subparsers(dest="command")
- OmniServeCommand().subparser_init(subparsers)
-
- args = parser.parse_args(
- [
- "serve",
- "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
- "--omni",
- "--enable-diffusion-pipeline-profiler",
- ]
- )
-
- stage_cfg = _create_default_diffusion_stage_cfg(args)[0]
-
- assert args.enable_diffusion_pipeline_profiler is True
- assert stage_cfg["engine_args"]["enable_diffusion_pipeline_profiler"] is True
diff --git a/tests/entrypoints/test_cfg_companion_tracker.py b/tests/entrypoints/test_cfg_companion_tracker.py
new file mode 100644
index 00000000000..941ead41ff0
--- /dev/null
+++ b/tests/entrypoints/test_cfg_companion_tracker.py
@@ -0,0 +1,114 @@
+import time
+from types import SimpleNamespace
+
+import pytest
+
+from vllm_omni.entrypoints.cfg_companion_tracker import CfgCompanionTracker
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def dummy_expand_func(prompt, sp0):
+    if prompt != "expand_me":  # only the sentinel prompt gets a cfg_text companion
+        return []
+    return [SimpleNamespace(prompt={"prompt": "neg"}, role="cfg_text", request_id_suffix="__cfg_text")]
+
+
+@pytest.fixture
+def tracker():
+    params = SimpleNamespace()  # empty stand-in passed as the stage-0 sampling params
+    return CfgCompanionTracker(timeout_s=0.1, prompt_expand_func=dummy_expand_func, stage0_sampling_params=params)
+
+
+def test_companion_tracker_initialization(tracker):
+    assert tracker.num_companions == 0  # nothing registered before expand_prompts is called
+    assert not tracker.is_active
+
+
+def test_expand_prompts_registers_companions(tracker):
+    # dummy_expand_func only expands "expand_me", so req2 must stay companion-free.
+
+    pairs = tracker.expand_prompts({"req1": "expand_me", "req2": "do_not_expand"})
+
+    assert len(pairs) == 1
+    companion_id, companion_prompt = pairs[0]
+    assert companion_id == "req1__cfg_text"
+    assert companion_prompt == {"prompt": "neg"}
+
+    assert tracker.is_active
+    assert tracker.num_companions == 1
+    assert tracker.is_companion("req1__cfg_text")
+    assert not tracker.is_companion("req2__cfg_text")
+    assert tracker.has_companions("req1")
+    assert not tracker.has_companions("req2")
+
+    # The parent's companions are looked up by role.
+    assert tracker.get_companion_request_ids("req1") == {"cfg_text": "req1__cfg_text"}
+
+
+def test_companion_lifecycle_success(tracker):
+    # Register one cfg_text companion for req1.
+    tracker.expand_prompts({"req1": "expand_me"})
+
+    # Park the parent's stage-0 output while its companion is still outstanding.
+    parent_outputs = {"out": 123}
+    tracker.defer_parent("req1", parent_outputs, stage_id=0)
+
+    # The companion has not reported back yet.
+    assert not tracker.all_companions_done("req1")
+
+    # Report the only companion as completed.
+    released_parent = tracker.on_companion_completed("req1__cfg_text")
+
+    # With every companion done and the parent deferred, the parent id is handed back.
+    assert released_parent == "req1"
+    assert tracker.all_companions_done("req1")
+
+    # The deferred entry still carries the original outputs and stage id.
+    pending = tracker.pop_pending_parent("req1")
+    assert pending is not None
+    assert pending["engine_outputs"] == parent_outputs
+    assert pending["stage_id"] == 0
+
+
+def test_companion_lifecycle_failure(tracker):
+    # Register one cfg_text companion for req1, then park the parent.
+    tracker.expand_prompts({"req1": "expand_me"})
+
+    tracker.defer_parent("req1", {"out": 123}, stage_id=0)
+
+    # A companion error is reported back as (parent id, aborted flag).
+    failed_parent, was_aborted = tracker.on_companion_error("req1__cfg_text")
+
+    assert failed_parent == "req1"
+    assert was_aborted is True
+    assert tracker.is_parent_failed("req1")
+
+    # The failed parent is no longer pending.
+    assert tracker.pop_pending_parent("req1") is None
+
+    # Consuming the failure clears the failed flag.
+    tracker.consume_parent_failure("req1")
+    assert not tracker.is_parent_failed("req1")
+
+
+def test_companion_lifecycle_timeout(tracker):
+    # Register one cfg_text companion for req1, then park the parent.
+    tracker.expand_prompts({"req1": "expand_me"})
+
+    tracker.defer_parent("req1", {"out": 123}, stage_id=0)
+
+    # Checked right away, nothing has outlived the fixture's 0.1 s timeout yet.
+    early = tracker.check_timeouts()
+    assert len(early) == 0
+
+    # Sleep past the 0.1 s timeout (real wall-clock wait).
+    time.sleep(0.15)
+
+    # The deferred parent is now reported as timed out ...
+    expired = tracker.check_timeouts()
+    assert len(expired) == 1
+    assert expired[0] == "req1"
+
+    # ... and it is no longer pending.
+    assert tracker.pop_pending_parent("req1") is None
diff --git a/tests/entrypoints/test_omni_base_profiler.py b/tests/entrypoints/test_omni_base_profiler.py
index ca10eed91f6..0c1ddc6a5db 100644
--- a/tests/entrypoints/test_omni_base_profiler.py
+++ b/tests/entrypoints/test_omni_base_profiler.py
@@ -1,7 +1,8 @@
"""Unit tests for OmniBase and AsyncOmni profiler methods."""
+from unittest.mock import MagicMock, patch
+
import pytest
-from pytest_mock import MockerFixture
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -10,12 +11,12 @@ class TestOmniBaseProfiler:
"""Test suite for OmniBase profiler methods (start_profile, stop_profile)."""
@pytest.fixture
- def mock_engine(self, mocker: MockerFixture):
+ def mock_engine(self):
"""Create a mock AsyncOmniEngine for testing."""
- engine = mocker.MagicMock()
+ engine = MagicMock()
engine.num_stages = 3
engine.is_alive.return_value = True
- engine.default_sampling_params_list = [mocker.MagicMock() for _ in range(3)]
+ engine.default_sampling_params_list = [MagicMock() for _ in range(3)]
engine.get_stage_metadata.side_effect = lambda i: {
"final_output_type": "text" if i == 0 else "audio",
"final_output": True,
@@ -24,15 +25,17 @@ def mock_engine(self, mocker: MockerFixture):
return engine
@pytest.fixture
- def omni_base_instance(self, mock_engine, mocker: MockerFixture):
+ def omni_base_instance(self, mock_engine):
"""Create an OmniBase instance with mocked dependencies."""
- mocker.patch("vllm_omni.entrypoints.omni_base.AsyncOmniEngine", return_value=mock_engine)
- mocker.patch("vllm_omni.entrypoints.omni_base.omni_snapshot_download", side_effect=lambda x: x)
- mocker.patch("vllm_omni.entrypoints.omni_base.weakref.finalize")
- from vllm_omni.entrypoints.omni_base import OmniBase
-
- instance = OmniBase(model="test-model")
- return instance
+ with (
+ patch("vllm_omni.entrypoints.omni_base.AsyncOmniEngine", return_value=mock_engine),
+ patch("vllm_omni.entrypoints.omni_base.omni_snapshot_download", side_effect=lambda x: x),
+ patch("vllm_omni.entrypoints.omni_base.weakref.finalize"),
+ ):
+ from vllm_omni.entrypoints.omni_base import OmniBase
+
+ instance = OmniBase(model="test-model")
+ return instance
def test_start_profile_calls_collective_rpc(self, omni_base_instance, mock_engine):
"""Test that start_profile calls collective_rpc with correct arguments."""
diff --git a/tests/entrypoints/test_omni_entrypoints.py b/tests/entrypoints/test_omni_entrypoints.py
index a96b4dd1df6..3cffcd37df4 100644
--- a/tests/entrypoints/test_omni_entrypoints.py
+++ b/tests/entrypoints/test_omni_entrypoints.py
@@ -1,21 +1,17 @@
from __future__ import annotations
-import argparse
import queue
from collections.abc import Callable
from types import SimpleNamespace
from typing import Any
-from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.sampling_params import RequestOutputKind, SamplingParams
-from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.entrypoints.omni_base import OmniEngineDeadError
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -167,40 +163,6 @@ def _patch_engine(monkeypatch: pytest.MonkeyPatch, engine: FakeAsyncOmniEngine)
monkeypatch.setattr("vllm_omni.entrypoints.omni_base.omni_snapshot_download", lambda model: model)
-def test_from_cli_args_only_nulls_untyped_override_fields(monkeypatch: pytest.MonkeyPatch):
- from vllm_omni.entrypoints.omni import Omni
-
- captured: dict[str, Any] = {}
-
- def fake_engine(*args: Any, **kwargs: Any) -> FakeAsyncOmniEngine:
- captured.update(kwargs)
- return FakeAsyncOmniEngine()
-
- monkeypatch.setattr("vllm_omni.entrypoints.omni_base.AsyncOmniEngine", fake_engine)
- monkeypatch.setattr("vllm_omni.entrypoints.omni_base.omni_snapshot_download", lambda model: model)
- monkeypatch.setattr("sys.argv", ["prog"])
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
- parser.add_argument("--hsdp-shard-size", type=int, default=-1)
- args = parser.parse_args([])
- args.model = "fake-model"
-
- Omni.from_cli_args(args, parser=parser)
-
- assert captured["gpu_memory_utilization"] is None
- assert captured["hsdp_shard_size"] == -1
-
-
-def _make_base():
- from vllm_omni.entrypoints.omni_base import OmniBase
-
- obj = object.__new__(OmniBase)
- obj.engine = MagicMock()
- obj.request_states = {}
- return obj
-
-
def _stage_spec(
stage_id: int,
*,
@@ -362,18 +324,6 @@ def _enqueue_error_message(engine: FakeAsyncOmniEngine, msg: dict[str, Any]) ->
)
-def _enqueue_fatal_error_message(engine: FakeAsyncOmniEngine, msg: dict[str, Any]) -> None:
- engine.output_q.put_nowait(
- {
- "type": "error",
- "fatal": True,
- "request_id": msg["request_id"],
- "stage_id": 2,
- "error": "engine dead",
- }
- )
-
-
@pytest.mark.asyncio
async def test_get_supported_tasks_returns_engine_supported_tasks():
omni = object.__new__(AsyncOmni)
@@ -585,22 +535,18 @@ async def test_async_omni_abort_forwards_to_engine(monkeypatch: pytest.MonkeyPat
@pytest.mark.asyncio
-async def test_async_omni_propagates_fatal_error_context(monkeypatch: pytest.MonkeyPatch):
- engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META, on_add_request=_enqueue_fatal_error_message)
+async def test_async_omni_propagates_engine_error(monkeypatch: pytest.MonkeyPatch):
+ engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META, on_add_request=_enqueue_error_message)
_patch_engine(monkeypatch, engine)
app = AsyncOmni("dummy-model")
try:
- with pytest.raises(EngineDeadError, match="engine dead") as exc_info:
+ with pytest.raises(RuntimeError, match="engine boom"):
async for _ in app.generate(prompt="hello", request_id="req-1"):
pass
finally:
app.shutdown()
- assert isinstance(exc_info.value, OmniEngineDeadError)
- assert str(exc_info.value) == "engine dead"
- assert getattr(exc_info.value, "error_stage_id") == 2
-
def test_omni_generate_py_generator_yields_final_outputs_for_each_request(monkeypatch: pytest.MonkeyPatch):
sampling_params = [SamplingParams(max_tokens=8) for _ in range(3)]
@@ -741,255 +687,3 @@ def test_omni_forces_final_only_on_llm_stages(monkeypatch: pytest.MonkeyPatch):
assert submitted_params[1].output_kind == RequestOutputKind.FINAL_ONLY
assert submitted_params[2].output_kind == original_diffusion_output_kind
assert len(outputs) == 2
-
-
-def test_fatal_error_raises_engine_dead():
- base = _make_base()
- msg = {"type": "error", "error": "orchestrator crashed", "fatal": True}
-
- with pytest.raises(EngineDeadError, match="orchestrator crashed"):
- base._handle_output_message(msg)
-
-
-def test_non_fatal_error_raises_runtime():
- base = _make_base()
- msg = {"type": "error", "error": "something wrong"}
-
- with pytest.raises(RuntimeError, match="something wrong"):
- base._handle_output_message(msg)
-
-
-def test_async_omni_errored_property_alive():
- omni = object.__new__(AsyncOmni)
- omni.engine = SimpleNamespace(
- is_alive=lambda: True,
- stage_clients=[SimpleNamespace(is_comprehension=False)],
- )
-
- assert omni.errored is False
-
-
-def test_async_omni_errored_property_dead_engine():
- omni = object.__new__(AsyncOmni)
- omni.engine = SimpleNamespace(
- is_alive=lambda: False,
- stage_clients=[SimpleNamespace(is_comprehension=False)],
- )
-
- assert omni.errored is True
-
-
-def test_async_omni_errored_property_dead_stage():
- omni = object.__new__(AsyncOmni)
- dead_stage = SimpleNamespace(is_comprehension=False, _engine_dead=True)
- omni.engine = SimpleNamespace(
- is_alive=lambda: True,
- stage_clients=[dead_stage],
- )
-
- assert omni.errored is True
-
-
-def _enqueue_stage_error(
- engine: FakeAsyncOmniEngine,
- msg,
- *,
- error_text: str,
- kill_engine: bool = False,
-):
- """Enqueue a stage error output, optionally killing the engine."""
- if kill_engine:
- engine._alive = False
- engine.output_q.put_nowait(
- {
- "type": "output",
- "request_id": msg["request_id"],
- "stage_id": 0,
- "engine_outputs": SimpleNamespace(
- payload="",
- finished=True,
- images=[],
- stage_durations={},
- error=error_text,
- ),
- "finished": False,
- }
- )
-
-
-@pytest.mark.asyncio
-async def test_async_omni_propagates_engine_dead_error(monkeypatch: pytest.MonkeyPatch):
- """When the engine is dead and an error output arrives, ``generate()``
- must raise ``EngineDeadError`` (not plain ``RuntimeError``)."""
-
- engine = FakeAsyncOmniEngine(
- stage_metadata=THREE_STAGE_META,
- on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="worker OOM", kill_engine=True),
- )
- _patch_engine(monkeypatch, engine)
-
- app = AsyncOmni("dummy-model")
- try:
- with pytest.raises(EngineDeadError, match="worker OOM"):
- async for _ in app.generate(prompt="hello", request_id="req-dead"):
- pass
- finally:
- app.shutdown()
-
-
-@pytest.mark.asyncio
-async def test_async_omni_propagates_engine_generate_error(monkeypatch: pytest.MonkeyPatch):
- """When the engine is alive but a stage error occurs, ``generate()``
- must raise ``EngineGenerateError`` (recoverable, not ``EngineDeadError``)."""
-
- engine = FakeAsyncOmniEngine(
- stage_metadata=THREE_STAGE_META,
- on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="diffusion step failed"),
- )
- _patch_engine(monkeypatch, engine)
-
- app = AsyncOmni("dummy-model")
- try:
- with pytest.raises(EngineGenerateError):
- async for _ in app.generate(prompt="hello", request_id="req-recover"):
- pass
- finally:
- app.shutdown()
-
-
-# ───────── OmniBase.check_health() aggregation ─────────
-
-
-def test_check_health_passes_when_all_healthy():
- base = _make_base()
- healthy_stage = MagicMock()
- healthy_stage.check_health = MagicMock()
- base.engine.is_alive.return_value = True
- base.engine.stage_clients = [healthy_stage]
- base.check_health() # should not raise
-
-
-def test_check_health_raises_when_stage_dead():
- base = _make_base()
- dead_stage = MagicMock()
- dead_stage.check_health = MagicMock(side_effect=EngineDeadError("Stage-1 dead"))
- base.engine.is_alive.return_value = True
- base.engine.stage_clients = [dead_stage]
- with pytest.raises(EngineDeadError, match="Stage-1 dead"):
- base.check_health()
-
-
-def test_check_health_raises_when_orchestrator_dead():
- base = _make_base()
- base.engine.is_alive.return_value = False
- base.engine.stage_clients = []
- with pytest.raises(EngineDeadError, match="not alive"):
- base.check_health()
-
-
-# ───────── OmniBase.errored property ─────────
-
-
-def test_omni_base_errored_false_when_alive():
- base = _make_base()
- base.engine.is_alive.return_value = True
- base.engine.stage_clients = [SimpleNamespace()]
- assert base.errored is False
-
-
-def test_omni_base_errored_true_when_orchestrator_dead():
- base = _make_base()
- base.engine.is_alive.return_value = False
- base.engine.stage_clients = []
- assert base.errored is True
-
-
-def test_omni_base_errored_true_when_stage_engine_dead():
- base = _make_base()
- base.engine.is_alive.return_value = True
- dead_stage = SimpleNamespace(_engine_dead=True)
- base.engine.stage_clients = [dead_stage]
- assert base.errored is True
-
-
-def test_omni_base_errored_true_when_stage_resources_engine_dead():
- base = _make_base()
- base.engine.is_alive.return_value = True
- dead_stage = SimpleNamespace(resources=SimpleNamespace(engine_dead=True))
- base.engine.stage_clients = [dead_stage]
- assert base.errored is True
-
-
-# ───────── Omni (sync) EngineDeadError / EngineGenerateError ─────────
-
-
-def test_omni_propagates_engine_dead_error(monkeypatch: pytest.MonkeyPatch):
- """When the engine is dead and a stage error output arrives,
- ``Omni.generate()`` must raise ``EngineDeadError``."""
- engine = FakeAsyncOmniEngine(
- stage_metadata=THREE_STAGE_META,
- on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="worker OOM", kill_engine=True),
- )
- _patch_engine(monkeypatch, engine)
-
- app = Omni("dummy-model")
- try:
- with pytest.raises(EngineDeadError, match="worker OOM"):
- list(app.generate(["hello"], py_generator=False, use_tqdm=False))
- finally:
- app.shutdown()
-
-
-def test_omni_propagates_engine_generate_error(monkeypatch: pytest.MonkeyPatch):
- """When the engine is alive but a stage error occurs,
- ``Omni.generate()`` must raise ``EngineGenerateError`` (recoverable)."""
- engine = FakeAsyncOmniEngine(
- stage_metadata=THREE_STAGE_META,
- on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="diffusion step failed"),
- )
- _patch_engine(monkeypatch, engine)
-
- app = Omni("dummy-model")
- try:
- with pytest.raises(EngineGenerateError):
- list(app.generate(["hello"], py_generator=False, use_tqdm=False))
- finally:
- app.shutdown()
-
-
-def test_omni_errored_property_alive(monkeypatch: pytest.MonkeyPatch):
- """Omni.errored (inherited from OmniBase) returns False when healthy."""
- engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META)
- _patch_engine(monkeypatch, engine)
-
- app = Omni("dummy-model")
- try:
- assert app.errored is False
- finally:
- app.shutdown()
-
-
-def test_omni_errored_property_dead_engine(monkeypatch: pytest.MonkeyPatch):
- """Omni.errored returns True when the orchestrator is dead."""
- engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META)
- _patch_engine(monkeypatch, engine)
-
- app = Omni("dummy-model")
- try:
- engine._alive = False
- assert app.errored is True
- finally:
- app.shutdown()
-
-
-def test_omni_errored_property_dead_stage(monkeypatch: pytest.MonkeyPatch):
- """Omni.errored returns True when a stage client is marked dead."""
- engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META)
- _patch_engine(monkeypatch, engine)
-
- app = Omni("dummy-model")
- try:
- engine.stage_clients[0]._engine_dead = True
- assert app.errored is True
- finally:
- app.shutdown()
diff --git a/tests/entrypoints/test_omni_sleep_mode.py b/tests/entrypoints/test_omni_sleep_mode.py
deleted file mode 100644
index aa7be1ba0f7..00000000000
--- a/tests/entrypoints/test_omni_sleep_mode.py
+++ /dev/null
@@ -1,336 +0,0 @@
-import asyncio
-import logging
-import os
-
-import pytest
-import torch
-from vllm import SamplingParams
-
-from tests.helpers.mark import hardware_test
-from vllm_omni.entrypoints.async_omni import AsyncOmni
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.platforms import current_omni_platform
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger("OmniTest")
-pytestmark = [pytest.mark.advanced_model]
-
-
-def clean_gpu_envs():
- """clean up GPU environment variables to ensure tests run on all available devices."""
- device_visibility_vars = [
- "CUDA_VISIBLE_DEVICES", # NVIDIA
- "HIP_VISIBLE_DEVICES", # AMD ROCm
- "ZE_AFFINITY_MASK", # Intel XPU
- "ONEAPI_DEVICE_SELECTOR", # Intel OneAPI
- "ASCEND_RT_VISIBLE_DEVICES", # Huawei NPU (CAN)
- ]
- for key in device_visibility_vars:
- os.environ.pop(key, None)
-
-
-def get_vram_info(device_id: int) -> dict:
- """Obtain a snapshot of the specified GPU's memory (GiB)."""
- try:
- if current_omni_platform.is_rocm():
- num_gpus = torch.cuda.device_count()
- safe_id = device_id if device_id < num_gpus else 0
- torch.cuda.synchronize(safe_id)
- return {
- "reserved": torch.cuda.memory_reserved(safe_id) / 1024**3,
- "allocated": torch.cuda.memory_allocated(safe_id) / 1024**3,
- }
- else:
- with torch.cuda.device(device_id):
- torch.cuda.synchronize()
- return {
- "reserved": torch.cuda.memory_reserved() / 1024**3,
- "allocated": torch.cuda.memory_allocated() / 1024**3,
- }
- except Exception as e:
- logger.warning(f"memory skip ({device_id}): {e}")
- return {"reserved": 0.0, "allocated": 0.0}
-
-
-def get_ack_info(ack, key, default=None):
- """
- Since ACKs in a distributed environment can be either objects or dictionaries,
- this tool ensures compatibility.
- """
- if hasattr(ack, key):
- return getattr(ack, key)
- if isinstance(ack, dict):
- return ack.get(key, default)
- return default
-
-
-@pytest.fixture(scope="function")
-async def llm_engine():
- if current_omni_platform.is_rocm():
- clean_gpu_envs()
- model_name = "ByteDance-Seed/BAGEL-7B-MoT"
- common_args = {
- "worker_type": "ar",
- "enable_sleep_mode": True,
- "dtype": "bfloat16",
- "trust_remote_code": True,
- "max_model_len": 2048,
- "max_num_batched_tokens": 8192,
- "enforce_eager": True,
- }
- stages = [
- {
- "stage_id": 0,
- "stage_type": "llm",
- "runtime": {"process": True, "devices": "0", "max_batch_size": 1},
- "engine_args": {**common_args, "model_stage": "thinker", "gpu_memory_utilization": 0.1},
- },
- {
- "stage_id": 1,
- "stage_type": "llm",
- "engine_input_source": [0],
- "runtime": {"process": True, "devices": "1", "max_batch_size": 1, "connector_type": "queue"},
- "engine_args": {**common_args, "model_stage": "talker", "gpu_memory_utilization": 0.1},
- },
- ]
- connectors = [{"src_stage_id": 0, "dst_stage_id": 1, "connector_type": "queue"}]
- engine = AsyncOmni(model=model_name, stages=stages, connectors=connectors, init_timeout=600, enable_sleep_mode=True)
- yield engine
- engine.shutdown()
-
-
-@pytest.fixture(scope="function")
-async def diffusion_engine():
- if current_omni_platform.is_rocm():
- clean_gpu_envs()
- model_name = "ByteDance-Seed/BAGEL-7B-MoT"
- stages = [
- {
- "stage_id": 0,
- "stage_type": "diffusion",
- "runtime": {"process": True, "devices": "0,1", "max_batch_size": 1},
- "engine_args": {
- "model_stage": "base",
- "gpu_memory_utilization": 0.1,
- "model_class_name": "BagelPipeline",
- "enable_sleep_mode": True,
- "enforce_eager": True,
- "max_num_batched_tokens": 8192,
- "parallel_config": {
- "tensor_parallel_size": 2,
- },
- },
- "final_output": True,
- "final_output_type": "image",
- }
- ]
- engine = AsyncOmni(model=model_name, stages=stages, init_timeout=600, enable_sleep_mode=True)
- yield engine
- engine.shutdown()
-
-
-class TestOmniSleepMode:
- @pytest.mark.asyncio
- @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=1)
- async def test_llm_sleep_ack(self, llm_engine: AsyncOmni):
- """LLM Thinker (GPU0) Signal and Physical Recycling Audit"""
- try:
- acks = await llm_engine.sleep(stage_ids=[0], level=2)
- # Verification signal successful
- assert all(get_ack_info(ack, "status") == "SUCCESS" for ack in acks)
- # Verify physical recycling volume
- total_freed_bytes = sum(get_ack_info(ack, "freed_bytes", 0) for ack in acks)
- freed_gib = total_freed_bytes / 1024**3
- logger.info(f"Thinker VRAM physically reclaimed: {freed_gib:.2f} GiB")
- assert freed_gib > 5.0
- finally:
- llm_engine.shutdown()
-
- @pytest.mark.asyncio
- @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
- async def test_diffusion_sleep_handshake(self, diffusion_engine: AsyncOmni):
- """Diffusion Worker stage signal loop"""
- try:
- logger.info("Starting Diffusion Worker Handshake Test")
- acks = await diffusion_engine.sleep(stage_ids=[0], level=2)
-
- def _get_status(ack):
- return ack.status if hasattr(ack, "status") else ack.get("status")
-
- assert len(acks) >= 1, "Expected at least 1 ACK from Diffusion Workers"
- assert all(_get_status(ack) == "SUCCESS" for ack in acks)
- logger.info(f"Success: Received {len(acks)} Diffusion Worker ACKs")
- logger.info("Testing auto-wakeup before test end...")
- await diffusion_engine.wake_up(stage_ids=[0])
- logger.info("Test logic finished, triggering manual shutdown...")
- finally:
- diffusion_engine.shutdown()
- logger.info("Manual shutdown executed. Test should exit now.")
-
- @pytest.mark.asyncio
- @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
- async def test_cross_device_cleanup(self, diffusion_engine: AsyncOmni):
- """Physical recycling audit: leveraging deterministic data returned by Workers"""
- try:
- acks = await diffusion_engine.sleep(stage_ids=[0], level=1)
- # Sum up the release amounts reported by all Workers.
- total_freed_bytes = sum(get_ack_info(ack, "freed_bytes", 0) for ack in acks)
- freed_gb = total_freed_bytes / 1024**3
- logger.info("Physical reclamation summary from workers:")
- logger.info(f"- Total Workers: {len(acks)}")
- logger.info(f"- Total Freed: {freed_gb:.2f} GiB")
- assert freed_gb > 14.0
- logger.info("SUCCESS: 100% weights offloaded.")
- finally:
- diffusion_engine.shutdown()
-
- @pytest.mark.asyncio
- @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
- async def test_diffusion_integrity_bit_level(self, diffusion_engine: AsyncOmni):
- """Bit-level consistency after Diffusion wake-up (prevent image corruption)"""
- try:
- prompt = "A huge swimming pool, with many people swimming."
- sp = OmniDiffusionSamplingParams(num_inference_steps=4, height=512, width=512, seed=42)
- llm_sp = SamplingParams()
-
- # Baseline Generation
- logger.info("Running Baseline Generation...")
- base_output = None
- async for output in diffusion_engine.generate(prompt, request_id="base", sampling_params_list=[llm_sp, sp]):
- base_output = output
- assert base_output is not None and len(base_output.images) > 0
- logger.info("Baseline Generation successful.")
- # Sleep Level 2
- logger.info("Entering Deep Sleep (VRAM Scavenging)...")
- await diffusion_engine.sleep(stage_ids=[0], level=2)
- # Wake-up
- logger.info("Waking up (Reloading Weights)...")
- await diffusion_engine.wake_up(stage_ids=[0])
-
- await asyncio.sleep(2.0)
- import gc
-
- gc.collect()
-
- logger.info("Running Post-Wakeup Generation...")
- post_output = None
- async for output in diffusion_engine.generate(prompt, request_id="post", sampling_params_list=[llm_sp, sp]):
- post_output = output
- # Assert result consistency
- assert post_output is not None
- assert len(base_output.images) == len(post_output.images)
- assert post_output.images[0] is not None
- logger.info("SUCCESS: Diffusion integrity verified after Sleep/Wake cycle.")
- except Exception as e:
- logger.error(f"Integrity test failed: {e}")
- raise e
- finally:
- logger.info("Triggering mandatory cleanup...")
- diffusion_engine.shutdown()
- logger.info("Cleanup complete, test exiting.")
-
- @pytest.mark.asyncio
- @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
- async def test_coordinated_cross_device(self, llm_engine: AsyncOmni, diffusion_engine: AsyncOmni):
- """Heterogeneous Coordinated Cleanup Test (Talker and Diffusion on GPU 1)"""
- device_id = 1
- try:
- logger.info(f"Waking up both engines on GPU {device_id}...")
- await llm_engine.wake_up(stage_ids=[1])
- await diffusion_engine.wake_up(stage_ids=[0])
-
- get_vram_info(device_id)
- torch.cuda.empty_cache()
- await asyncio.sleep(2)
-
- initial_vram = get_vram_info(device_id)["reserved"]
- logger.info(f"GPU {device_id} Peak Pressure: {initial_vram:.2f} GiB")
-
- # coordinated sleep
- logger.info("Issuing concurrent SLEEP commands...")
- await llm_engine.sleep(stage_ids=[1], level=2)
- await asyncio.sleep(1.0)
- await diffusion_engine.sleep(stage_ids=[0], level=2)
-
- await asyncio.sleep(3.0)
- torch.cuda.empty_cache()
-
- final_vram = get_vram_info(device_id)["reserved"]
- logger.info(f"GPU {device_id} Final VRAM after coordinated sleep: {final_vram:.2f} GiB")
-
- assert initial_vram - final_vram > 15.0 or final_vram < 8.0
- logger.info(f"SUCCESS: Heterogeneous VRAM drop verified on GPU {device_id}.")
- except Exception as e:
- logger.error(f"Coordinated test failed: {e}")
- raise e
- finally:
- logger.info("Triggering mandatory cleanup for both engines...")
- llm_engine.shutdown()
- diffusion_engine.shutdown()
- logger.info("All engines scavenged. Ready for next test.")
-
- @pytest.mark.asyncio
- @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
- async def test_diffusion_vram_lifecycle_audit(self, diffusion_engine: AsyncOmni):
- """Diffusion memory loop: Active -> Deep Sleep -> Active -> inference sanity check"""
- device_id = 1
- try:
- get_vram_info(device_id)
- torch.cuda.empty_cache()
- vram_initial = get_vram_info(device_id)["reserved"]
- logger.info(f"Diffusion Initial VRAM: {vram_initial:.2f} GiB")
-
- # Sleep
- logger.info("Triggering Level 2 Deep Sleep (Full Weight Offloading)...")
- acks = await diffusion_engine.sleep(stage_ids=[0], level=2)
-
- reported_freed_bytes = sum(getattr(ack, "freed_bytes", 0) for ack in acks)
- reported_freed_gib = reported_freed_bytes / 1024**3
- logger.info(f"Worker internally reported freed: {reported_freed_gib:.2f} GiB")
-
- await asyncio.sleep(2)
- get_vram_info(device_id)
- torch.cuda.empty_cache()
-
- vram_sleeping = get_vram_info(device_id)["reserved"]
- logger.info(f"External VRAM measurement during Sleep: {vram_sleeping:.2f} GiB")
-
- assert reported_freed_gib > 14.0 or vram_sleeping < 5.0, (
- f"Reclamation failed. Reported: {reported_freed_gib:.2f}G, Measured: {vram_sleeping:.2f}G"
- )
-
- # wake-up
- logger.info("Triggering Wake-up (Reloading weights to GPU)...")
- await diffusion_engine.wake_up(stage_ids=[0])
-
- await asyncio.sleep(2)
- get_vram_info(device_id)
- torch.cuda.empty_cache()
- vram_restored = get_vram_info(device_id)["reserved"]
- logger.info(f"VRAM after Wake-up: {vram_restored:.2f} GiB")
-
- assert abs(vram_restored - vram_initial) < 3.0, "VRAM failed to restore to initial levels"
-
- # inference sanity check
- logger.info("Running post-lifecycle inference smoke test...")
- prompt = "A futuristic lab with glowing lights, high quality."
- sp = OmniDiffusionSamplingParams(num_inference_steps=2, height=512, width=512, seed=42)
- llm_sp = SamplingParams()
-
- base_img_found = False
- async for output in diffusion_engine.generate(
- prompt, request_id="lifecycle-check", sampling_params_list=[llm_sp, sp]
- ):
- if output.images and output.images[0] is not None:
- base_img_found = True
-
- assert base_img_found, "Inference failed after Wake-up cycle!"
- logger.info("SUCCESS: Full Diffusion Lifecycle (Active -> Sleep -> Active -> Generate) audited.")
-
- except Exception as e:
- logger.error(f"Lifecycle audit failed: {e}")
- raise e
- finally:
- logger.info("Cleaning up engine and scavenging processes...")
- diffusion_engine.shutdown()
- await asyncio.sleep(1)
diff --git a/tests/entrypoints/test_pd_disaggregation.py b/tests/entrypoints/test_pd_disaggregation.py
deleted file mode 100644
index 5ffabfbf2af..00000000000
--- a/tests/entrypoints/test_pd_disaggregation.py
+++ /dev/null
@@ -1,1222 +0,0 @@
-"""Unit tests for PD (Prefill-Decode) disaggregation in the Omni orchestrator.
-
-Tests the PD detection, validation, config parsing, sampling param
-preparation, and routing logic added by the PD disaggregation feature
-(issue #1188). All tests run without GPU.
-
-NOTE (v1908 adaptation): Tests that relied on the old OmniStage / stage_list
-architecture (removed in PR #1908) are marked xfail with
-``reason="Requires migration to v1908 Orchestrator architecture"``.
-The remaining tests exercise PDDisaggregationMixin directly and work
-without spinning up a real engine.
-"""
-
-import uuid
-import warnings
-from queue import Empty, Queue
-from types import SimpleNamespace
-from typing import Any
-
-import pytest
-from vllm import SamplingParams
-
-from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin
-
-pytestmark = pytest.mark.skip(reason="Temporarily skip PD entrypoint tests while PD config is being removed.")
-
-# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies.
-warnings.filterwarnings(
- "ignore",
- message=r"builtin type SwigPy.*has no __module__ attribute",
- category=DeprecationWarning,
-)
-
-
-def _ns(**kwargs):
- """Create a lightweight attribute object for tests."""
- return SimpleNamespace(**kwargs)
-
-
-# ---------------------------------------------------------------------------
-# Fake helpers (same pattern as test_omni_llm.py)
-# ---------------------------------------------------------------------------
-
-
-class _FakeEngineArgs(dict):
- """Fake engine args that supports both attribute and dict access."""
-
- def __init__(self, args_dict: dict[str, Any]):
- super().__init__(args_dict)
- if "model_stage" not in self:
- self["model_stage"] = None
- if "engine_output_type" not in self:
- self["engine_output_type"] = None
- for key, value in self.items():
- setattr(self, key, value)
-
-
-class _FakeStageConfig:
- def __init__(self, config_dict: dict[str, Any]):
- engine_args_dict = config_dict.get("engine_args", {})
- self.engine_args = _FakeEngineArgs(engine_args_dict)
- self.final_output = config_dict.get("final_output", False)
- self.final_output_type = config_dict.get("final_output_type", None)
- self.stage_id = config_dict.get("stage_id", 0)
- self.is_prefill_only = config_dict.get("is_prefill_only", False)
- self.is_decode_only = config_dict.get("is_decode_only", False)
- self.engine_input_source = config_dict.get("engine_input_source", [])
- self.is_comprehension = config_dict.get("is_comprehension", False)
- self._config_dict = config_dict
-
-
-class _FakeQueue:
- def __init__(self, maxsize=0):
- self._queue = Queue(maxsize=maxsize)
-
- def put(self, item):
- self._queue.put(item)
-
- def put_nowait(self, item):
- self._queue.put_nowait(item)
-
- def get(self):
- return self._queue.get()
-
- def get_nowait(self):
- return self._queue.get_nowait()
-
- def empty(self):
- return self._queue.empty()
-
-
-class _FakeStage:
- """Lightweight stage stub with PD disaggregation flag support."""
-
- def __init__(self, config, stage_init_timeout: int = 300):
- if isinstance(config, dict):
- config = _FakeStageConfig(config)
- self.config = config
- self.stage_config = config
- self.engine = None
- self.engine_outputs = None
- self.stage_id = getattr(config, "stage_id", 0)
- self.engine_args = config.engine_args
- self.model_stage = getattr(config.engine_args, "model_stage", None)
- self.stage_type = "llm"
- self.default_sampling_params = SamplingParams(temperature=1.0)
- self.final_output = config.final_output if hasattr(config, "final_output") else False
- self.final_output_type = getattr(config, "final_output_type", None)
- self.is_prefill_only = getattr(config, "is_prefill_only", False)
- self.is_decode_only = getattr(config, "is_decode_only", False)
- self.engine_input_source = getattr(config, "engine_input_source", [])
- self.is_comprehension = getattr(config, "is_comprehension", False)
- processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"])
- self._processed_input = processed_input
- self._in_q = None
- self._out_q = None
- self._proc = None
- self._stage_init_timeout = max(0, int(stage_init_timeout))
-
- def attach_queues(self, in_q, out_q):
- self._in_q = in_q
- self._out_q = out_q
-
- def init_stage_worker(
- self, model: str, *, is_async=False, shm_threshold_bytes=65536, ctx=None, batch_timeout=10, **kwargs
- ):
- self._proc = _ns(
- start=lambda: None,
- join=lambda timeout=None: None,
- is_alive=lambda: False,
- terminate=lambda: None,
- )
- if self._out_q is not None:
- try:
- self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id})
- except Exception:
- pass
-
- def stop_stage_worker(self):
- if self._in_q is not None:
- try:
- self._in_q.put_nowait({"type": "shutdown"})
- except Exception:
- pass
-
- def submit(self, payload: dict[str, Any]):
- if self._in_q is not None:
- self._in_q.put(payload)
-
- def try_collect(self) -> Any:
- if self._out_q is None:
- return None
- try:
- return self._out_q.get_nowait()
- except Empty:
- return None
-
- def set_engine_outputs(self, outputs):
- self.engine_outputs = outputs
-
- def process_engine_inputs(self, stage_list, prompts):
- return self._processed_input
-
-
-# ---------------------------------------------------------------------------
-# Shared mock setup helpers
-# ---------------------------------------------------------------------------
-
-
-def _setup_engine_mocks(monkeypatch):
- fake_engine = _ns()
- fake_engine.tokenizer = _ns()
- fake_engine.log_stats = False
- fake_engine.vllm_config = _ns()
- fake_engine.vllm_config.model_config = _ns()
- fake_engine.vllm_config.model_config.io_processor_plugin = None
- fake_engine.get_supported_tasks = lambda: []
- fake_engine.model_config = _ns()
- fake_engine.model_config.io_processor_plugin = None
- fake_registry = _ns()
- fake_registry.resolve_model_cls = lambda *args, **kwargs: (_ns(), "test_arch")
- fake_engine.model_config.registry = fake_registry
- fake_engine.vllm_config.model_config.registry = fake_registry
-
- monkeypatch.setattr(
- "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args",
- lambda **kw: fake_engine,
- raising=False,
- )
-
- class FakeModelClass:
- pass
-
- monkeypatch.setattr(
- "vllm.model_executor.model_loader.utils.get_model_architecture",
- lambda model_config: (FakeModelClass, "test_arch"),
- raising=False,
- )
- monkeypatch.setattr(
- "vllm.model_executor.model_loader.utils._get_model_architecture",
- lambda model_config: (FakeModelClass, "test_arch"),
- raising=False,
- )
- monkeypatch.setattr(
- "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls",
- lambda model_cls: model_cls,
- raising=False,
- )
- monkeypatch.setattr(
- "vllm.multimodal.cache._enable_processor_cache",
- lambda model_config, mm_registry: False,
- raising=False,
- )
- monkeypatch.setattr(
- "vllm.plugins.io_processors.get_io_processor",
- lambda vllm_config, io_processor_plugin: None,
- raising=False,
- )
-
-
-def _setup_multiprocessing_mocks(monkeypatch):
- import multiprocessing as mp
-
- fake_process_instance = _ns(
- start=lambda: None,
- join=lambda timeout=None: None,
- is_alive=lambda: False,
- terminate=lambda: None,
- )
-
- def fake_process_class(*args, **kwargs):
- return fake_process_instance
-
- fake_ctx = _ns()
- fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize)
- fake_ctx.Process = fake_process_class
-
- monkeypatch.setattr(mp, "get_context", lambda method: fake_ctx, raising=False)
- monkeypatch.setattr(mp, "Process", fake_process_class, raising=False)
-
-
-def _setup_ipc_mocks(monkeypatch):
- # These IPC helpers existed in the old architecture; no-op in new arch.
- pass
-
-
-def _setup_log_mocks(monkeypatch):
- class _FakeOrchestratorAggregator:
- def __init__(self, num_stages, enable_stats, wall_start_ts, final_stage_id_for_e2e=None):
- self.num_stages = num_stages
- self.enable_stats = enable_stats
- self.stage_first_ts = [None] * num_stages
- self.stage_last_ts = [None] * num_stages
- self.stage_total_tokens = [0] * num_stages
- self.accumulated_gen_time_ms = {}
- self.e2e_done = set()
- self.e2e_count = 0
- self.e2e_total_ms = 0.0
-
- def on_stage_metrics(self, stage_id, req_id, metrics, final_output_type=None):
- pass
-
- def on_finalize_request(self, stage_id, req_id, start_ts):
- self.e2e_done.add(req_id)
-
- def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm):
- pass
-
- def accumulate_diffusion_metrics(self, stage_type, req_id, engine_outputs):
- pass
-
- def record_audio_generated_frames(self, output, stage_id, req_id):
- pass
-
- def stage_postprocess_timer(self, stage_id, req_id):
- from contextlib import contextmanager
-
- @contextmanager
- def _noop():
- yield
-
- return _noop()
-
- def build_and_log_summary(self):
- return "Fake summary"
-
- monkeypatch.setattr(
- "vllm_omni.entrypoints.omni.OrchestratorAggregator",
- _FakeOrchestratorAggregator,
- raising=False,
- )
-
-
-def _clear_modules():
- import sys
-
- for module_name in [
- "vllm_omni.entrypoints.utils",
- "vllm_omni.entrypoints.omni",
- ]:
- if module_name in sys.modules:
- del sys.modules[module_name]
-
-
-@pytest.fixture(autouse=True)
-def mock_get_config(monkeypatch):
- """Auto-mock get_config and related model loading functions."""
- import sys
-
- fake_tokenizer = _ns()
- fake_tokenizer.encode = lambda *args, **kwargs: [1, 2, 3]
- fake_tokenizer.decode = lambda *args, **kwargs: "test"
-
- def _mock_init_tokenizer_from_configs(model_config=None, **kwargs):
- return fake_tokenizer
-
- monkeypatch.setattr(
- "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs",
- _mock_init_tokenizer_from_configs,
- raising=False,
- )
- tokenizer_module_path = "vllm.transformers_utils.tokenizer"
- if tokenizer_module_path in sys.modules:
- setattr(sys.modules[tokenizer_module_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
-
- def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None):
- if prompt_token_ids is not None:
- if isinstance(prompt_token_ids, list):
- return len(prompt_token_ids)
- return 10
-
- monkeypatch.setattr(
- "vllm.utils.length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds, raising=False
- )
- monkeypatch.setattr(
- "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds",
- _mock_length_from_prompt_token_ids_or_embeds,
- raising=False,
- )
-
- processor_module_path = "vllm_omni.engine.input_processor"
- if processor_module_path in sys.modules:
- setattr(
- sys.modules[processor_module_path],
- "length_from_prompt_token_ids_or_embeds",
- _mock_length_from_prompt_token_ids_or_embeds,
- )
-
- monkeypatch.setattr(
- "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", _mock_init_tokenizer_from_configs, raising=False
- )
- async_omni_path = "vllm_omni.entrypoints.async_omni"
- if async_omni_path in sys.modules:
- setattr(sys.modules[async_omni_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
-
- fake_hf_config = _ns()
- fake_hf_config.model_type = "qwen2_5_omni"
-
- monkeypatch.setattr(
- "vllm.transformers_utils.config.get_config", lambda model, **kwargs: fake_hf_config, raising=False
- )
- monkeypatch.setattr("vllm_omni.entrypoints.utils.get_config", lambda model, **kwargs: fake_hf_config, raising=False)
-
- def _mock_cached_file(path_or_repo_id, *args, **kwargs):
- import os
- import tempfile
-
- fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json")
- if not os.path.exists(fake_config_file):
- with open(fake_config_file, "w") as f:
- f.write('{"model_type": "qwen2_5_omni"}')
- return fake_config_file
-
- monkeypatch.setattr("transformers.utils.hub.cached_file", _mock_cached_file, raising=False)
- monkeypatch.setattr(
- "transformers.utils.hub.cached_files",
- lambda path_or_repo_id, filenames, **kwargs: (
- [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None
- ),
- raising=False,
- )
-
-
-# ---------------------------------------------------------------------------
-# Helper to build an Omni instance with PD stage configs
-# ---------------------------------------------------------------------------
-
-
-def _make_pd_omni(monkeypatch, stage_configs, *, extra_setup=None):
- """Create a lightweight PDDisaggregationMixin instance for unit tests.
-
- Bypasses the full OmniBase / AsyncOmniEngine init chain so tests run
- without GPU. Returns an object that has all PDDisaggregationMixin
- methods and state (``_pd_separation_pair``, ``_pd_kv_params_by_req``,
- etc.) initialised from *stage_configs*.
-
- Tests that need the full ``Omni.generate()`` loop (old stage_list / queue
- infrastructure) are marked ``xfail`` and not covered here.
- """
- configs = [_FakeStageConfig(c) for c in stage_configs]
-
- class _LightweightOmni(PDDisaggregationMixin):
- """Minimal shim: exposes stage_configs so PDDisaggregationMixin works."""
-
- def __init__(self):
- self._name = "Omni"
- self._stage_configs = configs
- self._init_pd_state()
-
- @property
- def stage_configs(self):
- return self._stage_configs
-
- if extra_setup:
- import vllm_omni.entrypoints.omni as omni_module
-
- extra_setup(monkeypatch, omni_module)
-
- return _LightweightOmni()
-
-
-# ---------------------------------------------------------------------------
-# Stage config templates
-# ---------------------------------------------------------------------------
-
-
-def _prefill_stage_cfg(stage_id=0, **overrides):
- cfg = {
- "stage_id": stage_id,
- "engine_args": {
- "model_stage": "thinker",
- "kv_transfer_config": {
- "kv_connector": "MooncakeConnector",
- "kv_role": "kv_producer",
- "kv_rank": 0,
- "kv_parallel_size": 2,
- "kv_connector_extra_config": {"mooncake_bootstrap_port": 25201},
- },
- },
- "is_prefill_only": True,
- "final_output": False,
- "is_comprehension": True,
- }
- cfg.update(overrides)
- return cfg
-
-
-def _decode_stage_cfg(stage_id=1, engine_input_source=None, **overrides):
- cfg = {
- "stage_id": stage_id,
- "engine_args": {
- "model_stage": "thinker",
- "kv_transfer_config": {
- "kv_connector": "MooncakeConnector",
- "kv_role": "kv_consumer",
- "kv_rank": 1,
- "kv_parallel_size": 2,
- "kv_connector_extra_config": {"mooncake_bootstrap_port": 25202},
- },
- },
- "is_decode_only": True,
- "engine_input_source": engine_input_source if engine_input_source is not None else [0],
- "final_output": True,
- "final_output_type": "text",
- "is_comprehension": True,
- }
- cfg.update(overrides)
- return cfg
-
-
-def _talker_stage_cfg(stage_id=2, engine_input_source=None, **overrides):
- cfg = {
- "stage_id": stage_id,
- "engine_args": {"model_stage": "talker"},
- "engine_input_source": engine_input_source if engine_input_source is not None else [1],
- "final_output": False,
- }
- cfg.update(overrides)
- return cfg
-
-
-def _code2wav_stage_cfg(stage_id=3, engine_input_source=None, **overrides):
- cfg = {
- "stage_id": stage_id,
- "engine_args": {"model_stage": "code2wav"},
- "engine_input_source": engine_input_source if engine_input_source is not None else [2],
- "final_output": True,
- "final_output_type": "audio",
- }
- cfg.update(overrides)
- return cfg
-
-
-# ===================================================================
-# Tests: PD pair detection
-# ===================================================================
-
-
-class TestDetectPDSeparation:
- """Tests for Omni._detect_pd_separation()."""
-
- def test_detects_pd_pair(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(stage_id=0),
- _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
- ],
- )
- assert omni._pd_separation_pair == (0, 1)
-
- def test_no_pd_pair_without_flags(self, monkeypatch):
- """Normal (non-PD) pipeline has no PD pair."""
- omni = _make_pd_omni(
- monkeypatch,
- [
- {
- "stage_id": 0,
- "engine_args": {"model_stage": "thinker"},
- "final_output": True,
- "final_output_type": "text",
- },
- {
- "stage_id": 1,
- "engine_args": {"model_stage": "talker"},
- "engine_input_source": [0],
- "final_output": True,
- "final_output_type": "audio",
- },
- ],
- )
- assert omni._pd_separation_pair is None
-
- def test_detects_pd_pair_in_4_stage_pipeline(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(stage_id=0),
- _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
- _talker_stage_cfg(stage_id=2, engine_input_source=[1]),
- _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]),
- ],
- )
- assert omni._pd_separation_pair == (0, 1)
-
- def test_pd_pair_uses_stage_id_for_input_source(self, monkeypatch):
- """engine_input_source references stage_id, not list index."""
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(stage_id=10),
- _decode_stage_cfg(stage_id=20, engine_input_source=[10]),
- ],
- )
- assert omni._pd_separation_pair == (0, 1)
-
-
-# ===================================================================
-# Tests: PD config validation
-# ===================================================================
-
-
-class TestValidatePDConfig:
- """Tests for Omni._validate_pd_separation_config()."""
-
- def test_valid_config_passes(self, monkeypatch):
- """Valid PD config should not raise."""
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- # If we got here without error, validation passed
- assert omni._pd_separation_pair == (0, 1)
-
- def test_mismatched_connector_raises(self, monkeypatch):
- """Different kv_connector types should raise ValueError."""
- decode_cfg = _decode_stage_cfg(engine_input_source=[0])
- decode_cfg["engine_args"]["kv_transfer_config"]["kv_connector"] = "NixlConnector"
-
- with pytest.raises(ValueError, match="connector mismatch"):
- _make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg])
-
- def test_wrong_prefill_role_raises(self, monkeypatch):
- """Prefill with kv_consumer role should raise."""
- prefill_cfg = _prefill_stage_cfg()
- prefill_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_consumer"
-
- with pytest.raises(ValueError, match="kv_role must be"):
- _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])])
-
- def test_wrong_decode_role_raises(self, monkeypatch):
- """Decode with kv_producer role should raise."""
- decode_cfg = _decode_stage_cfg(engine_input_source=[0])
- decode_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_producer"
-
- with pytest.raises(ValueError, match="kv_role must be"):
- _make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg])
-
- def test_missing_kv_transfer_config_raises(self, monkeypatch):
- """Missing kv_transfer_config should raise."""
- prefill_cfg = _prefill_stage_cfg()
- del prefill_cfg["engine_args"]["kv_transfer_config"]
-
- with pytest.raises(ValueError, match="kv_transfer_config"):
- _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])])
-
- def test_mismatched_buffer_device_raises(self, monkeypatch):
- """Mismatched kv_buffer_device should raise."""
- prefill_cfg = _prefill_stage_cfg()
- prefill_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cuda"
- decode_cfg = _decode_stage_cfg(engine_input_source=[0])
- decode_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cpu"
-
- with pytest.raises(ValueError, match="kv_buffer_device mismatch"):
- _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg])
-
-
-# ===================================================================
-# Tests: Connector info extraction
-# ===================================================================
-
-
-class TestGetPDConnectorInfo:
- """Tests for Omni._get_pd_connector_info()."""
-
- def test_extracts_bootstrap_addr_for_mooncake(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- info = omni._pd_connector_info
- assert "prefill_bootstrap_addr" in info
- assert info["prefill_bootstrap_addr"] == "127.0.0.1:25201"
-
- def test_none_for_non_pd_pipeline(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- {"stage_id": 0, "engine_args": {}, "final_output": True, "final_output_type": "text"},
- ],
- )
- assert omni._pd_connector_info is None
-
-
-# ===================================================================
-# Tests: Prefill sampling params preparation
-# ===================================================================
-
-
-class TestPreparePrefillSamplingParams:
- """Tests for Omni._prepare_prefill_sampling_params()."""
-
- def test_sets_max_tokens_to_1(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- sp = SamplingParams(max_tokens=2048)
- result = omni._prepare_prefill_sampling_params("req-1", sp)
-
- assert result.max_tokens == 1
- assert result is not sp # should be cloned
-
- def test_injects_kv_transfer_params(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- sp = SamplingParams(max_tokens=2048)
- result = omni._prepare_prefill_sampling_params("req-1", sp)
-
- kv_params = result.extra_args["kv_transfer_params"]
- assert kv_params["do_remote_decode"] is True
- assert kv_params["do_remote_prefill"] is False
- assert kv_params["transfer_id"] == "xfer-req-1"
-
- def test_preserves_existing_extra_args(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- sp = SamplingParams(max_tokens=2048, extra_args={"custom_key": "value"})
- result = omni._prepare_prefill_sampling_params("req-1", sp)
-
- assert result.extra_args["custom_key"] == "value"
- assert "kv_transfer_params" in result.extra_args
-
- def test_does_not_mutate_original(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- sp = SamplingParams(max_tokens=2048)
- _ = omni._prepare_prefill_sampling_params("req-1", sp)
-
- assert sp.max_tokens == 2048
- assert sp.extra_args is None
-
-
-# ===================================================================
-# Tests: Sampling params auto-duplication for PD split
-# ===================================================================
-
-
-@pytest.mark.xfail(reason="Requires migration to v1908 Orchestrator architecture (no stage_list / OmniStage)")
-class TestSamplingParamsAutoDuplication:
- """When user provides N-1 sampling params (for logical stages), the
- orchestrator should auto-duplicate the thinker params for the decode stage.
- """
-
- def test_auto_duplicates_for_4_stage_pipeline(self, monkeypatch):
- """User provides 3 params for 4 physical stages -> auto-insert decode params."""
- test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000001")
-
- def _extra_setup(mp, omni_module):
- mp.setattr(uuid, "uuid4", lambda: test_uuid)
- mp.setattr(omni_module, "uuid", uuid)
-
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(stage_id=0),
- _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
- _talker_stage_cfg(stage_id=2, engine_input_source=[1]),
- _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]),
- ],
- extra_setup=_extra_setup,
- )
-
- assert omni._pd_separation_pair == (0, 1)
- assert len(omni.stage_list) == 4
-
- # Simulate outputs for all stages
- expected_rid = f"0_{test_uuid}"
- for i in range(4):
- omni.stage_list[i]._out_q.put_nowait(
- {
- "request_id": expected_rid,
- "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2])])],
- "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
- }
- )
-
- # Provide 3 params (one less than 4 stages) - should auto-duplicate
- sp_thinker = SamplingParams(temperature=0.4, max_tokens=2048)
- sp_talker = SamplingParams(temperature=0.9, max_tokens=4096)
- sp_code2wav = SamplingParams(temperature=0.0, max_tokens=65536)
-
- # This should NOT raise ValueError about param count mismatch
- outputs = omni.generate(
- prompts=["hello"],
- sampling_params_list=[sp_thinker, sp_talker, sp_code2wav],
- )
- assert isinstance(outputs, list)
-
-
-# ===================================================================
-# Tests: KV transfer params normalization
-# ===================================================================
-
-
-class TestNormalizeKVTransferParams:
- def test_dict_passthrough(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- d = {"transfer_id": "test", "do_remote_decode": True}
- assert omni._normalize_kv_transfer_params(d) is d
-
- def test_none_returns_none(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- assert omni._normalize_kv_transfer_params(None) is None
-
- def test_dataclass_to_dict(self, monkeypatch):
- from dataclasses import dataclass
-
- @dataclass
- class FakeKVParams:
- transfer_id: str = "test"
- do_remote_decode: bool = True
-
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- result = omni._normalize_kv_transfer_params(FakeKVParams())
- assert isinstance(result, dict)
- assert result["transfer_id"] == "test"
-
-
-# ===================================================================
-# Tests: _kv_cfg_to_dict
-# ===================================================================
-
-
-class TestKvCfgToDict:
- def test_dict_passthrough(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- d = {"kv_connector": "MooncakeConnector"}
- assert omni._kv_cfg_to_dict(d) is d
-
- def test_none_returns_empty(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- assert omni._kv_cfg_to_dict(None) == {}
-
- def test_dataclass_converted(self, monkeypatch):
- from dataclasses import dataclass
-
- @dataclass
- class FakeCfg:
- kv_connector: str = "TestConnector"
- kv_role: str = "kv_producer"
-
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- result = omni._kv_cfg_to_dict(FakeCfg())
- assert result["kv_connector"] == "TestConnector"
- assert result["kv_role"] == "kv_producer"
-
-
-# ===================================================================
-# Tests: PD routing in scheduling loop
-# ===================================================================
-
-
-@pytest.mark.xfail(reason="Requires migration to v1908 Orchestrator architecture (no stage_list / OmniStage)")
-class TestPDRouting:
- """Test that the scheduling loop correctly routes requests from
- prefill to decode stage with proper kv_transfer_params.
- """
-
- def test_prefill_stage_receives_max_tokens_1(self, monkeypatch):
- """Stage 0 (prefill) should receive max_tokens=1."""
- test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000002")
-
- def _extra_setup(mp, omni_module):
- mp.setattr(uuid, "uuid4", lambda: test_uuid)
- mp.setattr(omni_module, "uuid", uuid)
-
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(stage_id=0),
- _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
- ],
- extra_setup=_extra_setup,
- )
-
- expected_rid = f"0_{test_uuid}"
-
- # Put stage outputs in both queues
- omni.stage_list[0]._out_q.put_nowait(
- {
- "request_id": expected_rid,
- "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])],
- "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
- }
- )
- omni.stage_list[1]._out_q.put_nowait(
- {
- "request_id": expected_rid,
- "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])],
- "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0},
- }
- )
-
- sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)]
- omni.generate(prompts=["hello"], sampling_params_list=sp_list)
-
- # Check what was submitted to stage 0's input queue
- # (skip the stage_ready message first)
- task = omni.stage_list[0]._in_q.get_nowait()
- assert task["sampling_params"].max_tokens == 1
- kv_params = task["sampling_params"].extra_args["kv_transfer_params"]
- assert kv_params["do_remote_decode"] is True
- assert kv_params["do_remote_prefill"] is False
- assert kv_params["transfer_id"] == f"xfer-{expected_rid}"
-
- def test_decode_stage_receives_original_prompt(self, monkeypatch):
- """Decode stage should get the original prompt (not processed outputs)."""
- test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000003")
-
- def _extra_setup(mp, omni_module):
- mp.setattr(uuid, "uuid4", lambda: test_uuid)
- mp.setattr(omni_module, "uuid", uuid)
-
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(stage_id=0),
- _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
- ],
- extra_setup=_extra_setup,
- )
-
- expected_rid = f"0_{test_uuid}"
- original_prompt = "test prompt for PD"
-
- omni.stage_list[0]._out_q.put_nowait(
- {
- "request_id": expected_rid,
- "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])],
- "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
- }
- )
- omni.stage_list[1]._out_q.put_nowait(
- {
- "request_id": expected_rid,
- "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])],
- "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0},
- }
- )
-
- sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)]
- omni.generate(prompts=[original_prompt], sampling_params_list=sp_list)
-
- # Check what was forwarded to stage 1 (decode)
- # The connector sends tasks to stage 1's input queue
- task = omni.stage_list[1]._in_q.get_nowait()
- # The engine_inputs should contain the original prompt
- engine_inputs = task.get("engine_inputs")
- # For PD routing, the original prompt is wrapped in a list
- if isinstance(engine_inputs, list):
- assert original_prompt in engine_inputs
- else:
- assert engine_inputs == original_prompt
-
- def test_decode_kv_params_have_correct_flags(self, monkeypatch):
- """Decode stage kv_transfer_params should have correct role flags."""
- test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000004")
-
- def _extra_setup(mp, omni_module):
- mp.setattr(uuid, "uuid4", lambda: test_uuid)
- mp.setattr(omni_module, "uuid", uuid)
-
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(stage_id=0),
- _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
- ],
- extra_setup=_extra_setup,
- )
-
- expected_rid = f"0_{test_uuid}"
-
- omni.stage_list[0]._out_q.put_nowait(
- {
- "request_id": expected_rid,
- "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])],
- "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
- }
- )
- omni.stage_list[1]._out_q.put_nowait(
- {
- "request_id": expected_rid,
- "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])],
- "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0},
- }
- )
-
- sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)]
- omni.generate(prompts=["hello"], sampling_params_list=sp_list)
-
- # Check decode task's kv_transfer_params
- task = omni.stage_list[1]._in_q.get_nowait()
- kv_params = task["sampling_params"].extra_args["kv_transfer_params"]
- assert kv_params["do_remote_prefill"] is True
- assert kv_params["do_remote_decode"] is False
- assert kv_params["transfer_id"] == f"xfer-{expected_rid}"
- assert kv_params["remote_bootstrap_addr"] == "127.0.0.1:25201"
-
-
-# ===================================================================
-# Tests: KV params cleanup
-# ===================================================================
-
-
-class TestKVParamsCleanup:
- def test_drop_cleans_up(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- omni._pd_kv_params_by_req["req-1"] = {"transfer_id": "xfer-1"}
- omni._drop_pd_kv_params("req-1")
- assert "req-1" not in omni._pd_kv_params_by_req
-
- def test_drop_nonexistent_is_noop(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- omni._drop_pd_kv_params("nonexistent") # should not raise
-
- def test_pop_returns_stored_params(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- stored = {"transfer_id": "xfer-1", "extra_field": "value"}
- omni._pd_kv_params_by_req["req-1"] = stored
-
- result = omni._pop_pd_kv_params("req-1")
- assert result == stored
- assert "req-1" not in omni._pd_kv_params_by_req
-
- def test_pop_uses_fallback_when_no_stored(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- fallback = {"transfer_id": "xfer-fallback"}
- result = omni._pop_pd_kv_params("req-1", fallback=fallback)
- assert result == fallback
-
-
-# ===================================================================
-# Tests: Config YAML loads without error
-# ===================================================================
-
-
-class TestPDYAMLConfig:
- def test_pd_yaml_loads(self):
- """The PD separation YAML config should load without errors."""
- import os
-
- yaml_path = os.path.join(
- os.path.dirname(__file__),
- "../../vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml",
- )
- yaml_path = os.path.abspath(yaml_path)
- if not os.path.exists(yaml_path):
- pytest.skip("PD separation YAML not found")
-
- from omegaconf import OmegaConf
-
- cfg = OmegaConf.load(yaml_path)
- stages = cfg.stage_args
- assert len(stages) == 4
-
- # Prefill stage
- assert stages[0].is_prefill_only is True
- assert stages[0].final_output is False
- assert stages[0].is_comprehension is True
-
- # Decode stage
- assert stages[1].is_decode_only is True
- assert stages[1].final_output is True
- assert stages[1].final_output_type == "text"
- assert stages[1].is_comprehension is True
- assert 0 in stages[1].engine_input_source
-
- # KV transfer configs
- assert stages[0].engine_args.kv_transfer_config.kv_role == "kv_producer"
- assert stages[1].engine_args.kv_transfer_config.kv_role == "kv_consumer"
- assert stages[0].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector"
- assert stages[1].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector"
-
-
-class TestPrefillStopNeutralization:
- """Tests that _prepare_prefill_sampling_params neutralizes stop
- conditions to ensure finish_reason='length'.
- """
-
- def test_clears_stop_strings(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- sp = SamplingParams(max_tokens=2048, stop=[" ", "STOP"])
- result = omni._prepare_prefill_sampling_params("req-1", sp)
- assert result.stop == []
-
- def test_clears_stop_token_ids(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- sp = SamplingParams(max_tokens=2048, stop_token_ids=[151643, 151644])
- result = omni._prepare_prefill_sampling_params("req-1", sp)
- assert result.stop_token_ids == []
-
- def test_clears_include_stop_str_in_output(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- sp = SamplingParams(max_tokens=2048, include_stop_str_in_output=True)
- result = omni._prepare_prefill_sampling_params("req-1", sp)
- assert result.include_stop_str_in_output is False
-
- def test_original_sp_unchanged(self, monkeypatch):
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- sp = SamplingParams(max_tokens=2048, stop=[" "], stop_token_ids=[151643])
- _ = omni._prepare_prefill_sampling_params("req-1", sp)
- assert sp.stop == [""]
- assert sp.stop_token_ids == [151643]
-
-
-# ===================================================================
-# Tests: Failure mode & memory leak prevention
-# ===================================================================
-# NOTE: Full generate()-level failure mode tests are removed for now.
-# The _run_generation error handler (line 1344-1350 in omni.py) calls
-# _drop_pd_kv_params but does not increment completed_requests, causing
-# the while-loop to hang. These tests need to be revisited once the
-# production error-handling path is fixed to properly terminate on
-# stage errors.
-
-
-# ===================================================================
-# Tests: TP size validation
-# ===================================================================
-
-
-class TestTPSizeValidation:
- """Tests that _validate_pd_separation_config checks tensor_parallel_size."""
-
- def test_matching_tp_passes(self, monkeypatch):
- """Same TP size should not raise."""
- prefill_cfg = _prefill_stage_cfg()
- prefill_cfg["engine_args"]["tensor_parallel_size"] = 2
- decode_cfg = _decode_stage_cfg(engine_input_source=[0])
- decode_cfg["engine_args"]["tensor_parallel_size"] = 2
- omni = _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg])
- assert omni._pd_separation_pair == (0, 1)
-
- def test_mismatched_tp_raises(self, monkeypatch):
- """Different TP sizes should raise ValueError."""
- prefill_cfg = _prefill_stage_cfg()
- prefill_cfg["engine_args"]["tensor_parallel_size"] = 2
- decode_cfg = _decode_stage_cfg(engine_input_source=[0])
- decode_cfg["engine_args"]["tensor_parallel_size"] = 4
- with pytest.raises(ValueError, match="tensor_parallel_size"):
- _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg])
-
- def test_default_tp_no_error(self, monkeypatch):
- """Stages without explicit TP (defaults to 1) should pass."""
- omni = _make_pd_omni(
- monkeypatch,
- [
- _prefill_stage_cfg(),
- _decode_stage_cfg(engine_input_source=[0]),
- ],
- )
- assert omni._pd_separation_pair == (0, 1)
diff --git a/tests/entrypoints/test_realtime_connection_helpers.py b/tests/entrypoints/test_realtime_connection_helpers.py
deleted file mode 100644
index e795aa92d0f..00000000000
--- a/tests/entrypoints/test_realtime_connection_helpers.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for realtime streaming helpers (PR #2581 /v1/realtime path)."""
-
-from __future__ import annotations
-
-import base64
-
-import numpy as np
-import pytest
-import torch
-from vllm.sampling_params import RequestOutputKind, SamplingParams
-
-from vllm_omni.entrypoints.async_omni import AsyncOmni
-from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.fixture
-def realtime_conn() -> RealtimeConnection:
- return RealtimeConnection.__new__(RealtimeConnection)
-
-
-class TestRealtimeConnectionTensorAndPcm:
- def test_tensor_to_numpy_none(self) -> None:
- assert RealtimeConnection._tensor_to_numpy(None) is None
-
- def test_tensor_to_numpy_1d_numpy(self) -> None:
- arr = np.array([1.0, 2.0], dtype=np.float64)
- out = RealtimeConnection._tensor_to_numpy(arr)
- assert out is not None
- assert out.dtype == np.float32
- assert out.shape == (2,)
-
- def test_tensor_to_numpy_2d_numpy_flattened(self) -> None:
- arr = np.array([[0.5], [-0.5]], dtype=np.float32)
- out = RealtimeConnection._tensor_to_numpy(arr)
- assert out is not None
- assert out.shape == (2,)
-
- def test_tensor_to_numpy_torch(self) -> None:
- t = torch.tensor([[0.25, -0.25]], dtype=torch.float32)
- out = RealtimeConnection._tensor_to_numpy(t)
- assert out is not None
- assert out.shape == (2,)
- np.testing.assert_allclose(out, [0.25, -0.25], rtol=1e-5)
-
- def test_pcm16_b64_roundtrip(self) -> None:
- audio = np.array([0.0, 1.0, -1.0], dtype=np.float32)
- b64 = RealtimeConnection._pcm16_b64(audio)
- raw = base64.b64decode(b64)
- assert len(raw) == 6
- pcm = np.frombuffer(raw, dtype=np.int16)
- assert pcm[0] == 0
- assert pcm[1] == 32767
- assert pcm[2] == -32767
-
-
-class TestAsyncOmniStreamingParamsValidation:
- def test_accepts_streaming_friendly_params(self) -> None:
- p = SamplingParams(
- n=1,
- stop=[],
- output_kind=RequestOutputKind.DELTA,
- )
- AsyncOmni._validate_streaming_input_sampling_params(p)
-
- def test_rejects_non_sampling_params(self) -> None:
- with pytest.raises(ValueError, match="Input streaming"):
- AsyncOmni._validate_streaming_input_sampling_params(object()) # type: ignore[arg-type]
-
- def test_rejects_n_greater_than_one(self) -> None:
- p = SamplingParams(n=2, stop=[], output_kind=RequestOutputKind.DELTA)
- with pytest.raises(ValueError, match="Input streaming"):
- AsyncOmni._validate_streaming_input_sampling_params(p)
-
- def test_rejects_final_only(self) -> None:
- p = SamplingParams(n=1, stop=[], output_kind=RequestOutputKind.FINAL_ONLY)
- with pytest.raises(ValueError, match="Input streaming"):
- AsyncOmni._validate_streaming_input_sampling_params(p)
-
- def test_rejects_stop_strings(self) -> None:
- p = SamplingParams(n=1, stop=["\n"], output_kind=RequestOutputKind.DELTA)
- with pytest.raises(ValueError, match="Input streaming"):
- AsyncOmni._validate_streaming_input_sampling_params(p)
diff --git a/tests/entrypoints/test_serve.py b/tests/entrypoints/test_serve.py
deleted file mode 100644
index e60afc9cd7b..00000000000
--- a/tests/entrypoints/test_serve.py
+++ /dev/null
@@ -1,211 +0,0 @@
-"""Unit tests for the Omni serve CLI helpers."""
-
-from __future__ import annotations
-
-import argparse
-
-import pytest
-from pytest_mock import MockerFixture
-
-from vllm_omni.entrypoints.cli.serve import OmniServeCommand, run_headless
-from vllm_omni.entrypoints.utils import detect_explicit_cli_keys
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def test_serve_parser_accepts_no_async_chunk_and_marks_it_explicit() -> None:
- """``--no-async-chunk`` should parse to ``async_chunk=False`` and mark the
- shared deploy-level dest as explicitly provided by the user."""
- try:
- from vllm.utils.argparse_utils import FlexibleArgumentParser
- except Exception as exc:
- pytest.skip(f"Cannot build parser in this environment: {exc}")
-
- root = FlexibleArgumentParser()
- subparsers = root.add_subparsers(dest="subcommand")
- cmd = OmniServeCommand()
- serve_parser = cmd.subparser_init(subparsers)
-
- argv = ["serve", "fake-model", "--omni", "--no-async-chunk"]
- args = root.parse_args(argv)
-
- assert args.async_chunk is False
- explicit = detect_explicit_cli_keys(argv, serve_parser)
- assert "async_chunk" in explicit
-
-
-def _make_headless_args() -> argparse.Namespace:
- return argparse.Namespace(
- model="fake-model",
- stage_id=3,
- omni_master_address="127.0.0.1",
- omni_master_port=26000,
- api_server_count=0,
- worker_backend="multi_process",
- stage_configs_path=None,
- log_stats=False,
- disable_log_stats=False,
- )
-
-
-def test_run_headless_registers_stage_once_and_launches_all_local_engines(mocker: MockerFixture) -> None:
- args = _make_headless_args()
- stage_cfg = mocker.Mock(stage_id=3)
- stage_cfgs = [stage_cfg]
- parallel_config = mocker.Mock(
- data_parallel_size_local=2,
- data_parallel_rank=4,
- data_parallel_rank_local=1,
- node_rank_within_dp=0,
- )
- vllm_config = mocker.Mock(parallel_config=parallel_config)
- executor_class = mocker.Mock()
- engine_manager = mocker.Mock()
-
- mocker.patch(
- "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
- return_value=("/fake/stages.yaml", stage_cfgs),
- )
- mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment")
- mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock())
- mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={})
- mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={})
- mocker.patch(
- "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mock_build_vllm_config = mocker.patch(
- "vllm_omni.engine.stage_init_utils.build_vllm_config",
- return_value=(vllm_config, executor_class),
- )
- mock_register = mocker.patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- )
- mock_manager_cls = mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager)
- mocker.patch("signal.signal")
- run_headless(args)
-
- mock_build_vllm_config.assert_called_once_with(
- stage_cfg,
- "fake-model",
- stage_connector_spec={},
- engine_args_dict={},
- headless=True,
- )
- mock_register.assert_called_once_with(
- omni_master_address="127.0.0.1",
- omni_master_port=26000,
- omni_stage_id=3,
- omni_stage_config=stage_cfg,
- coordinator=None,
- )
- mock_manager_cls.assert_called_once()
- manager_kwargs = mock_manager_cls.call_args.kwargs
- assert manager_kwargs["local_engine_count"] == 2
- assert manager_kwargs["start_index"] == 4
- assert manager_kwargs["local_start_index"] == 0
- assert manager_kwargs["local_client"] is False
- assert manager_kwargs["handshake_address"] == "tcp://127.0.0.1:26001"
- assert manager_kwargs["log_stats"] is False
- engine_manager.join_first.assert_called_once_with()
- engine_manager.shutdown.assert_called_once_with()
-
-
-def test_run_headless_honors_explicit_log_stats_flag(mocker: MockerFixture) -> None:
- args = _make_headless_args()
- args.log_stats = True
- stage_cfg = mocker.Mock(stage_id=3)
- stage_cfgs = [stage_cfg]
- parallel_config = mocker.Mock(
- data_parallel_size_local=2,
- data_parallel_rank=4,
- data_parallel_rank_local=1,
- node_rank_within_dp=0,
- )
- vllm_config = mocker.Mock(parallel_config=parallel_config)
- executor_class = mocker.Mock()
- engine_manager = mocker.Mock()
-
- mocker.patch(
- "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
- return_value=("/fake/stages.yaml", stage_cfgs),
- )
- mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment")
- mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock())
- mocker.patch("vllm_omni.engine.stage_init_utils.get_stage_connector_spec", return_value={})
- mocker.patch("vllm_omni.engine.stage_init_utils.build_engine_args_dict", return_value={})
- mocker.patch(
- "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.stage_init_utils.build_vllm_config",
- return_value=(vllm_config, executor_class),
- )
- mocker.patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value="tcp://127.0.0.1:26001",
- )
- mock_manager_cls = mocker.patch("vllm.v1.engine.utils.CoreEngineProcManager", return_value=engine_manager)
- mocker.patch("signal.signal")
- run_headless(args)
-
- manager_kwargs = mock_manager_cls.call_args.kwargs
- assert manager_kwargs["log_stats"] is True
-
-
-def test_run_headless_launches_diffusion_stage_via_omni_master(mocker: MockerFixture) -> None:
- args = _make_headless_args()
- stage_cfg = mocker.Mock(stage_id=3, stage_type="diffusion")
- stage_cfg.engine_args = mocker.Mock()
- stage_cfg.engine_input_source = []
- stage_cfgs = [stage_cfg]
- metadata = mocker.Mock(stage_id=3)
- od_config = mocker.Mock()
- proc = mocker.Mock()
- proc.exitcode = 0
- proc.is_alive.return_value = False
-
- mocker.patch(
- "vllm_omni.entrypoints.utils.load_and_resolve_stage_configs",
- return_value=("/fake/stages.yaml", stage_cfgs),
- )
- mocker.patch("vllm_omni.engine.stage_init_utils.prepare_engine_environment")
- mocker.patch("vllm_omni.engine.stage_init_utils.load_omni_transfer_config_for_model", return_value=mocker.Mock())
- mocker.patch(
- "vllm_omni.distributed.omni_connectors.utils.initialization.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch("vllm_omni.engine.stage_init_utils.extract_stage_metadata", return_value=metadata)
- mock_inject_stage_info = mocker.patch("vllm_omni.engine.stage_init_utils.inject_kv_stage_info")
- mocker.patch("vllm_omni.engine.stage_init_utils.build_diffusion_config", return_value=od_config)
- mock_register = mocker.patch(
- "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
- return_value=("tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
- )
- mock_spawn = mocker.patch(
- "vllm_omni.diffusion.stage_diffusion_proc.spawn_diffusion_proc",
- return_value=(proc, "tcp://127.0.0.1:26001", "tcp://127.0.0.1:26002", "tcp://127.0.0.1:26003"),
- )
- mock_handshake = mocker.patch("vllm_omni.diffusion.stage_diffusion_proc.complete_diffusion_handshake")
- mocker.patch("signal.signal")
- run_headless(args)
-
- mock_inject_stage_info.assert_called_once_with(stage_cfg, 3)
- mock_register.assert_called_once_with(
- omni_master_address="127.0.0.1",
- omni_master_port=26000,
- omni_stage_id=3,
- omni_stage_config=stage_cfg,
- return_addresses=True,
- )
- mock_spawn.assert_called_once_with(
- "fake-model",
- od_config,
- handshake_address="tcp://127.0.0.1:26001",
- request_address="tcp://127.0.0.1:26002",
- response_address="tcp://127.0.0.1:26003",
- )
- mock_handshake.assert_called_once_with(proc, "tcp://127.0.0.1:26001")
- proc.join.assert_called_once_with()
diff --git a/tests/entrypoints/test_stage_utils.py b/tests/entrypoints/test_stage_utils.py
index 15ee9c32a4e..2bb2231ccb8 100644
--- a/tests/entrypoints/test_stage_utils.py
+++ b/tests/entrypoints/test_stage_utils.py
@@ -6,6 +6,8 @@
from vllm_omni.entrypoints.stage_utils import set_stage_devices
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
def _make_dummy_torch(call_log):
class _Props:
@@ -53,8 +55,6 @@ def _make_mock_platform(mocker, device_type: str = "cuda", env_var: str = "CUDA_
return mock_platform
-@pytest.mark.core_model
-@pytest.mark.cpu
@pytest.mark.usefixtures("clean_gpu_memory_between_tests")
def test_set_stage_devices_respects_logical_ids(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch):
# Preserve an existing logical mapping and ensure devices "0,1" map through it.
@@ -75,8 +75,6 @@ def test_set_stage_devices_respects_logical_ids(mocker: MockerFixture, monkeypat
assert os.environ["CUDA_VISIBLE_DEVICES"] == "6,7"
-@pytest.mark.core_model
-@pytest.mark.cpu
@pytest.mark.usefixtures("clean_gpu_memory_between_tests")
def test_set_stage_devices_handles_not_enough_devices(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch):
# Preserve an existing logical mapping and ensure devices "0,1" map through it.
@@ -92,10 +90,9 @@ def test_set_stage_devices_handles_not_enough_devices(mocker: MockerFixture, mon
mock_platform,
)
- # Keep the logical mapping and resolve to the visible subset.
- set_stage_devices(stage_id=0, devices="0,1,2,3")
-
- assert os.environ["CUDA_VISIBLE_DEVICES"] == "6,7"
+ # Raise since we need 4 GPUs, but we only have 2 visible
+ with pytest.raises(ValueError):
+ set_stage_devices(stage_id=0, devices="0,1,2,3")
@pytest.mark.usefixtures("clean_gpu_memory_between_tests")
diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py
index 6e52e4c6c0c..352ed2aad9b 100644
--- a/tests/entrypoints/test_utils.py
+++ b/tests/entrypoints/test_utils.py
@@ -5,21 +5,14 @@
from dataclasses import dataclass
import pytest
-import torch
from pytest_mock import MockerFixture
-from vllm.sampling_params import RequestOutputKind, SamplingParams
-from vllm_omni.config.yaml_util import create_config
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.engine.arg_utils import OmniEngineArgs
-from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
from vllm_omni.entrypoints.utils import (
_convert_dataclasses_to_dict,
_filter_dict_like_object,
- coerce_param_message_types,
filter_dataclass_kwargs,
- load_and_resolve_stage_configs,
- load_stage_configs_from_yaml,
resolve_model_config_path,
)
@@ -311,159 +304,3 @@ def mock_exists(path):
assert result is not None
assert "glm_image.yaml" in result
-
- def test_voxcpm_transformers_format_resolution(self, mocker: MockerFixture):
- """Test VoxCPM transformers config resolves to the voxcpm stage config."""
- mocker.patch(
- "vllm_omni.entrypoints.utils.get_config",
- side_effect=ValueError("missing transformers config"),
- )
- mocker.patch(
- "vllm_omni.entrypoints.utils.file_or_path_exists",
- side_effect=lambda _model, filename, revision=None: filename == "config.json",
- )
- mocker.patch(
- "vllm_omni.entrypoints.utils.get_hf_file_to_dict",
- return_value={"model_type": "voxcpm"},
- )
- mocker.patch(
- "vllm_omni.entrypoints.utils.current_omni_platform.get_default_stage_config_path",
- return_value="vllm_omni/model_executor/stage_configs",
- )
-
- original_exists = os.path.exists
-
- def mock_exists(path):
- if "voxcpm.yaml" in str(path):
- return True
- return original_exists(path)
-
- mocker.patch("os.path.exists", side_effect=mock_exists)
-
- result = resolve_model_config_path("OpenBMB/VoxCPM1.5")
-
- assert result is not None
- assert "voxcpm.yaml" in result
-
-
-class TestLoadAndResolveStageConfigs:
- def test_load_and_resolve_with_kwargs(self):
- """Ensure that dtype survives default stage creation."""
- kwargs = {"dtype": torch.float32}
- config_path, stage_configs = load_and_resolve_stage_configs(
- model="black-forest-labs/FLUX.2-klein-4B",
- stage_configs_path=None,
- kwargs=kwargs,
- default_stage_cfg_factory=lambda: AsyncOmniEngine._create_default_diffusion_stage_cfg(kwargs),
- )
- assert config_path is None
- assert len(stage_configs) == 1
- assert "dtype" in stage_configs[0]["engine_args"]
-
-
-class TestLoadStageConfigsFromYaml:
- """Regression tests for stage-config loading and merging."""
-
- def test_deep_merges_stage_engine_args(self, mocker: MockerFixture):
- yaml_config = create_config(
- {
- "async_chunk": True,
- "stage_args": [
- {
- "stage_id": 0,
- "runtime": {"device": 0},
- "engine_args": {
- "parallel_config": {"tensor_parallel_size": 4},
- },
- }
- ],
- }
- )
- mocker.patch(
- "vllm_omni.entrypoints.utils.load_yaml_config",
- return_value=yaml_config,
- )
-
- stages = load_stage_configs_from_yaml(
- "fake.yaml",
- base_engine_args={
- "parallel_config": {
- "tensor_parallel_size": 1,
- "pipeline_parallel_size": 2,
- },
- "model": "base-model",
- },
- )
-
- merged_engine_args = stages[0]["engine_args"]
- assert merged_engine_args["parallel_config"]["tensor_parallel_size"] == 4
- assert merged_engine_args["parallel_config"]["pipeline_parallel_size"] == 2
- assert merged_engine_args["model"] == "base-model"
- assert merged_engine_args["async_chunk"] is True
-
- def test_merges_nested_stage_engine_args(self, mocker: MockerFixture):
- yaml_config = create_config(
- {
- "stage_args": [
- {
- "stage_id": 0,
- "engine_args": {
- "nested": {"override": 2},
- },
- }
- ],
- }
- )
- mocker.patch(
- "vllm_omni.entrypoints.utils.load_yaml_config",
- return_value=yaml_config,
- )
-
- stages = load_stage_configs_from_yaml(
- "fake.yaml",
- base_engine_args={"nested": {"base": 1}},
- )
-
- assert stages[0]["engine_args"]["nested"]["base"] == 1
- assert stages[0]["engine_args"]["nested"]["override"] == 2
-
-
-class TestCumulativeStreamingCoercion:
- @pytest.mark.parametrize("skip_clone", [True, False])
- def test_cumulative_default_becomes_delta_if_stream(self, skip_clone):
- """Ensure cumulative messages are coercible to delta if streaming."""
- sp = SamplingParams(output_kind=RequestOutputKind.CUMULATIVE)
- sp.skip_clone = skip_clone
- result = coerce_param_message_types([sp], is_streaming=True)[0]
- assert isinstance(result, SamplingParams)
- assert result.output_kind == RequestOutputKind.DELTA
- assert (skip_clone and sp is result) or (not skip_clone and sp is not result)
-
- @pytest.mark.parametrize("skip_clone", [True, False])
- def test_cumulative_default_becomes_final_only_if_not_stream(self, skip_clone):
- """Ensure cumulative messages are coercible to final only if not streaming."""
- sp = SamplingParams(output_kind=RequestOutputKind.CUMULATIVE)
- sp.skip_clone = skip_clone
- result = coerce_param_message_types([sp], is_streaming=False)[0]
- assert isinstance(result, SamplingParams)
- assert result.output_kind == RequestOutputKind.FINAL_ONLY
- assert (skip_clone and sp is result) or (not skip_clone and sp is not result)
-
- @pytest.mark.parametrize("is_streaming", [True, False])
- @pytest.mark.parametrize("output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
- def test_non_cumulative_are_coerced(self, output_kind, is_streaming):
- """Ensure non-cumulative params are coerced to the target type."""
- sp = SamplingParams(output_kind=output_kind)
- expected = RequestOutputKind.DELTA if is_streaming else RequestOutputKind.FINAL_ONLY
- result = coerce_param_message_types([sp], is_streaming=is_streaming)[0]
- assert isinstance(result, SamplingParams)
- assert result.output_kind == expected
-
- def test_coercion_applies_to_all_stages(self):
- """Ensure all stages are coerced to DELTA for streaming."""
- sp0 = SamplingParams(output_kind=RequestOutputKind.CUMULATIVE)
- sp1 = SamplingParams(output_kind=RequestOutputKind.CUMULATIVE)
- result = coerce_param_message_types([sp0, sp1], is_streaming=True)
- assert all([isinstance(r, SamplingParams) for r in result])
- assert result[0].output_kind == RequestOutputKind.DELTA
- assert result[1].output_kind == RequestOutputKind.DELTA
diff --git a/tests/examples/conftest.py b/tests/examples/conftest.py
index 867731b21f9..137d15f163f 100644
--- a/tests/examples/conftest.py
+++ b/tests/examples/conftest.py
@@ -1,3 +1,353 @@
-"""Pytest fixtures for tests/examples."""
+"""
+Shared fixtures, helpers, and path constants for tests/examples/.
+"""
-from tests.examples.helpers import example_runner # noqa: F401
+import json
+import os
+import re
+import shlex
+import subprocess
+import sys
+import tempfile
+from collections import defaultdict
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any, NamedTuple, cast
+
+import pytest
+import torch
+from safetensors.torch import save_file
+
+# ---------------------------------------------------------------------------
+# Path constants and fixtures
+# ---------------------------------------------------------------------------
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+EXAMPLES = REPO_ROOT / "examples"
+
+# Root directory for everything the example tests generate.
+#
+# A plain tempfile directory is used instead of pytest's tmp_path_factory because
+# OUTPUT_DIR is needed at test-collection time, while tmp_path_factory is only
+# available while tests are running: extract_readme_snippets replaces each LoRA
+# path with a generated one under OUTPUT_DIR, and it is called at collection time
+# to generate a separate test case for each README code block.
+#
+# If the OUTPUT_DIR environment variable is set and non-empty, it is joined onto
+# REPO_ROOT (pathlib keeps an absolute value unchanged); otherwise a fresh
+# temporary directory is created.
+OUTPUT_DIR = (
+ REPO_ROOT / prefix
+ if (prefix := os.environ.get("OUTPUT_DIR"))
+ else Path(tempfile.mkdtemp(prefix="vllm_omni_test_examples_"))
+)
+
+
+# ---------------------------------------------------------------------------
+# Code snippet extraction and asset file helpers
+# ---------------------------------------------------------------------------
+
+# Predicate consulted for every extracted snippet.
+# Called as skipif(language, code, h2_title); the (bool, str) pair it returns is
+# stored unchanged in ReadmeSnippet.skip (presumably (should_skip, reason) — confirm).
+ReadmeSnippetExtractionSkipPredicate = Callable[[str, str, str], tuple[bool, str]]
+
+
+class ReadmeSnippet(NamedTuple):
+    """One ``bash``/``python`` fenced code block extracted from a README.
+
+    Fields:
+        language: Normalized fence language, either ``"bash"`` or ``"python"``.
+        code: Snippet source. For bash this is the normalized command re-joined
+            with ``shlex.join``; for python it is the raw fence body.
+        h2_title: Text of the most recent level-2 heading before the fence
+            (``""`` when there is none).
+        index_in_section: 1-based position among the accepted snippets that
+            share ``h2_title``.
+        output_file_path: Value following ``--output`` in a bash snippet, or
+            ``None`` when absent or when the snippet is not bash.
+        skip: Pair returned by the ``skipif`` predicate; ``(False, "")`` when no
+            predicate was supplied.
+    """
+
+    language: str
+    code: str
+    h2_title: str
+    index_in_section: int
+    output_file_path: Path | None = None
+    skip: tuple[bool, str] = (False, "")
+
+    @property
+    def test_id(self) -> str:
+        """Identifier made of the slugged H2 title and the zero-padded 3-digit index."""
+        return f"{ReadmeSnippet._slug(self.h2_title)}_{self.index_in_section:03d}"
+
+    @staticmethod
+    def extract_readme_snippets(
+        readme_path: Path,
+        skipif: ReadmeSnippetExtractionSkipPredicate | None = None,
+    ) -> list["ReadmeSnippet"]:
+        """Return the ``bash``/``python`` fenced blocks of ``readme_path`` in document order.
+
+        The fence aliases ``shell``, ``sh``, ``ksh`` and ``zsh`` are treated as
+        ``bash``. Fences with any other language, or without a language tag,
+        are ignored. Bash snippets are normalized with
+        ``_normalize_bash_command`` and their ``--output`` value is recorded.
+
+        Args:
+            readme_path: Markdown file to parse. Its parent directory is used
+                to resolve relative script names in bash snippets.
+            skipif: Optional predicate called as ``skipif(language, code,
+                h2_title)`` with the already-normalized code; its return value
+                is stored in ``ReadmeSnippet.skip``.
+
+        Returns:
+            The extracted snippets, one per accepted fence.
+        """
+        import mistune
+
+        markdown = mistune.create_markdown(renderer="ast")
+        tokens = markdown(readme_path.read_text(encoding="utf-8"))
+        tokens = cast(list[dict[str, Any]], tokens)  # mistune's AST renderer always produces a list, not a str
+
+        h2_title = ""
+        section_counts: defaultdict[str, int] = defaultdict(int)
+        snippets: list[ReadmeSnippet] = []
+
+        for token in tokens:
+            token_type = token.get("type")
+
+            if token_type == "heading":
+                level = (token.get("attrs") or {}).get("level")
+                title = ReadmeSnippet._heading_text(token)
+                if level == 2:
+                    h2_title = title
+                continue
+
+            if token_type != "block_code":
+                continue
+
+            # A fence with no attrs, no info, or a blank info string carries no
+            # language tag; skip it instead of failing on the lookup.
+            info = (token.get("attrs") or {}).get("info")
+            words = info.split() if isinstance(info, str) else []
+            if not words:
+                continue
+            language = words[0].lower()
+
+            # Common shell aliases to "bash" in several markdown renderers.
+            if language in {"shell", "sh", "ksh", "zsh"}:
+                language = "bash"
+
+            if language not in {"bash", "python"}:
+                continue
+
+            key = h2_title
+            section_counts[key] += 1
+            code = token.get("raw", "")
+            output_file_path = None
+            if language == "bash":
+                argv = ReadmeSnippet._normalize_bash_command(code, Path(readme_path.parent))
+                code = shlex.join(argv)
+                output_file_path = ReadmeSnippet._output_file_path_from_argv(argv)
+            if skipif is not None:
+                skip_config = skipif(language, code, h2_title)
+            else:
+                skip_config = (False, "")
+            snippet = ReadmeSnippet(
+                language=language,
+                code=code,
+                h2_title=h2_title,
+                index_in_section=section_counts[key],
+                output_file_path=output_file_path,
+                skip=skip_config,
+            )
+            snippets.append(snippet)
+
+        return snippets
+
+    @staticmethod
+    def _normalize_bash_command(command: str, readme_dir: Path) -> list[str]:
+        """Turn a README bash fence into an argv list runnable from the test process.
+
+        Backslash-newline continuations are folded, shell comments are dropped,
+        ``python``/``python3`` is replaced by ``sys.executable``, a relative
+        ``.py`` script is resolved by file name inside ``readme_dir``, and the
+        value of ``--lora-path`` is replaced by a generated adapter directory
+        under ``OUTPUT_DIR``.
+
+        Raises:
+            AssertionError: If the fence is empty, the referenced script does
+                not exist, or ``--lora-path`` has no following value.
+        """
+        line_joined_command = re.sub(r"\\\s*\n", " ", command).strip()
+        argv = shlex.split(line_joined_command, comments=True)
+        assert argv, "README bash fence produced an empty command"
+
+        # Normalize python directory and example script location
+        if argv[0] in {"python", "python3"}:
+            argv[0] = sys.executable
+            if len(argv) > 1 and argv[1].endswith(".py"):
+                script_arg = argv[1]
+                script_path = Path(script_arg)
+                if script_path.is_absolute():
+                    resolved_script = script_path
+                else:
+                    # Keep only the file name and look it up next to the README.
+                    resolved_script = readme_dir / script_path.name
+                assert resolved_script.exists(), (
+                    f"README bash snippet references a script that does not exist: {script_arg} (resolved to {resolved_script})"
+                )
+                argv[1] = str(resolved_script)
+
+        # Normalize LoRA adapter path and ensure README LoRA assets exist.
+        # An explicit membership test is used (rather than catching ValueError
+        # from list.index) so that a ValueError raised while generating the
+        # adapter files is not silently swallowed.
+        if "--lora-path" in argv:
+            lora_arg_idx = argv.index("--lora-path")
+            assert len(argv) > lora_arg_idx + 1, "README bash snippet uses --lora-path without a following value"
+
+            lora_dir = OUTPUT_DIR / "lora"
+            adapter_model = lora_dir / "adapter_model.safetensors"
+            adapter_config = lora_dir / "adapter_config.json"
+            if not adapter_model.exists() or not adapter_config.exists():
+                write_zimage_lora(lora_dir, v_scale=8.0)
+
+            argv[lora_arg_idx + 1] = str(lora_dir)
+
+        return argv
+
+    @staticmethod
+    def _output_file_path_from_argv(argv: list[str]) -> Path | None:
+        """Return the value following the first ``--output`` in ``argv`` as a Path, or ``None`` if absent."""
+        if "--output" not in argv:
+            return None
+        output_param_idx = argv.index("--output")
+        assert len(argv) > output_param_idx + 1, "README bash snippet uses --output without a following value"
+        output_arg = argv[output_param_idx + 1]
+        return Path(output_arg)
+
+    @staticmethod
+    def _slug(text: str) -> str:
+        """Lower-case ``text``, replace every non-alphanumeric character with ``_``, and trim outer ``_``."""
+        return "".join(ch.lower() if ch.isalnum() else "_" for ch in text).strip("_")
+
+    @staticmethod
+    def _heading_text(token: dict) -> str:
+        """Concatenate the ``raw`` text of a heading token's direct children and strip whitespace."""
+        return "".join(child.get("raw", "") for child in token.get("children", [])).strip()
+
+
+# [TODO] Duplicate `_write_zimage_lora` in tests/e2e/online_serving/test_images_generations_lora.py. Combine these helpers and tests/e2e/offline_inference/test_diffusion_lora.py to test/utils later
+def write_zimage_lora(adapter_dir: Path, *, q_scale: float = 0.0, k_scale: float = 0.0, v_scale: float = 0.0):
+    """Write a rank-1 LoRA adapter (weights + config) into *adapter_dir*.
+
+    The directory is created if needed. Two files are produced:
+    ``adapter_model.safetensors`` holding one ``lora_A``/``lora_B`` pair for
+    ``transformer.layers.0.attention.to_qkv``, and ``adapter_config.json``.
+    ``lora_A`` is zero except for element ``[0, 0]``; ``lora_B`` is zero except
+    that each third of its rows is filled with the corresponding non-zero
+    ``q_scale`` / ``k_scale`` / ``v_scale``.
+    """
+    adapter_dir.mkdir(parents=True, exist_ok=True)
+
+    dim = 3840  # Z-Image transformer uses dim=3840 by default.
+    rank = 1
+    module_name = "transformer.layers.0.attention.to_qkv"
+
+    lora_a = torch.zeros((rank, dim), dtype=torch.float32)
+    lora_a[0, 0] = 1.0
+
+    # QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1),
+    # so slot 0/1/2 addresses the Q/K/V rows respectively.
+    lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32)
+    for slot, scale in enumerate((q_scale, k_scale, v_scale)):
+        if scale:
+            lora_b[slot * dim : (slot + 1) * dim, 0] = scale
+
+    tensors = {
+        f"base_model.model.{module_name}.lora_A.weight": lora_a,
+        f"base_model.model.{module_name}.lora_B.weight": lora_b,
+    }
+    save_file(tensors, str(adapter_dir / "adapter_model.safetensors"))
+
+    config = {
+        "r": rank,
+        "lora_alpha": rank,
+        "target_modules": [module_name],
+    }
+    (adapter_dir / "adapter_config.json").write_text(json.dumps(config), encoding="utf-8")
+
+
+# ---------------------------------------------------------------------------
+# Code runner and subprocess helpers
+# ---------------------------------------------------------------------------
+
+
+class ExampleRunResult(NamedTuple):
+ """Value returned by ``ExampleRunner.run`` for a single snippet."""
+
+ # Directory the snippet was executed in: output_root / output_subfolder / snippet.test_id.
+ run_dir: Path
+ # Paths reported for the snippet: for Python snippets, the image files that appeared under
+ # run_dir during the run (sorted); for bash snippets, the single path built from `--output`.
+ assets: list[Path]
+
+
+class ExampleRunner:
+    """Run extracted README snippets and return the assets they generate.
+
+    The output materials are organized in a three-level directory structure:
+    - Set at init: `self.output_root` for all tests (from env OUTPUT_DIR)
+    - Set at `self.run(...)`: `output_subfolder` for a specific example page (e.g., `example_offline_t2i`)
+    - Generated by `extract_readme_snippets`: `snippet.test_id` for a specific code block
+      (matching H2 titles, e.g., `basic_usage_001`)
+    """
+
+    # File suffixes (compared lower-cased) treated as image assets of a Python snippet.
+    IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}
+
+    def __init__(self, output_root: Path) -> None:
+        self.output_root = output_root
+
+    def run(
+        self, snippet: ReadmeSnippet, *, output_subfolder: Path = Path("."), env: dict[str, str] | None = None
+    ) -> ExampleRunResult:
+        """Execute *snippet* inside its own run directory and report what it produced.
+
+        The run directory ``output_root / output_subfolder / snippet.test_id`` is
+        created if missing. *env*, when given, is forwarded to ``run_cmd``.
+
+        Raises:
+            AssertionError: if ``snippet.language`` is neither ``python`` nor
+                ``bash``, or a bash snippet has no ``--output`` path.
+            subprocess.CalledProcessError: if the snippet exits non-zero.
+        """
+        run_dir = self.output_root / output_subfolder / snippet.test_id
+        run_dir.mkdir(parents=True, exist_ok=True)
+
+        if snippet.language == "python":
+            assets = self._run_python_snippet(snippet, run_dir, env)
+            return ExampleRunResult(run_dir=run_dir, assets=assets)
+
+        if snippet.language == "bash":
+            asset = self._run_bash_snippet(snippet, run_dir, env)
+            return ExampleRunResult(run_dir=run_dir, assets=[asset])
+
+        raise AssertionError(f"Unsupported snippet language: {snippet.language}")
+
+    def _run_python_snippet(
+        self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None
+    ) -> list[Path]:
+        """Run a Python snippet and return the image files that appeared under *run_dir*, sorted."""
+        # Save the snippet to a file and execute it with `run_cmd`.
+        # `exec(snippet.code)` is not used because the output is lost.
+        script_path = run_dir / "snippet.py"
+        script_path.write_text(snippet.code, encoding="utf-8")
+
+        # Diff the image set around the run so pre-existing files are not reported.
+        before = self._collect_images(run_dir)
+        run_cmd([sys.executable, str(script_path)], cwd=run_dir, env=env)
+        after = self._collect_images(run_dir)
+
+        return sorted(after - before)
+
+    def _run_bash_snippet(self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None) -> Path:
+        """Run a bash snippet in *run_dir* and return the path of its declared output file."""
+        # Validate before launching the command: without --output the result cannot be
+        # located, so there is no point in paying for the run first.
+        assert snippet.output_file_path is not None, (
+            f"README bash snippet is missing --output argument: {snippet.test_id}. "
+            "The test script cannot guess the output file path."
+        )
+
+        run_cmd(snippet.code, shell=True, cwd=run_dir, env=env)
+
+        # A relative output path is resolved against the run directory (the command's cwd).
+        # An absolute output path (unlikely, but possible) makes pathlib's `/` discard
+        # `run_dir`, so the result still points at the declared file.
+        return run_dir / snippet.output_file_path
+
+    def _collect_images(self, root: Path) -> set[Path]:
+        """Return every path below *root* (recursively) whose suffix is in ``IMAGE_SUFFIXES``."""
+        return {path for path in root.rglob("*") if path.suffix.lower() in self.IMAGE_SUFFIXES}
+
+
+@pytest.fixture
+def example_runner() -> ExampleRunner:
+    """Provide an ``ExampleRunner`` rooted at the module-level ``OUTPUT_DIR``."""
+    runner = ExampleRunner(output_root=OUTPUT_DIR)
+    return runner
+
+
+def run_cmd(
+    command: list[str] | str,
+    *,
+    shell: bool = False,
+    env: dict[str, str] | None = None,
+    cwd: Path | str | None = None,
+) -> str:
+    """Run a command as a subprocess; require a zero exit code and return stdout.
+
+    Output is fully captured and returned as a string so callers can parse it
+    (e.g. with :func:`extract_content_after_keyword`).
+    Use this for scripts whose printed output is part of the test assertion.
+
+    Args:
+        command: Argument list, or a single string when ``shell=True``.
+        shell: Passed through to :func:`subprocess.run`.
+        env: Extra environment variables, overlaid on the current process
+            environment. ``None`` means the child inherits the environment as-is.
+        cwd: Working directory for the child process.
+
+    Returns:
+        The captured standard output, decoded as text.
+
+    Raises:
+        subprocess.CalledProcessError: on a non-zero exit code. The captured
+            stdout and stderr are attached as ``output`` / ``stderr`` so the
+            failure can be diagnosed from the exception alone.
+    """
+    if env is not None:
+        # Caller-supplied variables win over inherited ones.
+        env = {**os.environ, **env}
+    result = subprocess.run(command, capture_output=True, text=True, shell=shell, env=env, cwd=cwd)
+
+    if result.returncode != 0:
+        # Show both streams: a failing script often reports the cause on stdout.
+        print(f"STDOUT: {result.stdout}")
+        print(f"STDERR: {result.stderr}")
+        raise subprocess.CalledProcessError(
+            result.returncode, command, output=result.stdout, stderr=result.stderr
+        )
+
+    all_output = result.stdout
+    print(f"All output:\n{all_output}")
+    return all_output
+
+
+# ---------------------------------------------------------------------------
+# Output validation helpers
+# ---------------------------------------------------------------------------
+
+
+def extract_content_after_keyword(keywords: str, text: str) -> str:
+    """Return everything in *text* that follows the first match of *keywords*.
+
+    *keywords* is interpolated into a regular expression unescaped, so it is
+    treated as a pattern. Whitespace directly after the match is skipped, and
+    because of ``re.DOTALL`` the returned content runs to the end of *text*,
+    newlines included.
+
+    Raises ``AssertionError`` when nothing matches, so a failing test names
+    the missing keyword.
+    """
+    found = re.findall(rf"{keywords}\s*(.+)", text, re.DOTALL)
+    if found:
+        return found[0]
+    raise AssertionError(f"Keywords {keywords} not found in provided text output")
+
+
+def strip_trailing_audio_saved_line(text: str) -> str:
+    """Remove trailing ``Audio saved to ...`` lines from captured client stdout.
+
+    ``openai_chat_completion_client_for_multimodal_generation.py`` may print
+    ``Chat completion output from text:`` for one choice and ``Audio saved to``
+    for another. Since :func:`extract_content_after_keyword` matches with
+    ``re.DOTALL``, the audio progress line would otherwise end up inside the
+    extracted *text* segment.
+
+    Only lines at the very end are dropped; the remainder is re-joined with
+    ``\\n`` and stripped of surrounding whitespace.
+    """
+    rows = text.splitlines()
+    end = len(rows)
+    # Walk backwards past every trailing "Audio saved to" line.
+    while end > 0 and rows[end - 1].strip().startswith("Audio saved to"):
+        end -= 1
+    return "\n".join(rows[:end]).strip()
diff --git a/tests/examples/helpers.py b/tests/examples/helpers.py
deleted file mode 100644
index 137d15f163f..00000000000
--- a/tests/examples/helpers.py
+++ /dev/null
@@ -1,353 +0,0 @@
-"""
-Shared fixtures, helpers, and path constants for tests/examples/.
-"""
-
-import json
-import os
-import re
-import shlex
-import subprocess
-import sys
-import tempfile
-from collections import defaultdict
-from collections.abc import Callable
-from pathlib import Path
-from typing import Any, NamedTuple, cast
-
-import pytest
-import torch
-from safetensors.torch import save_file
-
-# ---------------------------------------------------------------------------
-# Path constants and fixtures
-# ---------------------------------------------------------------------------
-
-REPO_ROOT = Path(__file__).resolve().parents[2]
-EXAMPLES = REPO_ROOT / "examples"
-
-# Use Python tempfile instead of pytest's tmp_path_factory because
-# OUTPUT_DIR is needed in test collection time, but tmp_path_factory is only available in test running time.
-# It is needed during test collection because extract_readme_snippets replaces LoRA path with a generated one under OUTPUT_DIR,
-# and extract_readme_snippets is called at collection time to generate separate test cases for each README code block.
-OUTPUT_DIR = (
- REPO_ROOT / prefix
- if (prefix := os.environ.get("OUTPUT_DIR"))
- else Path(tempfile.mkdtemp(prefix="vllm_omni_test_examples_"))
-)
-
-
-# ---------------------------------------------------------------------------
-# Code snippet extraction and asset file helpers
-# ---------------------------------------------------------------------------
-
-# parameters: language, code, h2_title
-ReadmeSnippetExtractionSkipPredicate = Callable[[str, str, str], tuple[bool, str]]
-
-
-class ReadmeSnippet(NamedTuple):
- language: str
- code: str
- h2_title: str
- index_in_section: int
- output_file_path: Path | None = None
- skip: tuple[bool, str] = (False, "")
-
- @property
- def test_id(self) -> str:
- return f"{ReadmeSnippet._slug(self.h2_title)}_{self.index_in_section:03d}"
-
- @staticmethod
- def extract_readme_snippets(
- readme_path: Path,
- skipif: ReadmeSnippetExtractionSkipPredicate | None = None,
- ) -> list["ReadmeSnippet"]:
- import mistune
-
- markdown = mistune.create_markdown(renderer="ast")
- tokens = markdown(readme_path.read_text(encoding="utf-8"))
- tokens = cast(list[dict[str, Any]], tokens) # mistune's AST renderer always produces a list, not a str
-
- h2_title = ""
- section_counts: defaultdict[str, int] = defaultdict(int)
- snippets: list[ReadmeSnippet] = []
-
- for token in tokens:
- token_type = token.get("type")
-
- if token_type == "heading":
- level = (token.get("attrs") or {}).get("level")
- title = ReadmeSnippet._heading_text(token)
- if level == 2:
- h2_title = title
- continue
-
- if token_type != "block_code":
- continue
-
- try:
- info = token.get("attrs").get("info") # type: ignore[reportOptionalMemberAccess]
- language = info.strip().split()[0].lower() # type: ignore[reportOptionalMemberAccess]
-
- # Common shell aliases to "bash" in several markdown renderers.
- if language in {"shell", "sh", "ksh", "zsh"}:
- language = "bash"
-
- if language not in {"bash", "python"}:
- continue
- except AttributeError:
- # The fence is missing explicit language info; skip it.
- continue
-
- key = h2_title
- section_counts[key] += 1
- code = token.get("raw", "")
- output_file_path = None
- if language == "bash":
- argv = ReadmeSnippet._normalize_bash_command(code, Path(readme_path.parent))
- code = shlex.join(argv)
- output_file_path = ReadmeSnippet._output_file_path_from_argv(argv)
- if skipif is not None:
- skip_config = skipif(language, code, h2_title)
- else:
- skip_config = (False, "")
- snippet = ReadmeSnippet(
- language=language,
- code=code,
- h2_title=h2_title,
- index_in_section=section_counts[key],
- output_file_path=output_file_path,
- skip=skip_config,
- )
- snippets.append(snippet)
-
- return snippets
-
- @staticmethod
- def _normalize_bash_command(command: str, readme_dir: Path) -> list[str]:
- line_joined_command = re.sub(r"\\\s*\n", " ", command).strip()
- argv = shlex.split(line_joined_command, comments=True)
- assert argv, "README bash fence produced an empty command"
-
- # Normalize python directory and example script location
- if argv[0] in {"python", "python3"}:
- argv[0] = sys.executable
- if len(argv) > 1 and argv[1].endswith(".py"):
- script_arg = argv[1]
- script_path = Path(script_arg)
- if script_path.is_absolute():
- resolved_script = script_path
- else:
- # Take the file name only, and append script_dir to its front
- resolved_script = readme_dir / script_path.name
- assert resolved_script.exists(), (
- f"README bash snippet references a script that does not exist: {script_arg} (resolved to {resolved_script})"
- )
- argv[1] = str(resolved_script)
-
- # Normalize LoRA adapter path and ensure README LoRA assets exist.
- try:
- lora_arg_idx = argv.index("--lora-path") # Raise ValueError if not found
- assert len(argv) > lora_arg_idx + 1, "README bash snippet uses --lora-path without a following value"
-
- lora_dir = OUTPUT_DIR / "lora"
- adapter_model = lora_dir / "adapter_model.safetensors"
- adapter_config = lora_dir / "adapter_config.json"
- if not adapter_model.exists() or not adapter_config.exists():
- write_zimage_lora(lora_dir, v_scale=8.0)
-
- argv[lora_arg_idx + 1] = str(lora_dir)
- except ValueError:
- pass
-
- return argv
-
- @staticmethod
- def _output_file_path_from_argv(argv: list[str]) -> Path | None:
- if "--output" not in argv:
- return None
- output_param_idx = argv.index("--output")
- assert len(argv) > output_param_idx + 1, "README bash snippet uses --output without a following value"
- output_arg = argv[output_param_idx + 1]
- return Path(output_arg)
-
- @staticmethod
- def _slug(text: str) -> str:
- return "".join(ch.lower() if ch.isalnum() else "_" for ch in text).strip("_")
-
- @staticmethod
- def _heading_text(token: dict) -> str:
- return "".join(child.get("raw", "") for child in token.get("children", [])).strip()
-
-
-# [TODO] Duplicate `_write_zimage_lora` in tests/e2e/online_serving/test_images_generations_lora.py. Combine these helpers and tests/e2e/offline_inference/test_diffusion_lora.py to test/utils later
-def write_zimage_lora(adapter_dir: Path, *, q_scale: float = 0.0, k_scale: float = 0.0, v_scale: float = 0.0):
- adapter_dir.mkdir(parents=True, exist_ok=True)
-
- # Z-Image transformer uses dim=3840 by default.
- dim = 3840
- module_name = "transformer.layers.0.attention.to_qkv"
- rank = 1
-
- lora_a = torch.zeros((rank, dim), dtype=torch.float32)
- lora_a[0, 0] = 1.0
-
- # QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1).
- lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32)
- if q_scale:
- lora_b[:dim, 0] = q_scale
- if k_scale:
- lora_b[dim : 2 * dim, 0] = k_scale
- if v_scale:
- lora_b[2 * dim :, 0] = v_scale
-
- save_file(
- {
- f"base_model.model.{module_name}.lora_A.weight": lora_a,
- f"base_model.model.{module_name}.lora_B.weight": lora_b,
- },
- str(adapter_dir / "adapter_model.safetensors"),
- )
- (adapter_dir / "adapter_config.json").write_text(
- json.dumps(
- {
- "r": rank,
- "lora_alpha": rank,
- "target_modules": [module_name],
- }
- ),
- encoding="utf-8",
- )
-
-
-# ---------------------------------------------------------------------------
-# Code runner and subprocess helpers
-# ---------------------------------------------------------------------------
-
-
-class ExampleRunResult(NamedTuple):
- run_dir: Path
- assets: list[Path]
-
-
-class ExampleRunner:
- """Run extracted README snippets and return generated assets.
-
- The output materials are organized in a three-level directory structure:
- - Set at init: `self.output_root` for all tests (from env OUTPUT_DIR)
- - Set at `self.run(...)`: `output_subfolder` for a specific example page (e.g., `example_offline_t2i`)
- - Generated by `extract_readme_snippets`: `snippet.test_id` for a specific code block (matching H2 titles, e.g., `basic_usage_001`)
- """
-
- IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}
-
- def __init__(self, output_root: Path) -> None:
- self.output_root = output_root
-
- def run(
- self, snippet: ReadmeSnippet, *, output_subfolder: Path = Path("."), env: dict[str, str] | None = None
- ) -> ExampleRunResult:
- run_dir = self.output_root / output_subfolder / snippet.test_id
- run_dir.mkdir(parents=True, exist_ok=True)
-
- if snippet.language == "python":
- assets = self._run_python_snippet(snippet, run_dir, env)
- return ExampleRunResult(run_dir=run_dir, assets=assets)
-
- if snippet.language == "bash":
- asset = self._run_bash_snippet(snippet, run_dir, env)
- return ExampleRunResult(run_dir=run_dir, assets=[asset])
-
- raise AssertionError(f"Unsupported snippet language: {snippet.language}")
-
- def _run_python_snippet(
- self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None
- ) -> list[Path]:
- # Saving the script to a temporary file and `run_cmd` it.
- # Not using `exec(snippet.code)` because the output is lost.
- script_path = run_dir / "snippet.py"
- script_path.write_text(snippet.code, encoding="utf-8")
-
- before = self._collect_images(run_dir)
- run_cmd([sys.executable, str(script_path)], cwd=run_dir, env=env)
- after = self._collect_images(run_dir)
-
- assets = sorted(after - before)
- return assets
-
- def _run_bash_snippet(self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None) -> Path:
- run_cmd(snippet.code, shell=True, cwd=run_dir, env=env)
-
- assert snippet.output_file_path is not None, (
- f"README bash snippet is missing --output argument: {snippet.test_id}. "
- "The test script cannot guess the output file path."
- )
-
- # If the code snippet declares a relative path for the output file, append this path to the parent output collection directory.
- # If the code snippet declares an absolute path (not likely but just in case), the return value resolution removes `run_dir`, also correctly pointing to this file.
- return run_dir / snippet.output_file_path
-
- def _collect_images(self, root: Path) -> set[Path]:
- return {path for path in root.rglob("*") if path.suffix.lower() in self.IMAGE_SUFFIXES}
-
-
-@pytest.fixture
-def example_runner() -> ExampleRunner:
- return ExampleRunner(output_root=OUTPUT_DIR)
-
-
-def run_cmd(
- command: list[str] | str,
- *,
- shell: bool = False,
- env: dict[str, str] | None = None,
- cwd: Path | str | None = None,
-) -> str:
- """Run a command as a subprocess; assert zero exit code and return stdout.
-
- Output is fully captured and returned as a string so callers can parse it
- (e.g. with :func:`extract_content_after_keyword`).
- Use this for scripts whose printed output is part of the test assertion.
- """
- if env is not None:
- env = {**os.environ.copy(), **env}
- result = subprocess.run(command, capture_output=True, text=True, shell=shell, env=env, cwd=cwd)
-
- if result.returncode != 0:
- print(f"STDERR: {result.stderr}")
- raise subprocess.CalledProcessError(result.returncode, command)
-
- all_output = result.stdout
- print(f"All output:\n{all_output}")
- return all_output
-
-
-# ---------------------------------------------------------------------------
-# Output validation helpers
-# ---------------------------------------------------------------------------
-
-
-def extract_content_after_keyword(keywords: str, text: str) -> str:
- """Return the text that follows *keywords* in *text* (regex match).
-
- Raises ``AssertionError`` if the keyword is not found, so test failures
- produce a clear message pointing at the missing keyword.
- """
- matches = re.findall(rf"{keywords}\s*(.+)", text, re.DOTALL)
-
- if not matches:
- raise AssertionError(f"Keywords {keywords} not found in provided text output")
- return matches[0]
-
-
-def strip_trailing_audio_saved_line(text: str) -> str:
- """Drop trailing ``Audio saved to ...`` lines from captured client stdout.
-
- ``openai_chat_completion_client_for_multimodal_generation.py`` may print
- ``Chat completion output from text:`` for one choice and ``Audio saved to``
- for another; :func:`extract_content_after_keyword` uses ``re.DOTALL`` and
- would otherwise keep the audio progress line inside the *text* segment.
- """
- lines = text.splitlines()
- while lines and lines[-1].strip().startswith("Audio saved to"):
- lines.pop()
- return "\n".join(lines).strip()
diff --git a/tests/examples/offline_inference/test_text_to_image.py b/tests/examples/offline_inference/test_text_to_image.py
index 041c32dc4ef..a08d16f1614 100644
--- a/tests/examples/offline_inference/test_text_to_image.py
+++ b/tests/examples/offline_inference/test_text_to_image.py
@@ -7,12 +7,11 @@
import pytest
-from tests.examples.helpers import EXAMPLES, ExampleRunner, ReadmeSnippet
-from tests.helpers.assertions import assert_image_valid
-from tests.helpers.mark import hardware_marks
-
-pytestmark = [pytest.mark.full_model, pytest.mark.example, *hardware_marks(res={"cuda": "H100"})]
+from tests.conftest import assert_image_valid
+from tests.examples.conftest import EXAMPLES, ExampleRunner, ReadmeSnippet
+from tests.utils import hardware_marks
+pytestmark = [pytest.mark.advanced_model, pytest.mark.example, *hardware_marks(res={"cuda": "H100"})]
T2I_SCRIPT = EXAMPLES / "offline_inference" / "text_to_image" / "text_to_image.py"
README_PATH = T2I_SCRIPT.with_name("README.md")
diff --git a/tests/examples/online_serving/test_qwen2_5_omni.py b/tests/examples/online_serving/test_qwen2_5_omni.py
index b3e49b8d9ad..a78ccf5924a 100644
--- a/tests/examples/online_serving/test_qwen2_5_omni.py
+++ b/tests/examples/online_serving/test_qwen2_5_omni.py
@@ -4,29 +4,34 @@
"""
import os
+
+from vllm_omni.platforms import current_omni_platform
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
from pathlib import Path
import pytest
-from tests.examples.helpers import (
+from tests.conftest import OmniServerParams, convert_audio_file_to_text, cosine_similarity_text
+from tests.examples.conftest import (
extract_content_after_keyword,
run_cmd,
strip_trailing_audio_saved_line,
)
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import convert_audio_file_to_text, cosine_similarity_text
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.utils import hardware_test
-pytestmark = [pytest.mark.full_model, pytest.mark.example, pytest.mark.omni]
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+pytestmark = [pytest.mark.advanced_model, pytest.mark.example]
models = ["Qwen/Qwen2.5-Omni-7B"]
-# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
-# platforms: section in vllm_omni/deploy/ci/qwen2_5_omni.yaml.
-stage_configs = [get_deploy_config_path("ci/qwen2_5_omni.yaml")]
+
+stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "qwen2_5_omni_ci.yaml")]
+
+if current_omni_platform.is_xpu():
+ stage_configs = [
+ str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "xpu" / "qwen2_5_omni_ci.yaml")
+ ]
example_dir = str(Path(__file__).parent.parent.parent.parent / "examples" / "online_serving")
# Create parameter combinations for model and stage config
@@ -39,6 +44,8 @@
common_args = ["python", os.path.join(example_dir, "openai_chat_completion_client_for_multimodal_generation.py")]
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_send_multimodal_request_001(omni_server) -> None:
@@ -74,6 +81,8 @@ def test_send_multimodal_request_001(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_send_multimodal_request_002(omni_server) -> None:
@@ -109,6 +118,8 @@ def test_send_multimodal_request_002(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_send_multimodal_request_003(omni_server) -> None:
@@ -134,6 +145,8 @@ def test_send_multimodal_request_003(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_modality_control_001(omni_server) -> None:
@@ -162,6 +175,8 @@ def test_modality_control_001(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_modality_control_002(omni_server) -> None:
@@ -189,6 +204,8 @@ def test_modality_control_002(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_modality_control_003(omni_server) -> None:
@@ -225,6 +242,8 @@ def test_modality_control_003(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_stream_001(omni_server) -> None:
diff --git a/tests/examples/online_serving/test_qwen3_omni.py b/tests/examples/online_serving/test_qwen3_omni.py
index e52a2bf5a67..65f99d7bf28 100644
--- a/tests/examples/online_serving/test_qwen3_omni.py
+++ b/tests/examples/online_serving/test_qwen3_omni.py
@@ -4,28 +4,32 @@
"""
import os
+
+from vllm_omni.platforms import current_omni_platform
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
from pathlib import Path
import pytest
-from tests.examples.helpers import (
+from tests.conftest import OmniServerParams, convert_audio_file_to_text, cosine_similarity_text
+from tests.examples.conftest import (
extract_content_after_keyword,
run_cmd,
strip_trailing_audio_saved_line,
)
-from tests.helpers.mark import hardware_test
-from tests.helpers.media import convert_audio_file_to_text, cosine_similarity_text
-from tests.helpers.runtime import OmniServerParams
-from tests.helpers.stage_config import get_deploy_config_path
+from tests.utils import hardware_test
-pytestmark = [pytest.mark.full_model, pytest.mark.example, pytest.mark.omni]
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+pytestmark = [pytest.mark.advanced_model, pytest.mark.example]
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
-stage_configs = [get_deploy_config_path("ci/qwen3_omni_moe.yaml")]
+stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "qwen3_omni_ci.yaml")]
+
+if current_omni_platform.is_xpu():
+ stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")]
example_dir = str(Path(__file__).parent.parent.parent.parent / "examples" / "online_serving")
@@ -38,6 +42,8 @@
common_args = ["python", os.path.join(example_dir, "openai_chat_completion_client_for_multimodal_generation.py")]
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_send_multimodal_request_001(omni_server) -> None:
@@ -66,6 +72,8 @@ def test_send_multimodal_request_001(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_send_multimodal_request_002(omni_server) -> None:
@@ -97,6 +105,8 @@ def test_send_multimodal_request_002(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_send_multimodal_request_003(omni_server) -> None:
@@ -112,6 +122,8 @@ def test_send_multimodal_request_003(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_modality_control_001(omni_server) -> None:
@@ -134,6 +146,8 @@ def test_modality_control_001(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_modality_control_002(omni_server) -> None:
@@ -156,6 +170,8 @@ def test_modality_control_002(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_modality_control_003(omni_server) -> None:
@@ -186,6 +202,8 @@ def test_modality_control_003(omni_server) -> None:
# TODO: Verify the E2E latency after confirmation baseline.
+@pytest.mark.advanced_model
+@pytest.mark.omni
@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_stream_001(omni_server) -> None:
diff --git a/tests/examples/online_serving/test_text_to_image.py b/tests/examples/online_serving/test_text_to_image.py
index ee0a1fedba7..51b7ff61bc9 100644
--- a/tests/examples/online_serving/test_text_to_image.py
+++ b/tests/examples/online_serving/test_text_to_image.py
@@ -13,12 +13,11 @@
import pytest
-from tests.examples.helpers import EXAMPLES, OUTPUT_DIR, run_cmd, write_zimage_lora
-from tests.helpers.assertions import assert_image_valid
-from tests.helpers.mark import hardware_marks
-from tests.helpers.runtime import OmniServer, OmniServerParams
+from tests.conftest import OmniServer, OmniServerParams, assert_image_valid
+from tests.examples.conftest import EXAMPLES, OUTPUT_DIR, run_cmd, write_zimage_lora
+from tests.utils import hardware_marks
-pytestmark = [pytest.mark.full_model, pytest.mark.example, *hardware_marks(res={"cuda": "H100"})]
+pytestmark = [pytest.mark.advanced_model, pytest.mark.example, *hardware_marks(res={"cuda": "H100"})]
T2I_ONLINE_CLIENT = EXAMPLES / "online_serving" / "text_to_image" / "openai_chat_client.py"
EXAMPLE_OUTPUT_SUBFOLDER = "example_online_t2i"
diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py
deleted file mode 100644
index a3348b07fe0..00000000000
--- a/tests/helpers/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-"""Shared, importable test helper utilities.
-
-Submodules (``assertions``, ``env``, ``media``, ``runtime``, …) are imported
-explicitly by callers. Avoid star-importing everything here: that ran before
-refactor only inside the old monolithic ``conftest``; a greedy ``__init__``
-changes import order and can affect in-process Omni (``OmniRunner`` / offline
-e2e) vs subprocess-based ``OmniServer`` tests.
-"""
diff --git a/tests/helpers/assertions.py b/tests/helpers/assertions.py
deleted file mode 100644
index 604b76b62ec..00000000000
--- a/tests/helpers/assertions.py
+++ /dev/null
@@ -1,522 +0,0 @@
-"""Assertion and response validation helpers for tests."""
-
-import io
-import tempfile
-import threading
-from io import BytesIO
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import soundfile as sf
-from PIL import Image
-
-from tests.helpers.media import (
- cosine_similarity_text,
-)
-
-_GENDER_PIPELINE = None
-_GENDER_PIPELINE_LOCK = threading.Lock()
-_PCM_SPEECH_SAMPLE_RATE_HZ = 24_000
-_MIN_PCM_SPEECH_HNR_DB = 1.0
-_PRESET_VOICE_GENDER_MAP: dict[str, str] = {
- "serena": "female",
- "uncle_fu": "male",
- "chelsie": "female",
- "clone": "female",
- "ethan": "male",
-}
-
-
-def assert_image_diffusion_response(
- response,
- request_config: dict[str, Any],
- run_level: str = None,
-) -> None:
- """
- Validate image diffusion response.
-
- Expected request_config schema:
- {
- "request_type": "image",
- "extra_body": {
- "num_outputs_per_prompt": 1,
- "width": ...,
- "height": ...,
- ...
- }
- }
- """
- assert response.images is not None, "Image response is None"
- assert len(response.images) > 0, "No images in response"
-
- extra_body = request_config.get("extra_body") or {}
-
- num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt")
- if num_outputs_per_prompt is not None:
- assert len(response.images) == num_outputs_per_prompt, (
- f"Expected {num_outputs_per_prompt} images, got {len(response.images)}"
- )
-
- if run_level in {"advanced_model", "full_model"}:
- width = extra_body.get("width")
- height = extra_body.get("height")
-
- if width is not None or height is not None:
- for img in response.images:
- assert_image_valid(img, width=width, height=height)
-
-
-def assert_video_diffusion_response(
- response,
- request_config: dict[str, Any],
- run_level: str = None,
-) -> None:
- """
- Validate video diffusion response.
-
- Expected request_config schema:
- {
- "request_type": "video",
- "form_data": {
- "prompt": "...",
- "num_frames": ...,
- "width": ...,
- "height": ...,
- "fps": ...,
- ...
- }
- }
- """
- form_data = request_config.get("form_data", {})
-
- assert response.videos is not None, "Video response is None"
- assert len(response.videos) > 0, "No videos in response"
-
- expected_frames = _maybe_int(form_data.get("num_frames"))
- expected_width = _maybe_int(form_data.get("width"))
- expected_height = _maybe_int(form_data.get("height"))
- expected_fps = _maybe_int(form_data.get("fps"))
-
- for vid_bytes in response.videos:
- assert_video_valid(
- vid_bytes,
- num_frames=expected_frames,
- width=expected_width,
- height=expected_height,
- fps=expected_fps,
- )
-
-
-def assert_audio_diffusion_response(
- response,
- request_config: dict[str, Any],
- run_level: str = None,
-) -> None:
- """
- Validate audio diffusion response.
- """
- raise NotImplementedError("Audio validation is not implemented yet")
-
-
-def _maybe_int(value: Any) -> int | None:
- if value is None:
- return None
- return int(value)
-
-
-def assert_image_valid(image: Path | Image.Image, *, width: int | None = None, height: int | None = None):
- """Assert the file is a loadable image with optional exact dimensions."""
- if isinstance(image, Path):
- assert image.exists(), f"Image not found: {image}"
- image = Image.open(image)
- image.load()
- assert image.width > 0 and image.height > 0
- if width is not None:
- assert image.width == width, f"Expected width={width}, got {image.width}"
- if height is not None:
- assert image.height == height, f"Expected height={height}, got {image.height}"
- return image
-
-
-def assert_video_valid(
- video: Path | bytes | BytesIO,
- *,
- num_frames: int | None = None,
- width: int | None = None,
- height: int | None = None,
- fps: float | None = None,
-) -> dict[str, int | float]:
- """Assert the MP4 has the expected resolution and frame count.
-
- For several diffusion backends, encoded MP4 frame count follows a codec-aligned
- convention (e.g. request `num_frames=8` can produce 9 encoded frames). Keep
- this compatibility behavior to avoid false negatives in online-serving tests.
- """
- temp_path = None
- cap = None
- try:
- import cv2
-
- if isinstance(video, Path):
- if not video.exists():
- raise AssertionError(f"Video file not found: {video}")
- video_path = str(video)
- else:
- suffix = ".mp4"
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, mode="wb") as tmp:
- if isinstance(video, bytes):
- tmp.write(video)
- elif isinstance(video, BytesIO):
- tmp.write(video.getvalue())
- else:
- raise TypeError(f"Unsupported video type: {type(video)}")
- temp_path = Path(tmp.name)
- video_path = str(temp_path)
-
- cap = cv2.VideoCapture(video_path)
- if not cap.isOpened():
- raise AssertionError("Failed to open video capture")
-
- actual_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- actual_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
- actual_fps = float(cap.get(cv2.CAP_PROP_FPS))
- actual_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
- if width is not None:
- assert actual_width == width, f"Expected width={width}, got {actual_width}"
- if height is not None:
- assert actual_height == height, f"Expected height={height}, got {actual_height}"
- if fps is not None and actual_fps:
- assert abs(actual_fps - float(fps)) < 1.0, f"Expected fps~={fps}, got {actual_fps}"
- if num_frames is not None:
- expected_frames = (int(num_frames) // 4) * 4 + 1
- assert actual_frames == expected_frames, f"Expected frames={expected_frames}, got {actual_frames}"
-
- return {
- "width": actual_width,
- "height": actual_height,
- "fps": actual_fps,
- "num_frames": actual_frames,
- }
- except Exception as e:
- print(f"ERROR: {type(e).__name__}: {e}", flush=True)
- raise
- finally:
- if cap is not None:
- cap.release()
- if temp_path and temp_path.exists():
- try:
- temp_path.unlink()
- except OSError:
- pass
-
-
-def assert_audio_valid(
- audio_or_path: Path | np.ndarray,
- *,
- sample_rate: int,
- channels: int,
- duration_s: float,
-) -> None:
- """Assert WAV file or (batch, channels, samples) ndarray matches expected audio format."""
- expected_samples = int(duration_s * sample_rate)
- if isinstance(audio_or_path, np.ndarray):
- audio = audio_or_path
- assert audio.ndim == 3, f"Expected audio ndim=3 (batch, channels, samples), got shape {audio.shape}"
- assert audio.shape[0] == 1, f"Expected batch size 1, got {audio.shape[0]}"
- assert audio.shape[1] == channels, f"Expected {channels} channels, got {audio.shape[1]}"
- assert audio.shape[2] == expected_samples, (
- f"Expected {expected_samples} samples ({duration_s}s @ {sample_rate} Hz), got {audio.shape[2]}"
- )
- return
-
- path = audio_or_path
- assert path.exists(), f"Audio not found: {path}"
- info = sf.info(str(path))
- assert info.samplerate == sample_rate, f"Expected sample_rate={sample_rate}, got {info.samplerate}"
- assert info.channels == channels, f"Expected {channels} channel(s), got {info.channels}"
- assert info.frames == expected_samples, (
- f"Expected {expected_samples} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}"
- )
-
-
-def _load_gender_pipeline():
- global _GENDER_PIPELINE
- if _GENDER_PIPELINE is not None:
- return _GENDER_PIPELINE
- model_name = "7wolf/wav2vec2-base-gender-classification"
- try:
- from transformers import pipeline
-
- _GENDER_PIPELINE = pipeline(task="audio-classification", model=model_name, device=-1)
- return _GENDER_PIPELINE
- except Exception as exc: # pragma: no cover
- print(f"Warning: failed to create gender pipeline '{model_name}': {exc}")
- _GENDER_PIPELINE = None
- return None
-
-
-def _median_pitch_hz_from_autocorr(mono: np.ndarray, sr: int) -> float | None:
- x = np.asarray(mono, dtype=np.float64)
- x = x - np.mean(x)
- if x.size < int(0.15 * sr):
- return None
- frame_len = int(0.04 * sr)
- hop = max(frame_len // 2, 1)
- f0_min_hz, f0_max_hz = 70.0, 400.0
- lag_min = max(1, int(sr / f0_max_hz))
- lag_max = min(frame_len - 2, int(sr / f0_min_hz))
- if lag_max <= lag_min:
- return None
- win = np.hamming(frame_len)
- pitches: list[float] = []
- for start in range(0, int(x.shape[0]) - frame_len, hop):
- frame = x[start : start + frame_len] * win
- frame = frame - np.mean(frame)
- if float(np.sqrt(np.mean(frame**2))) < 1e-4:
- continue
- ac = np.correlate(frame, frame, mode="full")[frame_len - 1 :]
- ac = ac / (float(ac[0]) + 1e-12)
- region = ac[lag_min : lag_max + 1]
- peak_rel = int(np.argmax(region))
- peak_lag = peak_rel + lag_min
- if peak_lag <= 0:
- continue
- f0 = float(sr) / float(peak_lag)
- if f0_min_hz <= f0 <= f0_max_hz:
- pitches.append(f0)
- if len(pitches) < 4:
- return None
- return float(np.median(np.asarray(pitches, dtype=np.float64)))
-
-
-def _estimate_voice_gender_from_audio(audio_bytes: bytes) -> str:
- data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
- if data.size == 0:
- raise ValueError("Empty audio")
- mono = np.mean(data, axis=1)
- try:
- target_sr = 16000
- if int(sr) != target_sr and mono.size > 1:
- src_len = int(mono.shape[0])
- dst_len = max(1, int(round(src_len * float(target_sr) / float(sr))))
- src_idx = np.arange(src_len, dtype=np.float32)
- dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32)
- mono = np.interp(dst_idx, src_idx, mono.astype(np.float32, copy=False)).astype(np.float32)
- sr = target_sr
-
- median_f0 = _median_pitch_hz_from_autocorr(mono, sr)
- clf = _load_gender_pipeline()
- if clf is None:
- print("gender model not available, returning 'unknown'")
- return "unknown"
- with _GENDER_PIPELINE_LOCK:
- outputs = clf(mono, sampling_rate=sr)
- if not outputs:
- return "unknown"
- top = outputs[0]
- label = str(top.get("label", "")).lower()
- conf = float(top.get("score", 0.0))
- if conf < 0.5:
- gender = "unknown"
- elif ("female" in label) or ("жен" in label):
- gender = "female"
- elif ("male" in label) or ("муж" in label):
- gender = "male"
- else:
- gender = "unknown"
-
- if gender == "female" and median_f0 is not None and median_f0 < 165.0 and conf < 0.88:
- print(f"gender pitch assist: reclassifying female->male (median_f0={median_f0:.1f} Hz, conf={conf:.3f})")
- gender = "male"
- elif gender == "male" and median_f0 is not None and median_f0 > 230.0 and conf < 0.88:
- print(f"gender pitch assist: reclassifying male->female (median_f0={median_f0:.1f} Hz, conf={conf:.3f})")
- gender = "female"
- print(
- f"gender classifier: label={label}, conf={conf:.3f}, gender={gender}"
- + (f", median_f0={median_f0:.1f}Hz" if median_f0 is not None else "")
- )
- return gender
- except Exception as exc: # pragma: no cover
- print(f"Warning: gender classification failed, returning 'unknown': {exc}")
- return "unknown"
-
-
-def _assert_preset_voice_gender_from_audio(audio_bytes: bytes | None, voice_name: str | None) -> None:
- """If ``voice_name`` matches a known preset, assert classifier gender matches (skip when unknown)."""
- if not voice_name or not audio_bytes:
- return
- key = str(voice_name).lower()
- expected_gender = _PRESET_VOICE_GENDER_MAP.get(key)
- if expected_gender is None:
- return
- estimated_gender = _estimate_voice_gender_from_audio(audio_bytes)
- print(f"Preset voice gender check: preset={key!r}, estimated={estimated_gender!r}, expected={expected_gender!r}")
- if estimated_gender != "unknown":
- assert estimated_gender == expected_gender, (
- f"{voice_name!r} is expected {expected_gender}, but estimated gender is {estimated_gender!r}"
- )
-
-
-def _compute_pcm_hnr_db(pcm_samples: np.ndarray, sr: int = _PCM_SPEECH_SAMPLE_RATE_HZ) -> float:
- frame_len = int(0.03 * sr)
- hop = frame_len // 2
- hnr_values: list[float] = []
- for start in range(0, len(pcm_samples) - frame_len, hop):
- frame = pcm_samples[start : start + frame_len].astype(np.float32, copy=False)
- frame = frame - np.mean(frame)
- if np.max(np.abs(frame)) < 0.01:
- continue
- ac = np.correlate(frame, frame, mode="full")[len(frame) - 1 :]
- ac = ac / (ac[0] + 1e-10)
- min_lag = int(sr / 400)
- max_lag = min(int(sr / 80), len(ac))
- if min_lag >= max_lag:
- continue
- peak = float(np.max(ac[min_lag:max_lag]))
- if 0 < peak < 1:
- hnr_values.append(10 * np.log10(peak / (1 - peak + 1e-10)))
- return float(np.mean(hnr_values)) if hnr_values else 0.0
-
-
-def _assert_pcm_int16_speech_hnr(audio_bytes: bytes) -> None:
- assert audio_bytes is not None and len(audio_bytes) >= 2, "missing PCM bytes"
- assert len(audio_bytes) % 2 == 0, "PCM byte length must be aligned to int16"
- pcm_samples = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
- hnr = _compute_pcm_hnr_db(pcm_samples)
- print(f"PCM speech HNR: {hnr:.2f} dB (threshold: {_MIN_PCM_SPEECH_HNR_DB} dB)")
- assert hnr >= _MIN_PCM_SPEECH_HNR_DB, (
- f"Audio distortion detected: HNR={hnr:.2f} dB < {_MIN_PCM_SPEECH_HNR_DB} dB. "
- "Voice clone decoder may be losing ref_code speaker context on later chunks."
- )
-
-
-def assert_omni_response(response: Any, request_config: dict[str, Any], run_level):
- """
- Validate response results.
-
- Args:
- response: OmniResponse object
-
- Raises:
- AssertionError: When the response does not meet validation criteria
- """
- assert response.success, "The request failed."
- e2e_latency = response.e2e_latency
- if e2e_latency is not None:
- print(f"the e2e latency is: {e2e_latency}")
-
- modalities = request_config.get("modalities", ["text", "audio"])
-
- if run_level in {"advanced_model", "full_model"}:
- # Verify output success
- if "audio" in modalities:
- assert response.audio_content is not None, "No audio output is generated"
- print(f"audio content is: {response.audio_content}")
- speaker = request_config.get("speaker")
- if speaker:
- _assert_preset_voice_gender_from_audio(
- response.audio_bytes,
- speaker,
- )
- if "text" in modalities:
- assert response.text_content is not None, "No text output is generated"
- print(f"text content is: {response.text_content}")
-
- # Verify keywords in output
- word_types = ["text", "image", "audio", "video"]
- keywords_dict = request_config.get("key_words", {})
- for word_type in word_types:
- keywords = keywords_dict.get(word_type)
- if "text" in modalities:
- if keywords:
- text_lower = response.text_content.lower()
- assert any(str(kw).lower() in text_lower for kw in keywords), (
- "The output does not contain any of the keywords."
- )
- else:
- if keywords:
- audio_lower = response.audio_content.lower()
- assert any(str(kw).lower() in audio_lower for kw in keywords), (
- "The output does not contain any of the keywords."
- )
-
- # Verify similarity (Whisper transcript vs streamed/detokenized text)
- if "audio" in modalities:
- audio_ref_text = request_config.get("audio_ref_text")
- if "text" in modalities:
- transcript = (response.audio_content or "").strip()
- text_output = (response.text_content or "").strip()
- similarity = cosine_similarity_text(
- transcript.lower(),
- text_output.lower(),
- )
- assert similarity > 0.9, "The audio content is not same as the text"
- print(f"similarity is: {similarity}")
- if audio_ref_text:
- audio_similarity = cosine_similarity_text(
- response.audio_content.lower(),
- str(audio_ref_text).lower(),
- )
- assert audio_similarity > 0.9, (
- f"The audio content does not match reference text: similarity={audio_similarity:.3f}"
- )
-
-
-def assert_audio_speech_response(response: Any, request_config: dict[str, Any], run_level: str) -> None:
- assert response.success, "The request failed."
- e2e_latency = getattr(response, "e2e_latency", None)
- if e2e_latency is not None:
- print(f"the avg e2e latency is: {e2e_latency}")
-
- req_fmt = request_config.get("response_format")
- if req_fmt == "pcm" and response.audio_bytes:
- _assert_pcm_int16_speech_hnr(response.audio_bytes)
- if response.audio_format:
- assert "pcm" in response.audio_format.lower(), (
- f"Expected audio/pcm content-type, got {response.audio_format!r}"
- )
- elif req_fmt == "wav" and response.audio_format:
- assert req_fmt in response.audio_format
-
- if run_level in {"advanced_model", "full_model"} and req_fmt != "pcm":
- expected_text = request_config.get("input")
- if expected_text:
- transcript = (response.audio_content or "").strip()
- print(f"audio content is: {transcript}")
- print(f"input text is: {expected_text}")
- similarity = cosine_similarity_text(transcript.lower(), expected_text.lower())
- print(f"Cosine similarity: {similarity:.3f}")
- assert similarity > 0.9, (
- f"Transcript doesn't match input: similarity={similarity:.2f}, transcript='{transcript}'"
- )
- _assert_preset_voice_gender_from_audio(response.audio_bytes, request_config.get("voice"))
-
-
-def assert_diffusion_response(response: Any, request_config: dict[str, Any], run_level: str = None):
- assert response.success, "The request failed."
- e2e_latency = getattr(response, "e2e_latency", None)
- if e2e_latency is not None:
- print(f"the avg e2e is: {e2e_latency}")
- has_any_content = any(content is not None for content in (response.images, response.videos, response.audios))
- assert has_any_content, "Response contains no images, videos, or audios"
- if response.images is not None:
- assert_image_diffusion_response(response=response, request_config=request_config, run_level=run_level)
- if response.videos is not None:
- assert_video_diffusion_response(response=response, request_config=request_config, run_level=run_level)
- if response.audios is not None:
- assert_audio_diffusion_response(response=response, request_config=request_config, run_level=run_level)
-
-
-__all__ = [
- "assert_audio_diffusion_response",
- "assert_audio_speech_response",
- "assert_diffusion_response",
- "assert_image_diffusion_response",
- "assert_image_valid",
- "assert_omni_response",
- "assert_video_diffusion_response",
- "assert_video_valid",
- "assert_audio_valid",
-]
diff --git a/tests/helpers/env.py b/tests/helpers/env.py
deleted file mode 100644
index 22ec9a78626..00000000000
--- a/tests/helpers/env.py
+++ /dev/null
@@ -1,280 +0,0 @@
-"""Test environment / lifecycle helpers (GPU cleanup hooks and memory monitoring for tests).
-
-``vllm.platforms`` / ``vllm_omni.platforms`` are imported only inside functions that need them
-so importing this module at pytest plugin load does not run before session autouse fixtures
-"""
-
-from __future__ import annotations
-
-import gc
-import os
-import subprocess
-import threading
-import time
-from contextlib import contextmanager
-
-import torch
-
-
-def run_forced_gpu_cleanup_round() -> None:
- run_pre_test_cleanup(enable_force=True)
- run_post_test_cleanup(enable_force=True)
-
-
-def get_physical_device_indices(devices):
- visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
- if visible_devices is None:
- return devices
- visible_indices = [int(x) for x in visible_devices.split(",")]
- index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
- return [index_mapping[i] for i in devices if i in index_mapping]
-
-
-def wait_for_gpu_memory_to_clear(
- *,
- devices: list[int],
- threshold_bytes: int | None = None,
- threshold_ratio: float | None = None,
- timeout_s: float = 120,
-) -> None:
- from vllm.platforms import current_platform
-
- assert threshold_bytes is not None or threshold_ratio is not None
- devices = get_physical_device_indices(devices)
- start_time = time.time()
-
- device_list = ", ".join(str(d) for d in devices)
- if threshold_bytes is not None:
- threshold_str = f"{threshold_bytes / 2**30:.2f} GiB"
- condition_str = f"Memory usage ≤ {threshold_str}"
- else:
- threshold_percent = threshold_ratio * 100
- threshold_str = f"{threshold_percent:.1f}%"
- condition_str = f"Memory usage ratio ≤ {threshold_str}"
-
- print(f"[GPU Memory Monitor] Waiting for GPU {device_list} to free memory, Condition: {condition_str}")
-
- if threshold_bytes is not None:
-
- def is_free(used, total):
- return used <= threshold_bytes / 2**30
- else:
-
- def is_free(used, total):
- return used / total <= threshold_ratio
-
- @contextmanager
- def nvml_scope():
- if current_platform.is_rocm():
- from amdsmi import amdsmi_init, amdsmi_shut_down
-
- amdsmi_init()
- try:
- yield
- finally:
- amdsmi_shut_down()
- elif current_platform.is_cuda():
- from vllm.third_party.pynvml import nvmlInit, nvmlShutdown
-
- nvmlInit()
- try:
- yield
- finally:
- nvmlShutdown()
- else:
- yield
-
- is_rocm = current_platform.is_rocm()
-
- with nvml_scope():
- if is_rocm:
- from amdsmi import amdsmi_get_gpu_vram_usage, amdsmi_get_processor_handles
- elif current_platform.is_cuda():
- from vllm.third_party.pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
-
- while True:
- output: dict[int, str] = {}
- output_raw: dict[int, tuple[float, float]] = {}
- for device in devices:
- if is_rocm:
- dev_handle = amdsmi_get_processor_handles()[device]
- mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
- gb_used = mem_info["vram_used"] / 2**10
- gb_total = mem_info["vram_total"] / 2**10
- else:
- dev_handle = nvmlDeviceGetHandleByIndex(device)
- mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
- gb_used = mem_info.used / 2**30
- gb_total = mem_info.total / 2**30
- output_raw[device] = (gb_used, gb_total)
- usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0
- output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)"
-
- print("[GPU Memory Status] Current usage:")
- for device_id, mem_info in output.items():
- print(f" GPU {device_id}: {mem_info}")
-
- dur_s = time.time() - start_time
- elapsed_minutes = dur_s / 60
- if all(is_free(used, total) for used, total in output_raw.values()):
- print(f"[GPU Memory Freed] Devices {device_list} meet memory condition")
- print(f" Condition: {condition_str}")
- print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)")
- break
-
- if dur_s >= timeout_s:
- raise ValueError(
- f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n"
- f"Condition: {condition_str}\n"
- f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices)
- )
-
- gc.collect()
- torch.cuda.empty_cache()
- time.sleep(5)
-
-
-def _print_gpu_processes() -> None:
- """Print GPU information including nvidia-smi and system processes."""
-
- print("\n" + "=" * 80)
- print("NVIDIA GPU Information (nvidia-smi)")
- print("=" * 80)
-
- try:
- nvidia_result = subprocess.run(
- ["nvidia-smi"],
- capture_output=True,
- text=True,
- timeout=5,
- )
-
- if nvidia_result.returncode == 0:
- lines = nvidia_result.stdout.strip().split("\n")
- for line in lines[:20]:
- print(line)
-
- if len(lines) > 20:
- print(f"... (showing first 20 of {len(lines)} lines)")
- else:
- print("nvidia-smi command failed")
-
- except (subprocess.TimeoutExpired, FileNotFoundError):
- print("nvidia-smi not available or timed out")
- except Exception as e:
- print(f"Error running nvidia-smi: {e}")
-
- print("\n" + "=" * 80)
- print("Detailed GPU Processes (nvidia-smi pmon)")
- print("=" * 80)
-
- try:
- pmon_result = subprocess.run(
- ["nvidia-smi", "pmon", "-c", "1"],
- capture_output=True,
- text=True,
- timeout=3,
- )
-
- if pmon_result.returncode == 0 and pmon_result.stdout.strip():
- print(pmon_result.stdout)
- else:
- print("No active GPU processes found via nvidia-smi pmon")
-
- except Exception:
- print("nvidia-smi pmon not available")
-
- print("\n" + "=" * 80)
- print("System Processes with GPU keywords")
- print("=" * 80)
-
-
-_SKIPPED_GPU_CLEANUP_MSG = (
- "\nSkipping GPU memory cleanup check (typically: instance already up; no check needed between tests)\n"
-)
-
-
-def run_pre_test_cleanup(enable_force: bool = False) -> None:
- if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
- print(_SKIPPED_GPU_CLEANUP_MSG)
- return
-
- print("Pre-test GPU status:")
-
- num_gpus = torch.cuda.device_count()
- if num_gpus > 0:
- try:
- wait_for_gpu_memory_to_clear(
- devices=list(range(num_gpus)),
- threshold_ratio=0.05,
- )
- except Exception as e:
- print(f"Pre-test cleanup note: {e}")
-
-
-def run_post_test_cleanup(enable_force: bool = False) -> None:
- if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
- print(_SKIPPED_GPU_CLEANUP_MSG)
- return
-
- if torch.cuda.is_available():
- gc.collect()
- torch.cuda.empty_cache()
-
- print("Post-test GPU status:")
- _print_gpu_processes()
-
-
-class DeviceMemoryMonitor:
- """Poll global device memory usage."""
-
- def __init__(self, device_index: int, interval: float = 0.05):
- self.device_index = device_index
- self.interval = interval
- self._peak_used_mb = 0.0
- self._stop_event = threading.Event()
- self._thread: threading.Thread | None = None
-
- def start(self) -> None:
- from vllm_omni.platforms import current_omni_platform
-
- def monitor_loop() -> None:
- while not self._stop_event.is_set():
- try:
- with current_omni_platform.device(self.device_index):
- free_bytes, total_bytes = current_omni_platform.mem_get_info()
- used_mb = (total_bytes - free_bytes) / (1024**2)
- self._peak_used_mb = max(self._peak_used_mb, used_mb)
- except Exception:
- pass
- time.sleep(self.interval)
-
- self._thread = threading.Thread(target=monitor_loop, daemon=False)
- self._thread.start()
-
- def stop(self) -> None:
- if self._thread is None:
- return
- self._stop_event.set()
- self._thread.join(timeout=2.0)
-
- @property
- def peak_used_mb(self) -> float:
- from vllm_omni.platforms import current_omni_platform
-
- fallback_alloc = current_omni_platform.max_memory_allocated(device=self.device_index) / (1024**2)
- fallback_reserved = current_omni_platform.max_memory_reserved(device=self.device_index) / (1024**2)
- return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
-
- def __del__(self):
- self.stop()
-
-
-__all__ = [
- "DeviceMemoryMonitor",
- "get_physical_device_indices",
- "run_post_test_cleanup",
- "run_pre_test_cleanup",
- "run_forced_gpu_cleanup_round",
- "wait_for_gpu_memory_to_clear",
-]
diff --git a/tests/helpers/fixtures/__init__.py b/tests/helpers/fixtures/__init__.py
deleted file mode 100644
index 8bd090b7824..00000000000
--- a/tests/helpers/fixtures/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Pytest fixture modules under tests.helpers."""
diff --git a/tests/helpers/fixtures/env.py b/tests/helpers/fixtures/env.py
deleted file mode 100644
index 939bad02ca4..00000000000
--- a/tests/helpers/fixtures/env.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import os
-
-import pytest
-import torch
-
-
-@pytest.fixture(scope="session", autouse=True)
-def default_env():
- # Keep behavior but avoid import-time side effects (RFC #2299).
- keys = ("VLLM_WORKER_MULTIPROC_METHOD", "VLLM_TARGET_DEVICE")
- previous = {key: os.environ.get(key) for key in keys}
- os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = previous["VLLM_WORKER_MULTIPROC_METHOD"] or "spawn"
- os.environ["VLLM_TARGET_DEVICE"] = previous["VLLM_TARGET_DEVICE"] or (
- "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu"
- )
- yield
- for key, value in previous.items():
- if value is None:
- os.environ.pop(key, None)
- else:
- os.environ[key] = value
-
-
-@pytest.fixture(scope="session")
-def model_prefix() -> str:
- prefix = os.environ.get("MODEL_PREFIX", "")
- return f"{prefix.rstrip('/')}/" if prefix else ""
-
-
-@pytest.fixture(autouse=True)
-def clean_gpu_memory_between_tests():
- # Import here so ``tests.helpers.env`` (and vLLM platform modules) load only
- # after session autouse fixtures like ``default_env`` have run (RFC #2299).
- from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup
-
- print("\n=== PRE-TEST GPU CLEANUP ===")
- run_pre_test_cleanup()
- yield
- run_post_test_cleanup()
-
-
-@pytest.fixture(scope="session", autouse=True)
-def default_vllm_config():
- """Set a default VllmConfig for the whole test session.
-
- Session scope ensures module-scoped fixtures (e.g. ``omni_runner``) and
- deferred imports of ``tests.helpers.runtime`` both see the same context.
- Function-scoped autouse ran too late for ``OmniRunner`` setup and could
- desynchronize vLLM init vs request preprocessing (e.g. renderer state).
- """
- from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config
-
- # Use CPU device if no GPU is available (e.g., in CI environments)
- has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
- device = "cuda" if has_gpu else "cpu"
- device_config = DeviceConfig(device=device)
-
- with set_current_vllm_config(VllmConfig(device_config=device_config)):
- yield
diff --git a/tests/helpers/fixtures/log.py b/tests/helpers/fixtures/log.py
deleted file mode 100644
index 798fa4ae6c7..00000000000
--- a/tests/helpers/fixtures/log.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import pytest
-
-
-@pytest.fixture(autouse=True)
-def log_test_name_before_test(request: pytest.FixtureRequest):
- print(f"--- Running test: {request.node.name}")
- yield
diff --git a/tests/helpers/fixtures/run_args.py b/tests/helpers/fixtures/run_args.py
deleted file mode 100644
index 975584d206b..00000000000
--- a/tests/helpers/fixtures/run_args.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import pytest
-
-
-def pytest_addoption(parser):
- parser.addoption(
- "--run-level",
- action="store",
- default="core_model",
- choices=["core_model", "advanced_model", "full_model"],
- help="Test level to run: L2, L3, L4",
- )
-
-
-@pytest.fixture(scope="session")
-def run_level(request) -> str:
- """Session test level from ``--run-level`` (see CI five-level docs)."""
- return request.config.getoption("--run-level")
diff --git a/tests/helpers/fixtures/runtime.py b/tests/helpers/fixtures/runtime.py
deleted file mode 100644
index 4cae13cd6eb..00000000000
--- a/tests/helpers/fixtures/runtime.py
+++ /dev/null
@@ -1,137 +0,0 @@
-"""Runtime fixtures (OmniRunner / OmniServer). Imports are deferred to fixture time.
-
-Loading ``tests.helpers.runtime`` at plugin import time (before session fixtures)
-pulls in vLLM/vllm_omni too early and breaks initialization order vs the legacy
-monolithic conftest. Defer imports until fixtures run so ``default_env`` /
-``default_vllm_config`` run first.
-"""
-
-from __future__ import annotations
-
-import threading
-from collections.abc import Generator
-from typing import Any
-
-import pytest
-import yaml
-
-from tests.helpers.runtime import OmniServer
-from tests.helpers.stage_config import modify_stage_config
-
-omni_fixture_lock = threading.Lock()
-
-
-@pytest.fixture(scope="module")
-def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: str) -> Generator[OmniServer, Any, None]:
- """Start vLLM-Omni through the standard or stage-CLI launcher.
-
- The fixture stays module-scoped because multi-stage initialization is costly.
- The ``use_stage_cli`` flag on ``OmniServerParams`` routes the setup through the
- stage-CLI harness while still reusing the same fixture grouping semantics.
- """
- with omni_fixture_lock:
- from tests.helpers.runtime import OmniServer, OmniServerParams, OmniServerStageCli
-
- params: OmniServerParams = request.param
- model = model_prefix + params.model
- port = params.port
- stage_config_path = params.stage_config_path
- if run_level in {"advanced_model", "full_model"} and stage_config_path is not None:
- with open(stage_config_path, encoding="utf-8") as f:
- cfg = yaml.safe_load(f) or {}
- # Strip ``load_format: dummy`` (CI overlay default) so advanced_model
- # tests use real weights. New schema (``stages:``) writes the field
- # flat at stage level; legacy schema (``stage_args:``) nests it as
- # ``engine_args.load_format``. Handle both.
- new_schema_stages = cfg.get("stages")
- stage_key = "stages" if new_schema_stages is not None else "stage_args"
- delete_path = "load_format" if new_schema_stages is not None else "engine_args.load_format"
- stage_entries = cfg.get(stage_key, [])
- stage_ids = [stage["stage_id"] for stage in stage_entries if "stage_id" in stage]
- stage_config_path = modify_stage_config(
- stage_config_path,
- deletes={stage_key: {stage_id: [delete_path] for stage_id in stage_ids}},
- )
-
- server_args = params.server_args or []
- if params.use_omni and params.stage_init_timeout is not None:
- server_args = [*server_args, "--stage-init-timeout", str(params.stage_init_timeout)]
- else:
- server_args = [*server_args, "--stage-init-timeout", "600"]
- if params.init_timeout is not None:
- server_args = [*server_args, "--init-timeout", str(params.init_timeout)]
- else:
- server_args = [*server_args, "--init-timeout", "900"]
- if params.use_stage_cli:
- if not params.use_omni:
- raise ValueError("omni_server with use_stage_cli=True requires use_omni=True")
- if stage_config_path is None:
- raise ValueError("omni_server with use_stage_cli=True requires a stage_config_path")
- server_args += ["--stage-configs-path", stage_config_path]
-
- with OmniServerStageCli(
- model,
- stage_config_path,
- server_args,
- port=port,
- env_dict=params.env_dict,
- ) as server:
- print("OmniServer started successfully")
- yield server
- print("OmniServer stopping...")
- else:
- if stage_config_path is not None:
- server_args += ["--stage-configs-path", stage_config_path]
-
- with (
- OmniServer(
- model,
- server_args,
- port=port,
- env_dict=params.env_dict,
- use_omni=params.use_omni,
- )
- if port
- else OmniServer(
- model,
- server_args,
- env_dict=params.env_dict,
- use_omni=params.use_omni,
- )
- ) as server:
- print("OmniServer started successfully")
- yield server
- print("OmniServer stopping...")
-
- print("OmniServer stopped")
-
-
-@pytest.fixture
-def openai_client(request: pytest.FixtureRequest, run_level: str):
- """Resolve ``omni_server`` lazily so parametrized server fixtures work like upstream."""
- from tests.helpers.runtime import OpenAIClientHandler
-
- server = request.getfixturevalue("omni_server")
- return OpenAIClientHandler(host=server.host, port=server.port, api_key="EMPTY", run_level=run_level)
-
-
-@pytest.fixture(scope="module")
-def omni_runner(request: pytest.FixtureRequest, model_prefix: str):
- from tests.helpers.runtime import OmniRunner
-
- with omni_fixture_lock:
- model, stage_config_path = request.param
- model = model_prefix + model
- with OmniRunner(model, seed=42, stage_configs_path=stage_config_path) as runner:
- print("OmniRunner started successfully")
- yield runner
- print("OmniRunner stopping...")
-
- print("OmniRunner stopped")
-
-
-@pytest.fixture
-def omni_runner_handler(omni_runner: Any):
- from tests.helpers.runtime import OmniRunnerHandler
-
- return OmniRunnerHandler(omni_runner)
diff --git a/tests/helpers/mark.py b/tests/helpers/mark.py
deleted file mode 100644
index ed45dd7e9a1..00000000000
--- a/tests/helpers/mark.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Pytest marks and decorators for hardware / resource selection (CUDA, ROCm, …)."""
-
-import pytest
-from vllm.utils.torch_utils import cuda_device_count_stateless
-
-# Re-exported from tests.helpers.env (GPU wait + DeviceMemoryMonitor).
-
-
-def cuda_marks(*, res: str, num_cards: int):
- test_platform_detail = pytest.mark.cuda
- if res == "L4":
- test_resource = pytest.mark.L4
- elif res == "H100":
- test_resource = pytest.mark.H100
- else:
- raise ValueError(f"Invalid CUDA resource type: {res}. Supported: L4, H100")
- marks = [test_resource, test_platform_detail]
- if num_cards == 1:
- return marks
- test_distributed = pytest.mark.distributed_cuda(num_cards=num_cards)
- test_skipif = pytest.mark.skipif_cuda(
- cuda_device_count_stateless() < num_cards,
- reason=f"Need at least {num_cards} CUDA GPUs to run the test.",
- )
- return marks + [test_distributed, test_skipif]
-
-
-def rocm_marks(*, res: str, num_cards: int):
- test_platform_detail = pytest.mark.rocm
- if res == "MI325":
- test_resource = pytest.mark.MI325
- else:
- raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325")
- marks = [test_resource, test_platform_detail]
- if num_cards == 1:
- return marks
- test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
- return marks + [test_distributed]
-
-
-def xpu_marks(*, res: str, num_cards: int):
- test_platform_detail = pytest.mark.xpu
- if res == "B60":
- test_resource = pytest.mark.B60
- else:
- raise ValueError(f"Invalid XPU resource type: {res}. Supported: B60")
- marks = [test_resource, test_platform_detail]
- if num_cards == 1:
- return marks
- test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
- return marks + [test_distributed]
-
-
-def musa_marks(*, res: str, num_cards: int):
- test_platform_detail = pytest.mark.musa
- if res == "S5000":
- test_resource = pytest.mark.S5000
- else:
- raise ValueError(f"Invalid MUSA resource type: {res}. Supported: S5000")
- marks = [test_resource, test_platform_detail]
- if num_cards == 1:
- return marks
- test_distributed = pytest.mark.distributed_musa(num_cards=num_cards)
- return marks + [test_distributed]
-
-
-def gpu_marks(*, res: str, num_cards: int):
- test_platform = pytest.mark.gpu
- if res in ("L4", "H100"):
- return [test_platform] + cuda_marks(res=res, num_cards=num_cards)
- if res == "MI325":
- return [test_platform] + rocm_marks(res=res, num_cards=num_cards)
- if res == "B60":
- return [test_platform] + xpu_marks(res=res, num_cards=num_cards)
- if res == "S5000":
- return [test_platform] + musa_marks(res=res, num_cards=num_cards)
- raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325, B60, S5000")
-
-
-def npu_marks(*, res: str, num_cards: int):
- test_platform = pytest.mark.npu
- if res == "A2":
- test_resource = pytest.mark.A2
- elif res == "A3":
- test_resource = pytest.mark.A3
- else:
- test_resource = None
- if num_cards == 1:
- return [mark for mark in [test_platform, test_resource] if mark is not None]
- test_distributed = pytest.mark.distributed_npu(num_cards=num_cards)
- return [mark for mark in [test_platform, test_resource, test_distributed] if mark is not None]
-
-
-def hardware_marks(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
- for platform, _ in res.items():
- if platform not in ("cuda", "rocm", "xpu", "npu", "musa"):
- raise ValueError(f"Unsupported platform: {platform}")
- if isinstance(num_cards, int):
- num_cards_dict = {platform: num_cards for platform in res.keys()}
- else:
- num_cards_dict = num_cards
- for platform in num_cards_dict.keys():
- if platform not in res:
- raise ValueError(f"Platform '{platform}' in num_cards but not in res.")
- for platform in res.keys():
- if platform not in num_cards_dict:
- num_cards_dict[platform] = 1
-
- all_marks: list[pytest.MarkDecorator] = []
- for platform, resource in res.items():
- cards = num_cards_dict[platform]
- if platform in ("cuda", "rocm", "xpu"):
- marks = gpu_marks(res=resource, num_cards=cards)
- elif platform == "musa":
- marks = musa_marks(res=resource, num_cards=cards)
- elif platform == "npu":
- marks = npu_marks(res=resource, num_cards=cards)
- else:
- raise ValueError(f"Unsupported platform: {platform}")
- all_marks.extend(marks)
- return all_marks
-
-
-def hardware_test(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
- all_marks = hardware_marks(res=res, num_cards=num_cards)
-
- def wrapper(f):
- func = f
- for mark in reversed(all_marks):
- func = mark(func)
- return func
-
- return wrapper
diff --git a/tests/helpers/media.py b/tests/helpers/media.py
deleted file mode 100644
index c0fb9717140..00000000000
--- a/tests/helpers/media.py
+++ /dev/null
@@ -1,657 +0,0 @@
-"""Synthetic media generation and media/text utilities for tests."""
-
-import base64
-import concurrent.futures
-import gc
-import hashlib
-import io
-import logging
-import math
-import multiprocessing
-import os
-import random
-import re
-import subprocess
-import tempfile
-import time
-import uuid
-from contextlib import contextmanager
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import soundfile as sf
-from PIL import Image
-
-logger = logging.getLogger(__name__)
-
-
-def _resolve_synthetic_media_cache_dir(cache_dir: Path | str | None) -> Path:
- if cache_dir is not None:
- return Path(cache_dir).expanduser().resolve()
- return Path(tempfile.gettempdir()) / "vllm_omni_test_synthetic_media"
-
-
-def _np_array_from_mp4_bytes(video_bytes: bytes) -> np.ndarray:
- """Decode MP4 bytes to a (T, H, W, 3) uint8 RGB stack (matches in-memory synthetic frames)."""
- import cv2
-
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
- tmp.write(video_bytes)
- path = tmp.name
- cap = None
- try:
- cap = cv2.VideoCapture(path)
- if not cap.isOpened():
- raise RuntimeError("Failed to open cached synthetic video for decode")
- frames: list[np.ndarray] = []
- while True:
- ok, frame_bgr = cap.read()
- if not ok:
- break
- frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
- if not frames:
- raise RuntimeError("Cached synthetic video has no decodable frames")
- return np.stack(frames, axis=0)
- finally:
- if cap is not None:
- cap.release()
- try:
- os.unlink(path)
- except OSError:
- pass
-
-
-def generate_synthetic_audio(
- duration: int,
- num_channels: int,
- sample_rate: int = 48000,
- *,
- phrase_text: str = "test",
- force_regenerate: bool = False,
- cache_dir: Path | str | None = None,
-) -> dict[str, Any]:
- """
- Generate TTS speech with pyttsx3 and return base64 string.
-
- Caches the WAV under ``cache_dir`` when given, else under the default temp
- subdirectory. Reuses the file when the same
- ``duration`` / ``num_channels`` / ``sample_rate`` / ``phrase_text`` are
- requested unless ``force_regenerate`` is true.
-
- The cache filename includes a SHA-256 digest of ``phrase_text`` so different
- phrases never share a WAV cache entry.
- """
- root = _resolve_synthetic_media_cache_dir(cache_dir)
- root.mkdir(parents=True, exist_ok=True)
- phrase_key = hashlib.sha256(phrase_text.encode("utf-8")).hexdigest()
- cache_path = root / f"synth_audio_d{duration}_ch{num_channels}_sr{sample_rate}_pt{phrase_key}.wav"
-
- if not force_regenerate and cache_path.is_file():
- data, _sr = sf.read(str(cache_path), dtype="float32", always_2d=True)
- audio_bytes = cache_path.read_bytes()
- return {
- "np_array": np.asarray(data, dtype=np.float32),
- "base64": base64.b64encode(audio_bytes).decode("utf-8"),
- "file_path": str(cache_path.resolve()),
- }
-
- import pyttsx3
-
- def _pick_voice(engine: pyttsx3.Engine) -> str | None:
- voices = engine.getProperty("voices")
- if not voices:
- return None
-
- preferred_tokens = (
- "natural",
- "jenny",
- "sonia",
- "susan",
- "zira",
- "aria",
- "hazel",
- "samantha",
- "ava",
- "allison",
- "female",
- "woman",
- "english-us",
- "en-us",
- "english",
- )
- discouraged_tokens = (
- "espeak",
- "robot",
- "mbrola",
- "microsoft david",
- "male",
- "man",
- )
-
- best_voice = voices[0]
- best_score = float("-inf")
- for voice in voices:
- voice_text = f"{getattr(voice, 'id', '')} {getattr(voice, 'name', '')}".lower()
- voice_languages = " ".join(
- lang.decode(errors="ignore") if isinstance(lang, bytes) else str(lang)
- for lang in getattr(voice, "languages", [])
- ).lower()
- combined_text = f"{voice_text} {voice_languages}"
- score = 0
- for idx, token in enumerate(preferred_tokens):
- if token in combined_text:
- score += 20 - idx
- for token in discouraged_tokens:
- if token in combined_text:
- score -= 10
- if "english" in combined_text or "en_" in combined_text or "en-" in combined_text:
- score += 4
- if "en-us" in combined_text or "english-us" in combined_text:
- score += 4
- if score > best_score:
- best_score = score
- best_voice = voice
-
- return best_voice.id
-
- def _resample_audio(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
- if src_sr == dst_sr or len(audio) == 0:
- return audio.astype(np.float32)
- src_len = audio.shape[0]
- dst_len = max(1, int(round(src_len * float(dst_sr) / float(src_sr))))
- src_idx = np.arange(src_len, dtype=np.float32)
- dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32)
- resampled_channels: list[np.ndarray] = []
- for ch in range(audio.shape[1]):
- resampled_channels.append(np.interp(dst_idx, src_idx, audio[:, ch]).astype(np.float32))
- return np.stack(resampled_channels, axis=1)
-
- def _match_channels(audio: np.ndarray, target_channels: int) -> np.ndarray:
- current_channels = audio.shape[1]
- if current_channels == target_channels:
- return audio.astype(np.float32)
- if target_channels == 1:
- return np.mean(audio, axis=1, keepdims=True, dtype=np.float32)
- if current_channels == 1:
- return np.repeat(audio, target_channels, axis=1).astype(np.float32)
- collapsed = np.mean(audio, axis=1, keepdims=True, dtype=np.float32)
- return np.repeat(collapsed, target_channels, axis=1).astype(np.float32)
-
- def _trim_silence(audio: np.ndarray, threshold: float = 0.01) -> np.ndarray:
- if len(audio) == 0:
- return audio
- energy = np.max(np.abs(audio), axis=1)
- voiced = np.where(energy > threshold)[0]
- if len(voiced) == 0:
- return audio
- start = max(0, int(voiced[0]) - int(sample_rate * 0.02))
- end = min(len(audio), int(voiced[-1]) + int(sample_rate * 0.04) + 1)
- return audio[start:end]
-
- def _enhance_speech(audio: np.ndarray) -> np.ndarray:
- if len(audio) == 0:
- return audio.astype(np.float32)
- enhanced = audio.astype(np.float32).copy()
- enhanced -= np.mean(enhanced, axis=0, keepdims=True, dtype=np.float32)
- if len(enhanced) > 1:
- preemphasis = enhanced.copy()
- preemphasis[1:] = enhanced[1:] - 0.94 * enhanced[:-1]
- enhanced = 0.7 * enhanced + 0.3 * preemphasis
- enhanced = np.sign(enhanced) * np.sqrt(np.abs(enhanced))
- fade = min(len(enhanced) // 4, max(1, int(sample_rate * 0.01)))
- if fade > 1:
- ramp_in = np.linspace(0.0, 1.0, fade, dtype=np.float32)
- ramp_out = np.linspace(1.0, 0.0, fade, dtype=np.float32)
- enhanced[:fade] *= ramp_in[:, None]
- enhanced[-fade:] *= ramp_out[:, None]
- peak = float(np.max(np.abs(enhanced)))
- if peak > 1e-8:
- enhanced = enhanced / peak * 0.95
- return enhanced.astype(np.float32)
-
- num_samples = int(sample_rate * max(1, duration))
- audio_data = np.zeros((num_samples, num_channels), dtype=np.float32)
-
- engine = pyttsx3.init()
- engine.setProperty("rate", 112)
- engine.setProperty("volume", 1.0)
- selected_voice = _pick_voice(engine)
- if selected_voice is not None:
- engine.setProperty("voice", selected_voice)
-
- temp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
- temp_wav.close()
- try:
- engine.save_to_file(phrase_text, temp_wav.name)
- engine.runAndWait()
- engine.stop()
-
- ready = False
- for _ in range(50):
- if os.path.exists(temp_wav.name) and os.path.getsize(temp_wav.name) > 44:
- ready = True
- break
- time.sleep(0.1)
- if not ready:
- raise RuntimeError("pyttsx3 did not produce a WAV file in time.")
-
- tts_audio, tts_sr = sf.read(temp_wav.name, dtype="float32", always_2d=True)
- finally:
- if os.path.exists(temp_wav.name):
- os.unlink(temp_wav.name)
-
- if len(tts_audio) == 0:
- raise RuntimeError("pyttsx3 produced an empty WAV file.")
-
- tts_audio = _resample_audio(tts_audio, tts_sr, sample_rate)
- tts_audio = _match_channels(tts_audio, num_channels)
- tts_audio = _trim_silence(tts_audio, threshold=0.012)
- tts_audio = _enhance_speech(tts_audio)
-
- lead_silence = min(int(sample_rate * 0.02), num_samples // 8)
- pause_samples = int(sample_rate * 0.18)
- start = lead_silence
- phrase_len = tts_audio.shape[0]
- while start < num_samples:
- take = min(phrase_len, num_samples - start)
- audio_data[start : start + take] = tts_audio[:take]
- start += phrase_len + pause_samples
-
- max_amp = float(np.max(np.abs(audio_data)))
- if max_amp > 0:
- audio_data = audio_data / max_amp * 0.95
-
- sf.write(str(cache_path), audio_data, sample_rate, format="WAV", subtype="PCM_16")
- audio_bytes = cache_path.read_bytes()
-
- return {
- "np_array": audio_data.copy(),
- "base64": base64.b64encode(audio_bytes).decode("utf-8"),
- "file_path": str(cache_path.resolve()),
- }
-
-
-def _mux_mp4_bytes_with_synthetic_audio(
- video_mp4_bytes: bytes,
- *,
- num_frames: int,
- fps: float = 30.0,
- sample_rate: int = 48000,
-) -> bytes:
- duration_sec = num_frames / fps if fps > 0 else 0.0
- duration_int = max(1, int(math.ceil(duration_sec)))
-
- try:
- audio_result = generate_synthetic_audio(
- duration=duration_int,
- num_channels=1,
- sample_rate=sample_rate,
- )
- audio_pcm = audio_result["np_array"]
- except Exception as e:
- logger.warning("Synthetic video: generate_synthetic_audio failed (%s); using video-only MP4.", e)
- return video_mp4_bytes
-
- try:
- import imageio_ffmpeg
-
- ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
- except Exception:
- ffmpeg_exe = "ffmpeg"
-
- try:
- with tempfile.TemporaryDirectory(prefix="syn_vid_mux_") as tmp:
- vid_path = os.path.join(tmp, "video.mp4")
- wav_path = os.path.join(tmp, "audio.wav")
- out_path = os.path.join(tmp, "out.mp4")
- with open(vid_path, "wb") as f:
- f.write(video_mp4_bytes)
- sf.write(wav_path, audio_pcm, sample_rate, format="WAV", subtype="PCM_16")
- cmd = [
- ffmpeg_exe,
- "-y",
- "-nostdin",
- "-hide_banner",
- "-loglevel",
- "error",
- "-i",
- vid_path,
- "-i",
- wav_path,
- "-c:v",
- "copy",
- "-c:a",
- "aac",
- "-b:a",
- "128k",
- "-shortest",
- "-movflags",
- "+faststart",
- out_path,
- ]
- subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL, timeout=300)
- with open(out_path, "rb") as f:
- return f.read()
- except (
- FileNotFoundError,
- subprocess.CalledProcessError,
- subprocess.TimeoutExpired,
- OSError,
- ) as e:
- logger.warning("Synthetic video: audio mux failed (%s); using video-only MP4.", e)
- return video_mp4_bytes
-
-
-def generate_synthetic_video(
- width: int,
- height: int,
- num_frames: int,
- *,
- embed_audio: bool = False,
- force_regenerate: bool = False,
- cache_dir: Path | str | None = None,
-) -> dict[str, Any]:
- """
- Generate synthetic MP4 (optional AAC audio). Caches final bytes by
- ``width`` / ``height`` / ``num_frames`` / ``embed_audio`` unless
- ``force_regenerate`` is true. Cache root: ``cache_dir`` if given, else the
- default temp subdirectory.
- """
- root = _resolve_synthetic_media_cache_dir(cache_dir)
- root.mkdir(parents=True, exist_ok=True)
- cache_path = root / f"synth_video_w{width}_h{height}_nf{num_frames}_ea{int(embed_audio)}.mp4"
-
- if not force_regenerate and cache_path.is_file():
- video_bytes = cache_path.read_bytes()
- return {
- "np_array": _np_array_from_mp4_bytes(video_bytes),
- "base64": base64.b64encode(video_bytes).decode("utf-8"),
- "file_path": str(cache_path.resolve()),
- }
-
- import cv2
- import imageio
-
- num_balls = random.randint(3, 8)
- balls = []
- for _ in range(num_balls):
- radius = min(width, height) // 8
- if radius < 1:
- raise ValueError(f"Video dimensions ({width}x{height}) too small")
- x = random.randint(radius, width - radius)
- y = random.randint(radius, height - radius)
- speed = random.uniform(3.0, 8.0)
- angle = random.uniform(0, 2 * math.pi)
- vx = speed * math.cos(angle)
- vy = speed * math.sin(angle)
- color_bgr = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
- balls.append({"x": x, "y": y, "vx": vx, "vy": vy, "radius": radius, "color_bgr": color_bgr})
-
- video_frames = []
- for _ in range(num_frames):
- frame_bgr = np.zeros((height, width, 3), dtype=np.uint8)
- for ball in balls:
- ball["x"] += ball["vx"]
- ball["y"] += ball["vy"]
- if ball["x"] - ball["radius"] <= 0 or ball["x"] + ball["radius"] >= width:
- ball["vx"] = -ball["vx"]
- ball["x"] = max(ball["radius"], min(width - ball["radius"], ball["x"]))
- if ball["y"] - ball["radius"] <= 0 or ball["y"] + ball["radius"] >= height:
- ball["vy"] = -ball["vy"]
- ball["y"] = max(ball["radius"], min(height - ball["radius"], ball["y"]))
- x, y = int(ball["x"]), int(ball["y"])
- radius = int(ball["radius"])
- cv2.circle(frame_bgr, (x, y), radius, ball["color_bgr"], -1)
- frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
- video_frames.append(frame_rgb)
-
- fps = 30
- buffer = io.BytesIO()
- writer_kwargs = {
- "format": "mp4",
- "fps": fps,
- "codec": "libx264",
- "quality": 7,
- "pixelformat": "yuv420p",
- "macro_block_size": 16,
- "ffmpeg_params": ["-preset", "medium", "-crf", "23", "-movflags", "+faststart", "-pix_fmt", "yuv420p"],
- }
- try:
- with imageio.get_writer(buffer, **writer_kwargs) as writer:
- for frame in video_frames:
- writer.append_data(frame)
- buffer.seek(0)
- video_only_bytes = buffer.read()
- except Exception as e:
- print(f"Warning: Failed to encode synthetic video: {e}")
- raise
- video_bytes = (
- _mux_mp4_bytes_with_synthetic_audio(video_only_bytes, num_frames=num_frames, fps=float(fps))
- if embed_audio
- else video_only_bytes
- )
-
- cache_path.write_bytes(video_bytes)
-
- return {
- "np_array": np.array(video_frames),
- "base64": base64.b64encode(video_bytes).decode("utf-8"),
- "file_path": str(cache_path.resolve()),
- }
-
-
-def generate_synthetic_image(
- width: int,
- height: int,
- *,
- force_regenerate: bool = False,
- cache_dir: Path | str | None = None,
- seed: int | None = None,
-) -> dict[str, Any]:
- """
- Random colored squares on white background. Caches JPEG by ``width`` /
- ``height`` unless ``force_regenerate`` is true. Cache root: ``cache_dir``
- if given, else the default temp subdirectory.
- """
- if seed is not None:
- random.seed(seed)
-
- root = _resolve_synthetic_media_cache_dir(cache_dir)
- root.mkdir(parents=True, exist_ok=True)
- cache_path = root / f"synth_image_w{width}_h{height}.jpg"
-
- if not force_regenerate and cache_path.is_file():
- from PIL import Image as PILImage
-
- image = PILImage.open(cache_path)
- image.load()
- image_bytes = cache_path.read_bytes()
- return {
- "np_array": np.array(image).copy(),
- "base64": base64.b64encode(image_bytes).decode("utf-8"),
- "file_path": str(cache_path.resolve()),
- }
-
- from PIL import ImageDraw
-
- image = Image.new("RGB", (width, height), (255, 255, 255))
- draw = ImageDraw.Draw(image)
- num_squares = random.randint(3, 8)
- for _ in range(num_squares):
- square_size = random.randint(max(1, min(width, height) // 8), max(2, min(width, height) // 4))
- x = random.randint(0, max(0, width - square_size - 1))
- y = random.randint(0, max(0, height - square_size - 1))
- color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
- border_width = random.randint(1, 5)
- draw.rectangle([x, y, x + square_size, y + square_size], fill=color, outline=(0, 0, 0), width=border_width)
-
- image.save(str(cache_path), format="JPEG", quality=85, optimize=True)
- image_bytes = cache_path.read_bytes()
-
- return {
- "np_array": np.array(image).copy(),
- "base64": base64.b64encode(image_bytes).decode("utf-8"),
- "file_path": str(cache_path.resolve()),
- }
-
-
-def decode_b64_image(b64: str):
- img = Image.open(io.BytesIO(base64.b64decode(b64)))
- img.load()
- return img
-
-
-def preprocess_text(text):
- import opencc
-
- word_to_num = {
- "zero": "0",
- "one": "1",
- "two": "2",
- "three": "3",
- "four": "4",
- "five": "5",
- "six": "6",
- "seven": "7",
- "eight": "8",
- "nine": "9",
- "ten": "10",
- }
- for word, num in word_to_num.items():
- pattern = r"\b" + re.escape(word) + r"\b"
- text = re.sub(pattern, num, text, flags=re.IGNORECASE)
-
- text = re.sub(r"[^\w\s]", "", text)
- text = re.sub(r"\s+", " ", text)
- cc = opencc.OpenCC("t2s")
- text = cc.convert(text)
- text = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", text)
- return text.lower().strip()
-
-
-def cosine_similarity_text(text1, text2, n: int = 3):
- from collections import Counter
-
- if not text1 or not text2:
- return 0.0
-
- text1 = preprocess_text(text1)
- text2 = preprocess_text(text2)
- print(f"cosine similarity text1 is: {text1}, text2 is: {text2}")
-
- ngrams1 = [text1[i : i + n] for i in range(len(text1) - n + 1)]
- ngrams2 = [text2[i : i + n] for i in range(len(text2) - n + 1)]
- counter1 = Counter(ngrams1)
- counter2 = Counter(ngrams2)
-
- all_ngrams = set(counter1.keys()) | set(counter2.keys())
- vec1 = [counter1.get(ng, 0) for ng in all_ngrams]
- vec2 = [counter2.get(ng, 0) for ng in all_ngrams]
- dot_product = sum(a * b for a, b in zip(vec1, vec2))
- norm1 = sum(a * a for a in vec1) ** 0.5
- norm2 = sum(b * b for b in vec2) ** 0.5
- if norm1 == 0 or norm2 == 0:
- return 0.0
- cosine = dot_product / (norm1 * norm2)
- # Down-weight when lengths differ: repeated/hallucinated transcripts stay
- # high in bag-of-ngrams cosine (e.g. ABCABCABC vs ABC) but should score low.
- len1, len2 = len(text1), len(text2)
- length_harmony = (2.0 * min(len1, len2)) / (len1 + len2)
- return cosine * length_harmony
-
-
-def _merge_base64_audio_to_segment(base64_list: list[str]):
- from pydub import AudioSegment
-
- merged = None
- for b64 in base64_list:
- raw = base64.b64decode(b64.split(",", 1)[-1])
- seg = AudioSegment.from_file(io.BytesIO(raw))
- merged = seg if merged is None else merged + seg
- return merged
-
-
-@contextmanager
-def _serialize_whisper_small_model_download():
- """Serialize Whisper ``small`` cache writes across processes (Linux/Unix)."""
- import fcntl
-
- lock_path = Path.home() / ".cache" / "whisper" / ".small_model_download.lock"
- lock_path.parent.mkdir(parents=True, exist_ok=True)
- f = open(lock_path, "a+b")
- try:
- fcntl.flock(f.fileno(), fcntl.LOCK_EX)
- yield
- finally:
- fcntl.flock(f.fileno(), fcntl.LOCK_UN)
- f.close()
-
-
-def _whisper_transcribe_in_current_process(output_path: str) -> str:
- import whisper
-
- device_index = None
- from vllm_omni.platforms import current_omni_platform
-
- if current_omni_platform.is_available():
- n = current_omni_platform.get_device_count()
- if n == 1:
- device_index = 0
- elif n > 1:
- device_index = n - 1
-
- if device_index is not None:
- torch_device = current_omni_platform.get_torch_device(device_index)
- current_omni_platform.set_device(torch_device)
- device = str(torch_device)
- use_accelerator = True
- else:
- use_accelerator = False
- device = "cpu"
-
- with _serialize_whisper_small_model_download():
- model = whisper.load_model("small", device=device)
- try:
- text = model.transcribe(
- output_path,
- temperature=0.0,
- word_timestamps=True,
- condition_on_previous_text=False,
- )["text"]
- finally:
- del model
- gc.collect()
- if use_accelerator:
- current_omni_platform.synchronize()
- current_omni_platform.empty_cache()
- return text or ""
-
-
-def convert_audio_file_to_text(output_path: str) -> str:
- """Convert an audio file to text in an isolated subprocess."""
- ctx = multiprocessing.get_context("spawn")
- with concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=ctx) as executor:
- future = executor.submit(_whisper_transcribe_in_current_process, output_path)
- return future.result()
-
-
-def convert_audio_bytes_to_text(raw_bytes: bytes) -> str:
- output_path = f"./test_{uuid.uuid4().hex}.wav"
- data, samplerate = sf.read(io.BytesIO(raw_bytes))
- sf.write(output_path, data, samplerate, format="WAV", subtype="PCM_16")
- print(f"audio data is saved: {output_path}")
- return convert_audio_file_to_text(output_path)
-
-
-__all__ = [
- "_merge_base64_audio_to_segment",
- "convert_audio_bytes_to_text",
- "convert_audio_file_to_text",
- "cosine_similarity_text",
- "decode_b64_image",
- "generate_synthetic_audio",
- "generate_synthetic_image",
- "generate_synthetic_video",
- "preprocess_text",
-]
diff --git a/tests/helpers/runtime.py b/tests/helpers/runtime.py
deleted file mode 100644
index 0cf0f9e480d..00000000000
--- a/tests/helpers/runtime.py
+++ /dev/null
@@ -1,1406 +0,0 @@
-"""Server/client/runner runtime primitives for tests."""
-
-import base64
-import concurrent.futures
-import io
-import json
-import os
-import socket
-import subprocess
-import sys
-import tempfile
-import time
-from dataclasses import dataclass
-from io import BytesIO
-from pathlib import Path
-from typing import Any, NamedTuple
-
-import psutil
-import requests
-import soundfile as sf
-import torch
-import yaml
-from openai import OpenAI, omit
-from PIL import Image
-from vllm import TextPrompt
-from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
-from vllm.logger import init_logger
-
-from tests.helpers.assertions import (
- assert_audio_speech_response,
- assert_diffusion_response,
- assert_omni_response,
-)
-from tests.helpers.env import run_forced_gpu_cleanup_round
-from tests.helpers.media import (
- _merge_base64_audio_to_segment,
- convert_audio_bytes_to_text,
- decode_b64_image,
-)
-from vllm_omni.config.stage_config import resolve_deploy_yaml
-from vllm_omni.platforms import current_omni_platform
-
-logger = init_logger(__name__)
-
-PromptAudioInput = list[tuple[Any, int]] | tuple[Any, int] | None
-PromptImageInput = list[Any] | Any | None
-PromptVideoInput = list[Any] | Any | None
-
-try:
- from vllm.distributed.parallel_state import cleanup_dist_env_and_memory # type: ignore
-except Exception: # pragma: no cover
-
- def cleanup_dist_env_and_memory() -> None:
- return None
-
-
-def get_open_port() -> int:
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(("127.0.0.1", 0))
- return int(s.getsockname()[1])
-
-
-def dummy_messages_from_mix_data(
- system_prompt: dict[str, Any] = None,
- video_data_url: Any = None,
- audio_data_url: Any = None,
- image_data_url: Any = None,
- content_text: str = None,
-):
- """Create messages with video、image、audio data URL for OpenAI API."""
- if content_text is not None:
- content = [{"type": "text", "text": content_text}]
- else:
- content = []
-
- media_items = []
- if isinstance(video_data_url, list):
- for video_url in video_data_url:
- media_items.append((video_url, "video"))
- else:
- media_items.append((video_data_url, "video"))
-
- if isinstance(image_data_url, list):
- for url in image_data_url:
- media_items.append((url, "image"))
- else:
- media_items.append((image_data_url, "image"))
-
- if isinstance(audio_data_url, list):
- for url in audio_data_url:
- media_items.append((url, "audio"))
- else:
- media_items.append((audio_data_url, "audio"))
-
- content.extend(
- {"type": f"{media_type}_url", f"{media_type}_url": {"url": url}}
- for url, media_type in media_items
- if url is not None
- )
- messages = [{"role": "user", "content": content}]
- if system_prompt is not None:
- messages = [system_prompt] + messages
- return messages
-
-
-def _omni_subprocess_cwd() -> str:
- """Repo root for ``python -m vllm_omni...`` (legacy conftest lived under ``tests/``; helpers under ``tests/helpers/``)."""
- return os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
-
-
-class OmniServerParams(NamedTuple):
- model: str
- port: int | None = None
- stage_config_path: str | None = None
- server_args: list[str] | None = None
- env_dict: dict[str, str] | None = None
- use_omni: bool = True
- use_stage_cli: bool = False
- init_timeout: int | None = None
- stage_init_timeout: int | None = None # None: fixture supplies default (600 s)
-
-
-class OmniServer:
- """Omniserver for vLLM-Omni tests."""
-
- def __init__(
- self,
- model: str,
- serve_args: list[str],
- *,
- port: int | None = None,
- env_dict: dict[str, str] | None = None,
- use_omni: bool = True,
- ) -> None:
- run_forced_gpu_cleanup_round()
- cleanup_dist_env_and_memory()
- self.model = model
- self.serve_args = serve_args
- self.env_dict = env_dict
- self.use_omni = use_omni
- self.proc: subprocess.Popen | None = None
- self.host = "127.0.0.1"
- self.port = get_open_port() if port is None else port
-
- def _start_server(self) -> None:
- env = os.environ.copy()
- env.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
- if self.env_dict is not None:
- env.update(self.env_dict)
-
- cmd = [
- sys.executable,
- "-m",
- "vllm_omni.entrypoints.cli.main",
- "serve",
- self.model,
- "--host",
- self.host,
- "--port",
- str(self.port),
- ]
- if self.use_omni:
- cmd.append("--omni")
- cmd += self.serve_args
-
- print(f"Launching OmniServer with: {' '.join(cmd)}")
- self.proc = subprocess.Popen(
- cmd,
- env=env,
- cwd=_omni_subprocess_cwd(),
- )
-
- max_wait = 1200
- start_time = time.time()
- while time.time() - start_time < max_wait:
- ret = self.proc.poll()
- if ret is not None:
- raise RuntimeError(f"Server processes exited with code {ret} before becoming ready.")
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
- sock.settimeout(1)
- if sock.connect_ex((self.host, self.port)) == 0:
- print(f"Server ready on {self.host}:{self.port}")
- return
- time.sleep(2)
- raise RuntimeError(f"Server failed to start within {max_wait} seconds")
-
- def _kill_process_tree(self, pid):
- try:
- parent = psutil.Process(pid)
- children = parent.children(recursive=True)
- all_pids = [pid] + [child.pid for child in children]
-
- for child in children:
- try:
- child.terminate()
- except psutil.NoSuchProcess:
- pass
-
- _, still_alive = psutil.wait_procs(children, timeout=10)
-
- for child in still_alive:
- try:
- child.kill()
- except psutil.NoSuchProcess:
- pass
-
- try:
- parent.terminate()
- parent.wait(timeout=10)
- except (psutil.NoSuchProcess, psutil.TimeoutExpired):
- try:
- parent.kill()
- except psutil.NoSuchProcess:
- pass
-
- time.sleep(1)
- alive_processes = []
- for check_pid in all_pids:
- if psutil.pid_exists(check_pid):
- alive_processes.append(check_pid)
-
- if alive_processes:
- print(f"Warning: Processes still alive: {alive_processes}")
- for alive_pid in alive_processes:
- try:
- subprocess.run(["kill", "-9", str(alive_pid)], timeout=2)
- except Exception as e:
- print(f"Cleanup failed: {e}")
-
- except psutil.NoSuchProcess:
- pass
-
- def __enter__(self):
- self._start_server()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self.proc:
- self._kill_process_tree(self.proc.pid)
- run_forced_gpu_cleanup_round()
- cleanup_dist_env_and_memory()
-
-
-class OmniServerStageCli(OmniServer):
- """Omni server harness that exercises the stage CLI flow."""
-
- def __init__(
- self,
- model: str,
- stage_config_path: str,
- serve_args: list[str] | None = None,
- *,
- stage_ids: list[int] | None = None,
- port: int | None = None,
- env_dict: dict[str, str] | None = None,
- ) -> None:
- super().__init__(model, serve_args or [], port=port, env_dict=env_dict, use_omni=True)
- self.stage_config_path = stage_config_path
- self.master_port = get_open_port()
- self.visible_device_list = self._load_visible_device_list(env_dict)
- resolved_cfg = resolve_deploy_yaml(stage_config_path)
- # Dump the resolved deploy config so CI logs show each stage's
- # gpu_memory_utilization / max_model_len / max_num_seqs after
- # base_config inheritance and overlay merge — essential when
- # diagnosing OOMs that depend on the merged values.
- print(
- f"[OmniServerStageCli] Resolved deploy config from {stage_config_path}:\n"
- f"{yaml.safe_dump(resolved_cfg, sort_keys=False, default_flow_style=False)}",
- flush=True,
- )
- self.stage_runtime_devices = self._load_stage_runtime_devices(resolved_cfg)
- self.stage_ids = stage_ids or self._load_stage_ids(resolved_cfg)
- if 0 not in self.stage_ids:
- raise ValueError(f"Stage CLI test requires stage_id=0 in config: {stage_config_path}")
- self.stage_procs: dict[int, subprocess.Popen] = {}
- self.proc = None
-
- @staticmethod
- def _stage_entries(cfg: dict) -> list[dict]:
- """Return the list of stage entries from either legacy (``stage_args``)
- or new-schema (``stages``) deploy YAMLs."""
- return cfg.get("stage_args") or cfg.get("stages") or []
-
- @staticmethod
- def _load_stage_ids(resolved_config: dict) -> list[int]:
- stage_ids = [
- stage["stage_id"] for stage in OmniServerStageCli._stage_entries(resolved_config) if "stage_id" in stage
- ]
- if not stage_ids:
- raise ValueError("No stage IDs found in resolved config")
- return stage_ids
-
- @staticmethod
- def _load_stage_runtime_devices(resolved_config: dict) -> dict[int, str]:
- runtime_devices: dict[int, str] = {}
- for stage in OmniServerStageCli._stage_entries(resolved_config):
- stage_id = stage.get("stage_id")
- # New schema: stage.devices is flat at stage level.
- # Legacy schema: stage.runtime.devices is nested.
- devices = stage.get("devices") or stage.get("runtime", {}).get("devices")
- if stage_id is not None and devices:
- runtime_devices[int(stage_id)] = str(devices)
- return runtime_devices
-
- @classmethod
- def _parse_device_list(cls, devices: str | int) -> list[str]:
- if isinstance(devices, int):
- if devices < 0:
- raise ValueError("Device IDs must be non-negative integers")
- return [str(devices)]
- return [token.strip() for token in str(devices).split(",") if token.strip()]
-
- @classmethod
- def _load_visible_device_list(cls, env_dict: dict[str, str] | None) -> list[str] | None:
- env = os.environ.copy()
- if env_dict is not None:
- env.update(env_dict)
-
- env_var = getattr(current_omni_platform, "device_control_env_var", None)
- if env_var and env_var in env:
- return [token.strip() for token in env[env_var].split(",") if token.strip()]
- return None
-
- @classmethod
- def _map_stage_devices(cls, stage_id: int, visible_device_list: list[str] | None, devices: str) -> str:
- device_list = cls._parse_device_list(devices)
-
- if visible_device_list is None:
- return ",".join(device_list)
-
- if not all(device.isdigit() for device in device_list):
- raise ValueError("Logical devices must be non-negative integers")
-
- logical_ids = [int(device) for device in device_list]
- if logical_ids and max(logical_ids) >= len(visible_device_list):
- raise ValueError(
- f"Stage {stage_id} has logical IDs {device_list}, one or more of which exceed the number of visible devices"
- )
-
- return ",".join(visible_device_list[idx] for idx in logical_ids)
-
- def _set_stage_device_env(self, stage_id: int, env: dict[str, str], devices: str) -> None:
- mapped_devices = self._map_stage_devices(stage_id, self.visible_device_list, devices)
- env_var = getattr(current_omni_platform, "device_control_env_var", None)
- if env_var:
- env[env_var] = mapped_devices
-
- def _build_stage_cmd(self, stage_id: int, *, headless: bool) -> list[str]:
- cmd = [
- sys.executable,
- "-m",
- "vllm_omni.entrypoints.cli.main",
- "serve",
- self.model,
- "--omni",
- "--stage-configs-path",
- self.stage_config_path,
- "--stage-id",
- str(stage_id),
- "--omni-master-address",
- self.host,
- "--omni-master-port",
- str(self.master_port),
- ]
-
- if headless:
- cmd.append("--headless")
- else:
- cmd += ["--host", self.host, "--port", str(self.port)]
-
- cmd += self.serve_args
- return cmd
-
- def _launch_stage(self, stage_id: int, *, headless: bool) -> None:
- env = os.environ.copy()
- env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
- if self.env_dict is not None:
- env.update(self.env_dict)
-
- devices = self.stage_runtime_devices.get(stage_id)
- if devices:
- self._set_stage_device_env(stage_id, env, devices)
-
- cmd = self._build_stage_cmd(stage_id, headless=headless)
- print(f"Launching OmniServerStageCli stage {stage_id}: {' '.join(cmd)}")
- # Capture each subprocess's stdout+stderr to a per-stage log file so
- # debugging "Stage N exited before API server ready" doesn't rely on
- # guessing; the file is surfaced in the RuntimeError message.
- log_path = Path(tempfile.gettempdir()) / f"omni_stage_{stage_id}_{self.master_port}.log"
- self._stage_log_paths = getattr(self, "_stage_log_paths", {})
- self._stage_log_paths[stage_id] = log_path
- log_fh = open(log_path, "w", buffering=1) # noqa: SIM115 - closed in __exit__
- self._stage_log_files = getattr(self, "_stage_log_files", {})
- self._stage_log_files[stage_id] = log_fh
- proc = subprocess.Popen(
- cmd,
- env=env,
- cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
- stdout=log_fh,
- stderr=subprocess.STDOUT,
- )
- self.stage_procs[stage_id] = proc
- if stage_id == 0:
- self.proc = proc
-
- def _ensure_stage_processes_alive(self) -> None:
- for stage_id, proc in self.stage_procs.items():
- ret = proc.poll()
- if ret is not None:
- log_path = getattr(self, "_stage_log_paths", {}).get(stage_id)
- tail = ""
- if log_path and log_path.exists():
- try:
- with open(log_path, encoding="utf-8", errors="replace") as f:
- lines = f.readlines()
- tail = "\n=== Last 60 lines of stage {} log ({}) ===\n{}".format(
- stage_id, log_path, "".join(lines[-60:]) or ""
- )
- except Exception as exc: # pragma: no cover - diagnostic only
- tail = f"\n"
- raise RuntimeError(f"Stage {stage_id} exited with code {ret} before API server became ready.{tail}")
-
- def _start_server(self) -> None:
- ordered_stage_ids = [0, *[stage_id for stage_id in self.stage_ids if stage_id != 0]]
-
- self._launch_stage(0, headless=False)
- time.sleep(2)
- self._ensure_stage_processes_alive()
-
- for stage_id in ordered_stage_ids[1:]:
- self._launch_stage(stage_id, headless=True)
-
- max_wait = 1200
- start_time = time.time()
- while time.time() - start_time < max_wait:
- self._ensure_stage_processes_alive()
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
- sock.settimeout(1)
- result = sock.connect_ex((self.host, self.port))
- if result == 0:
- print(f"OmniServerStageCli ready on {self.host}:{self.port}")
- return
- time.sleep(2)
-
- raise RuntimeError(f"OmniServerStageCli failed to start within {max_wait} seconds")
-
- def _dump_stage_logs_for_debug(self, head_lines: int = 300, tail_lines: int = 500) -> None:
- """Tail each stage's subprocess log back to stdout on teardown.
-
- Stage subprocesses redirect stdout/stderr to ``/tmp/omni_stage_*.log``
- so we don't spam the main CI stream while tests run; but that also
- hides engine init (KV cache size, Available KV cache memory, vLLM
- engine config) when things go wrong. Dump them here so buildkite
- captures them post-run. Head covers engine init; tail covers
- whatever state the stage was in when it was torn down.
- """
- log_paths = getattr(self, "_stage_log_paths", {}) or {}
- for stage_id in sorted(log_paths):
- log_path = log_paths[stage_id]
- if not log_path or not log_path.exists():
- continue
- try:
- with open(log_path, encoding="utf-8", errors="replace") as f:
- lines = f.readlines()
- except Exception as exc: # pragma: no cover - diagnostic only
- print(f"[OmniServerStageCli] stage {stage_id} log read failed: {exc}", flush=True)
- continue
- total = len(lines)
- if total <= head_lines + tail_lines:
- head_chunk = lines
- tail_chunk = []
- elided = 0
- else:
- head_chunk = lines[:head_lines]
- tail_chunk = lines[-tail_lines:]
- elided = total - head_lines - tail_lines
- print(f"\n=== stage {stage_id} log HEAD ({log_path}) ===", flush=True)
- print("".join(head_chunk).rstrip("\n"), flush=True)
- if tail_chunk:
- print(f"\n... [{elided} lines elided] ...", flush=True)
- print(f"\n=== stage {stage_id} log TAIL ({log_path}) ===", flush=True)
- print("".join(tail_chunk).rstrip("\n"), flush=True)
- print(f"=== end stage {stage_id} log ===\n", flush=True)
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._dump_stage_logs_for_debug()
- for stage_id in sorted(self.stage_procs, reverse=True):
- proc = self.stage_procs[stage_id]
- if proc.poll() is None:
- self._kill_process_tree(proc.pid)
- run_forced_gpu_cleanup_round()
- cleanup_dist_env_and_memory()
-
-
-@dataclass
-class OmniResponse:
- text_content: str | None = None
- audio_data: list[str] | None = None
- audio_content: str | None = None
- audio_format: str | None = None
- audio_bytes: bytes | None = None
- e2e_latency: float | None = None
- success: bool = False
- error_message: str | None = None
- cached_tokens: int | None = None
-
-
-@dataclass
-class DiffusionResponse:
- text_content: str | None = None
- images: list[Image.Image] | None = None
- audios: list[Any] | None = None
- videos: list[Any] | None = None
- e2e_latency: float | None = None
- success: bool = False
- error_message: str | None = None
-
-
-class OpenAIClientHandler:
- def __init__(self, host: str = "127.0.0.1", port: int = None, api_key: str = "EMPTY", run_level: str = None):
- if port is None:
- port = get_open_port()
- self.base_url = f"http://{host}:{port}"
- self.client = OpenAI(base_url=f"http://{host}:{port}/v1", api_key=api_key)
- self.run_level = run_level
-
- def _process_stream_omni_response(self, chat_completion) -> OmniResponse:
- result = OmniResponse()
- start_time = time.perf_counter()
- try:
- text_content = ""
- audio_data = []
- for chunk in chat_completion:
- for choice in chunk.choices:
- content = getattr(getattr(choice, "delta", None), "content", None)
- modality = getattr(chunk, "modality", None)
- if modality == "audio" and content:
- audio_data.append(content)
- elif modality == "text" and content:
- text_content += content
- result.e2e_latency = time.perf_counter() - start_time
- audio_content = None
- if audio_data:
- merged_seg = _merge_base64_audio_to_segment(audio_data)
- wav_buf = BytesIO()
- merged_seg.export(wav_buf, format="wav")
- result.audio_bytes = wav_buf.getvalue()
- audio_content = convert_audio_bytes_to_text(result.audio_bytes)
- result.text_content = text_content
- result.audio_data = audio_data
- result.audio_content = audio_content
- result.success = True
- except Exception as e:
- result.error_message = f"Stream processing error: {str(e)}"
- print(f"Error: {result.error_message}")
- return result
-
- def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse:
- result = OmniResponse()
- start_time = time.perf_counter()
- try:
- audio_data = None
- text_content = None
- for choice in chat_completion.choices:
- if hasattr(choice.message, "audio") and choice.message.audio is not None:
- audio_data = choice.message.audio.data
- if hasattr(choice.message, "content") and choice.message.content is not None:
- text_content = choice.message.content
- # Extract cached_tokens for prefix caching tests
- usage = getattr(chat_completion, "usage", None)
- if usage and (details := getattr(usage, "prompt_tokens_details", None)):
- result.cached_tokens = details.cached_tokens
- result.e2e_latency = time.perf_counter() - start_time
- audio_content = None
- if audio_data:
- result.audio_bytes = base64.b64decode(audio_data)
- audio_content = convert_audio_bytes_to_text(result.audio_bytes)
- result.text_content = text_content
- result.audio_content = audio_content
- result.success = True
- except Exception as e:
- result.error_message = f"Non-stream processing error: {str(e)}"
- print(f"Error: {result.error_message}")
- return result
-
- def _process_diffusion_response(self, chat_completion) -> DiffusionResponse:
- result = DiffusionResponse()
- start_time = time.perf_counter()
- try:
- images = []
- for choice in chat_completion.choices:
- content = getattr(choice.message, "content", None)
- if isinstance(content, list):
- for item in content:
- image_url = None
- if isinstance(item, dict):
- image_url = item.get("image_url", {}).get("url")
- else:
- image_url_obj = getattr(item, "image_url", None)
- image_url = getattr(image_url_obj, "url", None) if image_url_obj else None
- if image_url and image_url.startswith("data:image"):
- b64_data = image_url.split(",", 1)[1]
- images.append(decode_b64_image(b64_data))
- result.e2e_latency = time.perf_counter() - start_time
- result.images = images if images else None
- result.success = True
- except Exception as e:
- result.error_message = f"Diffusion response processing error: {str(e)}"
- print(f"Error: {result.error_message}")
- return result
-
- def send_omni_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
- responses: list[OmniResponse] = []
- stream = request_config.get("stream", False)
- modalities = request_config.get("modalities", ["text", "audio"])
- extra_body: dict[str, Any] = {}
- if "speaker" in request_config:
- extra_body["speaker"] = request_config["speaker"]
- if request_config.get("use_audio_in_video"):
- mm = dict(extra_body.get("mm_processor_kwargs") or {})
- mm["use_audio_in_video"] = True
- extra_body["mm_processor_kwargs"] = mm
- if "sampling_params_list" in request_config:
- extra_body["sampling_params_list"] = request_config["sampling_params_list"]
-
- create_kwargs: dict[str, Any] = {
- "model": request_config.get("model"),
- "messages": request_config.get("messages"),
- "stream": stream,
- "modalities": modalities,
- }
- if extra_body:
- create_kwargs["extra_body"] = extra_body
-
- if request_num == 1:
- chat_completion = self.client.chat.completions.create(**create_kwargs)
- resp = (
- self._process_stream_omni_response(chat_completion)
- if stream
- else self._process_non_stream_omni_response(chat_completion)
- )
- assert_omni_response(resp, request_config, run_level=self.run_level)
- responses.append(resp)
- return responses
-
- def _one():
- chat_completion = self.client.chat.completions.create(**create_kwargs)
- return (
- self._process_stream_omni_response(chat_completion)
- if stream
- else self._process_non_stream_omni_response(chat_completion)
- )
-
- with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
- futures = [executor.submit(_one) for _ in range(request_num)]
- for future in concurrent.futures.as_completed(futures):
- resp = future.result()
- assert_omni_response(resp, request_config, run_level=self.run_level)
- responses.append(resp)
- return responses
-
- def _process_stream_audio_speech_response(self, response, *, response_format: str | None = None) -> OmniResponse:
- """
- Process streaming /v1/audio/speech responses into an OmniResponse.
-
- This mirrors _process_stream_omni_response but operates on low-level
- audio bytes and produces an OmniResponse with audio_content filled
- from Whisper transcription.
- """
- result = OmniResponse()
- start_time = time.perf_counter()
-
- try:
- # Aggregate all audio bytes from the streaming response.
- data = bytearray()
-
- # Preferred OpenAI helper.
- if hasattr(response, "iter_bytes") and callable(getattr(response, "iter_bytes")):
- for chunk in response.iter_bytes():
- if chunk:
- data.extend(chunk)
- else:
- # Generic iterable-of-bytes fallback (e.g., generator or list of chunks).
- try:
- iterator = iter(response)
- except TypeError:
- iterator = None
-
- if iterator is not None:
- for chunk in iterator:
- if not chunk:
- continue
- if isinstance(chunk, (bytes, bytearray)):
- data.extend(chunk)
- elif hasattr(chunk, "data"):
- data.extend(chunk.data) # type: ignore[arg-type]
- elif hasattr(chunk, "content"):
- data.extend(chunk.content) # type: ignore[arg-type]
- else:
- raise TypeError(f"Unsupported stream chunk type: {type(chunk)}")
- else:
- raise TypeError(f"Unsupported audio speech streaming response type: {type(response)}")
-
- raw_bytes = bytes(data)
- if response_format == "pcm":
- transcript = None
- else:
- transcript = convert_audio_bytes_to_text(raw_bytes)
-
- # Populate OmniResponse.
- result.audio_bytes = raw_bytes
- result.audio_content = transcript
- result.e2e_latency = time.perf_counter() - start_time
- result.success = True
- result.audio_format = getattr(response, "response", None)
- if result.audio_format is not None:
- result.audio_format = result.audio_format.headers.get("content-type", "")
-
- except Exception as e:
- result.error_message = f"Audio speech stream processing error: {str(e)}"
- print(f"Error: {result.error_message}")
-
- return result
-
- def _process_non_stream_audio_speech_response(
- self, response, *, response_format: str | None = None
- ) -> OmniResponse:
- """
- Process non-streaming /v1/audio/speech responses into an OmniResponse.
-
- This mirrors _process_non_stream_omni_response but for the binary
- audio payload returned by audio.speech.create.
- """
- result = OmniResponse()
- start_time = time.perf_counter()
-
- try:
- # OpenAI non-streaming audio.speech.create returns HttpxBinaryResponseContent (.read() or .content)
- if hasattr(response, "read") and callable(getattr(response, "read")):
- raw_bytes = response.read()
- elif hasattr(response, "content"):
- raw_bytes = response.content # type: ignore[assignment]
- else:
- raise TypeError(f"Unsupported audio speech response type: {type(response)}")
-
- if response_format == "pcm":
- transcript = None
- else:
- transcript = convert_audio_bytes_to_text(raw_bytes)
-
- result.audio_bytes = raw_bytes
- result.audio_content = transcript
- result.e2e_latency = time.perf_counter() - start_time
- result.success = True
- result.audio_format = getattr(response, "response", None)
- if result.audio_format is not None:
- result.audio_format = result.audio_format.headers.get("content-type", "")
-
- except Exception as e:
- result.error_message = f"Audio speech non-stream processing error: {str(e)}"
- print(f"Error: {result.error_message}")
-
- return result
-
- def send_audio_speech_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
- """
- Call the /v1/audio/speech endpoint using the same configuration-dict
- style as send_omni_request, but via the OpenAI Python client's
- audio.speech APIs.
-
- Expected keys in request_config:
- - model: model name/path (required)
- - input: text to synthesize (required)
- - response_format: audio format such as "wav" or "pcm" (optional)
- - task_type, ref_text, ref_audio: TTS-specific extras (optional, passed via extra_body)
- - timeout: request timeout in seconds (float, optional, default 120.0)
- - stream: whether to use streaming API (bool, optional, default False)
- """
- timeout = float(request_config.get("timeout", 120.0))
-
- model = request_config["model"]
- text_input = request_config["input"]
- stream = bool(request_config.get("stream", False))
- voice = request_config.get("voice", None)
-
- # Standard OpenAI param: use omit when not provided to keep default behavior.
- response_format = request_config.get("response_format", omit)
-
- # Qwen3-TTS custom fields, forwarded via extra_body.
- extra_body: dict[str, Any] = {}
- # Keep this list aligned with vllm_omni.entrypoints.openai.protocol.audio params.
- for key in ("task_type", "ref_text", "ref_audio", "language", "max_new_tokens"):
- if key in request_config:
- extra_body[key] = request_config[key]
-
- responses: list[OmniResponse] = []
-
- speech_fmt: str | None = None if response_format is omit else str(response_format).lower()
-
- if request_num == 1:
- if stream:
- # Use streaming response helper.
- with self.client.audio.speech.with_streaming_response.create(
- model=model,
- input=text_input,
- response_format=response_format,
- extra_body=extra_body or None,
- timeout=timeout,
- voice=voice,
- ) as resp:
- omni_resp = self._process_stream_audio_speech_response(resp, response_format=speech_fmt)
- else:
- # Non-streaming response.
- resp = self.client.audio.speech.create(
- model=model,
- input=text_input,
- response_format=response_format,
- extra_body=extra_body or None,
- timeout=timeout,
- voice=voice,
- )
- omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt)
-
- assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
- responses.append(omni_resp)
- return responses
- else:
- # request_num > 1: concurrent requests (use same params as single-request path)
-
- if stream:
-
- def _stream_task():
- with self.client.audio.speech.with_streaming_response.create(
- model=model,
- input=text_input,
- response_format=response_format,
- extra_body=extra_body or None,
- timeout=timeout,
- voice=voice,
- ) as resp:
- return self._process_stream_audio_speech_response(resp, response_format=speech_fmt)
-
- with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
- futures = [executor.submit(_stream_task) for _ in range(request_num)]
- for future in concurrent.futures.as_completed(futures):
- omni_resp = future.result()
- assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
- responses.append(omni_resp)
- else:
- with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
- futures = []
- for _ in range(request_num):
- future = executor.submit(
- self.client.audio.speech.create,
- model=model,
- input=text_input,
- response_format=response_format,
- extra_body=extra_body or None,
- timeout=timeout,
- voice=voice,
- )
- futures.append(future)
-
- for future in concurrent.futures.as_completed(futures):
- resp = future.result()
- omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt)
- assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
- responses.append(omni_resp)
-
- return responses
-
- def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[DiffusionResponse]:
- """
- Send OpenAI requests for diffusion models.
- Args:
- request_config: Request configuration dictionary containing parameters like model, messages
- request_num: Number of requests to send concurrently, defaults to 1 (single request)
- Returns:
- list[DiffusionResponse]: List of DiffusionResponse objects containing the response data
- """
- responses: list[DiffusionResponse] = []
- stream = request_config.get("stream", False)
- modalities = request_config.get("modalities", omit) # Most diffusion models don't require modalities param
- extra_body = request_config.get("extra_body", None)
- if stream:
- raise NotImplementedError("Streaming is not currently implemented for diffusion model e2e test")
- if request_num == 1:
- # Send single request
- chat_completion = self.client.chat.completions.create(
- model=request_config.get("model"),
- messages=request_config.get("messages"),
- extra_body=extra_body,
- modalities=modalities,
- )
- response = self._process_diffusion_response(chat_completion)
- assert_diffusion_response(response, request_config, run_level=self.run_level)
- responses.append(response)
- else:
- # Send concurrent requests
- with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
- futures = []
- # Submit all request tasks
- for _ in range(request_num):
- future = executor.submit(
- self.client.chat.completions.create,
- model=request_config.get("model"),
- messages=request_config.get("messages"),
- modalities=modalities,
- extra_body=extra_body,
- )
- futures.append(future)
- # Process completed tasks
- for future in concurrent.futures.as_completed(futures):
- chat_completion = future.result()
- response = self._process_diffusion_response(chat_completion)
- assert_diffusion_response(response, request_config, run_level=self.run_level)
- responses.append(response)
- return responses
-
- def send_video_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
- """
- Send native /v1/videos requests.
- """
- if request_num != 1:
- raise NotImplementedError("Concurrent video diffusion requests are not currently implemented")
- form_data = request_config.get("form_data")
- if not isinstance(form_data, dict):
- raise ValueError("Video request_config must contain 'form_data'")
- normalized_form_data = {key: str(value) for key, value in form_data.items() if value is not None}
- files: dict[str, tuple[str, BytesIO, str]] = {}
- image_reference = request_config.get("image_reference")
- if image_reference:
- if image_reference.startswith("data:image"):
- header, encoded = image_reference.split(",", 1)
- content_type = header.split(";")[0].removeprefix("data:")
- extension = content_type.split("/")[-1]
- file_data = base64.b64decode(encoded)
- files["input_reference"] = (f"reference.{extension}", BytesIO(file_data), content_type)
- else:
- normalized_form_data["image_reference"] = json.dumps({"image_url": image_reference})
-
- result = DiffusionResponse()
- start_time = time.perf_counter()
- create_url = self._build_url("/v1/videos")
- response = requests.post(
- create_url,
- data=normalized_form_data,
- files=files,
- headers={"Accept": "application/json"},
- timeout=60,
- )
- response.raise_for_status()
- job_data = response.json()
- video_id = job_data["id"]
- self._wait_until_video_completed(video_id)
- video_content = self._download_video_content(video_id)
- result.success = True
- result.videos = [video_content]
- result.e2e_latency = time.perf_counter() - start_time
- assert_diffusion_response(result, request_config, run_level=self.run_level)
- return [result]
-
- def _wait_until_video_completed(
- self, video_id: str, poll_interval_seconds: int = 2, timeout_seconds: int = 300
- ) -> None:
- status_url = self._build_url(f"/v1/videos/{video_id}")
- deadline = time.monotonic() + timeout_seconds
- while time.monotonic() < deadline:
- status_resp = requests.get(status_url, headers={"Accept": "application/json"}, timeout=30)
- status_resp.raise_for_status()
- status_data = status_resp.json()
- current_status = status_data["status"]
- if current_status == "completed":
- return
- if current_status == "failed":
- error_msg = status_data.get("last_error", "Unknown error")
- raise RuntimeError(f"Job failed: {error_msg}")
- time.sleep(poll_interval_seconds)
- raise TimeoutError(f"Video job {video_id} did not complete within {timeout_seconds}s")
-
- def _download_video_content(self, video_id: str) -> bytes:
- download_url = self._build_url(f"/v1/videos/{video_id}/content")
- video_resp = requests.get(download_url, stream=True, timeout=60)
- video_resp.raise_for_status()
- video_bytes = BytesIO()
- for chunk in video_resp.iter_content(chunk_size=8192):
- if chunk:
- video_bytes.write(chunk)
- return video_bytes.getvalue()
-
- def _build_url(self, path: str) -> str:
- return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
-
-
-class OmniRunner:
- def __init__(
- self,
- model_name: str,
- seed: int = 42,
- stage_init_timeout: int = 600,
- batch_timeout: int = 10,
- init_timeout: int = 900,
- shm_threshold_bytes: int = 65536,
- log_stats: bool = False,
- stage_configs_path: str | None = None,
- **kwargs,
- ) -> None:
- cleanup_dist_env_and_memory()
- run_forced_gpu_cleanup_round()
- self.model_name = model_name
- self.seed = seed
- self._prompt_len_estimate_cache: dict[str, Any] = {}
- from vllm_omni.entrypoints.omni import Omni
-
- self.omni = Omni(
- model=model_name,
- log_stats=log_stats,
- stage_init_timeout=stage_init_timeout,
- batch_timeout=batch_timeout,
- init_timeout=init_timeout,
- shm_threshold_bytes=shm_threshold_bytes,
- stage_configs_path=stage_configs_path,
- **kwargs,
- )
-
- def get_default_sampling_params_list(self) -> list[Any]:
- if not hasattr(self.omni, "default_sampling_params_list"):
- raise AttributeError("Omni.default_sampling_params_list is not available")
- return list(self.omni.default_sampling_params_list)
-
- def _estimate_prompt_len(
- self,
- additional_information: dict[str, Any],
- model_name: str,
- ) -> int:
- """Estimate prompt_token_ids placeholder length for the Talker stage.
-
- The AR Talker replaces all input embeddings via ``preprocess``, so the
- placeholder values are irrelevant but the **length** must match the
- embeddings that ``preprocess`` will produce.
- """
- _cache = self._prompt_len_estimate_cache
- try:
- from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig
- from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
- Qwen3TTSTalkerForConditionalGeneration,
- )
-
- if model_name not in _cache:
- from transformers import AutoTokenizer
-
- tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
- cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True)
- _cache[model_name] = (tok, getattr(cfg, "talker_config", None))
-
- tok, tcfg = _cache[model_name]
- task_type = (additional_information.get("task_type") or ["CustomVoice"])[0]
- return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information(
- additional_information=additional_information,
- task_type=task_type,
- tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"],
- codec_language_id=getattr(tcfg, "codec_language_id", None),
- spk_is_dialect=getattr(tcfg, "spk_is_dialect", None),
- )
- except Exception as exc:
- logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc)
- return 2048
-
- def get_omni_inputs(
- self,
- prompts: list[str] | str,
- system_prompt: str | None = None,
- audios: PromptAudioInput = None,
- images: PromptImageInput = None,
- videos: PromptVideoInput = None,
- mm_processor_kwargs: dict[str, Any] | None = None,
- modalities: list[str] | None = None,
- ) -> list[TextPrompt]:
- if system_prompt is None:
- system_prompt = (
- "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
- "Group, capable of perceiving auditory and visual inputs, as well as "
- "generating text and speech."
- )
- video_padding_token = "<|VIDEO|>"
- image_padding_token = "<|IMAGE|>"
- audio_padding_token = "<|AUDIO|>"
- if "Qwen3-Omni-30B-A3B-Instruct" in self.model_name:
- video_padding_token = "<|video_pad|>"
- image_padding_token = "<|image_pad|>"
- audio_padding_token = "<|audio_pad|>"
- elif "Ming-flash-omni" in self.model_name:
- video_padding_token = ""
- image_padding_token = ""
- audio_padding_token = ""
- if isinstance(prompts, str):
- prompts = [prompts]
-
- # Qwen-TTS: follow examples/offline_inference/qwen3_tts/end2end.py style.
- # Stage 0 expects token placeholders + additional_information (text/speaker/task_type/...),
- # and Talker replaces embeddings in preprocess based on additional_information only.
- is_tts_model = "Qwen3-TTS" in self.model_name or "qwen3_tts" in self.model_name.lower()
- if is_tts_model and modalities == ["audio"]:
- tts_kw = mm_processor_kwargs or {}
- task_type = tts_kw.get("task_type", "CustomVoice")
- speaker = tts_kw.get("speaker", "Vivian")
- language = tts_kw.get("language", "Auto")
- max_new_tokens = int(tts_kw.get("max_new_tokens", 2048))
- ref_audio = tts_kw.get("ref_audio", None)
- ref_text = tts_kw.get("ref_text", None)
-
- omni_inputs: list[TextPrompt] = []
- for prompt_text in prompts:
- text_str = str(prompt_text).strip() or " "
- additional_information: dict[str, Any] = {
- "task_type": [task_type],
- "text": [text_str],
- "language": [language],
- "speaker": [speaker],
- "max_new_tokens": [max_new_tokens],
- }
- if ref_audio is not None:
- additional_information["ref_audio"] = [ref_audio]
- if ref_text is not None:
- additional_information["ref_text"] = [ref_text]
- plen = self._estimate_prompt_len(additional_information, self.model_name)
- input_dict: TextPrompt = {
- "prompt_token_ids": [0] * plen,
- "additional_information": additional_information,
- }
- omni_inputs.append(input_dict)
- return omni_inputs
-
- def _normalize(mm_input, num_prompts):
- if mm_input is None:
- return [None] * num_prompts
- if isinstance(mm_input, list):
- if len(mm_input) != num_prompts:
- raise ValueError("Multimodal input list length must match prompts length")
- return mm_input
- return [mm_input] * num_prompts
-
- num_prompts = len(prompts)
- audios_list = _normalize(audios, num_prompts)
- images_list = _normalize(images, num_prompts)
- videos_list = _normalize(videos, num_prompts)
-
- omni_inputs = []
- for i, prompt_text in enumerate(prompts):
- user_content = ""
- multi_modal_data = {}
- audio = audios_list[i]
- if audio is not None:
- if isinstance(audio, list):
- for _ in audio:
- user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>"
- multi_modal_data["audio"] = audio
- else:
- user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>"
- multi_modal_data["audio"] = audio
- image = images_list[i]
- if image is not None:
- if isinstance(image, list):
- for _ in image:
- user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>"
- multi_modal_data["image"] = image
- else:
- user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>"
- multi_modal_data["image"] = image
- video = videos_list[i]
- if video is not None:
- if isinstance(video, list):
- for _ in video:
- user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>"
- multi_modal_data["video"] = video
- else:
- user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>"
- multi_modal_data["video"] = video
- user_content += prompt_text
-
- full_prompt = (
- f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
- f"<|im_start|>user\n{user_content}<|im_end|>\n"
- f"<|im_start|>assistant\n"
- )
- input_dict: dict[str, Any] = {"prompt": full_prompt}
- if multi_modal_data:
- input_dict["multi_modal_data"] = multi_modal_data
- if modalities:
- input_dict["modalities"] = modalities
- if mm_processor_kwargs:
- input_dict["mm_processor_kwargs"] = mm_processor_kwargs
- omni_inputs.append(input_dict)
- return omni_inputs
-
- def generate(
- self,
- prompts: list[Any],
- sampling_params_list: list[Any] | None = None,
- ) -> list[Any]:
- if sampling_params_list is None:
- sampling_params_list = self.get_default_sampling_params_list()
- return self.omni.generate(prompts, sampling_params_list)
-
- def generate_multimodal(
- self,
- prompts: list[str] | str,
- sampling_params_list: list[Any] | None = None,
- system_prompt: str | None = None,
- audios: PromptAudioInput = None,
- images: PromptImageInput = None,
- videos: PromptVideoInput = None,
- mm_processor_kwargs: dict[str, Any] | None = None,
- modalities: list[str] | None = None,
- ) -> list[Any]:
- omni_inputs = self.get_omni_inputs(
- prompts=prompts,
- system_prompt=system_prompt,
- audios=audios,
- images=images,
- videos=videos,
- mm_processor_kwargs=mm_processor_kwargs,
- modalities=modalities,
- )
- return self.generate(omni_inputs, sampling_params_list)
-
- def start_profile(self, profile_prefix: str | None = None, stages: list[int] | None = None) -> list[Any]:
- return self.omni.start_profile(profile_prefix=profile_prefix, stages=stages)
-
- def stop_profile(self, stages: list[int] | None = None) -> list[Any]:
- return self.omni.stop_profile(stages=stages)
-
- def _cleanup_process(self):
- try:
- keywords = ["enginecore"]
- matched = []
- for proc in psutil.process_iter(["pid", "name", "cmdline", "username"]):
- try:
- cmdline = " ".join(proc.cmdline()).lower() if proc.cmdline() else ""
- name = proc.name().lower()
- if any(k in cmdline for k in keywords) or any(k in name for k in keywords):
- print(f"Found vllm process: PID={proc.pid}, cmd={cmdline[:100]}")
- matched.append(proc)
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- pass
- for proc in matched:
- try:
- proc.terminate()
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- pass
- _, still_alive = psutil.wait_procs(matched, timeout=5)
- for proc in still_alive:
- try:
- proc.kill()
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- pass
- if still_alive:
- _, stubborn = psutil.wait_procs(still_alive, timeout=3)
- if stubborn:
- print(f"Warning: failed to kill residual vllm pids: {[p.pid for p in stubborn]}")
- else:
- print(f"Force-killed residual vllm pids: {[p.pid for p in still_alive]}")
- elif matched:
- print(f"Terminated vllm pids: {[p.pid for p in matched]}")
- except Exception as e:
- print(f"Error in psutil vllm cleanup: {e}")
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if hasattr(self.omni, "close"):
- self.omni.close()
- self._cleanup_process()
- run_forced_gpu_cleanup_round()
- cleanup_dist_env_and_memory()
-
-
-class OmniRunnerHandler:
- def __init__(self, omni_runner):
- self.runner = omni_runner
-
- def _process_output(self, outputs: list[Any]) -> OmniResponse:
- result = OmniResponse()
- try:
- text_content = None
- audio_content = None
- for stage_output in outputs:
- if getattr(stage_output, "final_output_type", None) == "text":
- text_content = stage_output.request_output.outputs[0].text
- if getattr(stage_output, "final_output_type", None) == "audio":
- audio_content = stage_output.request_output.outputs[0].multimodal_output["audio"]
- result.audio_content = audio_content
- result.text_content = text_content
- result.success = True
- except Exception as e:
- result.error_message = f"Output processing error: {str(e)}"
- result.success = False
- print(f"Error: {result.error_message}")
- return result
-
- def send_request(self, request_config: dict[str, Any] | None = None) -> OmniResponse:
- if request_config is None:
- request_config = {}
- prompts = request_config.get("prompts")
- videos = request_config.get("videos")
- images = request_config.get("images")
- audios = request_config.get("audios")
- modalities = request_config.get("modalities", ["text", "audio"])
- outputs = self.runner.generate_multimodal(
- prompts=prompts, videos=videos, images=images, audios=audios, modalities=modalities
- )
- response = self._process_output(outputs)
- assert_omni_response(response, request_config, run_level="core_model")
- return response
-
- def send_audio_speech_request(self, request_config: dict[str, Any]) -> OmniResponse:
- """
- Offline TTS: text -> audio via generate_multimodal, then validate with assert_audio_speech_response.
-
- request_config must contain:
- - 'input' or 'prompts': text to synthesize.
- Optional keys:
- - 'voice' -> speaker (CustomVoice)
- - 'task_type' -> task_type in additional_information (default: "CustomVoice")
- - 'language' -> language in additional_information (default: "Auto")
- - 'max_new_tokens' -> max_new_tokens in additional_information (default: 2048)
- - 'response_format' -> desired audio format (used only for assertion)
- """
- input_text = request_config.get("input") or request_config.get("prompts")
- if input_text is None:
- raise ValueError("request_config must contain 'input' or 'prompts' for TTS")
- if isinstance(input_text, list):
- input_text = input_text[0] if input_text else ""
-
- mm_processor_kwargs: dict[str, Any] = {}
- if "voice" in request_config:
- mm_processor_kwargs["speaker"] = request_config["voice"]
- if "task_type" in request_config:
- mm_processor_kwargs["task_type"] = request_config["task_type"]
- if "ref_audio" in request_config:
- mm_processor_kwargs["ref_audio"] = request_config["ref_audio"]
- if "ref_text" in request_config:
- mm_processor_kwargs["ref_text"] = request_config["ref_text"]
- if "language" in request_config:
- mm_processor_kwargs["language"] = request_config["language"]
- if "max_new_tokens" in request_config:
- mm_processor_kwargs["max_new_tokens"] = request_config["max_new_tokens"]
-
- outputs = self.runner.generate_multimodal(
- prompts=input_text,
- modalities=["audio"],
- mm_processor_kwargs=mm_processor_kwargs or None,
- )
- mm_out: dict[str, Any] | None = None
- for stage_out in outputs:
- if getattr(stage_out, "final_output_type", None) == "audio":
- mm_out = stage_out.request_output.outputs[0].multimodal_output
- break
- if mm_out is None:
- result = OmniResponse(success=False, error_message="No audio output from pipeline")
- assert result.success, result.error_message
- return result
-
- audio_data = mm_out.get("audio")
- if audio_data is None:
- result = OmniResponse(success=False, error_message="No audio tensor in multimodal output")
- assert result.success, result.error_message
- return result
-
- sr_raw = mm_out.get("sr")
- sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw
- sr = int(sr_val.item() if hasattr(sr_val, "item") else sr_val)
- wav_tensor = torch.cat(audio_data, dim=-1) if isinstance(audio_data, list) else audio_data
- wav_buf = io.BytesIO()
- sf.write(
- wav_buf,
- wav_tensor.float().cpu().numpy().reshape(-1),
- samplerate=sr,
- format="WAV",
- subtype="PCM_16",
- )
- result = OmniResponse(success=True, audio_bytes=wav_buf.getvalue(), audio_format="audio/wav")
- assert_audio_speech_response(result, request_config, run_level="core_model")
- return result
-
- def start_profile(self, profile_prefix: str | None = None, stages: list[int] | None = None) -> list[Any]:
- return self.runner.start_profile(profile_prefix=profile_prefix, stages=stages)
-
- def stop_profile(self, stages: list[int] | None = None) -> list[Any]:
- return self.runner.stop_profile(stages=stages)
-
-
-__all__ = [
- "DiffusionResponse",
- "OmniResponse",
- "OmniRunner",
- "OmniRunnerHandler",
- "OmniServer",
- "OmniServerParams",
- "OmniServerStageCli",
- "OpenAIClientHandler",
- "get_open_port",
- "run_forced_gpu_cleanup_round",
- "dummy_messages_from_mix_data",
-]
diff --git a/tests/helpers/stage_config.py b/tests/helpers/stage_config.py
deleted file mode 100644
index 2bb017b811f..00000000000
--- a/tests/helpers/stage_config.py
+++ /dev/null
@@ -1,548 +0,0 @@
-"""Config/message construction helpers used by tests."""
-
-import atexit
-import os
-import tempfile
-from pathlib import Path
-from typing import Any
-
-import yaml
-
-
-def modify_stage_config(
- yaml_path: str,
- updates: dict[str, Any] = None,
- deletes: dict[str, Any] = None,
-) -> str:
- """
- Modify configurations in a YAML file, supporting both top-level and stage-specific modifications,
- including addition, modification, and deletion of configurations.
-
- Args:
- yaml_path: Path to the YAML configuration file.
- updates: Dictionary containing both top-level and stage-specific modifications to add or update.
- Format: {
- 'async_chunk': True,
- 'stage_args': {
- 0: {'engine_args.max_model_len': 5800},
- 1: {'engine_args.max_num_seqs': 2}
- }
- }
- deletes: Dictionary containing configurations to delete.
- Format: {
- 'old_config': None, # Delete entire key
- 'stage_args': {
- 0: ['engine_args.old_param'],
- 1: ['runtime.unused_setting']
- }
- }
-
- Returns:
- str: Path to the newly created modified YAML file with timestamp suffix.
- """
- path = Path(yaml_path)
- if not path.exists():
- raise FileNotFoundError(f"yaml does not exist: {path}")
-
- try:
- with open(yaml_path, encoding="utf-8") as f:
- config = yaml.safe_load(f) or {}
- except Exception as e:
- raise ValueError(f"Cannot parse YAML file: {e}")
-
- # Helper function to apply update
- def apply_update(config_dict: dict, key_path: str, value: Any) -> None:
- """Apply update to dictionary using dot-separated path."""
- # Handle direct list assignment (e.g., engine_input_source: [1, 2])
- if "." not in key_path:
- # Simple key, set directly
- config_dict[key_path] = value
- return
-
- current = config_dict
- keys = key_path.split(".")
-
- for i in range(len(keys) - 1):
- key = keys[i]
-
- # Handle list indices
- if key.isdigit() and isinstance(current, list):
- index = int(key)
- if index < 0:
- raise ValueError(f"Negative list index not allowed: {index}")
- if index >= len(current):
- # Expand list if needed
- while len(current) <= index:
- # If we need to go deeper (more keys after this), create a dict
- # Otherwise, create None placeholder
- current.append({} if i < len(keys) - 2 else None)
- current = current[index]
- elif isinstance(current, dict):
- # Handle dictionary keys
- if key not in current:
- # If there are more keys after this, create appropriate structure
- if i < len(keys) - 1:
- # Check if next key is a digit (list index) or string (dict key)
- if keys[i + 1].isdigit():
- current[key] = []
- else:
- current[key] = {}
- else:
- # This is the last key, create based on value type
- current[key] = [] if isinstance(value, list) else {}
- elif not isinstance(current[key], (dict, list)) and i < len(keys) - 1:
- # If current value is not dict/list but we need to go deeper, replace it
- if keys[i + 1].isdigit():
- current[key] = []
- else:
- current[key] = {}
- current = current[key]
- else:
- # Current is not a dict or list, cannot traverse further
- raise TypeError(
- f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
- )
-
- # Set the final value
- last_key = keys[-1]
- if isinstance(current, list) and last_key.isdigit():
- # Setting a value in a list by index
- index = int(last_key)
- if index < 0:
- raise ValueError(f"Negative list index not allowed: {index}")
- if index >= len(current):
- # Expand list if needed
- while len(current) <= index:
- current.append(None)
- current[index] = value
- elif isinstance(current, dict):
- # Special case: if the value is a list and we're setting a top-level key
- # Example: updating engine_input_source with [1, 2]
- current[last_key] = value
- else:
- # Current is not a dict, cannot set key
- raise TypeError(f"Cannot set value at {key_path}. Current type is {type(current).__name__}, expected dict.")
-
- # Helper function to delete by path
- def delete_by_path(config_dict: dict, path: str) -> None:
- """Delete configuration by dot-separated path."""
- if not path:
- return
-
- current = config_dict
- keys = path.split(".")
-
- # Traverse to the parent
- for i in range(len(keys) - 1):
- key = keys[i]
-
- # Handle list indices
- if key.isdigit() and isinstance(current, list):
- index = int(key)
- if index < 0 or index >= len(current):
- raise KeyError(f"List index {index} out of bounds")
- current = current[index]
- elif isinstance(current, dict):
- if key not in current:
- raise KeyError(f"Path {'.'.join(keys[: i + 1])} does not exist")
- current = current[key]
- else:
- raise TypeError(
- f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
- )
-
- # Delete the item
- last_key = keys[-1]
-
- if isinstance(current, list) and last_key.isdigit():
- index = int(last_key)
- if index < 0 or index >= len(current):
- raise KeyError(f"List index {index} out of bounds")
- del current[index]
- elif isinstance(current, dict) and last_key in current:
- del current[last_key]
- else:
- print(f"Path {path} does not exist")
-
- _stage_key = "stages" if "stages" in config else "stage_args"
-
- # Apply deletions first
- if deletes:
- for key, value in deletes.items():
- if key in ("stage_args", "stages"):
- if value and isinstance(value, dict):
- stage_args = config.get(_stage_key, [])
- if not stage_args:
- raise ValueError("stage_args does not exist in config")
-
- for stage_id, delete_paths in value.items():
- if not delete_paths:
- continue
-
- # Find stage by ID
- target_stage = None
- for stage in stage_args:
- if stage.get("stage_id") == int(stage_id):
- target_stage = stage
- break
-
- if target_stage is None:
- continue
-
- # Delete specified paths in this stage
- # Avoid shadowing the original YAML Path used for the output filename below.
- for delete_path in delete_paths:
- if delete_path: # Skip empty paths
- delete_by_path(target_stage, delete_path)
- elif "." in key:
- # Delete using dot-separated path
- delete_by_path(config, key)
- elif value is None and key in config:
- # Delete entire key
- del config[key]
-
- # Apply updates
- if updates:
- for key, value in updates.items():
- if key in ("stage_args", "stages"):
- if value and isinstance(value, dict):
- stage_args = config.get(_stage_key, [])
- if not stage_args:
- raise ValueError("stage_args does not exist in config")
-
- for stage_id, stage_updates in value.items():
- # Find stage by ID
- target_stage = None
- for stage in stage_args:
- if stage.get("stage_id") == int(stage_id):
- target_stage = stage
- break
-
- if target_stage is None:
- available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s]
- raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}")
-
- # Apply updates to this stage
- for update_path, val in stage_updates.items():
- # Check if this is a simple key (not dot-separated)
- # Example: 'engine_input_source' vs 'engine_args.max_model_len'
- if "." not in update_path:
- # Direct key assignment (e.g., updating a list value)
- target_stage[update_path] = val
- else:
- # Dot-separated path (e.g., nested dict access)
- apply_update(target_stage, update_path, val)
- elif "." in key:
- # Apply using dot-separated path
- apply_update(config, key, value)
- else:
- # Direct top-level key
- config[key] = value
-
- # Unique suffix: multiple modify_stage_config calls in one process often run
- # within the same second (e.g. test_qwen3_omni_expansion imports both
- # get_chunk_config and get_batch_token_config). int(time.time()) would collide
- # and the later write would overwrite the earlier YAML on disk.
- # Keep generated configs outside the repo and delete them when pytest exits.
- output_fd, output_path = tempfile.mkstemp(prefix=f"{path.stem}_", suffix=".yaml")
- atexit.register(Path(output_path).unlink, missing_ok=True)
-
- with os.fdopen(output_fd, "w", encoding="utf-8") as f:
- yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2)
-
- return str(output_path)
-
-
-# ``stage_config.py`` lives under ``tests/helpers/``; repo root is three parents up.
-_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
-_DEPLOY_DIR = _REPO_ROOT / "vllm_omni" / "deploy"
-_CI_GENERATED_DIR = _REPO_ROOT / "tests" / ".ci_generated"
-
-
-# CI overlays as Python dicts (LSP-friendly). Materialized on demand to
-# tests/.ci_generated/.yaml via get_deploy_config_path("ci/.yaml").
-_CI_OVERLAYS: dict[str, dict[str, Any]] = {
- "qwen2_5_omni": {
- "base_config": "qwen2_5_omni.yaml",
- "async_chunk": False,
- "stages": [
- {
- "stage_id": 0,
- "max_model_len": 16384,
- "max_num_batched_tokens": 16384,
- "max_num_seqs": 1,
- "gpu_memory_utilization": 0.9,
- "skip_mm_profiling": True,
- "load_format": "dummy",
- "default_sampling_params": {"max_tokens": 128},
- },
- {
- "stage_id": 1,
- "max_model_len": 16384,
- "max_num_batched_tokens": 16384,
- "max_num_seqs": 1,
- "gpu_memory_utilization": 0.4,
- "skip_mm_profiling": True,
- "load_format": "dummy",
- "default_sampling_params": {"max_tokens": 4096},
- },
- {
- "stage_id": 2,
- "max_num_seqs": 1,
- "gpu_memory_utilization": 0.5,
- "max_num_batched_tokens": 8192,
- "max_model_len": 8192,
- "load_format": "dummy",
- "devices": "2",
- "default_sampling_params": {"max_tokens": 8192},
- },
- ],
- "platforms": {
- "rocm": {
- "stages": [
- {"stage_id": 0, "gpu_memory_utilization": 0.9},
- {"stage_id": 1, "gpu_memory_utilization": 0.4},
- {"stage_id": 2, "gpu_memory_utilization": 0.5, "devices": "2"},
- ],
- },
- "xpu": {
- "stages": [
- {
- "stage_id": 0,
- "gpu_memory_utilization": 0.9,
- "max_num_batched_tokens": 16384,
- "max_model_len": 16384,
- },
- {"stage_id": 1, "gpu_memory_utilization": 0.5},
- {
- "stage_id": 2,
- "gpu_memory_utilization": 0.3,
- "max_num_batched_tokens": 4096,
- "max_model_len": 4096,
- "devices": "2",
- },
- ],
- },
- },
- },
- "qwen3_omni_moe": {
- "base_config": "qwen3_omni_moe.yaml",
- "async_chunk": False,
- "stages": [
- {
- "stage_id": 0,
- "max_num_seqs": 5,
- "max_model_len": 32768,
- "mm_processor_cache_gb": 0,
- "load_format": "dummy",
- "default_sampling_params": {"max_tokens": 150, "ignore_eos": False},
- },
- {
- "stage_id": 1,
- "gpu_memory_utilization": 0.5,
- "max_num_seqs": 5,
- "max_model_len": 32768,
- "load_format": "dummy",
- "default_sampling_params": {"max_tokens": 1000},
- },
- {
- "stage_id": 2,
- "max_num_seqs": 5,
- "max_num_batched_tokens": 100000,
- "load_format": "dummy",
- "default_sampling_params": {"max_tokens": 2000},
- },
- ],
- "platforms": {
- "rocm": {
- "stages": [
- {"stage_id": 0, "max_num_seqs": 1, "default_sampling_params": {"max_tokens": 100}},
- {
- "stage_id": 1,
- "max_num_seqs": 1,
- "enforce_eager": True,
- "default_sampling_params": {"max_tokens": 100},
- },
- {
- "stage_id": 2,
- "max_num_seqs": 1,
- "max_num_batched_tokens": 1000000,
- "default_sampling_params": {"max_tokens": 200},
- },
- ],
- },
- "xpu": {
- "stages": [
- {
- "stage_id": 0,
- "gpu_memory_utilization": 0.85,
- "max_num_seqs": 1,
- "tensor_parallel_size": 4,
- "enforce_eager": True,
- "max_num_batched_tokens": 4096,
- "max_model_len": 4096,
- "max_cudagraph_capture_size": 0,
- "skip_mm_profiling": True,
- "devices": "0,1,2,3",
- "default_sampling_params": {"max_tokens": 100, "ignore_eos": False},
- },
- {
- "stage_id": 1,
- "gpu_memory_utilization": 0.6,
- "max_num_seqs": 1,
- "enforce_eager": True,
- "max_num_batched_tokens": 4096,
- "max_model_len": 4096,
- "max_cudagraph_capture_size": 0,
- "skip_mm_profiling": True,
- "devices": "4",
- },
- {
- "stage_id": 2,
- "gpu_memory_utilization": 0.3,
- "max_num_seqs": 1,
- "max_num_batched_tokens": 100000,
- "max_cudagraph_capture_size": 0,
- "skip_mm_profiling": True,
- "devices": "5",
- "default_sampling_params": {"max_tokens": 2000},
- },
- ],
- },
- },
- },
- "bagel": {
- "base_config": "bagel.yaml",
- "stages": [
- {
- "stage_id": 0,
- "max_num_seqs": 3,
- "gpu_memory_utilization": 0.45,
- "load_format": "dummy",
- },
- {
- "stage_id": 1,
- "max_num_seqs": 1,
- "load_format": "dummy",
- },
- ],
- },
- "bagel_think": {
- "base_config": "bagel_think.yaml",
- "stages": [
- {
- "stage_id": 0,
- "max_num_seqs": 3,
- "gpu_memory_utilization": 0.45,
- "load_format": "dummy",
- },
- {
- "stage_id": 1,
- "max_num_seqs": 1,
- "load_format": "dummy",
- },
- ],
- },
- "bagel_single_stage": {
- "base_config": "bagel_single_stage.yaml",
- "stages": [
- {
- "stage_id": 0,
- "max_num_seqs": 1,
- "load_format": "dummy",
- },
- ],
- },
- "bagel_mooncake": {
- "base_config": "bagel.yaml",
- "stages": [
- {
- "stage_id": 0,
- "max_num_seqs": 1,
- "gpu_memory_utilization": 0.45,
- "load_format": "dummy",
- "output_connectors": {"to_stage_1": "mooncake_connector"},
- },
- {
- "stage_id": 1,
- "max_num_seqs": 1,
- "load_format": "dummy",
- "input_connectors": {"from_stage_0": "mooncake_connector"},
- },
- ],
- "connectors": {
- "mooncake_connector": {
- "name": "MooncakeConnector",
- "extra": {
- "host": "${MOONCAKE_HOST}",
- "metadata_server": "http://${MOONCAKE_HOST}:${MOONCAKE_HTTP_PORT}/metadata",
- "master": "${MOONCAKE_HOST}:${MOONCAKE_RPC_PORT}",
- "segment": 64000000,
- "localbuf": 64000000,
- "proto": "tcp",
- },
- },
- },
- },
- # Single-stage thinker-only topology for the abort test.
- "qwen2_5_omni_thinker_only": {
- "async_chunk": False,
- "pipeline": "qwen2_5_omni_thinker_only",
- "stages": [
- {
- "stage_id": 0,
- "max_num_seqs": 1,
- "gpu_memory_utilization": 0.9,
- "enforce_eager": True,
- "max_num_batched_tokens": 16384,
- "max_model_len": 16384,
- "skip_mm_profiling": True,
- "mm_processor_cache_gb": 0,
- "load_format": "dummy",
- "devices": "0",
- "default_sampling_params": {
- "temperature": 0.0,
- "top_p": 1.0,
- "top_k": -1,
- "max_tokens": 128,
- "seed": 42,
- "repetition_penalty": 1.1,
- },
- },
- ],
- },
-}
-
-
-def _materialize_ci_overlay(model_type: str) -> Path:
- import yaml
-
- if model_type not in _CI_OVERLAYS:
- raise KeyError(f"No CI overlay registered for {model_type!r}. Available: {sorted(_CI_OVERLAYS)}")
-
- _CI_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
- out = _CI_GENERATED_DIR / f"{model_type}.yaml"
-
- overlay = {**_CI_OVERLAYS[model_type]}
- base = overlay.get("base_config")
- if base:
- overlay["base_config"] = str(_DEPLOY_DIR / base)
-
- with open(out, "w", encoding="utf-8") as f:
- yaml.safe_dump(overlay, f, sort_keys=False)
- return out
-
-
-def get_deploy_config_path(rel_path: str) -> str:
- """Resolve a deploy yaml; ``ci/.yaml`` materializes from ``_CI_OVERLAYS``."""
- if rel_path.startswith("ci/") and rel_path.endswith(".yaml"):
- model_type = rel_path[len("ci/") : -len(".yaml")]
- if model_type in _CI_OVERLAYS:
- return str(_materialize_ci_overlay(model_type))
- return str(_DEPLOY_DIR / rel_path)
-
-
-__all__ = [
- "modify_stage_config",
- "get_deploy_config_path",
-]
diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py
index 0e071f724e5..3b1471365d5 100644
--- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py
+++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py
@@ -2,14 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for CosyVoice3 components."""
-from types import SimpleNamespace
-
import pytest
import torch
import torch.nn as nn
-from tests.helpers.mark import hardware_test
-
class TestPreLookaheadLayer:
"""Tests for PreLookaheadLayer."""
@@ -20,8 +16,6 @@ def layer(self):
return PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
- @pytest.mark.core_model
- @pytest.mark.cpu
def test_forward_shape(self, layer):
"""Test that output shape matches input shape."""
batch, seq_len, channels = 2, 10, 512
@@ -31,8 +25,6 @@ def test_forward_shape(self, layer):
assert out.shape == x.shape
- @pytest.mark.core_model
- @pytest.mark.cpu
def test_forward_with_context(self, layer):
"""Test forward with context for streaming."""
batch, seq_len, channels = 1, 10, 512
@@ -44,8 +36,6 @@ def test_forward_with_context(self, layer):
assert out.shape == x.shape
- @pytest.mark.core_model
- @pytest.mark.cpu
def test_residual_connection(self, layer):
"""Test that residual connection is applied."""
batch, seq_len, channels = 1, 5, 512
@@ -67,8 +57,6 @@ def attention(self):
return DiTAttention(dim=512, heads=8, dim_head=64, dropout=0.0)
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_forward_shape(self, attention):
"""Test attention output shape."""
batch, seq_len, dim = 2, 16, 512
@@ -78,8 +66,6 @@ def test_forward_shape(self, attention):
assert out.shape == x.shape
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_forward_with_mask(self, attention):
"""Test attention with mask."""
batch, seq_len, dim = 2, 16, 512
@@ -93,8 +79,6 @@ def test_forward_with_mask(self, attention):
# Masked positions should be zero
assert torch.allclose(out[:, -3:], torch.zeros_like(out[:, -3:]))
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_qkv_projections(self, attention):
"""Test that Q/K/V projections exist and have correct dimensions."""
assert hasattr(attention, "to_q")
@@ -114,8 +98,6 @@ def block(self):
return DiTBlock(dim=512, heads=8, dim_head=64, ff_mult=4, dropout=0.0)
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_forward_shape(self, block):
"""Test block output shape."""
batch, seq_len, dim = 2, 16, 512
@@ -126,8 +108,6 @@ def test_forward_shape(self, block):
assert out.shape == x.shape
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_adalayernorm_modulation(self, block):
"""Test that AdaLayerNorm modulates based on timestep."""
batch, seq_len, dim = 1, 8, 512
@@ -162,8 +142,6 @@ def dit(self):
long_skip_connection=True,
)
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_forward_shape(self, dit):
"""Test DiT forward output shape."""
batch, mel_dim, seq_len = 1, 80, 32
@@ -178,8 +156,6 @@ def test_forward_shape(self, dit):
assert out.shape == (batch, mel_dim, seq_len)
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_timestep_embedding(self, dit):
"""Test that different timesteps produce different outputs."""
batch, mel_dim, seq_len = 1, 80, 16
@@ -212,8 +188,6 @@ def forward(self, x, mask, mu, t, spks=None, cond=None):
return DummyEstimator()
- @pytest.mark.core_model
- @pytest.mark.cpu
def test_causal_conditional_cfm_forward(self, dummy_estimator):
"""Test CausalConditionalCFM forward pass."""
from omegaconf import DictConfig
@@ -252,8 +226,6 @@ def test_causal_conditional_cfm_forward(self, dummy_estimator):
class TestSDPAFallback:
"""Test SDPA fallback for float32 inputs."""
- @pytest.mark.core_model
- @hardware_test(res={"cuda": "L4"}, num_cards=1)
def test_float32_uses_sdpa(self):
"""Test that float32 inputs use SDPA fallback."""
from vllm_omni.diffusion.attention.layer import Attention
@@ -275,32 +247,3 @@ def test_float32_uses_sdpa(self):
assert out.shape == (batch, seq_len, heads, dim)
assert out.dtype == torch.float32
-
-
-def test_code2wav_forward_finalizes_hift_tail():
- from vllm_omni.model_executor.models.cosyvoice3.cosyvoice3_code2wav import CosyVoice3Code2Wav
-
- class DummyHiFT(nn.Module):
- def __init__(self):
- super().__init__()
- self.m_source = SimpleNamespace(l_linear=SimpleNamespace(weight=torch.ones(1, dtype=torch.float32)))
- self.finalize_calls: list[bool] = []
-
- def inference(self, speech_feat, finalize=True):
- self.finalize_calls.append(bool(finalize))
- return torch.zeros((speech_feat.shape[0], 1, speech_feat.shape[-1]), dtype=speech_feat.dtype), None
-
- model = object.__new__(CosyVoice3Code2Wav)
- nn.Module.__init__(model)
- model.hift = DummyHiFT()
- model._forward_mel = lambda **_: torch.ones((1, 80, 8), dtype=torch.float32)
-
- out = model.forward(
- token=torch.tensor([[1, 2, 3]], dtype=torch.int32),
- prompt_token=torch.tensor([[4, 5]], dtype=torch.int32),
- prompt_feat=torch.ones((1, 4, 80), dtype=torch.float32),
- embedding=torch.ones((1, 192), dtype=torch.float32),
- )
-
- assert out.shape == (1, 1, 8)
- assert model.hift.finalize_calls == [True]
diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py
deleted file mode 100644
index 9a78c54de65..00000000000
--- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_model_helpers.py
+++ /dev/null
@@ -1,463 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from threading import Lock
-from types import SimpleNamespace
-
-import pytest
-import torch
-import torch.nn as nn
-from vllm.v1.outputs import SamplerOutput
-from vllm.v1.sample.logits_processor.state import LogitsProcessors
-from vllm.v1.sample.metadata import SamplingMetadata
-
-from vllm_omni.model_executor.models.cosyvoice3.cosyvoice3 import CosyVoice3Model
-from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-class _DummyCode2Wav:
- def __init__(
- self,
- vocab_size: int,
- num_samples: int = 32,
- outputs: list[tuple[torch.Tensor, dict[str, object] | None]] | None = None,
- ):
- self.input_embedding = SimpleNamespace(num_embeddings=vocab_size)
- self.num_samples = num_samples
- self.outputs = list(outputs or [])
- self.forward_calls: list[dict[str, object]] = []
- self.forward_streaming_calls: list[dict[str, object]] = []
-
- def forward(self, **kwargs):
- self.forward_calls.append(kwargs)
- token = kwargs["token"]
- num_samples = int(token.shape[-1])
- return torch.linspace(-1.0, 1.0, max(num_samples, 1), dtype=torch.float32).reshape(1, 1, -1)
-
- def forward_streaming(self, **kwargs):
- self.forward_streaming_calls.append(kwargs)
- if self.outputs:
- return self.outputs.pop(0)
-
- token = kwargs["token"]
- num_samples = int(token.shape[-1])
- audio = torch.linspace(-1.0, 1.0, max(num_samples, 1), dtype=torch.float32).reshape(1, 1, -1)
- new_state = None
- if not kwargs.get("finalize", False):
- new_state = {
- "mel": torch.ones((1, 80, max(num_samples, 1)), dtype=torch.float32),
- "speech_offset": audio.shape[-1],
- }
- return audio, new_state
-
-
-def _make_code2wav_model(
- *,
- with_stride_cfg: bool = False,
- num_samples: int = 32,
- outputs: list[tuple[torch.Tensor, dict[str, object] | None]] | None = None,
-) -> CosyVoice3Model:
- model = object.__new__(CosyVoice3Model)
- nn.Module.__init__(model)
- model.model_stage = "cosyvoice3_code2wav"
- hift_cfg = {} if not with_stride_cfg else {"upsample_rates": [8, 5, 3], "istft_params": {"hop_len": 4}}
- model.config = SimpleNamespace(
- sample_rate=24000,
- hift=hift_cfg,
- token_frame_rate=25 if with_stride_cfg else 0,
- token_mel_ratio=2 if with_stride_cfg else 0,
- )
- model.code2wav = _DummyCode2Wav(vocab_size=4, num_samples=num_samples, outputs=outputs)
- model.source_cache_len = 4
- model.speech_window = torch.hamming_window(8, periodic=False)
- model._stream_audio_cache_by_req = {}
- model._stream_audio_cache_lock = Lock()
- model._stream_vocoder_cache_by_req = {}
- return model
-
-
-def _make_talker_model() -> CosyVoice3Model:
- model = object.__new__(CosyVoice3Model)
- nn.Module.__init__(model)
- model.model_stage = "cosyvoice3_talker"
- model.config = SimpleNamespace(
- llm={
- "speech_token_size": 6561,
- "eos_token_id": 6562,
- "sampling": {
- "top_p": 0.8,
- "top_k": 25,
- "win_size": 10,
- "tau_r": 0.1,
- },
- },
- vocab_size=151923,
- )
- return model
-
-
-def _make_sampling_metadata(
- *,
- output_token_ids: list[list[int]],
- repetition_penalty: float = 2.0,
-) -> SamplingMetadata:
- return SamplingMetadata(
- temperature=torch.tensor([1.0], dtype=torch.float32),
- all_greedy=False,
- all_random=True,
- top_p=torch.tensor([0.8], dtype=torch.float32),
- top_k=torch.tensor([25], dtype=torch.int32),
- generators={},
- max_num_logprobs=None,
- no_penalties=False,
- prompt_token_ids=None,
- frequency_penalties=torch.zeros(1, dtype=torch.float32),
- presence_penalties=torch.zeros(1, dtype=torch.float32),
- repetition_penalties=torch.tensor([repetition_penalty], dtype=torch.float32),
- output_token_ids=output_token_ids,
- allowed_token_ids_mask=None,
- bad_words_token_ids={},
- logitsprocs=LogitsProcessors(),
- )
-
-
-def test_split_request_ids_uses_seq_token_counts():
- ids = torch.tensor([10, 11, 12, 13, 14], dtype=torch.long)
- chunks = CosyVoice3Model._split_request_ids(ids, [2, 2, 2])
- assert [c.tolist() for c in chunks] == [[10, 11], [12, 13], [14]]
-
-
-def test_split_request_ids_honors_single_request_seq_token_counts():
- ids = torch.tensor([10, 11, 12, 13, 14], dtype=torch.long)
- chunks = CosyVoice3Model._split_request_ids(ids, [3])
- assert [c.tolist() for c in chunks] == [[10, 11, 12]]
-
-
-def test_sanitize_codec_tokens_filters_out_of_range():
- model = _make_code2wav_model()
- raw = torch.tensor([-1, 0, 3, 4, 99], dtype=torch.long)
- clean = model._sanitize_codec_tokens(raw)
- assert clean.tolist() == [0, 3]
-
-
-def test_forward_prefers_token_offset_when_present():
- model = _make_code2wav_model()
-
- runtime_info = [
- {
- "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long),
- "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32),
- "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32),
- "token_offset": 2,
- "left_context_size": 1,
- }
- ]
-
- out = model.forward(
- input_ids=torch.tensor([0, 1, 2], dtype=torch.long),
- positions=torch.tensor([0, 1, 2], dtype=torch.long),
- model_intermediate_buffer=runtime_info,
- seq_token_counts=[3],
- )
-
- assert len(out.multimodal_outputs["audio"]) == 1
- assert out.multimodal_outputs["audio"][0].numel() > 0
- assert len(model.code2wav.forward_streaming_calls) == 1
- call = model.code2wav.forward_streaming_calls[0]
- assert call["token"].shape == (1, 3)
- assert call["token_offset_tokens"] == 2
- assert call["finalize"] is False
-
-
-def test_forward_falls_back_to_left_context_size_for_backward_compat():
- model = _make_code2wav_model()
-
- runtime_info = [
- {
- "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long),
- "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32),
- "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32),
- "left_context_size": 2,
- }
- ]
-
- model.forward(
- input_ids=torch.tensor([0, 1, 2], dtype=torch.long),
- positions=torch.tensor([0, 1, 2], dtype=torch.long),
- model_intermediate_buffer=runtime_info,
- seq_token_counts=[3],
- )
-
- assert model.code2wav.forward_streaming_calls[0]["token_offset_tokens"] == 2
-
-
-def test_forward_ignores_single_request_padded_tail_tokens():
- model = _make_code2wav_model(with_stride_cfg=True)
- runtime_info = [
- {
- "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long),
- "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32),
- "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32),
- "token_offset": 0,
- }
- ]
-
- out = model.forward(
- input_ids=torch.tensor([0, 1, 2, 3, 3], dtype=torch.long),
- positions=torch.tensor([0, 1, 2, 3, 4], dtype=torch.long),
- model_intermediate_buffer=runtime_info,
- seq_token_counts=[3],
- )
-
- # The padded tail must not contribute to code2wav length.
- assert out.multimodal_outputs["audio"][0].numel() == 3
- assert model.code2wav.forward_streaming_calls[0]["token"].tolist() == [[0, 1, 2]]
-
-
-def test_forward_uses_non_stream_decode_without_chunk_metadata():
- model = _make_code2wav_model()
-
- runtime_info = [
- {
- "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long),
- "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32),
- "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32),
- "prefix_ids": [101, 102],
- "generated_len": 3,
- }
- ]
-
- out = model.forward(
- input_ids=torch.tensor([0, 1, 2], dtype=torch.long),
- positions=torch.tensor([0, 1, 2], dtype=torch.long),
- model_intermediate_buffer=runtime_info,
- seq_token_counts=[3],
- )
-
- assert out.multimodal_outputs["audio"][0].numel() == 3
- assert len(model.code2wav.forward_calls) == 1
- assert len(model.code2wav.forward_streaming_calls) == 0
- call = model.code2wav.forward_calls[0]
- assert call["token"].tolist() == [[0, 1, 2]]
-
-
-def test_forward_reuses_streaming_cache_state_between_chunks():
- model = _make_code2wav_model(
- outputs=[
- (
- torch.arange(4, dtype=torch.float32).reshape(1, 1, -1),
- {"mel": torch.ones((1, 80, 3), dtype=torch.float32), "speech_offset": 4},
- ),
- (
- torch.full((1, 1, 2), 9.0, dtype=torch.float32),
- {"mel": torch.ones((1, 80, 5), dtype=torch.float32), "speech_offset": 6},
- ),
- ]
- )
- runtime_info = [
- {
- "req_id": ["rid-stream"],
- "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long),
- "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32),
- "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32),
- "token_offset": 0,
- "stream_finished": torch.tensor(False),
- }
- ]
-
- out1 = model.forward(
- input_ids=torch.tensor([0, 1, 2], dtype=torch.long),
- positions=torch.tensor([0, 1, 2], dtype=torch.long),
- model_intermediate_buffer=runtime_info,
- seq_token_counts=[3],
- )
- assert out1.multimodal_outputs["audio"][0].tolist() == [0.0, 1.0, 2.0, 3.0]
- assert model.code2wav.forward_streaming_calls[0]["cache_state"] is None
-
- out2 = model.forward(
- input_ids=torch.tensor([0, 1, 2], dtype=torch.long),
- positions=torch.tensor([0, 1, 2], dtype=torch.long),
- model_intermediate_buffer=runtime_info,
- seq_token_counts=[3],
- )
- assert out2.multimodal_outputs["audio"][0].tolist() == [9.0, 9.0]
- cache_state = model.code2wav.forward_streaming_calls[1]["cache_state"]
- assert cache_state is not None
- assert cache_state["speech_offset"] == 4
- assert "rid-stream" in model._stream_vocoder_cache_by_req
-
-
-def test_forward_clears_streaming_cache_on_terminal_chunk():
- model = _make_code2wav_model(
- outputs=[
- (
- torch.arange(4, dtype=torch.float32).reshape(1, 1, -1),
- {"mel": torch.ones((1, 80, 3), dtype=torch.float32), "speech_offset": 4},
- ),
- (
- torch.full((1, 1, 1), 7.0, dtype=torch.float32),
- None,
- ),
- ]
- )
- runtime_info = [
- {
- "req_id": ["rid-stream"],
- "speech_token": torch.tensor([[1, 2, 3]], dtype=torch.long),
- "speech_feat": torch.tensor([[[0.1, 0.2], [0.3, 0.4]]], dtype=torch.float32),
- "embedding": torch.tensor([[0.5, 0.6]], dtype=torch.float32),
- "token_offset": 0,
- "stream_finished": torch.tensor(False),
- }
- ]
-
- model.forward(
- input_ids=torch.tensor([0, 1, 2], dtype=torch.long),
- positions=torch.tensor([0, 1, 2], dtype=torch.long),
- model_intermediate_buffer=runtime_info,
- seq_token_counts=[3],
- )
- assert "rid-stream" in model._stream_vocoder_cache_by_req
-
- runtime_info[0]["stream_finished"] = torch.tensor(True)
- out = model.forward(
- input_ids=torch.tensor([0, 1, 2], dtype=torch.long),
- positions=torch.tensor([0, 1, 2], dtype=torch.long),
- model_intermediate_buffer=runtime_info,
- seq_token_counts=[3],
- )
- assert out.multimodal_outputs["audio"][0].tolist() == [7.0]
- assert "rid-stream" not in model._stream_vocoder_cache_by_req
-
-
-def test_sample_uses_ras_rejection_for_recent_repetition():
- model = _make_talker_model()
- metadata = _make_sampling_metadata(output_token_ids=[[1] * 10])
- logits = torch.tensor([[-1e9, 10.0, 0.0]], dtype=torch.float32)
-
- out = model.sample(logits, metadata)
-
- assert out is not None
- assert out.sampled_token_ids.tolist() == [[2]]
-
-
-def test_sample_tolerates_padded_rows_without_history():
- model = _make_talker_model()
- metadata = _make_sampling_metadata(output_token_ids=[[1] * 10])
- logits = torch.tensor(
- [
- [-1e9, 10.0, 0.0],
- [-1e9, 0.0, 10.0],
- ],
- dtype=torch.float32,
- )
-
- out = model.sample(logits, metadata)
-
- assert out is not None
- assert out.sampled_token_ids.shape == (2, 1)
-
-
-def test_gpu_ar_model_runner_prefers_model_sampler_when_opted_in():
- metadata = _make_sampling_metadata(output_token_ids=[[1, 2, 3]])
- expected = SamplerOutput(
- sampled_token_ids=torch.tensor([[7]], dtype=torch.int32),
- logprobs_tensors=None,
- )
- calls: list[torch.Tensor] = []
-
- class _DummyInputBatch:
- def __init__(self):
- self.sampling_metadata = metadata
- self.updated = False
-
- def update_async_output_token_ids(self):
- self.updated = True
-
- runner = object.__new__(GPUARModelRunner)
- runner.input_batch = _DummyInputBatch()
- runner.model = SimpleNamespace(
- prefer_model_sampler=True,
- sample=lambda logits, sampling_metadata: calls.append(logits.clone()) or expected,
- )
- runner.sampler = lambda **_: (_ for _ in ()).throw(AssertionError("fallback sampler should not be used"))
-
- out = runner._sample(torch.tensor([[0.1, 0.2]], dtype=torch.float32), spec_decode_metadata=None)
-
- assert out is expected
- assert runner.input_batch.updated is False
- assert len(calls) == 1
-
-
-def test_gpu_ar_model_runner_supplies_req_output_history_to_model_sampler():
- metadata = _make_sampling_metadata(output_token_ids=[])
- seen_histories: list[list[list[int]]] = []
-
- class _DummyInputBatch:
- def __init__(self):
- self.sampling_metadata = metadata
- self.req_output_token_ids = [[1, 2, 3]]
- self.req_ids = ["rid-1"]
- self.sampled_token_ids_cpu = None
- self.async_copy_ready_event = None
- self.prev_req_id_to_index = None
-
- def update_async_output_token_ids(self):
- raise AssertionError("fallback async repair should not run for model sampler path")
-
- runner = object.__new__(GPUARModelRunner)
- runner.input_batch = _DummyInputBatch()
- runner.model = SimpleNamespace(
- prefer_model_sampler=True,
- sample=lambda logits, sampling_metadata: seen_histories.append(
- [list(x) for x in sampling_metadata.output_token_ids]
- )
- or SamplerOutput(sampled_token_ids=torch.tensor([[7]], dtype=torch.int32), logprobs_tensors=None),
- )
- runner.sampler = lambda **_: (_ for _ in ()).throw(AssertionError("fallback sampler should not be used"))
-
- runner._sample(torch.tensor([[0.1, 0.2]], dtype=torch.float32), spec_decode_metadata=None)
-
- assert seen_histories == [[[1, 2, 3]]]
-
-
-def test_gpu_ar_model_runner_repairs_async_placeholders_for_model_sampler():
- metadata = _make_sampling_metadata(output_token_ids=[])
- seen_histories: list[list[list[int]]] = []
-
- class _ReadyEvent:
- def __init__(self):
- self.synced = False
-
- def synchronize(self):
- self.synced = True
-
- class _DummyInputBatch:
- def __init__(self):
- self.sampling_metadata = metadata
- self.req_output_token_ids = [[11, -1]]
- self.req_ids = ["rid-1"]
- self.sampled_token_ids_cpu = torch.tensor([[29]], dtype=torch.int32)
- self.async_copy_ready_event = _ReadyEvent()
- self.prev_req_id_to_index = {"rid-1": 0}
-
- def update_async_output_token_ids(self):
- raise AssertionError("fallback async repair should not run for model sampler path")
-
- runner = object.__new__(GPUARModelRunner)
- runner.input_batch = _DummyInputBatch()
- runner.model = SimpleNamespace(
- prefer_model_sampler=True,
- sample=lambda logits, sampling_metadata: seen_histories.append(
- [list(x) for x in sampling_metadata.output_token_ids]
- )
- or SamplerOutput(sampled_token_ids=torch.tensor([[7]], dtype=torch.int32), logprobs_tensors=None),
- )
- runner.sampler = lambda **_: (_ for _ in ()).throw(AssertionError("fallback sampler should not be used"))
-
- runner._sample(torch.tensor([[0.1, 0.2]], dtype=torch.float32), spec_decode_metadata=None)
-
- assert runner.input_batch.async_copy_ready_event.synced is True
- assert seen_histories == [[[11, 29]]]
diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_utils.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_utils.py
index 76428ed582d..828bb2b1473 100644
--- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_utils.py
+++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_utils.py
@@ -5,8 +5,6 @@
import pytest
import torch
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
class TestMakePadMask:
"""Tests for make_pad_mask utility."""
diff --git a/tests/model_executor/models/glm_image/test_glm_image_ar.py b/tests/model_executor/models/glm_image/test_glm_image_ar.py
deleted file mode 100644
index 32a016b2a67..00000000000
--- a/tests/model_executor/models/glm_image/test_glm_image_ar.py
+++ /dev/null
@@ -1,352 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for GLM-Image AR model: DataParser, processor, and M-RoPE."""
-
-import importlib.util
-import os
-import sys
-import types
-from unittest.mock import MagicMock, patch
-
-import pytest
-import torch
-
-# ---------------------------------------------------------------------------
-# Load target classes via importlib to avoid requiring transformers.models.glm_image
-# (which may not exist in CI). This follows the same pattern as
-# tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py.
-# ---------------------------------------------------------------------------
-
-_BASE = os.path.join(
- os.path.dirname(__file__),
- os.pardir,
- os.pardir,
- os.pardir,
- os.pardir,
- "vllm_omni",
- "model_executor",
- "models",
- "glm_image",
-)
-
-
-def _load_module(name: str, filename: str):
- path = os.path.abspath(os.path.join(_BASE, filename))
- spec = importlib.util.spec_from_file_location(name, path)
- mod = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(mod)
- return mod
-
-
-def _build_mock_modules() -> dict[str, object]:
- """Build the dict of modules to inject into sys.modules."""
- # Stub transformers.models.glm_image submodules
- glm_image_mod = types.ModuleType("transformers.models.glm_image")
- glm_config_mod = types.ModuleType("transformers.models.glm_image.configuration_glm_image")
- glm_config_mod.GlmImageConfig = type("GlmImageConfig", (), {})
- glm_config_mod.GlmImageTextConfig = type("GlmImageTextConfig", (), {})
- glm_config_mod.GlmImageVisionConfig = type("GlmImageVisionConfig", (), {})
- glm_config_mod.GlmImageVQVAEConfig = type("GlmImageVQVAEConfig", (), {})
- glm_proc_mod = types.ModuleType("transformers.models.glm_image.processing_glm_image")
- glm_proc_mod.GlmImageProcessor = type("GlmImageProcessor", (), {})
-
- # vllm_omni submodules needed by the import chain
- vllm_omni_mod = MagicMock()
- vllm_omni_models = types.ModuleType("vllm_omni.model_executor.models")
- vllm_omni_glm_image_pkg = types.ModuleType("vllm_omni.model_executor.models.glm_image")
- vllm_omni_glm_image_pkg.__path__ = [os.path.abspath(_BASE)]
- vllm_omni_output = MagicMock()
-
- return {
- "transformers.models.glm_image": glm_image_mod,
- "transformers.models.glm_image.configuration_glm_image": glm_config_mod,
- "transformers.models.glm_image.processing_glm_image": glm_proc_mod,
- "vllm_omni": vllm_omni_mod,
- "vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"),
- "vllm_omni.model_executor.models": vllm_omni_models,
- "vllm_omni.model_executor.models.glm_image": vllm_omni_glm_image_pkg,
- "vllm_omni.model_executor.models.output_templates": vllm_omni_output,
- }
-
-
-def _load_target_classes():
- """Load the glm_image_ar module with mocked dependencies."""
- mocks = _build_mock_modules()
- with patch.dict(sys.modules, mocks):
- mod = _load_module(
- "vllm_omni.model_executor.models.glm_image.glm_image_ar",
- "glm_image_ar.py",
- )
- sys.modules["vllm_omni.model_executor.models.glm_image.glm_image_ar"] = mod
- return mod
-
-
-_ar_mod = _load_target_classes()
-
-GlmImageDataParser = _ar_mod.GlmImageDataParser
-GlmImageMultiModalProcessor = _ar_mod.GlmImageMultiModalProcessor
-GlmImageForConditionalGeneration = _ar_mod.GlmImageForConditionalGeneration
-GlmImageRotaryEmbedding = _ar_mod.GlmImageRotaryEmbedding
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-# =============================================================================
-# Helper: Minimal config for testing
-# =============================================================================
-
-
-def _make_hf_config(**overrides):
- """Create a minimal GlmImageConfig-like object for testing."""
- defaults = {
- "image_token_id": 167855,
- "image_start_token_id": 16384,
- "image_end_token_id": 16385,
- "grid_bos_token_id": None,
- "grid_eos_token_id": None,
- }
- defaults.update(overrides)
- from types import SimpleNamespace
-
- return SimpleNamespace(**defaults)
-
-
-# =============================================================================
-# Tests for GlmImageDataParser
-# =============================================================================
-
-
-class TestGlmImageDataParser:
- """Test that img2img key is normalized to image in the data parser."""
-
- def test_img2img_normalized_to_image(self):
- parser = GlmImageDataParser.__new__(GlmImageDataParser)
- parser._expected_hidden_size = 4096
- # The _get_subparsers should include img2img
- subparsers = parser._get_subparsers()
- assert "img2img" in subparsers
- assert subparsers["img2img"] == parser._parse_image_data
-
- def test_parse_mm_data_normalizes_img2img(self):
- parser = GlmImageDataParser.__new__(GlmImageDataParser)
- parser._expected_hidden_size = 4096
- # Create a mock for the parent parse_mm_data
- original_parse = type(parser).parse_mm_data
-
- calls = []
-
- def mock_parse(mm_data, **kwargs):
- calls.append(mm_data)
- return MagicMock()
-
- # Monkey-patch temporarily
- type(parser).parse_mm_data = mock_parse
- try:
- parser.parse_mm_data({"img2img": "fake_image"})
- except Exception:
- pass # parse might fail on mock, we just check the normalization
- finally:
- type(parser).parse_mm_data = original_parse
-
- # Verify that "img2img" was normalized to "image"
- if calls:
- assert "image" in calls[0]
- assert "img2img" not in calls[0]
-
-
-# =============================================================================
-# Tests for _build_generation_grids
-# =============================================================================
-
-
-class TestBuildGenerationGrids:
- """Test M-RoPE grid construction for t2i mode."""
-
- @pytest.fixture
- def processor(self):
- """Create a minimal processor instance with mocked info."""
- proc = object.__new__(GlmImageMultiModalProcessor)
- proc.info = MagicMock()
- return proc
-
- def test_1024x1024(self, processor):
- kwargs = {"target_h": 1024, "target_w": 1024}
- grids = processor._build_generation_grids(kwargs)
- # token_h = 32, token_w = 32
- # ratio = 1.0, small_h = 16, small_w = 16
- assert grids.shape == (2, 3)
- assert grids[0].tolist() == [1, 32, 32] # large
- assert grids[1].tolist() == [1, 16, 16] # small
-
- def test_512x512(self, processor):
- kwargs = {"target_h": 512, "target_w": 512}
- grids = processor._build_generation_grids(kwargs)
- assert grids.shape == (2, 3)
- assert grids[0].tolist() == [1, 16, 16]
- # small: ratio=1.0, small_h=int(sqrt(1)*16)=16, small_w=16
- assert grids[1].tolist() == [1, 16, 16]
-
- def test_non_square(self, processor):
- kwargs = {"target_h": 1024, "target_w": 512}
- grids = processor._build_generation_grids(kwargs)
- # token_h = 32, token_w = 16, ratio = 2.0
- # small_h = int(sqrt(2)*16) = 22, small_w = int(sqrt(0.5)*16) = 11
- assert grids[0].tolist() == [1, 32, 16]
- assert grids[1].tolist() == [1, 22, 11]
-
- def test_defaults_to_1024_when_no_target(self, processor):
- kwargs = {}
- grids = processor._build_generation_grids(kwargs)
- assert grids[0].tolist() == [1, 32, 32]
-
- def test_height_width_fallback(self, processor):
- kwargs = {"height": 512, "width": 512}
- grids = processor._build_generation_grids(kwargs)
- assert grids[0].tolist() == [1, 16, 16]
-
- def test_aligned_to_factor(self, processor):
- # 1000 not aligned to 32, should be rounded down to 992
- kwargs = {"target_h": 1000, "target_w": 1000}
- grids = processor._build_generation_grids(kwargs)
- # 1000 // 32 = 31
- assert grids[0].tolist() == [1, 31, 31]
-
-
-# =============================================================================
-# Tests for get_mrope_input_positions
-# =============================================================================
-
-
-class TestGetMropeInputPositions:
- """Test M-RoPE position ID computation."""
-
- @pytest.fixture
- def model(self):
- """Create a minimal model instance for M-RoPE testing."""
- model = object.__new__(GlmImageForConditionalGeneration)
- model.config = _make_hf_config()
- return model
-
- def test_pure_text(self, model):
- """Pure text tokens: all 3 dimensions get same sequential positions."""
- input_tokens = [100, 101, 102, 103]
- positions, delta = model.get_mrope_input_positions(input_tokens)
- assert positions.shape == (3, 4)
- # All three dims should be [0, 1, 2, 3]
- for dim in range(3):
- assert positions[dim].tolist() == [0, 1, 2, 3]
- assert delta == 0 # max(3) + 1 - seq_len(4) = 0
-
- def test_t2i_with_target_size(self, model):
- """t2i with explicit target_h/target_w: grids built from them."""
- input_tokens = [100, 101, 102, 16384] # text +
- kwargs = {"target_h": 256, "target_w": 256}
-
- positions, delta = model.get_mrope_input_positions(input_tokens, **kwargs)
- # 256/32=8 -> grids = [[1,8,8], [1,16,16]] (small uses factor//2=16 base)
- # Decode order (reversed): grid[-1]=[1,16,16]=256, grid[-2]=[1,8,8]=64, EOS=1
- total_decode = 256 + 64 + 1 # 321
- assert positions.shape == (3, 4 + total_decode)
- # delta = max_position + 1 - seq_len
- # Positions advance by max(h,w) per grid: max(16,16)=16, max(8,8)=8
- # max_pos = seq_len(4) + 16 + 8 = 28, then EOS at 28
- # delta = 28 + 1 - 4 = 25
- assert delta == 25
-
- def test_t2i_1024_default_grids(self, model):
- """t2i with default 1024x1024 grids when no explicit target size."""
- # Prompt ending with image_start_token_id but no image_end_token_id
- input_tokens = [100, 101, 16384]
- # No target_h/target_w, no mrope_image_grid_thw
- # Falls back to token parsing then to default [[1,32,32], [1,16,16]]
- positions, delta = model.get_mrope_input_positions(input_tokens)
- assert positions.shape[0] == 3
-
- def test_i2i_with_mrope_grid(self, model):
- """i2i: mrope_image_grid_thw contains source + target grids."""
- # Source image tokens: [16384, 167855*4, 16385] + text + 16384(bos)
- source_grid = [1, 2, 2] # 2x2 = 4 image tokens
- target_grid = [1, 32, 32] # 32x32 = 1024 tokens
- mrope_grid = torch.tensor([source_grid, target_grid], dtype=torch.long)
-
- # input_tokens: text + + 4*image_token + +
- input_tokens = [100, 101, 16384] + [167855] * 4 + [16385, 16384]
-
- positions, delta = model.get_mrope_input_positions(input_tokens, mrope_image_grid_thw=mrope_grid)
-
- # 1 source image (num_complete_images=1), 1 target grid (num_decode_grids=1)
- # Prefill covers all input tokens
- # Decode covers: 32*32 + 1(EOS) = 1025 tokens
- assert positions.shape[0] == 3
-
- def test_position_delta_non_negative(self, model):
- """mrope_position_delta should be non-negative for valid inputs."""
- input_tokens = [100, 16384]
- kwargs = {"target_h": 64, "target_w": 64}
- positions, delta = model.get_mrope_input_positions(input_tokens, **kwargs)
- assert delta >= 0
-
-
-# =============================================================================
-# Tests for GlmImageRotaryEmbedding._apply_mrope
-# =============================================================================
-
-
-class TestGlmImageRotaryEmbedding:
- """Test M-RoPE section interleaving in the rotary embedding."""
-
- @pytest.fixture
- def rotary_emb(self):
- # mrope_section=[8,12,12] sums to 32, so rotary_dim//2 must be >= 32
- # -> head_dim=64 gives rotary_dim=64, rotary_dim//2=32
- return GlmImageRotaryEmbedding(head_dim=64, mrope_section=[8, 12, 12])
-
- def test_apply_mrope_shape(self, rotary_emb):
- """Output shape matches [num_tokens, rotary_dim // 2]."""
- freqs = torch.randn(3, 5, 32) # 3 dims, 5 tokens, rotary_dim//2=32
- result = rotary_emb._apply_mrope(freqs)
- assert result.shape == (5, 32)
-
- def test_apply_mrope_interleaving(self, rotary_emb):
- """Verify that M-RoPE correctly interleaves T/H/W sections."""
- # mrope_section = [8, 12, 12] splits dim 32 into 3 chunks: [8, 12, 12]
- # chunk 0 (size 8): dim 0 % 3 = 0 (temporal)
- # chunk 1 (size 12): dim 1 % 3 = 1 (height)
- # chunk 2 (size 12): dim 2 % 3 = 2 (width)
- freqs = torch.ones(3, 1, 32)
- freqs[0, :, :] = 1.0 # temporal
- freqs[1, :, :] = 2.0 # height
- freqs[2, :, :] = 3.0 # width
-
- result = rotary_emb._apply_mrope(freqs)
- assert result.shape == (1, 32)
- assert (result[0, :8] == 1.0).all() # chunk 0: temporal
- assert (result[0, 8:20] == 2.0).all() # chunk 1: height
- assert (result[0, 20:32] == 3.0).all() # chunk 2: width
-
- def test_forward_1d_positions(self, rotary_emb):
- """Forward with 1D positions (text-only) produces correct shapes."""
- positions = torch.arange(10) # [10]
- q = torch.randn(10, 64)
- k = torch.randn(10, 64)
- q_out, k_out = rotary_emb(positions, q, k)
- assert q_out.shape == (10, 64)
- assert k_out.shape == (10, 64)
-
- def test_forward_3d_positions(self, rotary_emb):
- """Forward with 3D M-RoPE positions produces correct shapes."""
- positions = torch.arange(30).reshape(3, 10) # [3, 10]
- q = torch.randn(10, 64)
- k = torch.randn(10, 64)
- q_out, k_out = rotary_emb(positions, q, k)
- assert q_out.shape == (10, 64)
- assert k_out.shape == (10, 64)
-
- def test_forward_preserves_dtype(self, rotary_emb):
- """Output dtype matches input dtype."""
- positions = torch.arange(5)
- q = torch.randn(5, 64, dtype=torch.float32)
- k = torch.randn(5, 64, dtype=torch.float32)
- q_out, k_out = rotary_emb(positions, q, k)
- assert q_out.dtype == torch.float32
- assert k_out.dtype == torch.float32
diff --git a/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py b/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py
index 8858d1f8f16..85c0e8b56e4 100644
--- a/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py
+++ b/tests/model_executor/models/mimo_audio/test_mimo_audio_code2wav_batch_decode.py
@@ -2,10 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
+from unittest.mock import Mock
import pytest
import torch
-from pytest_mock import MockerFixture
from vllm_omni.model_executor.models.mimo_audio.config_mimo_audio import TALKER_CODEC_PAD_TOKEN_ID
from vllm_omni.model_executor.models.mimo_audio.mimo_audio_code2wav import (
@@ -51,7 +51,7 @@ def _make_invalid_flat_immediate_eostm(eostm_id: int = 666) -> torch.Tensor:
return g.reshape(-1)
-def _minimal_model(mocker: MockerFixture):
+def _minimal_model():
"""Avoid __init__ (HF tokenizer paths); only fields used by _batch_decode_waveforms."""
model = object.__new__(MiMoAudioToken2WavForConditionalGenerationVLLM)
model.device = torch.device("cpu")
@@ -59,7 +59,7 @@ def _minimal_model(mocker: MockerFixture):
model.streamer_config = AudioStreamerConfig(group_size=_GROUP, audio_channels=_AC)
model.codes = _codes_ns()
- decode_vq = mocker.Mock(
+ decode_vq = Mock(
side_effect=lambda audio_codes: torch.ones(
audio_codes.shape[1],
7,
@@ -67,7 +67,7 @@ def _minimal_model(mocker: MockerFixture):
device=audio_codes.device,
)
)
- decoder = mocker.Mock()
+ decoder = Mock()
audio_tok = SimpleNamespace(
encoder=SimpleNamespace(decode_vq=decode_vq),
@@ -78,9 +78,9 @@ def _minimal_model(mocker: MockerFixture):
return model, audio_tok
-def test_batch_decode_waveforms_empty_input_list(mocker: MockerFixture):
+def test_batch_decode_waveforms_empty_input_list():
"""Empty input list returns a single zero-length float32 tensor on model device."""
- model, _ = _minimal_model(mocker)
+ model, _ = _minimal_model()
out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(model, [])
assert len(out) == 1
assert out[0].dtype == torch.float32
@@ -88,9 +88,9 @@ def test_batch_decode_waveforms_empty_input_list(mocker: MockerFixture):
assert out[0].device == model.device
-def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes(mocker: MockerFixture):
+def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes():
"""Single and multi-request batches produce correctly shaped packed hidden states and trimmed waveforms."""
- model, audio_tok = _minimal_model(mocker)
+ model, audio_tok = _minimal_model()
decoder = audio_tok.decoder
# Single valid request: decoder output rank-3 for double squeeze path
@@ -118,9 +118,9 @@ def test_batch_decode_waveforms_single_vs_multiple_decoder_shapes(mocker: Mocker
assert out2[1].shape == (8 * _FTP,)
-def test_batch_decode_waveforms_mixed_valid_invalid_requests(mocker: MockerFixture):
+def test_batch_decode_waveforms_mixed_valid_invalid_requests():
"""Mixed valid and invalid requests: invalid slots get empty tensors, valid slots get decoded waveforms."""
- model, audio_tok = _minimal_model(mocker)
+ model, audio_tok = _minimal_model()
valid_a = _make_valid_flat_codes(1)
valid_b = _make_valid_flat_codes(1)
dummy = _make_dummy_code_tensor()
@@ -151,9 +151,9 @@ def test_batch_decode_waveforms_mixed_valid_invalid_requests(mocker: MockerFixtu
assert input_lengths.tolist() == [4, 4]
-def test_batch_decode_waveforms_all_invalid_returns_per_request_empty(mocker: MockerFixture):
+def test_batch_decode_waveforms_all_invalid_returns_per_request_empty():
"""All-invalid batch skips decoder entirely and returns empty tensors for every slot."""
- model, audio_tok = _minimal_model(mocker)
+ model, audio_tok = _minimal_model()
out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(
model,
[None, _make_dummy_code_tensor(), torch.tensor([], dtype=torch.long)],
@@ -163,9 +163,9 @@ def test_batch_decode_waveforms_all_invalid_returns_per_request_empty(mocker: Mo
audio_tok.decoder.assert_not_called()
-def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_samples(mocker: MockerFixture):
+def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_samples():
"""Decoder output longer than valid_len is trimmed to the exact expected waveform length."""
- model, audio_tok = _minimal_model(mocker)
+ model, audio_tok = _minimal_model()
flat = _make_valid_flat_codes(1)
# Longer than valid_len so branch wav = wav[:valid_len] runs
audio_tok.decoder.return_value = torch.ones(1, 1, 10_000, dtype=torch.float32)
@@ -175,9 +175,9 @@ def test_batch_decode_waveforms_output_shape_trim_when_decoder_returns_extra_sam
assert out[0].dtype == torch.float32
-def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_returns_extra(mocker: MockerFixture):
+def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_returns_extra():
"""Else-branch split: per-request wav[:valid_len] when decoder pads each batch row."""
- model, audio_tok = _minimal_model(mocker)
+ model, audio_tok = _minimal_model()
a = _make_valid_flat_codes(1)
b = _make_valid_flat_codes(2)
audio_tok.decoder.return_value = torch.ones(2, 1, 10_000, dtype=torch.float32)
@@ -189,9 +189,9 @@ def test_batch_decode_waveforms_multi_request_trims_each_row_when_decoder_return
assert out[1].dtype == torch.float32
-def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices(mocker: MockerFixture):
+def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices():
"""Tensor packing order must match valid_indices when invalid requests are in the middle."""
- model, audio_tok = _minimal_model(mocker)
+ model, audio_tok = _minimal_model()
first = _make_valid_flat_codes(1)
last = _make_valid_flat_codes(2)
inputs = [
@@ -212,9 +212,9 @@ def test_batch_decode_waveforms_valid_only_at_edges_maps_to_correct_indices(mock
assert input_lengths.tolist() == [4, 8]
-def test_batch_decode_waveforms_output_shapes_1d_float32_for_all_slots(mocker: MockerFixture):
+def test_batch_decode_waveforms_output_shapes_1d_float32_for_all_slots():
"""Every slot is a 1-D float32 vector (empty or waveform), matching downstream expectations."""
- model, audio_tok = _minimal_model(mocker)
+ model, audio_tok = _minimal_model()
inputs = [_make_valid_flat_codes(1), None, _make_valid_flat_codes(1)]
audio_tok.decoder.return_value = torch.ones(2, 1, 5000, dtype=torch.float32)
out = MiMoAudioToken2WavForConditionalGenerationVLLM._batch_decode_waveforms(model, inputs)
diff --git a/tests/model_executor/models/ming_flash_omni/test_talker_cfm.py b/tests/model_executor/models/ming_flash_omni/test_talker_cfm.py
deleted file mode 100644
index 419ce00dae1..00000000000
--- a/tests/model_executor/models/ming_flash_omni/test_talker_cfm.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from __future__ import annotations
-
-from types import SimpleNamespace
-
-import pytest
-
-from vllm_omni.model_executor.models.ming_flash_omni.talker_module import (
- CFM,
- Aggregator,
- CFMGraphExecutor,
- CFMGraphExecutorPool,
- DiT,
-)
-
-torch = pytest.importorskip("torch")
-pytest.importorskip("x_transformers")
-
-pytestmark = [
- pytest.mark.core_model,
- pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA for graph capture"),
-]
-
-_LATENT_DIM = 8
-_PATCH_SIZE = 4
-_HIS_PATCH_SIZE = 8
-_LLM_HIDDEN = 16
-_DIT_HIDDEN = 32
-_AGG_HIDDEN = 32
-_NUM_HEADS = 4
-_DEPTH = 2
-_STEPS = 5
-_DTYPE = torch.float32
-
-
-def _warmup_pipeline(cfm: CFM, aggregator: Aggregator, stop_head: torch.nn.Linear, device: torch.device) -> None:
- llm_cond = torch.randn(1, 1, _LLM_HIDDEN, device=device, dtype=_DTYPE)
- lat_cond = torch.randn(1, _HIS_PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE)
- y0 = torch.randn(1, _PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE)
- t = torch.linspace(0.0, 1.0, _STEPS + 1, device=device, dtype=_DTYPE)
- sde_args = torch.tensor([2.0, 0.25, 0.0], device=device, dtype=_DTYPE)
- sde_rnd = torch.randn(_STEPS, 1, _PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE)
-
- with torch.no_grad():
- gen_lat = cfm.sample(llm_cond, lat_cond, y0, t, sde_args, sde_rnd)
- aggregator(gen_lat)
- stop_head(llm_cond[:, -1, :]).softmax(dim=-1)
- torch.cuda.synchronize(device)
-
-
-def _build_pipeline():
- device = torch.device("cuda")
- dit = (
- DiT(
- in_channels=_LATENT_DIM,
- hidden_size=_DIT_HIDDEN,
- depth=_DEPTH,
- num_heads=_NUM_HEADS,
- mlp_ratio=2.0,
- llm_cond_dim=_LLM_HIDDEN,
- )
- .to(device=device, dtype=_DTYPE)
- .eval()
- )
- cfm = CFM(dit, steps=_STEPS, sway_sampling_coef=-1.0).to(device=device, dtype=_DTYPE).eval()
- aggregator = (
- Aggregator(
- in_channels=_LATENT_DIM,
- hidden_size=_AGG_HIDDEN,
- depth=_DEPTH,
- num_heads=_NUM_HEADS,
- mlp_ratio=2.0,
- llm_input_dim=_LLM_HIDDEN,
- )
- .to(device=device, dtype=_DTYPE)
- .eval()
- )
- stop_head = torch.nn.Linear(_LLM_HIDDEN, 2).to(device=device, dtype=_DTYPE).eval()
-
- config = SimpleNamespace(steps=_STEPS, patch_size=_PATCH_SIZE)
- _warmup_pipeline(cfm, aggregator, stop_head, device)
- return config, cfm, aggregator, stop_head, device
-
-
-class TestCFMGraphExecutor:
- """Capture once, replay twice: outputs must stay consistently-shaped."""
-
- def test_execute_shapes_and_replay(self) -> None:
- config, cfm, aggregator, stop_head, device = _build_pipeline()
- executor = CFMGraphExecutor(config, cfm, aggregator, stop_head)
-
- bsz = 1
- input_tensor = torch.randn(bsz, 1, _LLM_HIDDEN, device=device, dtype=_DTYPE)
- his_lat = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE)
-
- gen_lat, inputs_embeds, stop_out = executor.execute(input_tensor, his_lat)
- torch.cuda.synchronize()
-
- assert gen_lat.shape == (bsz, _PATCH_SIZE, _LATENT_DIM)
- assert inputs_embeds.shape == (bsz, 1, _LLM_HIDDEN)
- assert stop_out.shape == (bsz, 2)
- assert torch.isfinite(gen_lat).all()
- assert torch.isfinite(inputs_embeds).all()
- # stop_head output is softmax-normalized across the last dim.
- assert torch.allclose(stop_out.sum(dim=-1), torch.ones(bsz, device=device, dtype=_DTYPE), atol=1e-4)
-
- # Replay the captured graph with fresh inputs — shapes must match.
- new_input = torch.randn_like(input_tensor)
- new_his = torch.randn_like(his_lat)
- gen_lat2, inputs_embeds2, stop_out2 = executor.execute(new_input, new_his)
- torch.cuda.synchronize()
- assert gen_lat2.shape == gen_lat.shape
- assert inputs_embeds2.shape == inputs_embeds.shape
- assert stop_out2.shape == stop_out.shape
- assert executor.initialized is True
-
- def test_execute_is_noninplace_on_inputs(self) -> None:
- config, cfm, aggregator, stop_head, device = _build_pipeline()
- executor = CFMGraphExecutor(config, cfm, aggregator, stop_head)
-
- input_tensor = torch.randn(1, 1, _LLM_HIDDEN, device=device, dtype=_DTYPE)
- his_lat = torch.randn(1, _HIS_PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE)
- snapshot_input = input_tensor.clone()
- snapshot_his = his_lat.clone()
-
- executor.execute(input_tensor, his_lat)
- torch.cuda.synchronize()
- assert torch.equal(input_tensor, snapshot_input)
- assert torch.equal(his_lat, snapshot_his)
-
-
-class TestCFMGraphExecutorPool:
- def test_pool_acquires_and_releases(self) -> None:
- config, cfm, aggregator, stop_head, device = _build_pipeline()
- pool = CFMGraphExecutorPool(config, cfm, aggregator, stop_head, pool_size=2)
-
- input_tensor = torch.randn(1, 1, _LLM_HIDDEN, device=device, dtype=_DTYPE)
- his_lat = torch.randn(1, _HIS_PATCH_SIZE, _LATENT_DIM, device=device, dtype=_DTYPE)
-
- gen_lat, inputs_embeds, stop_out = pool.execute(input_tensor, his_lat)
- torch.cuda.synchronize()
- assert gen_lat.shape == (1, _PATCH_SIZE, _LATENT_DIM)
- assert inputs_embeds.shape == (1, 1, _LLM_HIDDEN)
- assert stop_out.shape == (1, 2)
- assert pool.pool.qsize() == 2
diff --git a/tests/model_executor/models/ming_flash_omni/test_talker_modules.py b/tests/model_executor/models/ming_flash_omni/test_talker_modules.py
deleted file mode 100644
index 4cbbc887a5e..00000000000
--- a/tests/model_executor/models/ming_flash_omni/test_talker_modules.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from __future__ import annotations
-
-import pytest
-
-from vllm_omni.model_executor.models.ming_flash_omni.talker_module import CFM, Aggregator, DiT
-
-torch = pytest.importorskip("torch")
-pytest.importorskip("x_transformers")
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-_LATENT_DIM = 8
-_PATCH_SIZE = 4
-_HIS_PATCH_SIZE = 8
-_LLM_HIDDEN = 16
-_DIT_HIDDEN = 32
-_AGG_HIDDEN = 32
-_NUM_HEADS = 4
-_DEPTH = 2
-_STEPS = 5
-
-
-def _make_dit() -> DiT:
- return DiT(
- in_channels=_LATENT_DIM,
- hidden_size=_DIT_HIDDEN,
- depth=_DEPTH,
- num_heads=_NUM_HEADS,
- mlp_ratio=2.0,
- llm_cond_dim=_LLM_HIDDEN,
- )
-
-
-def _make_aggregator() -> Aggregator:
- return Aggregator(
- in_channels=_LATENT_DIM,
- hidden_size=_AGG_HIDDEN,
- depth=_DEPTH,
- num_heads=_NUM_HEADS,
- mlp_ratio=2.0,
- llm_input_dim=_LLM_HIDDEN,
- )
-
-
-class TestDiTDummyForward:
- """DiT with dummy weights runs forward + CFG-doubled forward."""
-
- def test_forward_shape(self) -> None:
- dit = _make_dit().eval()
- bsz = 2
- x = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM)
- t = torch.zeros(bsz)
- c = torch.randn(bsz, 1, _LLM_HIDDEN)
- latent_history = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM)
-
- with torch.no_grad():
- out = dit(x, t, c, latent_history)
-
- # Output preserves the concatenated (history + time/cond prefix + x)
- # token axis: history + 1 (time+cond) + patch.
- assert out.shape == (bsz, _HIS_PATCH_SIZE + 1 + _PATCH_SIZE, _LATENT_DIM)
-
- def test_forward_with_cfg_trims_to_patch(self) -> None:
- dit = _make_dit().eval()
- bsz = 1
- x = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM)
- t = torch.zeros(())
- c = torch.randn(bsz, 1, _LLM_HIDDEN)
- latent_history = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM)
-
- with torch.no_grad():
- out = dit.forward_with_cfg(x, t, c, latent_history)
-
- # CFG doubles the batch and trims the output to the patch window.
- assert out.shape == (2 * bsz, _PATCH_SIZE, _LATENT_DIM)
-
-
-class TestAggregatorDummyForward:
- """Aggregator with dummy weights maps latent patch -> LLM hidden."""
-
- def test_forward_shape(self) -> None:
- agg = _make_aggregator().eval()
- bsz = 3
- gen_lat = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM)
-
- with torch.no_grad():
- out = agg(gen_lat)
-
- assert out.shape == (bsz, 1, _LLM_HIDDEN)
-
- def test_forward_is_finite(self) -> None:
- agg = _make_aggregator().eval()
- gen_lat = torch.randn(1, _PATCH_SIZE, _LATENT_DIM)
- with torch.no_grad():
- out = agg(gen_lat)
- assert torch.isfinite(out).all()
-
-
-class TestCFMSampleDummy:
- """CFM.sample drives DiT.forward_with_cfg through the integration loop."""
-
- def test_sample_shape_and_finite(self) -> None:
- cfm = CFM(_make_dit(), steps=_STEPS, sway_sampling_coef=-1.0).eval()
- bsz = 1
- llm_cond = torch.randn(bsz, 1, _LLM_HIDDEN)
- lat_cond = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM)
- y0 = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM)
- # Grid used by the talker; must span [0, 1] inclusive.
- t = torch.linspace(0.0, 1.0, _STEPS + 1)
- sde_args = torch.tensor([2.0, 0.0, 0.0]) # cfg=2.0, sigma=0, temp=0
- sde_rnd = torch.zeros(_STEPS, bsz, _PATCH_SIZE, _LATENT_DIM)
-
- with torch.no_grad():
- out = cfm.sample(llm_cond, lat_cond, y0, t, sde_args, sde_rnd)
-
- assert out.shape == y0.shape
- assert torch.isfinite(out).all()
-
- def test_sample_zero_cfg_reduces_to_unguided(self) -> None:
- """With cfg=0 the guidance term drops, but output shape is still valid."""
- cfm = CFM(_make_dit(), steps=_STEPS, sway_sampling_coef=None).eval()
- bsz = 2
- llm_cond = torch.randn(bsz, 1, _LLM_HIDDEN)
- lat_cond = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM)
- y0 = torch.zeros(bsz, _PATCH_SIZE, _LATENT_DIM)
- t = torch.linspace(0.0, 1.0, _STEPS + 1)
- sde_args = torch.tensor([0.0, 0.0, 0.0])
- sde_rnd = torch.zeros(_STEPS, bsz, _PATCH_SIZE, _LATENT_DIM)
-
- with torch.no_grad():
- out = cfm.sample(llm_cond, lat_cond, y0, t, sde_args, sde_rnd)
-
- assert out.shape == (bsz, _PATCH_SIZE, _LATENT_DIM)
- assert torch.isfinite(out).all()
-
-
-class TestTalkerPipelineDummyWiring:
- """End-to-end wiring of DiT -> CFM.sample -> Aggregator with dummy weights."""
-
- def test_cfm_then_aggregator(self) -> None:
- dit = _make_dit().eval()
- cfm = CFM(dit, steps=_STEPS, sway_sampling_coef=-1.0).eval()
- agg = _make_aggregator().eval()
-
- bsz = 1
- llm_cond = torch.randn(bsz, 1, _LLM_HIDDEN)
- lat_cond = torch.randn(bsz, _HIS_PATCH_SIZE, _LATENT_DIM)
- y0 = torch.randn(bsz, _PATCH_SIZE, _LATENT_DIM)
- t = torch.linspace(0.0, 1.0, _STEPS + 1)
- sde_args = torch.tensor([2.0, 0.0, 0.0])
- sde_rnd = torch.zeros(_STEPS, bsz, _PATCH_SIZE, _LATENT_DIM)
-
- with torch.no_grad():
- gen_lat = cfm.sample(llm_cond, lat_cond, y0, t, sde_args, sde_rnd)
- agg_out = agg(gen_lat)
-
- assert gen_lat.shape == (bsz, _PATCH_SIZE, _LATENT_DIM)
- assert agg_out.shape == (bsz, 1, _LLM_HIDDEN)
- assert torch.isfinite(agg_out).all()
diff --git a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py
index 587e7f7f8b1..8e04b04966b 100644
--- a/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py
+++ b/tests/model_executor/models/qwen2_5_omni/test_qwen2_5_omni_embed.py
@@ -10,9 +10,10 @@
- Interleaved (use_audio_in_video) should also work correctly.
"""
+from unittest.mock import Mock
+
import pytest
import torch
-from pytest_mock import MockerFixture
from vllm.model_executor.models.qwen2_5_omni_thinker import (
check_interleaved_audio_video,
merge_interleaved_embeddings,
@@ -106,7 +107,7 @@ def test_interleaved(self):
# ---------------------------------------------------------------------------
-def make_mock_model(mocker: MockerFixture, hidden: int = 8):
+def make_mock_model(hidden: int = 8):
"""
Return a minimal mock of Qwen2_5OmniThinkerForConditionalGeneration
that has enough structure to run embed_input_ids.
@@ -115,10 +116,10 @@ def make_mock_model(mocker: MockerFixture, hidden: int = 8):
Qwen2_5OmniThinkerForConditionalGeneration,
)
- model = mocker.Mock(spec=Qwen2_5OmniThinkerForConditionalGeneration)
+ model = Mock(spec=Qwen2_5OmniThinkerForConditionalGeneration)
# Config with token IDs
- cfg = mocker.Mock()
+ cfg = Mock()
cfg.video_token_index = VIDEO_TOKEN_ID
cfg.audio_token_index = AUDIO_TOKEN_ID
model.config = cfg
@@ -129,9 +130,9 @@ def fake_lm_embed(ids: torch.Tensor) -> torch.Tensor:
# view with shared memory, which masked_scatter_ cannot handle).
return ids.float().unsqueeze(-1).expand(-1, hidden).clone()
- lang_model = mocker.Mock()
+ lang_model = Mock()
lang_model.embed_input_ids = fake_lm_embed
- model.get_language_model = mocker.Mock(return_value=lang_model)
+ model.get_language_model = Mock(return_value=lang_model)
from vllm.model_executor.models.interfaces import SupportsMultiModal
@@ -168,7 +169,7 @@ def build_mm_embeds(audio_n, image_n, video_n, hidden, audio_val=10.0, image_val
class TestEmbedInputIds:
- def _run(self, mocker: MockerFixture, audio_n, image_n, video_n, hidden=8):
+ def _run(self, audio_n, image_n, video_n, hidden=8):
"""
Run embed_input_ids for a non-interleaved mixed-modality sequence.
Returns (result_embeds, input_ids, is_multimodal).
@@ -176,33 +177,33 @@ def _run(self, mocker: MockerFixture, audio_n, image_n, video_n, hidden=8):
input_ids, is_multimodal = make_token_seq(audio_n, image_n, video_n)
mm_embeds = build_mm_embeds(audio_n, image_n, video_n, hidden)
- model, _ = make_mock_model(mocker, hidden)
+ model, _ = make_mock_model(hidden)
result = model.embed_input_ids(input_ids, mm_embeds, is_multimodal=is_multimodal)
return result, input_ids, is_multimodal
- def test_audio_only(self, mocker: MockerFixture):
+ def test_audio_only(self):
"""Audio-only: audio positions get audio embeddings."""
audio_n, hidden = 5, 8
audio_val = 10.0
- result, input_ids, is_multimodal = self._run(mocker, audio_n, 0, 0, hidden)
+ result, input_ids, is_multimodal = self._run(audio_n, 0, 0, hidden)
audio_pos = (input_ids == AUDIO_TOKEN_ID).nonzero(as_tuple=True)[0]
assert result[audio_pos].allclose(torch.full((audio_n, hidden), audio_val)), (
"Audio positions should get audio embeddings"
)
- def test_video_only(self, mocker: MockerFixture):
+ def test_video_only(self):
"""Video-only: video positions get video embeddings."""
video_n, hidden = 6, 8
video_val = 30.0
- result, input_ids, is_multimodal = self._run(mocker, 0, 0, video_n, hidden)
+ result, input_ids, is_multimodal = self._run(0, 0, video_n, hidden)
video_pos = (input_ids == VIDEO_TOKEN_ID).nonzero(as_tuple=True)[0]
assert result[video_pos].allclose(torch.full((video_n, hidden), video_val)), (
"Video positions should get video embeddings"
)
- def test_mixed_modalities_audio_goes_to_audio_pos(self, mocker: MockerFixture):
+ def test_mixed_modalities_audio_goes_to_audio_pos(self):
"""
Regression test for GitHub issue #34506:
With audio + image + video (non-interleaved), audio positions must
@@ -211,7 +212,7 @@ def test_mixed_modalities_audio_goes_to_audio_pos(self, mocker: MockerFixture):
audio_n, image_n, video_n, hidden = 5, 4, 6, 8
audio_val, image_val, video_val = 10.0, 20.0, 30.0
- result, input_ids, is_multimodal = self._run(mocker, audio_n, image_n, video_n, hidden)
+ result, input_ids, is_multimodal = self._run(audio_n, image_n, video_n, hidden)
audio_pos = (input_ids == AUDIO_TOKEN_ID).nonzero(as_tuple=True)[0]
image_pos = (input_ids == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
@@ -232,10 +233,10 @@ def test_mixed_modalities_audio_goes_to_audio_pos(self, mocker: MockerFixture):
f"Video emb wrong: expected {video_val}, got mean={mean_v:.1f}"
)
- def test_text_positions_unchanged(self, mocker: MockerFixture):
+ def test_text_positions_unchanged(self):
"""Text positions should keep their text embeddings."""
audio_n, image_n, video_n, hidden = 3, 2, 4, 8
- result, input_ids, is_multimodal = self._run(mocker, audio_n, image_n, video_n, hidden)
+ result, input_ids, is_multimodal = self._run(audio_n, image_n, video_n, hidden)
text_pos = (~is_multimodal).nonzero(as_tuple=True)[0]
# Text tokens have value TEXT_TOKEN_ID=0, so embed -> 0.0
@@ -243,7 +244,7 @@ def test_text_positions_unchanged(self, mocker: MockerFixture):
"Text positions should keep text embeddings"
)
- def test_interleaved_use_audio_in_video(self, mocker: MockerFixture):
+ def test_interleaved_use_audio_in_video(self):
"""
Interleaved (use_audio_in_video): video chunks interleaved with audio.
Video embeddings must go to video positions, audio to audio positions.
@@ -262,7 +263,7 @@ def test_interleaved_use_audio_in_video(self, mocker: MockerFixture):
torch.full((audio_n, hidden), audio_val),
]
- model, _ = make_mock_model(mocker, hidden)
+ model, _ = make_mock_model(hidden)
result = model.embed_input_ids(input_ids, mm_embeds, is_multimodal=is_multimodal)
video_pos = (input_ids == VIDEO_TOKEN_ID).nonzero(as_tuple=True)[0]
diff --git a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py b/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py
deleted file mode 100644
index 8798cb3ca9a..00000000000
--- a/tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py
+++ /dev/null
@@ -1,335 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Tests for code predictor dtype alignment (fix for #2385).
-
-Verifies that the code predictor handles dtype mismatches between input
-tensors and model parameters without raising RuntimeError. This can happen
-when model weights are loaded in float16/bfloat16 but upstream modules
-produce float32 hidden states.
-"""
-
-from __future__ import annotations
-
-import importlib.util
-import os
-import sys
-import types
-
-import pytest
-import torch
-from pytest_mock import MockerFixture
-
-# Direct file import to avoid vllm_omni.__init__ patch dependencies.
-_MODELS = os.path.join(
- os.path.dirname(__file__),
- os.pardir,
- os.pardir,
- os.pardir,
- os.pardir,
- "vllm_omni",
- "model_executor",
- "models",
-)
-_BASE = os.path.join(_MODELS, "qwen3_tts")
-_COMMON = os.path.join(_MODELS, "common")
-
-
-def _load_module(name: str, filename: str):
- path = os.path.abspath(os.path.join(_BASE, filename))
- spec = importlib.util.spec_from_file_location(name, path)
- mod = importlib.util.module_from_spec(spec)
- sys.modules[name] = mod # register before exec (needed for dataclasses etc.)
- spec.loader.exec_module(mod)
- return mod
-
-
-def _build_mock_modules(mocker: MockerFixture) -> dict[str, object]:
- """Build the dict of modules to inject into sys.modules."""
- platforms_mock = mocker.MagicMock()
- platforms_mock.current_omni_platform.supports_torch_inductor.return_value = False
-
- logger_mock = mocker.MagicMock()
- logger_mock.init_logger = lambda name: mocker.MagicMock()
-
- vllm_config_mod = mocker.MagicMock()
- vllm_config_mod.set_current_vllm_config = lambda cfg: mocker.MagicMock(
- __enter__=mocker.MagicMock(),
- __exit__=mocker.MagicMock(),
- )
-
- weight_utils_mock = mocker.MagicMock()
- weight_utils_mock.default_weight_loader = lambda p, w: None
-
- tts_pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts")
- tts_pkg.__path__ = [os.path.abspath(_BASE)]
-
- common_pkg = types.ModuleType("vllm_omni.model_executor.models.common")
- common_pkg.__path__ = [os.path.abspath(_COMMON)]
-
- models_pkg = types.ModuleType("vllm_omni.model_executor.models")
- models_pkg.__path__ = [os.path.abspath(_MODELS)]
-
- vllm_parallel_mock = mocker.MagicMock()
- vllm_parallel_mock.VocabParallelEmbedding = torch.nn.Embedding
-
- return {
- "vllm_omni": mocker.MagicMock(),
- "vllm_omni.platforms": platforms_mock,
- "vllm.logger": logger_mock,
- "vllm.config": mocker.MagicMock(),
- "vllm.config.vllm": vllm_config_mod,
- "vllm.model_executor.model_loader.weight_utils": weight_utils_mock,
- "vllm.model_executor.layers.vocab_parallel_embedding": vllm_parallel_mock,
- "vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"),
- "vllm_omni.model_executor.models": models_pkg,
- "vllm_omni.model_executor.models.common": common_pkg,
- "vllm_omni.model_executor.models.qwen3_tts": tts_pkg,
- }
-
-
-def _load_target_classes(mocker: MockerFixture):
- """Load config and code predictor modules with mocked dependencies.
-
- Uses mocker.patch.dict to ensure sys.modules is always restored, even on failure.
- """
- mocks = _build_mock_modules(mocker)
- mocker.patch.dict(sys.modules, mocks)
- config_mod = _load_module(
- "vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts",
- "configuration_qwen3_tts.py",
- )
- sys.modules["vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts"] = config_mod
-
- # Load the shared common module (thin wrappers import from it)
- common_cp_path = os.path.abspath(os.path.join(_COMMON, "qwen3_code_predictor.py"))
- common_spec = importlib.util.spec_from_file_location(
- "vllm_omni.model_executor.models.common.qwen3_code_predictor", common_cp_path
- )
- common_cp_mod = importlib.util.module_from_spec(common_spec)
- sys.modules["vllm_omni.model_executor.models.common.qwen3_code_predictor"] = common_cp_mod
- common_spec.loader.exec_module(common_cp_mod)
-
- cp_mod = _load_module(
- "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code_predictor_vllm",
- "qwen3_tts_code_predictor_vllm.py",
- )
-
- return config_mod, cp_mod
-
-
-@pytest.fixture
-def loaded_target_classes(mocker: MockerFixture):
- config_mod, cp_mod = _load_target_classes(mocker)
- return (
- config_mod.Qwen3TTSTalkerCodePredictorConfig,
- config_mod.Qwen3TTSTalkerConfig,
- cp_mod.Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM,
- cp_mod.Qwen3TTSTalkerCodePredictorModelVLLM,
- cp_mod.CodePredictorWrapperConfig,
- )
-
-
-def _make_tiny_config(loaded_target_classes) -> tuple:
- """Create minimal configs for a tiny code predictor model."""
- (
- qwen3_tts_talker_code_predictor_config,
- qwen3_tts_talker_config,
- _,
- _,
- _,
- ) = loaded_target_classes
- cp_config = qwen3_tts_talker_code_predictor_config(
- vocab_size=64,
- hidden_size=32,
- intermediate_size=64,
- num_hidden_layers=1,
- num_attention_heads=4,
- num_key_value_heads=2,
- head_dim=8,
- num_code_groups=4,
- rms_norm_eps=1e-6,
- )
- talker_config = qwen3_tts_talker_config(
- hidden_size=32,
- num_code_groups=4,
- )
- return cp_config, talker_config
-
-
-def _make_vllm_config(mocker: MockerFixture, max_num_seqs: int = 4):
- """Create a mock VllmConfig with scheduler_config."""
- vllm_config = mocker.MagicMock()
- vllm_config.scheduler_config.max_num_seqs = max_num_seqs
- return vllm_config
-
-
-class TestCodePredictorDtypeAlignment:
- """Test that code predictor buffers match model parameter dtype."""
-
- def test_ensure_buffers_uses_given_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
- """_ensure_buffers should create proj_buf with the given dtype."""
- _, _, code_predictor_wrapper, _, _ = loaded_target_classes
- cp_config, talker_config = _make_tiny_config(loaded_target_classes)
- vllm_config = _make_vllm_config(mocker)
-
- predictor = code_predictor_wrapper(
- vllm_config=vllm_config,
- config=cp_config,
- talker_config=talker_config,
- )
-
- # Create buffer in float16
- predictor._ensure_buffers(torch.device("cpu"), torch.float16, 4)
- assert predictor._proj_buf is not None
- assert predictor._proj_buf.dtype == torch.float16
-
- # Re-create buffer in float32 (different dtype triggers re-allocation)
- predictor._ensure_buffers(torch.device("cpu"), torch.float32, 4)
- assert predictor._proj_buf.dtype == torch.float32
-
- def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loaded_target_classes) -> None:
- """_warmup_buckets should align proj_buf dtype to model parameters."""
- _, _, code_predictor_wrapper, _, _ = loaded_target_classes
- cp_config, talker_config = _make_tiny_config(loaded_target_classes)
- vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
-
- predictor = code_predictor_wrapper(
- vllm_config=vllm_config,
- config=cp_config,
- talker_config=talker_config,
- )
-
- # Cast model to float16 (simulating vLLM loading weights in half precision)
- predictor = predictor.to(torch.float16)
-
- # Pre-create proj_buf with WRONG dtype (float32) — simulating the bug
- predictor._ensure_buffers(torch.device("cpu"), torch.float32, 2)
- assert predictor._proj_buf.dtype == torch.float32
-
- # Simulate _setup_compile having cached model dtype and compiled forward
- predictor._model_dtype = torch.float16
- predictor._compiled_model_fwd = predictor.model.forward
-
- # _warmup_buckets should fix the dtype mismatch
- predictor._warmup_buckets()
-
- assert predictor._proj_buf.dtype == torch.float16
-
- def test_setup_compile_caches_model_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
- """_setup_compile should cache model parameter dtype."""
- _, _, code_predictor_wrapper, _, _ = loaded_target_classes
- cp_config, talker_config = _make_tiny_config(loaded_target_classes)
- vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
-
- predictor = code_predictor_wrapper(
- vllm_config=vllm_config,
- config=cp_config,
- talker_config=talker_config,
- )
- predictor = predictor.to(torch.float16)
-
- assert predictor._model_dtype is None
- predictor._setup_compile()
- assert predictor._model_dtype == torch.float16
-
- def test_forward_with_mismatched_input_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
- """forward() should not crash when inputs are float32 but model is float16."""
- _, _, code_predictor_wrapper, _, _ = loaded_target_classes
- cp_config, talker_config = _make_tiny_config(loaded_target_classes)
- vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
-
- predictor = code_predictor_wrapper(
- vllm_config=vllm_config,
- config=cp_config,
- talker_config=talker_config,
- )
-
- # Model in float16
- predictor = predictor.to(torch.float16)
-
- bsz = 1
- num_groups = cp_config.num_code_groups
- hidden = talker_config.hidden_size
-
- # Inputs in float32 (simulating the dtype mismatch from #2385)
- layer0_code = torch.zeros(bsz, dtype=torch.long)
- layer0_embed = torch.randn(bsz, hidden, dtype=torch.float32)
- last_talker_hidden = torch.randn(bsz, hidden, dtype=torch.float32)
-
- # This should NOT raise RuntimeError about dtype mismatch
- result = predictor(
- layer0_code=layer0_code,
- layer0_embed=layer0_embed,
- last_talker_hidden=last_talker_hidden,
- do_sample=False,
- )
-
- assert result.shape == (bsz, num_groups)
- assert result.dtype == torch.long
-
-
-class TestCodePredictorModelDtype:
- """Test the inner model forward with different dtypes."""
-
- def test_model_forward_float16(self, loaded_target_classes) -> None:
- """Inner model forward should work in float16."""
- _, _, _, code_predictor_model, _ = loaded_target_classes
- cp_config, _ = _make_tiny_config(loaded_target_classes)
- model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float16)
-
- bsz, seq_len = 1, 4
- inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float16)
- pos_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)
-
- output = model(inputs, pos_ids)
- assert output.dtype == torch.float16
- assert output.shape == (bsz, seq_len, 32)
-
- def test_model_forward_float32(self, loaded_target_classes) -> None:
- """Inner model forward should work in float32."""
- _, _, _, code_predictor_model, _ = loaded_target_classes
- cp_config, _ = _make_tiny_config(loaded_target_classes)
- model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float32)
-
- bsz, seq_len = 1, 4
- inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float32)
- pos_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)
-
- output = model(inputs, pos_ids)
- assert output.dtype == torch.float32
- assert output.shape == (bsz, seq_len, 32)
-
-
-class TestCodePredictorWrapperConfig:
- """Test wrapper configuration for different models."""
-
- def test_omni_config(self, loaded_target_classes) -> None:
- """Qwen3-Omni uses correct wrapper config."""
- _, _, _, _, code_predictor_wrapper_config = loaded_target_classes
- config = code_predictor_wrapper_config(
- use_cuda_graphs=False,
- use_parallel_embedding=True,
- use_projection=False,
- return_proj_buf=True,
- sampling_mode="stored",
- )
- assert config.use_cuda_graphs is False
- assert config.use_parallel_embedding is True
- assert config.return_proj_buf is True
- assert config.sampling_mode == "stored"
-
- def test_tts_config(self, loaded_target_classes) -> None:
- """Qwen3-TTS uses correct wrapper config."""
- _, _, _, _, code_predictor_wrapper_config = loaded_target_classes
- config = code_predictor_wrapper_config(
- use_cuda_graphs=True,
- use_parallel_embedding=False,
- use_projection=True,
- return_proj_buf=False,
- sampling_mode="per_call",
- )
- assert config.use_cuda_graphs is True
- assert config.use_parallel_embedding is False
- assert config.return_proj_buf is False
- assert config.sampling_mode == "per_call"
diff --git a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py b/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py
deleted file mode 100644
index d7559a72d9c..00000000000
--- a/tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from types import SimpleNamespace
-
-import pytest
-import torch
-import torch.nn as nn
-
-from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav import Qwen3TTSCode2Wav
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-class _FakeDecoder(nn.Module):
- def __init__(self, total_upsample: int = 4):
- super().__init__()
- self.total_upsample = total_upsample
-
- def chunked_decode(self, codes: torch.Tensor) -> torch.Tensor:
- frames = codes.shape[-1]
- wav_len = frames * self.total_upsample + 6
- wav = torch.arange(wav_len, dtype=torch.float32)
- return wav.view(1, 1, -1)
-
-
-def _make_model() -> Qwen3TTSCode2Wav:
- model = Qwen3TTSCode2Wav(
- vllm_config=SimpleNamespace(
- model_config=SimpleNamespace(model="unused"),
- device_config=SimpleNamespace(device=torch.device("cpu")),
- )
- )
- model._decoder = _FakeDecoder()
- model._num_quantizers = 2
- model._output_sample_rate = 24000
- model._total_upsample = 4
- model._ensure_speech_tokenizer_loaded = lambda: None
- return model
-
-
-def test_forward_trims_context_on_exact_frame_boundaries():
- model = _make_model()
-
- out = model.forward(
- input_ids=torch.arange(12, dtype=torch.long),
- runtime_additional_information=[{"meta": {"left_context_size": 2}}],
- )
-
- audio = out.multimodal_outputs["model_outputs"][0]
- expected = torch.arange(8, 24, dtype=torch.float32)
- torch.testing.assert_close(audio, expected)
-
-
-def test_forward_trims_trailing_padding_without_context():
- model = _make_model()
-
- out = model.forward(
- input_ids=torch.arange(12, dtype=torch.long),
- runtime_additional_information=[{"meta": {"left_context_size": 0}}],
- )
-
- audio = out.multimodal_outputs["model_outputs"][0]
- expected = torch.arange(24, dtype=torch.float32)
- torch.testing.assert_close(audio, expected)
diff --git a/tests/model_executor/models/test_encoder_quant_config.py b/tests/model_executor/models/test_encoder_quant_config.py
deleted file mode 100644
index 80201849863..00000000000
--- a/tests/model_executor/models/test_encoder_quant_config.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Regression test for #2686: pre-quantized methods must not apply
-quant config to vision / audio encoders.
-
-For modelopt FP8/FP4/MXFP8 checkpoints the Thinker LM is the only
-quantized component. Vision and audio encoder weights are BF16 with no
-FP8 scale tensors — passing quant_config to them causes FP8 kernels to
-run on BF16 weights, producing garbage embeddings.
-"""
-
-from __future__ import annotations
-
-from unittest.mock import MagicMock
-
-import pytest
-
-from vllm_omni.quantization.component_config import (
- PRE_QUANTIZED_METHODS,
- ComponentQuantizationConfig,
- resolve_encoder_quant_config,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-# ---------------------------------------------------------------------------
-# resolve_encoder_quant_config — the core routing logic for encoder quant
-# ---------------------------------------------------------------------------
-
-
-@pytest.mark.parametrize("method", sorted(PRE_QUANTIZED_METHODS))
-def test_pre_quantized_returns_none(method: str) -> None:
- """visual_quant_config and audio_quant_config must be None for
- pre-quantized methods (modelopt, modelopt_fp4, modelopt_mxfp8)."""
- mock_config = MagicMock()
- mock_config.get_name.return_value = method
-
- assert resolve_encoder_quant_config(mock_config) is None
-
-
-@pytest.mark.parametrize("method", ["fp8", "awq", "gptq", "bitsandbytes"])
-def test_non_pre_quantized_preserves_config(method: str) -> None:
- """Non-pre-quantized methods should pass through the original config."""
- mock_config = MagicMock()
- mock_config.get_name.return_value = method
-
- assert resolve_encoder_quant_config(mock_config) is mock_config
-
-
-def test_none_input_returns_none() -> None:
- """No quantization → None for encoders."""
- assert resolve_encoder_quant_config(None) is None
-
-
-def test_component_config_passed_through() -> None:
- """ComponentQuantizationConfig should be returned as-is so the caller
- can call .resolve() with the appropriate prefix."""
- inner = MagicMock()
- inner.get_name.return_value = "modelopt" # would be None if not Component
- component = ComponentQuantizationConfig(
- component_configs={"language_model": inner},
- default_config=None,
- )
-
- result = resolve_encoder_quant_config(component)
- assert result is component
-
-
-# ---------------------------------------------------------------------------
-# PRE_QUANTIZED_METHODS constant — exhaustiveness check
-# ---------------------------------------------------------------------------
-
-
-def test_pre_quantized_methods_contains_expected() -> None:
- """Guard against accidental removal of a known pre-quantized method."""
- expected = {"modelopt", "modelopt_fp4", "modelopt_mxfp8"}
- assert PRE_QUANTIZED_METHODS == expected
diff --git a/tests/model_executor/models/test_fish_speech_regressions.py b/tests/model_executor/models/test_fish_speech_regressions.py
index b8cad93106a..1f8c3cf71e8 100644
--- a/tests/model_executor/models/test_fish_speech_regressions.py
+++ b/tests/model_executor/models/test_fish_speech_regressions.py
@@ -45,7 +45,7 @@ def test_dac_decoder_mixed_batch_empty_request_does_not_misalign_indices():
out = decoder.forward(
input_ids=torch.arange(20, dtype=torch.long),
- runtime_additional_information=[{}, {"meta": {"left_context_size": 1}}],
+ runtime_additional_information=[{}, {"left_context_size": 1}],
)
audios = out.multimodal_outputs["model_outputs"]
@@ -80,6 +80,8 @@ def test_structured_voice_clone_prefill_adds_full_codebooks_with_decode_scale(mo
model.codebook_embeddings = codebook_embed
model._get_tokenizer = lambda: _FakeTokenizer({"<|audio_start|>": 10, "<|audio_end|>": 11})
+ monkeypatch.setattr(slow_ar_module.np, "load", lambda path: [0.0])
+ monkeypatch.setattr(slow_ar_module.os, "remove", lambda path: None)
monkeypatch.setattr(
slow_ar_module,
"encode_reference_audio_codes",
@@ -95,7 +97,7 @@ def test_structured_voice_clone_prefill_adds_full_codebooks_with_decode_scale(mo
{
"ref_text": "ref",
"text": "target",
- "ref_audio_wav": torch.tensor([0.0]),
+ "ref_audio_path": "unused.npy",
"ref_audio_sr": 16000,
}
)
diff --git a/tests/model_executor/models/test_fish_speech_voice_cache.py b/tests/model_executor/models/test_fish_speech_voice_cache.py
deleted file mode 100644
index fef4b551ab2..00000000000
--- a/tests/model_executor/models/test_fish_speech_voice_cache.py
+++ /dev/null
@@ -1,218 +0,0 @@
-"""Tests for Fish Speech DAC-code caching via VoiceEmbeddingCache.
-
-Covers:
- - Cache miss → DAC encode → store
- - Cache hit → skip DAC encode, reuse cached ref_codes_fq
- - Inline ref_audio (no voice name) → no caching, full encode path
- - Stale-cache protection via created_at
- - Temp file cleanup on cache hit
-"""
-
-import os
-import tempfile
-
-import numpy as np
-import pytest
-import torch
-from pytest_mock import MockerFixture
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _make_info_dict(
- *,
- text: str = "Hello world",
- ref_text: str = "Reference transcript",
- ref_audio_sr: int = 44100,
- voice_name: str | None = None,
- voice_created_at: float | None = None,
- ref_audio_path: str | None = None,
-) -> dict:
- """Build a minimal info_dict for _build_structured_voice_clone_prefill_embeds."""
- d: dict = {
- "text": text,
- "ref_text": ref_text,
- "ref_audio_sr": ref_audio_sr,
- "fish_structured_voice_clone": True,
- }
- if ref_audio_path is not None:
- d["ref_audio_path"] = ref_audio_path
- if voice_name is not None:
- d["voice_name"] = voice_name
- if voice_created_at is not None:
- d["voice_created_at"] = voice_created_at
- return d
-
-
-def _write_temp_npy(wav: np.ndarray | None = None) -> str:
- """Write a temporary .npy file with dummy audio and return its path."""
- if wav is None:
- wav = np.random.randn(44100).astype(np.float32) # 1 second @ 44.1kHz
- with tempfile.NamedTemporaryFile(prefix="fish_test_", suffix=".npy", delete=False) as f:
- np.save(f, wav)
- return f.name
-
-
-# Fake ref_codes_fq: [frames, codebooks]
-_FAKE_REF_CODES = torch.randint(0, 1024, (10, 10), dtype=torch.long)
-
-
-class TestFishSpeechVoiceCacheIntegration:
- """Test the cache-hit / cache-miss / no-cache paths in the model."""
-
- @pytest.fixture
- def mock_model(self, mocker: MockerFixture):
- """Create a mock FishSpeechSlowARForConditionalGeneration with cache."""
- from vllm_omni.utils.voice_cache import VoiceEmbeddingCache
-
- model = mocker.MagicMock()
- model._voice_cache = VoiceEmbeddingCache(max_entries=4)
- model._semantic_begin_id = 151678
- model._num_codebooks = 10
- model._codebook_size = 4096
- model.model_path = "/fake/model"
- model.codebook_embeddings = mocker.MagicMock()
- model.codebook_embeddings.weight = mocker.MagicMock()
- model.codebook_embeddings.weight.device = torch.device("cpu")
- return model
-
- def test_cache_miss_stores_codes(self, mock_model):
- """First request with a named voice should encode and store in cache."""
- cache = mock_model._voice_cache
- voice_name = "alice"
- created_at = 1712345678.0
-
- # Verify cache starts empty.
- key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at)
- assert cache.get(key) is None
-
- # Simulate a cache store (what the model does on miss).
- cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()})
-
- # Verify it's now cached.
- cached = cache.get(key)
- assert cached is not None
- assert torch.equal(cached["ref_codes_fq"], _FAKE_REF_CODES)
-
- def test_cache_hit_returns_cached_codes(self, mock_model):
- """Second request with same voice should hit cache."""
- cache = mock_model._voice_cache
- voice_name = "alice"
- created_at = 1712345678.0
-
- key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at)
- cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()})
-
- # Hit.
- cached = cache.get(key)
- assert cached is not None
- ref_codes = cached["ref_codes_fq"].to(device=torch.device("cpu"), dtype=torch.long)
- assert torch.equal(ref_codes, _FAKE_REF_CODES)
- assert cache.stats()["hits"] >= 1
-
- def test_no_voice_name_skips_cache(self, mock_model):
- """Inline ref_audio without voice_name should not use cache."""
- cache = mock_model._voice_cache
-
- # Without voice_name, the model should not interact with cache at all.
- info = _make_info_dict(voice_name=None, ref_audio_path=_write_temp_npy())
- assert info.get("voice_name") is None
- # Cache should remain untouched.
- assert cache.stats()["hits"] == 0
- assert cache.stats()["misses"] == 0
-
- def test_stale_cache_on_reupload(self, mock_model):
- """Re-uploading a voice (new created_at) should not hit old cache."""
- cache = mock_model._voice_cache
- voice_name = "alice"
-
- key_old = cache.make_cache_key(voice_name, xvec_only=False, created_at=1000.0)
- cache.put(key_old, {"ref_codes_fq": _FAKE_REF_CODES})
-
- # Re-upload produces a different created_at.
- key_new = cache.make_cache_key(voice_name, xvec_only=False, created_at=2000.0)
- assert cache.get(key_new) is None # miss
- assert cache.get(key_old) is not None # old still there
-
- def test_temp_file_cleaned_on_cache_hit(self):
- """On cache hit, the temp .npy file written by the entrypoint should be deleted."""
- tmp_path = _write_temp_npy()
- assert os.path.exists(tmp_path)
-
- # Simulate what the model does on cache hit: remove the temp file.
- try:
- os.remove(tmp_path)
- except OSError:
- pass
- assert not os.path.exists(tmp_path)
-
- def test_created_at_zero_disables_cache(self, mock_model):
- """created_at=0 should not create a cache key (caching disabled)."""
- cache = mock_model._voice_cache
-
- info = _make_info_dict(
- voice_name="bob",
- voice_created_at=0.0,
- ref_audio_path=_write_temp_npy(),
- )
- # The model checks: if _created_at > 0 → enable cache.
- # With 0.0, no cache interaction should happen.
- _created_at = float(info.get("voice_created_at", 0))
- assert _created_at <= 0
- assert cache.stats()["hits"] == 0
- assert cache.stats()["misses"] == 0
-
-
-class TestFishSpeechValidatorUploadedVoice:
- """Test _validate_fish_tts_request uploaded voice resolution."""
-
- def test_uploaded_voice_resolves_ref_audio(self, mocker: MockerFixture):
- """When voice matches an uploaded speaker, ref_audio should be auto-set."""
- request = mocker.MagicMock()
- request.input = "Hello"
- request.voice = "alice"
- request.ref_audio = None
- request.ref_text = None
- request.max_new_tokens = None
-
- # Uploaded speaker with ref_text.
- uploaded_speakers = {
- "alice": {
- "file_path": "/tmp/fake_audio.wav",
- "ref_text": "Hi this is Alice",
- "created_at": 1712345678,
- },
- }
-
- # Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL.
- mocker.patch("pathlib.Path.exists", return_value=True)
- voice_lower = request.voice.lower()
- assert voice_lower in uploaded_speakers
-
- speaker_info = uploaded_speakers[voice_lower]
- ref_text_from_upload = speaker_info.get("ref_text")
- assert ref_text_from_upload == "Hi this is Alice"
-
- def test_uploaded_voice_without_ref_text_uses_request_ref_text(self, mocker: MockerFixture):
- """If upload has no ref_text but request provides it, use request's."""
- request = mocker.MagicMock()
- request.input = "Hello"
- request.voice = "bob"
- request.ref_audio = None
- request.ref_text = "Request-level transcript"
- request.max_new_tokens = None
-
- uploaded_speakers = {
- "bob": {
- "file_path": "/tmp/fake_audio.wav",
- "ref_text": None,
- "created_at": 1712345678,
- },
- }
-
- voice_lower = request.voice.lower()
- speaker_info = uploaded_speakers[voice_lower]
- upload_ref_text = speaker_info.get("ref_text")
- # Upload has no ref_text, so request.ref_text should remain.
- assert upload_ref_text is None
- assert request.ref_text == "Request-level transcript"
diff --git a/tests/model_executor/models/voxcpm2/__init__.py b/tests/model_executor/models/voxcpm2/__init__.py
deleted file mode 100644
index 208f01a7cb5..00000000000
--- a/tests/model_executor/models/voxcpm2/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py b/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
deleted file mode 100644
index 929e8a36adc..00000000000
--- a/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Regression tests for VoxCPM2 talker per-request state lifecycle."""
-
-from __future__ import annotations
-
-import pytest
-
-torch = pytest.importorskip("torch")
-
-from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import ( # noqa: E402
- VoxCPM2TalkerForConditionalGeneration,
- _RequestState,
-)
-
-
-def _make_bare_talker() -> VoxCPM2TalkerForConditionalGeneration:
- talker = VoxCPM2TalkerForConditionalGeneration.__new__(VoxCPM2TalkerForConditionalGeneration)
- talker._active_states = {}
- talker._current_request_id = None
- talker._pending_requests = []
- talker._results_queue = []
- talker._audio_queue = []
- talker._deferred_cleanup_ids = set()
- talker._max_batch_size = 4
- talker._active_state_warn_threshold = 512
- talker._active_state_warned = False
- return talker
-
-
-def _seed_cached_decode(talker, req_id: str) -> _RequestState:
- state = _RequestState(request_id=req_id)
- state.prefill_completed = True
- state.decode_step_count = 5
- talker._active_states[req_id] = state
- return state
-
-
-class TestStateEvictionContract:
- def test_pending_requests_is_not_used_for_eviction(self) -> None:
- talker = _make_bare_talker()
-
- cached_ids = [f"req-{i}" for i in range(4)]
- for rid in cached_ids:
- _seed_cached_decode(talker, rid)
-
- walked_so_far = ["req-new", cached_ids[0], cached_ids[1]]
- talker._pending_requests = [(rid, False, None, 0) for rid in walked_so_far]
-
- for rid in cached_ids:
- assert rid in talker._active_states
- assert talker._active_states[rid].prefill_completed is True
-
- def test_on_requests_finished_defers_cleanup(self) -> None:
- talker = _make_bare_talker()
- _seed_cached_decode(talker, "req-A")
- _seed_cached_decode(talker, "req-B")
-
- talker.on_requests_finished({"req-A"})
-
- assert "req-A" in talker._active_states
- assert "req-A" in talker._deferred_cleanup_ids
-
- def test_flush_deferred_cleanup_removes_only_finished(self) -> None:
- talker = _make_bare_talker()
- _seed_cached_decode(talker, "req-A")
- _seed_cached_decode(talker, "req-B")
- talker.on_requests_finished(["req-A"])
-
- talker._flush_deferred_cleanup()
-
- assert "req-A" not in talker._active_states
- assert "req-B" in talker._active_states
- assert talker._deferred_cleanup_ids == set()
-
- def test_current_request_id_cleared_when_matching(self) -> None:
- talker = _make_bare_talker()
- _seed_cached_decode(talker, "req-A")
- talker._current_request_id = "req-A"
-
- talker.on_requests_finished({"req-A"})
- talker._flush_deferred_cleanup()
-
- assert talker._current_request_id is None
-
- def test_current_request_id_preserved_when_not_finished(self) -> None:
- talker = _make_bare_talker()
- _seed_cached_decode(talker, "req-A")
- _seed_cached_decode(talker, "req-B")
- talker._current_request_id = "req-B"
-
- talker.on_requests_finished({"req-A"})
- talker._flush_deferred_cleanup()
-
- assert talker._current_request_id == "req-B"
-
-
-class TestLeakWarnGuard:
- def test_warn_fires_once_over_threshold(self, monkeypatch) -> None:
- from vllm_omni.model_executor.models.voxcpm2 import voxcpm2_talker as tk
-
- calls: list[str] = []
-
- def _capture(msg, *args, **kwargs):
- calls.append(msg % args if args else msg)
-
- monkeypatch.setattr(tk.logger, "warning", _capture)
-
- talker = _make_bare_talker()
- talker._active_state_warn_threshold = 3
-
- for i in range(4):
- talker._active_states[f"seed-{i}"] = _RequestState(request_id=f"seed-{i}")
-
- talker._get_or_create_state("new-1")
- talker._get_or_create_state("new-2")
-
- leak_warnings = [m for m in calls if "cleanup path leak" in m]
- assert len(leak_warnings) == 1
- assert talker._active_state_warned is True
diff --git a/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py b/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py
index c7b023361a7..6f072944d9a 100644
--- a/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py
+++ b/tests/model_executor/models/voxtral_tts/test_cuda_graph_acoustic_transformer.py
@@ -78,13 +78,6 @@
AudioSpecialTokens = _mod2.AudioSpecialTokens
-class SyntheticAcousticTransformerArgs:
- """Mimics AcousticTransformerArgs interface."""
-
- def __init__(self):
- self.n_decoding_steps = 7
-
-
class SyntheticModelArgs:
"""Mimics MultimodalAudioModelArgs interface."""
@@ -103,7 +96,6 @@ class SyntheticAcousticTransformer(nn.Module):
def __init__(self):
super().__init__()
self.model_args = SyntheticModelArgs()
- self.acoustic_transformer_args = SyntheticAcousticTransformerArgs()
self.acoustic_embeddings_levels = ACOUSTIC_EMBEDDINGS_LEVELS
# semantic_codebook_output: hidden_dim -> padded_codebook_size
@@ -137,7 +129,7 @@ def __init__(self):
def compute_mm_logits(
self,
hidden_states: torch.Tensor,
- cfg_alpha: torch.Tensor,
+ mm_sampling_tensors=None,
):
"""Eager fallback path: replicate what the wrapper does."""
at = self.acoustic_transformer
@@ -216,10 +208,6 @@ def _random_hidden(batch_size, device=DEVICE, dtype=torch.bfloat16):
return torch.randn(batch_size, HIDDEN_DIM, device=device, dtype=dtype)
-def _cfg_alpha(batch_size, value=1.2, device=DEVICE):
- return torch.full((batch_size,), value, device=device, dtype=torch.float32)
-
-
def _unpack_audio_codes(result):
"""Unpack (fake_eos, {"audio": [list of tensors]}) into (fake_eos, audio_codes)."""
fake_eos, mm_tokens = result
@@ -239,7 +227,7 @@ def test_exact_size_output_format(model, wrapper, batch_size):
"""Graph path returns correctly shaped and bounded outputs."""
hidden = _random_hidden(batch_size)
with torch.no_grad():
- graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=_cfg_alpha(batch_size)))
+ graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
assert graph_eos.shape == (batch_size,)
assert graph_codes.shape == (batch_size, 1 + N_ACOUSTIC_CODEBOOK)
# fake_eos should be 0.0 or 1.0
@@ -252,12 +240,11 @@ def test_exact_size_output_format(model, wrapper, batch_size):
def test_exact_size_deterministic(model, wrapper, batch_size):
"""Same input + same RNG state produces identical CUDA graph output."""
hidden = _random_hidden(batch_size)
- cfg_alpha = _cfg_alpha(batch_size)
with torch.no_grad():
torch.manual_seed(42)
- eos1, codes1 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=cfg_alpha))
+ eos1, codes1 = _unpack_audio_codes(wrapper(hidden))
torch.manual_seed(42)
- eos2, codes2 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=cfg_alpha))
+ eos2, codes2 = _unpack_audio_codes(wrapper(hidden))
torch.testing.assert_close(eos1, eos2, atol=0, rtol=0)
torch.testing.assert_close(codes1, codes2, atol=0, rtol=0)
@@ -272,7 +259,7 @@ def test_padded_output_shape(model, wrapper, batch_size):
"""Padded decode must return output trimmed to actual batch size."""
hidden = _random_hidden(batch_size)
with torch.no_grad():
- graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=_cfg_alpha(batch_size)))
+ graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
assert graph_eos.shape == (batch_size,)
assert graph_codes.shape == (batch_size, 1 + N_ACOUSTIC_CODEBOOK)
@@ -282,7 +269,7 @@ def test_padded_output_bounded(model, wrapper, batch_size):
"""Padded output audio codes should be non-negative integers."""
hidden = _random_hidden(batch_size)
with torch.no_grad():
- graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=_cfg_alpha(batch_size)))
+ graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
# fake_eos should be 0.0 or 1.0
assert torch.all((graph_eos == 0.0) | (graph_eos == 1.0))
# Audio codes should be non-negative
@@ -298,12 +285,11 @@ def test_padded_output_bounded(model, wrapper, batch_size):
def test_fallback_eager_exact_match(model, wrapper, batch_size):
"""Cudagraph fallback to eager. Two eager runs should produce identical results."""
hidden = _random_hidden(batch_size)
- alpha = _cfg_alpha(batch_size)
with torch.no_grad():
torch.manual_seed(100)
- eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden, cfg_alpha=alpha))
+ eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden))
torch.manual_seed(100)
- graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha))
+ graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
torch.testing.assert_close(graph_eos, eager_eos, atol=0, rtol=0)
torch.testing.assert_close(graph_codes, eager_codes, atol=0, rtol=0)
@@ -316,13 +302,12 @@ def test_fallback_eager_exact_match(model, wrapper, batch_size):
def test_disabled_wrapper_matches_eager(model, wrapper):
"""Cudagraph fallback to eager. Two eager runs should produce identical results."""
hidden = _random_hidden(4)
- alpha = _cfg_alpha(4)
wrapper.enabled = False
with torch.no_grad():
torch.manual_seed(200)
- eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden, cfg_alpha=alpha))
+ eager_eos, eager_codes = _unpack_audio_codes(model.compute_mm_logits(hidden))
torch.manual_seed(200)
- graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha))
+ graph_eos, graph_codes = _unpack_audio_codes(wrapper(hidden))
wrapper.enabled = True
torch.testing.assert_close(graph_eos, eager_eos, atol=0, rtol=0)
torch.testing.assert_close(graph_codes, eager_codes, atol=0, rtol=0)
@@ -336,11 +321,10 @@ def test_disabled_wrapper_matches_eager(model, wrapper):
def test_deterministic_across_calls(model, wrapper):
"""Same input + same RNG state. Two cudagraph runs should produce identical results."""
hidden = _random_hidden(4)
- alpha = _cfg_alpha(4)
with torch.no_grad():
torch.manual_seed(300)
- eos1, codes1 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha))
+ eos1, codes1 = _unpack_audio_codes(wrapper(hidden))
torch.manual_seed(300)
- eos2, codes2 = _unpack_audio_codes(wrapper(hidden, cfg_alpha=alpha))
+ eos2, codes2 = _unpack_audio_codes(wrapper(hidden))
torch.testing.assert_close(eos1, eos2, atol=0, rtol=0)
torch.testing.assert_close(codes1, codes2, atol=0, rtol=0)
diff --git a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py b/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py
deleted file mode 100644
index 9c236664616..00000000000
--- a/tests/model_executor/stage_input_processors/test_cosyvoice3_stage_input_processors.py
+++ /dev/null
@@ -1,267 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from collections import defaultdict
-from types import SimpleNamespace
-
-import torch
-
-from vllm_omni.model_executor.stage_input_processors.cosyvoice3 import talker2code2wav_async_chunk, text2flow
-
-
-def _source_output(request_id: str, prompt_ids: list[int], out_ids: list[int], mm: dict):
- return SimpleNamespace(
- request_id=request_id,
- prompt_token_ids=prompt_ids,
- outputs=[SimpleNamespace(token_ids=out_ids, cumulative_token_ids=out_ids, multimodal_output=mm)],
- )
-
-
-def _transfer_manager(
- *,
- chunk_frames: int = 2,
- pre_lookahead_frames: int = 0,
- stream_scale_factor: int = 1,
- max_chunk_frames: int | None = None,
-):
- if max_chunk_frames is None:
- max_chunk_frames = chunk_frames
- return SimpleNamespace(
- code_prompt_token_ids=defaultdict(list),
- request_payload={},
- connector=SimpleNamespace(
- config={
- "extra": {
- "codec_chunk_frames": chunk_frames,
- "codec_pre_lookahead_frames": pre_lookahead_frames,
- "codec_max_chunk_frames": max_chunk_frames,
- "codec_stream_scale_factor": stream_scale_factor,
- "codec_vocab_size": 6561,
- }
- }
- ),
- )
-
-
-def test_text2flow_supports_batched_source_outputs():
- stage_list = [
- SimpleNamespace(
- engine_outputs=[
- _source_output("req-0", [10, 11], [1, 2, 3], {"speech_token": torch.tensor([[1, 2]])}),
- _source_output("req-1", [20, 21], [4, 5], {"speech_token": torch.tensor([[3, 4]])}),
- ]
- )
- ]
-
- outputs = text2flow(stage_list=stage_list, engine_input_source=[0], prompt=None)
-
- assert len(outputs) == 2
- assert outputs[0]["prompt_token_ids"] == [1, 2, 3]
- assert outputs[1]["prompt_token_ids"] == [4, 5]
- assert outputs[0]["additional_information"]["ids"]["prompt"] == [10, 11]
- assert outputs[1]["additional_information"]["ids"]["prompt"] == [20, 21]
-
-
-def test_talker2code2wav_async_chunk_final_payload_uses_absolute_token_offset():
- transfer_manager = _transfer_manager()
- request = SimpleNamespace(
- external_req_id="rid-0",
- output_token_ids=[1, 2, 6562, 3],
- additional_information={
- "speech_token": [torch.tensor([[11, 12, 13]])],
- "speech_feat": [torch.tensor([[[0.1, 0.2], [0.3, 0.4]]])],
- "embedding": [torch.tensor([[0.5, 0.6]])],
- },
- is_finished=lambda: True,
- )
-
- payload = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=True,
- )
-
- assert payload is not None
- assert payload["meta"]["finished"].item() is True
- assert payload["codes"]["audio"] == [1, 2, 3]
- assert payload["token_offset"] == 0
- assert payload["left_context_size"] == 0
- assert payload["req_id"] == ["rid-0"]
- assert payload["stream_finished"].item() is True
- assert "speech_token" in payload
- assert "speech_feat" in payload
- assert "embedding" in payload
-
-
-def test_talker2code2wav_async_chunk_emits_eof_when_finished_without_valid_codes():
- transfer_manager = _transfer_manager(chunk_frames=25)
- request = SimpleNamespace(
- external_req_id="rid-eof",
- output_token_ids=[6561, 6562], # all filtered out
- additional_information={},
- is_finished=lambda: True,
- )
-
- payload = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=True,
- )
-
- assert payload is not None
- assert payload["codes"]["audio"] == []
- assert payload["meta"]["finished"].item() is True
-
-
-def test_talker2code2wav_async_chunk_does_not_reemit_without_new_tokens():
- transfer_manager = _transfer_manager()
- request = SimpleNamespace(
- external_req_id="rid-stable",
- output_token_ids=[1, 2],
- additional_information={},
- is_finished=lambda: False,
- )
-
- payload1 = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=False,
- )
- payload2 = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=False,
- )
-
- assert payload1 is not None
- assert payload1["codes"]["audio"] == [1, 2]
- assert payload1["token_offset"] == 0
- assert payload2 is None
-
-
-def test_talker2code2wav_async_chunk_waits_for_prelookahead_and_emits_cumulative_prefix():
- transfer_manager = _transfer_manager(pre_lookahead_frames=1)
- request = SimpleNamespace(
- external_req_id="rid-pre",
- output_token_ids=[1, 2],
- additional_information={},
- is_finished=lambda: False,
- )
-
- payload_pending = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=False,
- )
- request.output_token_ids = [1, 2, 3]
- payload_ready = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=False,
- )
-
- assert payload_pending is None
- assert payload_ready is not None
- assert payload_ready["codes"]["audio"] == [1, 2, 3]
- assert payload_ready["token_offset"] == 0
- assert payload_ready["meta"]["finished"].item() is False
-
-
-def test_talker2code2wav_async_chunk_final_flush_uses_previous_token_offset():
- transfer_manager = _transfer_manager(pre_lookahead_frames=1)
- request = SimpleNamespace(
- external_req_id="rid-tail",
- output_token_ids=[3, 4, 5],
- additional_information={},
- is_finished=lambda: False,
- )
-
- payload_stream = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=False,
- )
- request.output_token_ids = [3, 4, 5, 6]
- payload_final = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=True,
- )
-
- assert payload_stream is not None
- assert payload_stream["meta"]["finished"].item() is False
- assert payload_stream["codes"]["audio"] == [3, 4, 5]
- assert payload_stream["token_offset"] == 0
- assert payload_final is not None
- assert payload_final["meta"]["finished"].item() is True
- assert payload_final["codes"]["audio"] == [3, 4, 5, 6]
- assert payload_final["token_offset"] == 2
-
-
-def test_talker2code2wav_async_chunk_respects_prompt_token_pad_on_first_chunk():
- transfer_manager = _transfer_manager(pre_lookahead_frames=1)
- request = SimpleNamespace(
- external_req_id="rid-pad",
- output_token_ids=[8, 9, 10],
- additional_information={
- "speech_token": [torch.tensor([[1, 2, 3]])],
- },
- is_finished=lambda: False,
- )
-
- payload_pending = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=False,
- )
- request.output_token_ids = [8, 9, 10, 11]
- payload_ready = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=False,
- )
-
- assert payload_pending is None
- assert payload_ready is not None
- assert payload_ready["codes"]["audio"] == [8, 9, 10, 11]
- assert payload_ready["token_offset"] == 0
-
-
-def test_talker2code2wav_async_chunk_emits_terminal_eof_without_duplicate_audio():
- transfer_manager = _transfer_manager()
- request = SimpleNamespace(
- external_req_id="rid-eof-tail",
- output_token_ids=[3, 4],
- additional_information={},
- is_finished=lambda: False,
- )
-
- payload_stream = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=False,
- )
- payload_final = talker2code2wav_async_chunk(
- transfer_manager=transfer_manager,
- pooling_output=None,
- request=request,
- is_finished=True,
- )
-
- assert payload_stream is not None
- assert payload_stream["meta"]["finished"].item() is False
- assert payload_stream["codes"]["audio"] == [3, 4]
- assert payload_final is not None
- assert payload_final["meta"]["finished"].item() is True
- assert payload_final["codes"]["audio"] == []
diff --git a/tests/model_executor/stage_input_processors/test_glm_image.py b/tests/model_executor/stage_input_processors/test_glm_image.py
deleted file mode 100644
index 542cdd5fba1..00000000000
--- a/tests/model_executor/stage_input_processors/test_glm_image.py
+++ /dev/null
@@ -1,403 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for GLM-Image stage input processor."""
-
-from types import SimpleNamespace
-
-import pytest
-import torch
-
-from vllm_omni.model_executor.stage_input_processors.glm_image import (
- _first_source_image,
- _has_source_image,
- _parse_generated_tokens,
- _upsample_token_ids,
- ar2diffusion,
- compute_max_tokens,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-# =============================================================================
-# Helpers
-# =============================================================================
-
-
-def _source_output(token_ids: list[int], mm_output: dict | None = None):
- """Create a minimal AR output mock."""
- return SimpleNamespace(
- outputs=[SimpleNamespace(token_ids=token_ids, cumulative_token_ids=token_ids)],
- multimodal_output=mm_output,
- )
-
-
-def _stage_with_outputs(outputs):
- """Create a stage list entry with engine_outputs."""
- return SimpleNamespace(engine_outputs=outputs)
-
-
-# =============================================================================
-# Tests for _has_source_image
-# =============================================================================
-
-
-class TestHasSourceImage:
- def test_none_input(self):
- assert _has_source_image(None) is False
-
- def test_non_dict_input(self):
- assert _has_source_image("not_a_dict") is False
-
- def test_empty_dict(self):
- assert _has_source_image({}) is False
-
- def test_image_key_present(self):
- from PIL import Image
-
- img = Image.new("RGB", (64, 64))
- assert _has_source_image({"image": img}) is True
-
- def test_image_key_none(self):
- assert _has_source_image({"image": None}) is False
-
- def test_img2img_key_present(self):
- from PIL import Image
-
- img = Image.new("RGB", (64, 64))
- assert _has_source_image({"img2img": img}) is True
-
- def test_images_key_list(self):
- from PIL import Image
-
- imgs = [Image.new("RGB", (64, 64))]
- assert _has_source_image({"images": imgs}) is True
-
- def test_images_key_empty_list(self):
- assert _has_source_image({"images": []}) is False
-
- def test_images_key_single(self):
- from PIL import Image
-
- img = Image.new("RGB", (64, 64))
- assert _has_source_image({"images": img}) is True
-
-
-# =============================================================================
-# Tests for _first_source_image
-# =============================================================================
-
-
-class TestFirstSourceImage:
- def test_none_input(self):
- assert _first_source_image(None) is None
-
- def test_non_dict_input(self):
- assert _first_source_image("not_a_dict") is None
-
- def test_image_key_single(self):
- from PIL import Image
-
- img = Image.new("RGB", (64, 64))
- assert _first_source_image({"image": img}) is img
-
- def test_image_key_list(self):
- from PIL import Image
-
- img = Image.new("RGB", (64, 64))
- assert _first_source_image({"image": [img]}) is img
-
- def test_image_key_empty_list(self):
- assert _first_source_image({"image": []}) is None
-
- def test_img2img_key_single(self):
- from PIL import Image
-
- img = Image.new("RGB", (64, 64))
- assert _first_source_image({"img2img": img}) is img
-
- def test_images_key_list(self):
- from PIL import Image
-
- imgs = [Image.new("RGB", (64, 64))]
- assert _first_source_image({"images": imgs}) is imgs[0]
-
- def test_images_key_empty_list(self):
- assert _first_source_image({"images": []}) is None
-
- def test_images_key_single_not_list(self):
- from PIL import Image
-
- img = Image.new("RGB", (64, 64))
- assert _first_source_image({"images": img}) is img
-
-
-# =============================================================================
-# Tests for compute_max_tokens
-# =============================================================================
-
-
-class TestComputeMaxTokens:
- def test_t2i_1024x1024(self):
- # t2i: small_tokens + large_tokens + 1 (EOS)
- # token_h = 1024/32 = 32, token_w = 1024/32 = 32
- # large = 32*32 = 1024
- # ratio = 1.0, small_h = sqrt(1)*16 = 16, small_w = sqrt(1)*16 = 16, small = 256
- # total = 256 + 1024 + 1 = 1281
- result = compute_max_tokens(1024, 1024, is_i2i=False)
- assert result == 1281
-
- def test_i2i_1024x1024(self):
- # i2i: large_tokens + 1 (EOS)
- # large = 32*32 = 1024, total = 1025
- result = compute_max_tokens(1024, 1024, is_i2i=True)
- assert result == 1025
-
- def test_t2i_512x512(self):
- # token_h = 16, token_w = 16, large = 256
- # ratio = 1.0, small_h = 16, small_w = 16, small = 256
- # total = 256 + 256 + 1 = 513
- result = compute_max_tokens(512, 512, is_i2i=False)
- assert result == 513
-
- def test_i2i_512x512(self):
- # large = 256, total = 257
- result = compute_max_tokens(512, 512, is_i2i=True)
- assert result == 257
-
- def test_non_square_t2i(self):
- # 1024x512: token_h=32, token_w=16, large=512
- # ratio = 32/16 = 2.0
- # small_h = max(1, int(sqrt(2)*16)) = 22, small_w = max(1, int(sqrt(0.5)*16)) = 11
- # small = 22*11 = 242
- # total = 242 + 512 + 1 = 755
- result = compute_max_tokens(1024, 512, is_i2i=False)
- assert result == 242 + 512 + 1
-
- def test_custom_factor(self):
- # factor=16, 512x512: token_h=32, token_w=32, large=1024
- # ratio=1.0, small_h=8, small_w=8, small=64
- # total = 64 + 1024 + 1 = 1089
- result = compute_max_tokens(512, 512, factor=16, is_i2i=False)
- assert result == 1089
-
- def test_i2i_smaller_than_t2i(self):
- t2i = compute_max_tokens(1024, 1024, is_i2i=False)
- i2i = compute_max_tokens(1024, 1024, is_i2i=True)
- assert i2i < t2i
-
-
-# =============================================================================
-# Tests for _upsample_token_ids
-# =============================================================================
-
-
-class TestUpsampleTokenIds:
- def test_2x2_to_4x4(self):
- tokens = torch.tensor([1, 2, 3, 4])
- result = _upsample_token_ids(tokens, 2, 2)
- assert result.shape == (16,) # 4 * 4 = 16 (2x each dim)
-
- def test_1x1_to_2x2(self):
- tokens = torch.tensor([7])
- result = _upsample_token_ids(tokens, 1, 1)
- assert result.shape == (4,) # 2 * 2
- assert (result == 7).all()
-
- def test_4x4_to_8x8(self):
- tokens = torch.arange(16, dtype=torch.long)
- result = _upsample_token_ids(tokens, 4, 4)
- assert result.shape == (64,)
-
- def test_preserves_dtype(self):
- tokens = torch.tensor([1, 2, 3, 4], dtype=torch.long)
- result = _upsample_token_ids(tokens, 2, 2)
- assert result.dtype == torch.long
-
-
-# =============================================================================
-# Tests for _parse_generated_tokens
-# =============================================================================
-
-
-class TestParseGeneratedTokens:
- def test_t2i_standard(self):
- # 1024x1024, t2i: small(256) + large(1024) + EOS
- # Generate 256 + 1024 + 1 = 1281 tokens, last is EOS (16385)
- large_tokens = list(range(1024))
- small_tokens = list(range(1000, 1256))
- eos = [16385]
- token_ids = small_tokens + large_tokens + eos
-
- prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=False)
- assert h == 1024
- assert w == 1024
- # Prior tokens should be upsampled: 1024 tokens -> 4*1024 = 4096
- assert prior.shape[0] == 1024 * 4
-
- def test_i2i_standard(self):
- # 1024x1024, i2i: large(1024) + EOS
- large_tokens = list(range(1024))
- eos = [16385]
- token_ids = large_tokens + eos
-
- prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=True)
- assert h == 1024
- assert w == 1024
- assert prior.shape[0] == 1024 * 4
-
- def test_i2i_without_eos(self):
- # i2i without EOS marker
- large_tokens = list(range(1024))
- prior, h, w = _parse_generated_tokens(large_tokens, 1024, 1024, is_i2i=True)
- assert h == 1024
- assert w == 1024
-
- def test_i2i_too_few_tokens_raises(self):
- with pytest.raises(ValueError, match="i2i token parse failed"):
- _parse_generated_tokens([1, 2, 3], 1024, 1024, is_i2i=True)
-
- def test_t2i_too_few_tokens_raises(self):
- # Only large tokens, no small preview
- large_tokens = list(range(1024))
- with pytest.raises(ValueError, match="t2i token parse failed"):
- _parse_generated_tokens(large_tokens, 1024, 1024, is_i2i=False)
-
- def test_i2i_t2i_style_layout_fallback(self):
- # i2i but got t2i-style (small + large) tokens
- small_tokens = list(range(256))
- large_tokens = list(range(1024))
- token_ids = small_tokens + large_tokens
-
- prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=True)
- # Should extract the large portion
- assert h == 1024
- assert w == 1024
-
-
-# =============================================================================
-# Tests for ar2diffusion
-# =============================================================================
-
-
-class TestAr2Diffusion:
- def test_basic_t2i(self):
- """Test basic text-to-image pipeline: AR -> Diffusion."""
- # 1024x1024 t2i: small(256) + large(1024) + EOS
- token_ids = list(range(256)) + list(range(1024)) + [16385]
- stage_list = [_stage_with_outputs([_source_output(token_ids)])]
-
- prompt = {"prompt": "a cat", "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}}
-
- result = ar2diffusion(stage_list, [0], prompt=[prompt])
- assert len(result) == 1
- assert result[0]["prompt"] == "a cat"
- assert result[0]["height"] == 1024
- assert result[0]["width"] == 1024
- assert "prior_token_ids" in result[0]["extra"]
-
- def test_i2i_with_mm_output(self):
- """Test image-to-image with prior_token_image_ids from AR model."""
- token_ids = list(range(1024)) + [16385]
- mm_output = {"ids": {"prior_image": torch.tensor([1, 2, 3])}}
- stage_list = [_stage_with_outputs([_source_output(token_ids, mm_output)])]
-
- from PIL import Image
-
- img = Image.new("RGB", (64, 64))
- prompt = {
- "prompt": "edit this",
- "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024},
- "multi_modal_data": {"image": img},
- }
-
- result = ar2diffusion(stage_list, [0], prompt=[prompt])
- assert len(result) == 1
- assert result[0]["extra"]["prior_token_image_ids"] is not None
-
- def test_i2i_detected_via_modalities(self):
- """Test i2i mode detected via modalities field."""
- token_ids = list(range(1024)) + [16385]
- stage_list = [_stage_with_outputs([_source_output(token_ids)])]
-
- prompt = {
- "prompt": "edit this",
- "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024},
- "modalities": ["img2img"],
- }
-
- result = ar2diffusion(stage_list, [0], prompt=[prompt])
- assert len(result) == 1
-
- def test_empty_engine_input_source_raises(self):
- with pytest.raises(ValueError, match="engine_input_source cannot be empty"):
- ar2diffusion([], [], prompt={})
-
- def test_invalid_stage_id_raises(self):
- with pytest.raises(IndexError, match="Invalid stage_id"):
- ar2diffusion([_stage_with_outputs(None)], [5], prompt={})
-
- def test_no_outputs_raises(self):
- with pytest.raises(RuntimeError, match="has no outputs yet"):
- ar2diffusion([SimpleNamespace(engine_outputs=None)], [0], prompt={})
-
- def test_default_dimensions(self):
- """When no height/width in prompt, defaults to 1024x1024."""
- token_ids = list(range(256)) + list(range(1024)) + [16385]
- stage_list = [_stage_with_outputs([_source_output(token_ids)])]
-
- prompt = {"prompt": "test"}
- result = ar2diffusion(stage_list, [0], prompt=[prompt])
- assert result[0]["height"] == 1024
- assert result[0]["width"] == 1024
-
- def test_requires_multimodal_data_with_pil_image(self):
- """Test that pil_image is included when requires_multimodal_data=True."""
- token_ids = list(range(256)) + list(range(1024)) + [16385]
- stage_list = [_stage_with_outputs([_source_output(token_ids)])]
-
- from PIL import Image
-
- img = Image.new("RGB", (64, 64))
- prompt = {
- "prompt": "test",
- "multi_modal_data": {"image": img},
- }
-
- result = ar2diffusion(stage_list, [0], prompt=[prompt], requires_multimodal_data=True)
- assert result[0]["pil_image"] is img
-
- def test_extra_params_passed_through(self):
- """Test that seed, num_inference_steps, guidance_scale, negative_prompt are passed."""
- token_ids = list(range(256)) + list(range(1024)) + [16385]
- stage_list = [_stage_with_outputs([_source_output(token_ids)])]
-
- prompt = {
- "prompt": "test",
- "seed": 42,
- "num_inference_steps": 50,
- "guidance_scale": 7.5,
- "negative_prompt": "blurry",
- }
-
- result = ar2diffusion(stage_list, [0], prompt=[prompt])
- assert result[0]["seed"] == 42
- assert result[0]["num_inference_steps"] == 50
- assert result[0]["guidance_scale"] == 7.5
- assert result[0]["negative_prompt"] == "blurry"
-
- def test_batch_requests(self):
- """Test processing multiple requests in a batch."""
- tokens1 = list(range(256)) + list(range(1024)) + [16385]
- tokens2 = list(range(256)) + list(range(1024)) + [16385]
- stage_list = [_stage_with_outputs([_source_output(tokens1), _source_output(tokens2)])]
-
- prompts = [
- {"prompt": "first", "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}},
- {"prompt": "second", "mm_processor_kwargs": {"target_h": 512, "target_w": 512}},
- ]
-
- result = ar2diffusion(stage_list, [0], prompt=prompts)
- assert len(result) == 2
- assert result[0]["prompt"] == "first"
- assert result[1]["prompt"] == "second"
diff --git a/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py b/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py
index b30da97800b..4807ac62744 100644
--- a/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py
+++ b/tests/model_executor/stage_input_processors/test_mimo_audio_flush_remaining_codes.py
@@ -13,23 +13,23 @@
def _sentinel():
- return {"codes": {"audio": []}, "meta": {"finished": torch.tensor(True, dtype=torch.bool)}}
+ return {"code_predictor_codes": [], "finished": torch.tensor(True, dtype=torch.bool)}
def test_flush_remaining_codes_when_no_codes_accumulated_missing_request_id():
"""No entry for request_id: treat as empty, return finished sentinel with empty codes."""
tm = SimpleNamespace(code_prompt_token_ids={})
out = _flush_remaining_codes(tm, "missing", chunk_size=3, left_context_size=3)
- assert out["codes"]["audio"] == _sentinel()["codes"]["audio"]
- assert out["meta"]["finished"].item() is True
+ assert out["code_predictor_codes"] == _sentinel()["code_predictor_codes"]
+ assert out["finished"].equal(_sentinel()["finished"])
def test_flush_remaining_codes_when_no_codes_accumulated_empty_list():
"""Explicit empty accumulation list returns the same sentinel."""
tm = SimpleNamespace(code_prompt_token_ids={"r": []})
out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3)
- assert out["codes"]["audio"] == []
- assert out["meta"]["finished"].item() is True
+ assert out["code_predictor_codes"] == []
+ assert out["finished"].item() is True
def test_flush_remaining_codes_partial_chunk_remaining():
@@ -41,8 +41,8 @@ def test_flush_remaining_codes_partial_chunk_remaining():
code_prompt_token_ids={"r": [[1], [2], [3], [4], [5], [6], [7]]},
)
out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3)
- assert out["meta"]["finished"].item() is True
- assert out["codes"]["audio"] == [4, 5, 6, 7]
+ assert out["finished"].item() is True
+ assert out["code_predictor_codes"] == [4, 5, 6, 7]
def test_flush_remaining_codes_when_length_is_exact_multiple_of_chunk_size():
@@ -52,7 +52,7 @@ def test_flush_remaining_codes_when_length_is_exact_multiple_of_chunk_size():
)
out = _flush_remaining_codes(tm, "r", chunk_size=3, left_context_size=3)
# context_length = chunk_size = 3, end_index = min(6, 6) -> all 6
- assert out["codes"]["audio"] == [1, 2, 3, 4, 5, 6]
+ assert out["code_predictor_codes"] == [1, 2, 3, 4, 5, 6]
@pytest.mark.parametrize(
@@ -74,5 +74,5 @@ def test_flush_remaining_codes_context_window_end_index(
tm = SimpleNamespace(code_prompt_token_ids={"r": accumulated})
out = _flush_remaining_codes(tm, "r", chunk_size=chunk_size, left_context_size=left_context)
expected_flat = list(range(length - expected_end_index, length))
- assert out["codes"]["audio"] == expected_flat
- assert out["meta"]["finished"].item() is True
+ assert out["code_predictor_codes"] == expected_flat
+ assert out["finished"].item() is True
diff --git a/tests/model_executor/stage_input_processors/test_mimo_audio_llm2code2wav.py b/tests/model_executor/stage_input_processors/test_mimo_audio_llm2code2wav.py
deleted file mode 100644
index 1f0ec02f750..00000000000
--- a/tests/model_executor/stage_input_processors/test_mimo_audio_llm2code2wav.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-
-import logging
-from types import SimpleNamespace
-
-import pytest
-import torch
-
-from vllm_omni.model_executor.stage_input_processors import mimo_audio as sip
-from vllm_omni.model_executor.stage_input_processors.mimo_audio import (
- MAX_CODE2WAV_TOKENS,
- llm2code2wav,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _make_stage_list(codec_codes: torch.Tensor, request_id: str = "req-0"):
- """Build a minimal stage_list[0] with a single talker output carrying codec_codes."""
- output = SimpleNamespace(multimodal_output={"codes": {"audio": codec_codes}})
- talker_output = SimpleNamespace(outputs=[output], request_id=request_id)
- stage0 = SimpleNamespace(engine_outputs=[talker_output])
- return [stage0]
-
-
-def test_llm2code2wav_truncates_when_flat_exceeds_max(caplog):
- """Flat codec sequences longer than MAX_CODE2WAV_TOKENS must be truncated, not passed through."""
- # prepend_and_flatten_colmajor produces 36 ids per (8, 4) codec frame:
- # pad adds one row -> (9, 4) per frame, permuted and flattened.
- # Pick enough frames to comfortably exceed the cap.
- frames = (MAX_CODE2WAV_TOKENS // 36) + 100
- codec_codes = torch.ones(frames, 1, 8, 4, dtype=torch.long)
-
- stage_list = _make_stage_list(codec_codes, request_id="req-long")
-
- # Attach caplog's handler directly to the module logger so the warning is
- # captured regardless of propagation (vllm's logger configuration can
- # interact badly with caplog.at_level's default root-handler path).
- target_logger = logging.getLogger("vllm_omni.model_executor.stage_input_processors.mimo_audio")
- target_logger.addHandler(caplog.handler)
- prev_level = target_logger.level
- target_logger.setLevel(logging.WARNING)
- try:
- prompts = llm2code2wav(stage_list, engine_input_source=[0])
- finally:
- target_logger.removeHandler(caplog.handler)
- target_logger.setLevel(prev_level)
-
- assert len(prompts) == 1
- assert len(prompts[0]["prompt_token_ids"]) == MAX_CODE2WAV_TOKENS
- assert any("truncating" in rec.getMessage() for rec in caplog.records), (
- f"Expected a 'truncating' warning; captured records: {[r.getMessage() for r in caplog.records]}"
- )
-
-
-def test_llm2code2wav_short_sequence_unchanged():
- """Short codec sequences are returned without truncation."""
- codec_codes = torch.ones(4, 1, 8, 4, dtype=torch.long)
- stage_list = _make_stage_list(codec_codes, request_id="req-short")
-
- prompts = llm2code2wav(stage_list, engine_input_source=[0])
-
- assert len(prompts) == 1
- # 4 frames + 1 pad row, flattened col-major → well below the cap
- assert 0 < len(prompts[0]["prompt_token_ids"]) <= MAX_CODE2WAV_TOKENS
-
-
-def test_llm2code2wav_truncation_boundary_constant_matches_yaml():
- """MAX_CODE2WAV_TOKENS must match the stage-1 max_model_len in mimo_audio.yaml and end2end.py."""
- assert sip.MAX_CODE2WAV_TOKENS == 18192
diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py
deleted file mode 100644
index 18972c91d5d..00000000000
--- a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for Qwen3-Omni streaming thinker→talker / talker→codec helpers (PR #2581)."""
-
-from __future__ import annotations
-
-from types import SimpleNamespace
-
-import pytest
-
-import vllm_omni.model_executor.stage_input_processors.qwen3_omni as q3
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-@pytest.fixture(autouse=True)
-def _streaming_context() -> SimpleNamespace:
- return SimpleNamespace(bridge_states={})
-
-
-def test_get_streaming_talker_tokens_first_segment(_streaming_context: SimpleNamespace) -> None:
- inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens(
- "r1",
- [1, 2],
- [10, 11],
- streaming_context=_streaming_context,
- )
- assert inc_p == [1, 2]
- assert inc_o == [10, 11]
- assert merged == [1, 2, 10, 11]
- assert thinker_in == [1, 2]
-
-
-def test_get_streaming_talker_tokens_second_segment_accumulates(_streaming_context: SimpleNamespace) -> None:
- q3._get_streaming_talker_tokens("r2", [1, 2], [10, 11], streaming_context=_streaming_context)
- inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens(
- "r2",
- [1, 2, 3, 4],
- [10, 11, 12, 13],
- streaming_context=_streaming_context,
- )
- assert inc_p == [3, 4]
- assert inc_o == [12, 13]
- assert merged == [1, 2, 10, 3, 4, 12, 13]
- assert thinker_in == [1, 2, 10, 3, 4]
-
-
-def test_get_streaming_talker_tokens_new_prompt_len_snapshot_truncates(
- _streaming_context: SimpleNamespace,
-) -> None:
- inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens(
- "r3",
- [1, 2, 3, 4, 5, 6],
- [10],
- new_prompt_len_snapshot=2,
- streaming_context=_streaming_context,
- )
- assert inc_p == [1, 2, 3, 4]
- assert inc_o == [10]
- assert merged == [1, 2, 3, 4, 10]
- assert thinker_in == [1, 2, 3, 4]
-
-
-def test_get_streaming_talker_tokens_clear_state(_streaming_context: SimpleNamespace) -> None:
- q3._get_streaming_talker_tokens("r4", [1], [2], streaming_context=_streaming_context, clear_state=True)
- state = q3._get_qwen3_streaming_state("r4", _streaming_context).thinker2talker
- assert state.last_prompt_len == 0
- assert state.last_output_len == 0
- assert state.merged_sequences == []
-
-
-def test_get_streaming_codec_delta_len_increments_and_finishes(_streaming_context: SimpleNamespace) -> None:
- d1 = q3._get_streaming_codec_delta_len(5, "c1", SimpleNamespace(finished=False), _streaming_context)
- assert d1 == 5
- d2 = q3._get_streaming_codec_delta_len(8, "c1", SimpleNamespace(finished=False), _streaming_context)
- assert d2 == 2
- # After d2, stored cursor is cur_seq_len + 1 == 9; next delta uses new cur_seq_len - 9.
- d3 = q3._get_streaming_codec_delta_len(10, "c1", SimpleNamespace(finished=True), _streaming_context)
- assert d3 == 1
- state = q3._get_qwen3_streaming_state("c1", _streaming_context)
- assert state.talker2code2wav_last_seq_len == 0
diff --git a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py
index 07e343bf030..95ee229298d 100644
--- a/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py
+++ b/tests/model_executor/stage_input_processors/test_qwen3_tts_async_chunk.py
@@ -55,7 +55,7 @@ def _call(tm, rid, *, n_frames, finished=False, req_ic=None):
tm.code_prompt_token_ids[rid] = [_FRAME[:] for _ in range(n_frames)]
return talker2code2wav_async_chunk(
transfer_manager=tm,
- pooling_output={"codes": {"audio": torch.zeros((0,))}},
+ pooling_output={"audio_codes": torch.zeros((0,))},
request=_req(rid, finished=finished, initial_codec_chunk_frames=req_ic),
is_finished=finished,
)
@@ -65,7 +65,7 @@ def test_empty_returns_none():
tm = _tm()
p = talker2code2wav_async_chunk(
transfer_manager=tm,
- pooling_output={"codes": {"audio": torch.zeros((0,))}},
+ pooling_output={"audio_codes": torch.zeros((0,))},
request=_req("r", finished=False),
)
assert p is None
@@ -79,8 +79,7 @@ def test_eof_marker_when_finished_empty():
request=_req("r", finished=True),
is_finished=True,
)
- assert p["codes"] == {"audio": []}
- assert p["meta"]["finished"].item() is True
+ assert p == {"code_predictor_codes": [], "finished": torch.tensor(True, dtype=torch.bool)}
def test_flush_on_finish():
@@ -93,8 +92,8 @@ def test_flush_on_finish():
is_finished=True,
)
assert p is not None
- assert p["meta"]["finished"].item() is True
- assert len(p["codes"]["audio"]) == _Q * 24
+ assert p["finished"] is True
+ assert len(p["code_predictor_codes"]) == _Q * 24
_CASES = [
@@ -159,8 +158,8 @@ def test_streaming_phases(config, n_frames, finished, expected):
else:
exp_ctx, exp_window = expected
assert payload is not None
- assert payload["meta"]["left_context_size"] == exp_ctx
- assert len(payload["codes"]["audio"]) == _Q * exp_window
+ assert payload["left_context_size"] == exp_ctx
+ assert len(payload["code_predictor_codes"]) == _Q * exp_window
def test_dynamic_ic_adapts_to_load():
@@ -170,14 +169,14 @@ def test_dynamic_ic_adapts_to_load():
# Low load (1/8) -> IC=2 -> emit at 2
p1 = _call(tm, "r", n_frames=2)
assert p1 is not None
- assert len(p1["codes"]["audio"]) == _Q * 2
+ assert len(p1["code_predictor_codes"]) == _Q * 2
# High load: add 4 others -> active=5/8 -> IC=8 -> emit at 8
for i in range(4):
tm.code_prompt_token_ids[f"other-{i}"] = [[0]]
p2 = _call(tm, "r", n_frames=8)
assert p2 is not None
- assert len(p2["codes"]["audio"]) == _Q * 8
+ assert len(p2["code_predictor_codes"]) == _Q * 8
# Requests past initial phase still count in load factor
tm2 = _tm(max_num_seqs=4)
@@ -186,7 +185,7 @@ def test_dynamic_ic_adapts_to_load():
# active=4/4=1.0 -> IC=16
p3 = _call(tm2, "new", n_frames=16)
assert p3 is not None
- assert len(p3["codes"]["audio"]) == _Q * 16
+ assert len(p3["code_predictor_codes"]) == _Q * 16
def test_ic_load_change_mid_request():
@@ -207,7 +206,6 @@ def test_ic_load_change_mid_request():
assert _call(tm, "r", n_frames=27) is None
p3 = _call(tm, "r", n_frames=49)
assert p3 is not None
- assert p3["meta"]["left_context_size"] == 24
# A *new* request under high load gets IC=16 (not IC=2).
# Frame 2 would emit under IC=2 but must hold under IC=16.
@@ -255,14 +253,14 @@ def test_first_streaming_chunk_prepends_ref_code_context():
payload = talker2code2wav_async_chunk(
transfer_manager=tm,
- pooling_output={"codes": {"audio": torch.zeros((0,)), "ref": ref_code}},
+ pooling_output={"audio_codes": torch.zeros((0,)), "ref_code": ref_code},
request=_req(rid, finished=False, initial_codec_chunk_frames=10),
is_finished=False,
)
assert payload is not None
- assert payload["meta"]["left_context_size"] == 2
- assert len(payload["codes"]["audio"]) == _Q * 12
+ assert payload["left_context_size"] == 2
+ assert len(payload["code_predictor_codes"]) == _Q * 12
def test_ref_code_context_applies_to_all_streaming_chunks():
@@ -276,15 +274,15 @@ def test_ref_code_context_applies_to_all_streaming_chunks():
payload = talker2code2wav_async_chunk(
transfer_manager=tm,
- pooling_output={"codes": {"audio": torch.zeros((0,)), "ref": ref_code}},
+ pooling_output={"audio_codes": torch.zeros((0,)), "ref_code": ref_code},
request=_req(rid, finished=False, initial_codec_chunk_frames=10),
is_finished=False,
)
assert payload is not None
# ref_code (2 frames) prepended as left context on second chunk too
- assert payload["meta"]["left_context_size"] == 10 + 2
- assert len(payload["codes"]["audio"]) == _Q * (20 + 2)
+ assert payload["left_context_size"] == 10 + 2
+ assert len(payload["code_predictor_codes"]) == _Q * (20 + 2)
def test_ref_code_context_can_be_buffered_before_first_emit():
@@ -294,7 +292,7 @@ def test_ref_code_context_can_be_buffered_before_first_emit():
first_payload = talker2code2wav_async_chunk(
transfer_manager=tm,
- pooling_output={"codes": {"audio": torch.tensor([[1, 2, 3, 4]]), "ref": ref_code}},
+ pooling_output={"audio_codes": torch.tensor([[1, 2, 3, 4]]), "ref_code": ref_code},
request=_req(rid, finished=False, initial_codec_chunk_frames=10),
is_finished=False,
)
@@ -304,22 +302,22 @@ def test_ref_code_context_can_be_buffered_before_first_emit():
for _ in range(8):
talker2code2wav_async_chunk(
transfer_manager=tm,
- pooling_output={"codes": {"audio": torch.tensor([[1, 2, 3, 4]])}},
+ pooling_output={"audio_codes": torch.tensor([[1, 2, 3, 4]])},
request=_req(rid, finished=False, initial_codec_chunk_frames=10),
is_finished=False,
)
payload = talker2code2wav_async_chunk(
transfer_manager=tm,
- pooling_output={"codes": {"audio": torch.tensor([[1, 2, 3, 4]])}},
+ pooling_output={"audio_codes": torch.tensor([[1, 2, 3, 4]])},
request=_req(rid, finished=False, initial_codec_chunk_frames=10),
is_finished=False,
)
assert payload is not None
# ref_code (2 frames) is kept (not popped) for subsequent chunks
- assert payload["meta"]["left_context_size"] == 2
- assert len(payload["codes"]["audio"]) == _Q * 12
+ assert payload["left_context_size"] == 2
+ assert len(payload["code_predictor_codes"]) == _Q * 12
assert rid in tm.request_payload
@@ -334,9 +332,8 @@ def test_non_async_processor_prepends_ref_code_and_sets_trim_context():
dtype=torch.long,
)
output = SimpleNamespace(
- multimodal_output={"codes": {"audio": audio_codes, "ref": ref_code}},
+ multimodal_output={"audio_codes": audio_codes, "ref_code": ref_code},
token_ids=list(range(3)),
- cumulative_token_ids=list(range(3)),
)
stage = SimpleNamespace(
engine_outputs=[SimpleNamespace(outputs=[output], finished=True)],
@@ -346,7 +343,7 @@ def test_non_async_processor_prepends_ref_code_and_sets_trim_context():
assert len(prompts) == 1
prompt = prompts[0]
- assert prompt["additional_information"] == {"meta": {"left_context_size": [2]}}
+ assert prompt["additional_information"] == {"left_context_size": [2]}
assert prompt["prompt_token_ids"] == [
9,
8,
@@ -380,9 +377,8 @@ def test_non_async_processor_filters_out_of_range_codec_values():
dtype=torch.long,
)
output = SimpleNamespace(
- multimodal_output={"codes": {"audio": audio_codes, "ref": ref_code}},
+ multimodal_output={"audio_codes": audio_codes, "ref_code": ref_code},
token_ids=list(range(4)),
- cumulative_token_ids=list(range(4)),
)
stage = SimpleNamespace(
engine_outputs=[SimpleNamespace(outputs=[output], finished=True)],
@@ -394,4 +390,4 @@ def test_non_async_processor_filters_out_of_range_codec_values():
prompt = prompts[0]
# Only ref_code (1 frame) + 2 valid frames = 3 frames * 4 quantizers = 12 codes
assert len(prompt["prompt_token_ids"]) == 4 * 3
- assert prompt["additional_information"] == {"meta": {"left_context_size": [1]}}
+ assert prompt["additional_information"] == {"left_context_size": [1]}
diff --git a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py
deleted file mode 100644
index 7d6fc6e74c9..00000000000
--- a/tests/model_executor/stage_input_processors/test_voxcpm_async_chunk.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""UTs for VoxCPM async-chunk stage input processing."""
-
-from types import SimpleNamespace
-
-import pytest
-import torch
-
-from vllm_omni.model_executor.stage_input_processors.voxcpm import (
- _VOXCPM_LATENT_MAGIC,
- _coerce_finished_flag,
- latent2vae_async_chunk,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _request(*, finished):
- return SimpleNamespace(is_finished=lambda: finished)
-
-
-def _decode_serialized_latent(codes: list[int]) -> torch.Tensor:
- assert codes[0] == _VOXCPM_LATENT_MAGIC
- latent_dim = codes[1]
- time_dim = codes[2]
- payload = torch.tensor(codes[3:], dtype=torch.int32).to(torch.uint16)
- return payload.view(torch.bfloat16).to(torch.float32).reshape(1, latent_dim, time_dim)
-
-
-@pytest.mark.parametrize(
- ("value", "expected"),
- [
- (None, False),
- (False, False),
- (True, True),
- (torch.tensor(False), False),
- (torch.tensor(True), True),
- ([torch.tensor(True)], True),
- (([True],), True),
- ([], False),
- ],
-)
-def test_coerce_finished_flag(value, expected):
- assert _coerce_finished_flag(value) is expected
-
-
-def test_latent2vae_async_chunk_serializes_latent_payload():
- latent = torch.arange(6, dtype=torch.float32).reshape(2, 3)
-
- payload = latent2vae_async_chunk(
- transfer_manager=None,
- pooling_output={"latent_audio_feat": latent},
- request=_request(finished=False),
- is_finished=torch.tensor(False),
- )
-
- assert payload is not None
- assert torch.equal(payload["finished"], torch.tensor(False, dtype=torch.bool))
- recovered = _decode_serialized_latent(payload["code_predictor_codes"])
- torch.testing.assert_close(recovered, latent.to(torch.bfloat16).to(torch.float32).unsqueeze(0))
-
-
-def test_latent2vae_async_chunk_returns_terminal_marker_without_latent():
- payload = latent2vae_async_chunk(
- transfer_manager=None,
- pooling_output=None,
- request=_request(finished=[torch.tensor(True)]),
- is_finished=False,
- )
-
- assert payload == {
- "code_predictor_codes": [],
- "finished": torch.tensor(True, dtype=torch.bool),
- }
-
-
-def test_latent2vae_async_chunk_returns_none_for_nonterminal_empty_chunk():
- payload = latent2vae_async_chunk(
- transfer_manager=None,
- pooling_output={"latent_audio_feat": torch.zeros((0,), dtype=torch.float32)},
- request=_request(finished=False),
- is_finished=False,
- )
-
- assert payload is None
diff --git a/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py b/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py
index 1b78b103da8..45e54eb69d3 100644
--- a/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py
+++ b/tests/model_executor/stage_input_processors/test_voxtral_tts_async_chunk.py
@@ -126,8 +126,8 @@ def test_flush_tail_when_finished():
)
assert payload is not None
- assert payload["meta"]["finished"].item() is True
- codes = payload["codes"]["audio"]
+ assert payload["finished"].item() is True
+ codes = payload["code_predictor_codes"]
# Format: [ctx_frames, context_length, ...flat_codes]
assert len(codes) >= 2 # At least ctx_frames + context_length header
ctx_frames = codes[0]
@@ -149,8 +149,10 @@ def test_eof_marker_when_finished_with_no_frames():
request=request,
)
- assert payload["codes"] == {"audio": []}
- assert payload["meta"]["finished"].item() is True
+ assert payload == {
+ "code_predictor_codes": [],
+ "finished": torch.tensor(True, dtype=torch.bool),
+ }
def test_normal_chunk_emission():
@@ -174,7 +176,7 @@ def test_normal_chunk_emission():
# A chunk should be emitted
assert payload is not None
- codes = payload["codes"]["audio"]
+ codes = payload["code_predictor_codes"]
ctx_frames = codes[0]
context_length = codes[1]
assert ctx_frames == 20 # 25 - 5(chunk_size_at_begin)
@@ -201,7 +203,7 @@ def test_small_initial_chunks():
)
assert payload is not None
- codes = payload["codes"]["audio"]
+ codes = payload["code_predictor_codes"]
ctx_frames = codes[0]
context_length = codes[1]
assert ctx_frames == 0
@@ -250,7 +252,7 @@ def test_context_handling_format():
)
assert payload is not None
- codes = payload["codes"]["audio"]
+ codes = payload["code_predictor_codes"]
# First two elements are ctx_frames and context_length
ctx_frames = codes[0]
context_length = codes[1]
diff --git a/tests/profile/test_omni_torch_profiler.py b/tests/profile/test_omni_torch_profiler.py
deleted file mode 100644
index 3920078af4d..00000000000
--- a/tests/profile/test_omni_torch_profiler.py
+++ /dev/null
@@ -1,582 +0,0 @@
-# tests/test_omni_torch_profiler.py
-from __future__ import annotations
-
-import gzip
-from dataclasses import dataclass
-from pathlib import Path
-from types import SimpleNamespace
-
-import pytest
-from openpyxl import load_workbook
-
-import vllm_omni.profiler.omni_torch_profiler as profiler_mod
-from vllm_omni.profiler.omni_torch_profiler import OmniTorchProfilerWrapper
-
-
-@pytest.fixture(autouse=True)
-def patch_worker_profiler_init(monkeypatch):
- def fake_init(self, profiler_config):
- self.profiler_config = profiler_config
-
- monkeypatch.setattr(
- profiler_mod.WorkerProfiler,
- "__init__",
- fake_init,
- )
-
-
-@dataclass
-class DummyProfilerConfig:
- torch_profiler_dir: str
- torch_profiler_use_gzip: bool = False
- torch_profiler_record_shapes: bool = True
- torch_profiler_with_memory: bool = True
- torch_profiler_with_stack: bool = True
- torch_profiler_with_flops: bool = False
- torch_profiler_dump_cuda_time_total: bool = False
-
-
-class FakeEvent:
- def __init__(
- self,
- *,
- name: str = "aten::mm",
- count: int = 1,
- input_shapes=None,
- stack=None,
- self_cpu_time_total: float = 10.0,
- cpu_time_total: float = 12.0,
- self_cuda_time_total: float = 20.0,
- cuda_time_total: float = 25.0,
- self_xpu_time_total: float = 0.0,
- xpu_time_total: float = 0.0,
- self_cpu_memory_usage: int = 128,
- cpu_memory_usage: int = 256,
- self_cuda_memory_usage: int = 1024,
- cuda_memory_usage: int = 2048,
- self_xpu_memory_usage: int = 0,
- xpu_memory_usage: int = 0,
- device_type: str = "CUDA",
- node_id: int = 0,
- overload_name: str = "",
- is_async: bool = False,
- is_legacy: bool = False,
- ):
- self.key = name
- self.name = name
- self.count = count
- self.input_shapes = input_shapes if input_shapes is not None else [[2, 2], [2, 2]]
- self.stack = stack if stack is not None else ["frame_a", "frame_b"]
- self.self_cpu_time_total = self_cpu_time_total
- self.cpu_time_total = cpu_time_total
- self.self_cuda_time_total = self_cuda_time_total
- self.cuda_time_total = cuda_time_total
- self.self_xpu_time_total = self_xpu_time_total
- self.xpu_time_total = xpu_time_total
- self.self_cpu_memory_usage = self_cpu_memory_usage
- self.cpu_memory_usage = cpu_memory_usage
- self.self_cuda_memory_usage = self_cuda_memory_usage
- self.cuda_memory_usage = cuda_memory_usage
- self.self_xpu_memory_usage = self_xpu_memory_usage
- self.xpu_memory_usage = xpu_memory_usage
- self.device_type = device_type
- self.node_id = node_id
- self.overload_name = overload_name
- self.is_async = is_async
- self.is_legacy = is_legacy
-
-
-class FakeEventList(list):
- def table(self, sort_by=None, row_limit=-1):
- return f"fake_table(sort_by={sort_by}, row_limit={row_limit}, len={len(self)})"
-
-
-class FakeTorchProfiler:
- def __init__(self, on_trace_ready=None):
- self.started = False
- self.stopped = False
- self.on_trace_ready = on_trace_ready
- self.exported_traces = []
- self.exported_stacks = []
-
- def start(self):
- self.started = True
-
- def stop(self):
- self.stopped = True
- if self.on_trace_ready is not None:
- self.on_trace_ready(self)
-
- def export_chrome_trace(self, path):
- Path(path).write_text('{"traceEvents": []}')
- self.exported_traces.append(path)
-
- def export_stacks(self, path, metric):
- Path(path).write_text(f"metric={metric}\nstack_line_1\nstack_line_2\n")
- self.exported_stacks.append((path, metric))
-
- def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
- if group_by_input_shape:
- return FakeEventList(
- [
- FakeEvent(
- name="aten::bmm",
- input_shapes=[[4, 8, 16], [4, 16, 32]],
- )
- ]
- )
- if group_by_stack_n:
- return FakeEventList(
- [
- FakeEvent(
- name="aten::all_reduce",
- stack=["python_a", "python_b", "python_c"],
- )
- ]
- )
- return FakeEventList(
- [
- FakeEvent(name="aten::mm"),
- FakeEvent(name="nccl:all_reduce"),
- ]
- )
-
-
-@pytest.fixture
-def fake_config(tmp_path):
- return DummyProfilerConfig(torch_profiler_dir=str(tmp_path))
-
-
-@pytest.fixture
-def fake_profiler_factory(monkeypatch):
- created = {}
-
- def fake_profile(*args, **kwargs):
- profiler = FakeTorchProfiler(on_trace_ready=kwargs.get("on_trace_ready"))
- created["profiler"] = profiler
- created["args"] = args
- created["kwargs"] = kwargs
- return profiler
-
- monkeypatch.setattr(profiler_mod.torch.profiler, "profile", fake_profile)
- return created
-
-
-@pytest.fixture
-def wrapper(fake_config, fake_profiler_factory):
- return OmniTorchProfilerWrapper(
- profiler_config=fake_config,
- worker_name="worker0",
- local_rank=0,
- activities=["CPU", "CUDA"],
- )
-
-
-def test_set_trace_filename_creates_timestamped_session_dir(wrapper, monkeypatch, tmp_path):
- class FixedDatetime:
- @classmethod
- def now(cls):
- class _Now:
- def strftime(self, fmt):
- return "20260403-034200"
-
- return _Now()
-
- monkeypatch.setattr(profiler_mod, "datetime", FixedDatetime)
-
- wrapper.set_trace_filename("stage_0_llm_1234567890")
-
- session_dir = Path(wrapper._session_dir)
- assert session_dir.exists()
- assert session_dir.parent == tmp_path
- assert session_dir.name == "20260403-034200_stage_0_llm_1234567890"
-
-
-def test_set_trace_filename_with_full_path_creates_timestamped_leaf(wrapper, monkeypatch, tmp_path):
- class FixedDatetime:
- @classmethod
- def now(cls):
- class _Now:
- def strftime(self, fmt):
- return "20260403-111111"
-
- return _Now()
-
- monkeypatch.setattr(profiler_mod, "datetime", FixedDatetime)
-
- target = tmp_path / "nested" / "stage_x"
- wrapper.set_trace_filename(str(target))
-
- session_dir = Path(wrapper._session_dir)
- assert session_dir.exists()
- assert session_dir.parent == target.parent
- assert session_dir.name == "20260403-111111_stage_x"
-
-
-def test_on_trace_ready_exports_trace_json(wrapper):
- wrapper.set_trace_filename("case_trace")
-
- wrapper._on_trace_ready(wrapper.profiler)
-
- trace_path = Path(wrapper._trace_path)
- assert trace_path.exists()
- assert trace_path.name == "trace_rank0.json"
- assert trace_path.read_text() == '{"traceEvents": []}'
-
-
-def test_on_trace_ready_exports_gzip_trace(fake_config, fake_profiler_factory, monkeypatch):
- fake_config.torch_profiler_use_gzip = True
-
- wrapper = OmniTorchProfilerWrapper(
- profiler_config=fake_config,
- worker_name="worker0",
- local_rank=0,
- activities=["CPU", "CUDA"],
- )
- wrapper.set_trace_filename("case_gzip")
-
- def fake_popen(cmd):
- assert cmd[:2] == ["gzip", "-f"]
- src = Path(cmd[2])
- gz_path = src.with_suffix(src.suffix + ".gz")
- gz_path.write_bytes(gzip.compress(src.read_bytes()))
- src.unlink()
-
- class DummyProc:
- pass
-
- return DummyProc()
-
- monkeypatch.setattr(profiler_mod.subprocess, "Popen", fake_popen)
-
- wrapper._on_trace_ready(wrapper.profiler)
-
- assert wrapper._trace_path.endswith(".json.gz")
- gz_path = Path(wrapper._trace_path)
- assert gz_path.exists()
- assert gzip.decompress(gz_path.read_bytes()) == b'{"traceEvents": []}'
-
-
-def test_start_enables_memory_history(wrapper, monkeypatch):
- calls = []
-
- monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True)
-
- def fake_record_memory_history(*args, **kwargs):
- calls.append((args, kwargs))
-
- monkeypatch.setattr(
- profiler_mod.torch.cuda.memory,
- "_record_memory_history",
- fake_record_memory_history,
- )
-
- wrapper.set_trace_filename("case_memory_start")
- wrapper._start()
-
- assert wrapper.profiler.started is True
- assert wrapper._memory_history_enabled is True
- assert len(calls) == 1
- assert calls[0][1]["enabled"] == "all"
- assert calls[0][1]["context"] == "all"
- assert calls[0][1]["stacks"] == "python"
- assert calls[0][1]["max_entries"] == 100000
- assert calls[0][1]["clear_history"] is True
-
-
-def test_start_skips_memory_history_when_memory_disabled(fake_config, fake_profiler_factory, monkeypatch):
- fake_config.torch_profiler_with_memory = False
-
- wrapper = OmniTorchProfilerWrapper(
- profiler_config=fake_config,
- worker_name="worker0",
- local_rank=0,
- activities=["CPU", "CUDA"],
- )
-
- called = {"n": 0}
-
- monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True)
-
- def fake_record_memory_history(*args, **kwargs):
- called["n"] += 1
-
- monkeypatch.setattr(
- profiler_mod.torch.cuda.memory,
- "_record_memory_history",
- fake_record_memory_history,
- )
-
- wrapper.set_trace_filename("case_skip_memory")
- wrapper._start()
-
- assert called["n"] == 0
- assert wrapper._memory_history_enabled is False
-
-
-def test_try_dump_memory_snapshot_writes_pickle(wrapper, monkeypatch):
- wrapper.set_trace_filename("case_snapshot")
- wrapper._memory_history_enabled = True
- wrapper._memory_history_backend = "CUDA"
- wrapper._memory_history_module = profiler_mod.torch.cuda.memory
-
- disable_calls = []
-
- def fake_record_memory_history(*args, **kwargs):
- disable_calls.append((args, kwargs))
-
- def fake_dump_snapshot(path):
- Path(path).write_bytes(b"fake pickle bytes")
-
- monkeypatch.setattr(
- profiler_mod.torch.cuda.memory,
- "_record_memory_history",
- fake_record_memory_history,
- )
- monkeypatch.setattr(
- profiler_mod.torch.cuda.memory,
- "_dump_snapshot",
- fake_dump_snapshot,
- )
-
- wrapper._try_dump_memory_snapshot()
-
- snapshot = Path(wrapper._artifact_paths["memory_snapshot"])
- assert snapshot.exists()
- assert snapshot.name == "memory_snapshot_rank0.pickle"
- assert snapshot.read_bytes() == b"fake pickle bytes"
- assert wrapper._memory_history_enabled is False
-
- assert disable_calls[-1][1]["enabled"] is None
-
-
-def test_stop_always_dumps_memory_snapshot_on_success_path(wrapper, monkeypatch):
- wrapper.set_trace_filename("case_stop")
-
- record_calls = []
- dump_calls = []
-
- monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True)
-
- def fake_record_memory_history(*args, **kwargs):
- record_calls.append((args, kwargs))
-
- def fake_dump_snapshot(path):
- dump_calls.append(path)
- Path(path).write_bytes(b"snapshot-bytes")
-
- monkeypatch.setattr(
- profiler_mod.torch.cuda.memory,
- "_record_memory_history",
- fake_record_memory_history,
- )
- monkeypatch.setattr(
- profiler_mod.torch.cuda.memory,
- "_dump_snapshot",
- fake_dump_snapshot,
- )
-
- wrapper._start()
- wrapper._stop()
-
- session_dir = Path(wrapper._session_dir)
-
- assert wrapper.profiler.started is True
- assert wrapper.profiler.stopped is True
- assert (session_dir / "memory_snapshot_rank0.pickle").exists()
- assert len(dump_calls) == 1
- assert record_calls[0][1]["enabled"] == "all"
- assert record_calls[-1][1]["enabled"] is None
-
-
-def test_on_stop_hook_generates_stack_and_excel_artifacts(wrapper):
- wrapper.set_trace_filename("case_artifacts")
- wrapper._on_stop_hook()
-
- session_dir = Path(wrapper._session_dir)
-
- assert not (session_dir / "ops_summary_rank0.txt").exists()
- assert not (session_dir / "ops_by_shape_rank0.txt").exists()
- assert not (session_dir / "ops_by_stack_rank0.txt").exists()
- assert (session_dir / "stacks_cpu_rank0.txt").exists()
- assert (session_dir / "stacks_cuda_rank0.txt").exists()
- assert (session_dir / "ops_rank0.xlsx").exists()
-
-
-def test_excel_contains_expected_sheets(wrapper):
- wrapper.set_trace_filename("case_excel")
- wrapper._on_stop_hook()
-
- xlsx_path = Path(wrapper._session_dir) / "ops_rank0.xlsx"
- wb = load_workbook(xlsx_path)
-
- assert "summary" in wb.sheetnames
- assert "by_shape" in wb.sheetnames
- assert "by_stack" in wb.sheetnames
-
-
-def test_excel_summary_has_expected_columns(wrapper):
- wrapper.set_trace_filename("case_excel_columns")
- wrapper._on_stop_hook()
-
- xlsx_path = Path(wrapper._session_dir) / "ops_rank0.xlsx"
- wb = load_workbook(xlsx_path)
- ws = wb["summary"]
-
- headers = [cell.value for cell in next(ws.iter_rows(min_row=1, max_row=1))]
- assert "name" in headers
- assert "count" in headers
- assert "self_cpu_time_total_us" in headers
- assert "self_cuda_time_total_us" in headers
- assert "self_cpu_memory_usage_bytes" in headers
- assert "self_cuda_memory_usage_bytes" in headers
- assert "input_shapes" in headers
- assert "stack" in headers
-
-
-def test_get_results_returns_all_artifact_paths(wrapper, monkeypatch):
- wrapper.set_trace_filename("case_results")
-
- monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True)
- monkeypatch.setattr(
- profiler_mod.torch.cuda.memory,
- "_record_memory_history",
- lambda *args, **kwargs: None,
- )
- monkeypatch.setattr(
- profiler_mod.torch.cuda.memory,
- "_dump_snapshot",
- lambda path: Path(path).write_bytes(b"snapshot"),
- )
-
- wrapper._start()
- wrapper._stop()
-
- results = wrapper.get_results()
-
- assert "trace" in results
- assert "table" in results
- assert "session_dir" in results
- assert "ops" in results
- assert "memory_snapshot" in results
- assert Path(results["session_dir"]).exists()
- assert Path(results["ops"]).exists()
- assert Path(results["table"]).exists()
- assert Path(results["table"]).name == "ops_rank0.xlsx"
- assert Path(results["memory_snapshot"]).exists()
-
-
-def test_start_uses_xpu_memory_history_when_available(wrapper, monkeypatch):
- calls = []
-
- def fake_record_memory_history(*args, **kwargs):
- calls.append((args, kwargs))
-
- fake_memory_module = SimpleNamespace(
- _record_memory_history=fake_record_memory_history,
- )
- monkeypatch.setattr(
- wrapper,
- "_resolve_memory_history_backend",
- lambda: ("XPU", fake_memory_module),
- )
-
- wrapper.set_trace_filename("case_xpu_memory_start")
- wrapper._start()
-
- assert wrapper._memory_history_enabled is True
- assert wrapper._memory_history_backend == "XPU"
- assert wrapper._memory_history_module is fake_memory_module
- assert calls[0][1]["enabled"] == "all"
-
-
-def test_start_uses_npu_memory_history_when_available(wrapper, monkeypatch):
- calls = []
-
- def fake_record_memory_history(*args, **kwargs):
- calls.append((args, kwargs))
-
- fake_memory_module = SimpleNamespace(
- _record_memory_history=fake_record_memory_history,
- )
- monkeypatch.setattr(
- wrapper,
- "_resolve_memory_history_backend",
- lambda: ("NPU", fake_memory_module),
- )
-
- wrapper.set_trace_filename("case_npu_memory_start")
- wrapper._start()
-
- assert wrapper._memory_history_enabled is True
- assert wrapper._memory_history_backend == "NPU"
- assert wrapper._memory_history_module is fake_memory_module
- assert calls[0][1]["enabled"] == "all"
-
-
-def test_start_skips_memory_history_when_backend_api_missing(wrapper, monkeypatch):
- fake_memory_module = SimpleNamespace()
- monkeypatch.setattr(
- wrapper,
- "_resolve_memory_history_backend",
- lambda: ("XPU", fake_memory_module),
- )
-
- wrapper.set_trace_filename("case_missing_memory_api")
- wrapper._start()
-
- assert wrapper._memory_history_enabled is False
- assert wrapper._memory_history_backend is None
- assert wrapper._memory_history_module is None
-
-
-def test_try_dump_memory_snapshot_uses_resolved_backend_module(wrapper):
- wrapper.set_trace_filename("case_xpu_snapshot")
- wrapper._memory_history_enabled = True
- wrapper._memory_history_backend = "XPU"
-
- calls = []
-
- def fake_record_memory_history(*args, **kwargs):
- calls.append((args, kwargs))
-
- def fake_dump_snapshot(path):
- Path(path).write_bytes(b"xpu snapshot bytes")
-
- wrapper._memory_history_module = SimpleNamespace(
- _record_memory_history=fake_record_memory_history,
- _dump_snapshot=fake_dump_snapshot,
- )
-
- wrapper._try_dump_memory_snapshot()
-
- snapshot = Path(wrapper._artifact_paths["memory_snapshot"])
- assert snapshot.exists()
- assert snapshot.read_bytes() == b"xpu snapshot bytes"
- assert calls[-1][1]["enabled"] is None
- assert wrapper._memory_history_enabled is False
- assert wrapper._memory_history_backend is None
- assert wrapper._memory_history_module is None
-
-
-def test_event_list_to_rows_contains_expected_fields(wrapper):
- rows = wrapper._event_list_to_rows(
- [
- FakeEvent(
- name="aten::linear",
- input_shapes=[[8, 16], [16, 32]],
- stack=["f1", "f2"],
- )
- ]
- )
-
- assert len(rows) == 1
- row = rows[0]
- assert row["name"] == "aten::linear"
- assert row["count"] == 1
- assert row["self_cpu_time_total_us"] == 10.0
- assert row["self_cuda_time_total_us"] == 20.0
- assert row["self_cpu_memory_usage_bytes"] == 128
- assert row["self_cuda_memory_usage_bytes"] == 1024
- assert "[[8, 16], [16, 32]]" == row["input_shapes"]
- assert row["stack"] == "f1\nf2"
diff --git a/tests/test_arg_utils.py b/tests/test_arg_utils.py
deleted file mode 100644
index ae640b2d861..00000000000
--- a/tests/test_arg_utils.py
+++ /dev/null
@@ -1,488 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for vllm_omni.engine.arg_utils — invariants that must
-hold for the orchestrator/engine/server CLI flag partition."""
-
-from __future__ import annotations
-
-import logging
-from dataclasses import dataclass, fields
-
-import pytest
-
-from vllm_omni.engine.arg_utils import (
- SHARED_FIELDS,
- derive_server_dests_from_vllm_parser,
- internal_blacklist_keys,
- orchestrator_args_from_argparse,
- orchestrator_field_names,
- split_kwargs,
-)
-
-# ---------------------------------------------------------------------------
-# Fake engine class for unit testing — avoids pulling in the full vllm
-# EngineArgs and its heavy __post_init__ at test time.
-# ---------------------------------------------------------------------------
-
-
-@dataclass
-class _FakeEngineArgs:
- """Stand-in for OmniEngineArgs with a representative subset of fields."""
-
- model: str = ""
- stage_id: int = 0
- max_num_seqs: int = 64
- gpu_memory_utilization: float = 0.9
- async_chunk: bool = False # also in OrchestratorArgs → shared
- log_stats: bool = False # also in OrchestratorArgs → shared
- stage_configs_path: str | None = None
-
-
-# ============================================================================
-# Invariant 1 — OrchestratorArgs and engine must not ambiguously overlap.
-# ============================================================================
-
-
-def test_no_ambiguous_overlap_with_fake_engine():
- """OrchestratorArgs ∩ engine fields must be ⊆ SHARED_FIELDS."""
- orch = orchestrator_field_names()
- engine = {f.name for f in fields(_FakeEngineArgs)}
- overlap = orch & engine
- unexpected = overlap - SHARED_FIELDS
- assert not unexpected, (
- f"Fields declared in both OrchestratorArgs and the engine class "
- f"but not in SHARED_FIELDS: {sorted(unexpected)}. These cause "
- f"double-routing — either remove the duplicate declaration or add "
- f"to SHARED_FIELDS if sharing is intentional."
- )
-
-
-def test_no_ambiguous_overlap_with_real_engine():
- """Same check, but against the real OmniEngineArgs."""
- try:
- from vllm_omni.engine.arg_utils import OmniEngineArgs
- except Exception as exc:
- pytest.skip(f"OmniEngineArgs not importable: {exc}")
-
- orch = orchestrator_field_names()
- engine = {f.name for f in fields(OmniEngineArgs)}
- overlap = orch & engine
- unexpected = overlap - SHARED_FIELDS
- assert not unexpected, (
- f"Real OmniEngineArgs has ambiguous overlap with OrchestratorArgs: "
- f"{sorted(unexpected)}. Update SHARED_FIELDS or remove duplication."
- )
-
-
-# ============================================================================
-# Invariant 2 — split_kwargs partitions correctly.
-# ============================================================================
-
-
-def test_split_orchestrator_only():
- """Pure orchestrator fields go to OrchestratorArgs, not engine_kwargs."""
- raw = {"stage_init_timeout": 500, "worker_backend": "ray"}
- orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
- assert orch.stage_init_timeout == 500
- assert orch.worker_backend == "ray"
- assert "stage_init_timeout" not in engine
- assert "worker_backend" not in engine
-
-
-def test_split_engine_only():
- """Pure engine fields go to engine_kwargs, not OrchestratorArgs."""
- raw = {"max_num_seqs": 128, "gpu_memory_utilization": 0.85}
- orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
- assert engine["max_num_seqs"] == 128
- assert engine["gpu_memory_utilization"] == 0.85
- # These fields don't exist on OrchestratorArgs at all.
-
-
-def test_split_shared_fields_go_to_both():
- """Fields in SHARED_FIELDS are copied to both buckets."""
- raw = {"model": "Qwen/Qwen2.5-Omni-7B", "log_stats": True}
- orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
- assert orch.log_stats is True
- assert engine["model"] == "Qwen/Qwen2.5-Omni-7B"
- assert engine["log_stats"] is True
-
-
-def test_split_drops_unclassified():
- """Unclassified fields (uvicorn/server) are dropped silently."""
- raw = {
- "max_num_seqs": 64, # engine
- "host": "0.0.0.0", # unclassified (server)
- "port": 8091, # unclassified (server)
- "ssl_keyfile": "key.pem", # unclassified (server)
- }
- orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
- assert engine == {"max_num_seqs": 64}
- assert "host" not in engine
- assert "port" not in engine
- assert "ssl_keyfile" not in engine
-
-
-def test_split_mixed_real_world():
- """End-to-end: raw CLI kwargs with all three classes present."""
- raw = {
- # orchestrator
- "stage_init_timeout": 400,
- "deploy_config": "/tmp/deploy.yaml",
- "worker_backend": "multi_process",
- "async_chunk": True,
- # engine
- "max_num_seqs": 32,
- "gpu_memory_utilization": 0.8,
- # shared
- "model": "Qwen/Qwen3-Omni",
- "log_stats": False,
- # server / unclassified
- "host": "0.0.0.0",
- "port": 8091,
- "api_key": "secret",
- # None values
- "ray_address": None,
- }
- orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
-
- # Orchestrator side
- assert orch.stage_init_timeout == 400
- assert orch.deploy_config == "/tmp/deploy.yaml"
- assert orch.worker_backend == "multi_process"
- assert orch.async_chunk is True
- assert orch.log_stats is False # shared, read from raw
- assert orch.ray_address is None # default preserved
-
- # Engine side
- assert engine["max_num_seqs"] == 32
- assert engine["gpu_memory_utilization"] == 0.8
- assert engine["model"] == "Qwen/Qwen3-Omni"
- assert engine["log_stats"] is False
- assert "host" not in engine
- assert "port" not in engine
- assert "api_key" not in engine
- # orchestrator-only keys never reach engine
- assert "stage_init_timeout" not in engine
- assert "deploy_config" not in engine
- assert "async_chunk" not in engine
-
-
-# ============================================================================
-# Invariant 3 — user-typed unclassifiable flags warn (don't fail silently).
-# ============================================================================
-
-
-def test_user_typed_unclassified_warns(caplog):
- """If the user types a flag we can't route, warn — don't silently drop."""
- raw = {"bogus_flag": "value", "max_num_seqs": 64}
- with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"):
- split_kwargs(raw, engine_cls=_FakeEngineArgs, user_typed={"bogus_flag"})
- assert any("bogus_flag" in rec.message for rec in caplog.records), (
- f"Expected warning mentioning 'bogus_flag', got: {[rec.message for rec in caplog.records]}"
- )
-
-
-def test_unclassified_without_user_typed_silent(caplog):
- """Without user_typed, unclassified keys drop silently (argparse defaults
- for server flags shouldn't spam logs on every launch)."""
- raw = {"host": "0.0.0.0", "port": 8091, "max_num_seqs": 64}
- with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"):
- split_kwargs(raw, engine_cls=_FakeEngineArgs, user_typed=None)
- # No warnings because we don't know these were user-typed.
- assert not any("host" in rec.message or "port" in rec.message for rec in caplog.records)
-
-
-# ============================================================================
-# Invariant 4 — CLI flag classification completeness.
-# Catches new flags added without updating OrchestratorArgs or OmniEngineArgs.
-# ============================================================================
-
-
-def test_all_omni_cli_flags_classified():
- """Every vllm-omni-added CLI flag must be classifiable.
-
- Runs ``OmniServeCommand.subparser_init`` and checks that every new
- argument (compared to vllm's base parser) is either:
- - a field on OrchestratorArgs, OR
- - a field on OmniEngineArgs, OR
- - in SHARED_FIELDS
- """
- try:
- from vllm.utils.argparse_utils import FlexibleArgumentParser
-
- from vllm_omni.engine.arg_utils import OmniEngineArgs
- from vllm_omni.entrypoints.cli.serve import OmniServeCommand
- except Exception as exc:
- pytest.skip(f"Cannot build parser in this environment: {exc}")
-
- # Build the serve parser
- root = FlexibleArgumentParser()
- subparsers = root.add_subparsers()
- cmd = OmniServeCommand()
- try:
- parser = cmd.subparser_init(subparsers)
- except Exception as exc:
- pytest.skip(f"subparser_init failed (dev env issue): {exc}")
-
- all_dests = {a.dest for a in parser._actions if a.dest and a.dest not in {"help", "model_tag"}}
-
- orch = orchestrator_field_names()
- engine = {f.name for f in fields(OmniEngineArgs)}
- server_derived = derive_server_dests_from_vllm_parser()
-
- unclassified = all_dests - orch - engine - SHARED_FIELDS - server_derived
- # Some argparse-internal dests (suppressed, private) may not match —
- # filter those out.
- unclassified = {d for d in unclassified if not d.startswith("_")}
-
- assert not unclassified, (
- f"These CLI flags are not classified as "
- f"orchestrator/engine/shared/server: {sorted(unclassified)}. "
- f"Add them to OrchestratorArgs (if consumed by orchestrator), "
- f"OmniEngineArgs (if consumed by per-stage engine), or the known-server "
- f"allowlist (if they're vllm frontend flags). "
- f"If intentional (e.g. a new CLI-only flag that doesn't map to either "
- f"dataclass), add it to a KNOWN_UNROUTED allowlist."
- )
-
-
-# ============================================================================
-# argparse interop (Phase 3).
-# ============================================================================
-
-
-def test_orchestrator_args_from_argparse():
- """Can build OrchestratorArgs from an argparse.Namespace."""
- import argparse
-
- ns = argparse.Namespace(
- stage_init_timeout=500,
- deploy_config="/tmp/x.yaml",
- max_num_seqs=64, # engine field — ignored
- host="0.0.0.0", # server field — ignored
- )
- orch = orchestrator_args_from_argparse(ns)
- assert orch.stage_init_timeout == 500
- assert orch.deploy_config == "/tmp/x.yaml"
- assert orch.worker_backend == "multi_process" # default
-
-
-def test_derive_server_dests_returns_frozenset():
- """Server-dest derivation returns a frozenset (possibly empty)."""
- result = derive_server_dests_from_vllm_parser()
- assert isinstance(result, frozenset)
-
-
-# ============================================================================
-# internal_blacklist_keys — single source of truth for per-stage forwarding.
-# ============================================================================
-
-
-def test_internal_blacklist_keys_derived_from_orchestrator():
- """Blacklist is exactly OrchestratorArgs fields minus SHARED_FIELDS.
-
- This function replaces the old hardcoded INTERNAL_STAGE_OVERRIDE_KEYS
- frozenset. Asserts the contract so future changes to OrchestratorArgs
- automatically propagate to the blacklist.
- """
- blacklist = internal_blacklist_keys()
- assert blacklist == orchestrator_field_names() - SHARED_FIELDS
- # Spot-check expected entries
- assert "stage_init_timeout" in blacklist
- assert "deploy_config" in blacklist
- assert "async_chunk" in blacklist
- # Shared fields must NOT appear — they flow to both orchestrator and engine
- assert "model" not in blacklist
- assert "log_stats" not in blacklist
-
-
-# ============================================================================
-# Boundary value analysis — edge cases around split_kwargs.
-# ============================================================================
-
-
-def test_split_empty_kwargs():
- """Empty kwargs yields default OrchestratorArgs and empty engine dict."""
- orch, engine = split_kwargs({}, engine_cls=_FakeEngineArgs)
- assert orch.stage_init_timeout == 300 # dataclass default
- assert orch.worker_backend == "multi_process" # dataclass default
- assert engine == {}
-
-
-def test_split_all_none_values_preserved_on_orchestrator():
- """None values for orchestrator fields are kept (represents 'not set')."""
- raw = {"ray_address": None, "deploy_config": None, "max_num_seqs": None}
- orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
- assert orch.ray_address is None
- assert orch.deploy_config is None
- # Engine-side None still passes through; caller decides semantics downstream.
- assert engine.get("max_num_seqs") is None
-
-
-def test_split_user_typed_with_empty_kwargs_no_warn(caplog):
- """user_typed non-empty but kwargs empty — no warnings emitted."""
- with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"):
- split_kwargs({}, engine_cls=_FakeEngineArgs, user_typed={"nothing"})
- assert not caplog.records
-
-
-def test_ambiguous_field_strict_raises():
- """strict=True raises ValueError on overlap outside SHARED_FIELDS."""
-
- # deploy_config is on OrchestratorArgs; declaring it on the engine class
- # too (without adding to SHARED_FIELDS) creates an ambiguous route.
- @dataclass
- class _AmbiguousEngine:
- deploy_config: str | None = None
-
- with pytest.raises(ValueError, match="both OrchestratorArgs and"):
- split_kwargs({"deploy_config": "x"}, engine_cls=_AmbiguousEngine, strict=True)
-
-
-def test_ambiguous_field_non_strict_routes_to_orchestrator(caplog):
- """strict=False logs ERROR but routes the ambiguous field to orchestrator."""
-
- @dataclass
- class _AmbiguousEngine:
- deploy_config: str | None = None
-
- with caplog.at_level(logging.ERROR, logger="vllm_omni.engine.arg_utils"):
- orch, engine = split_kwargs({"deploy_config": "x"}, engine_cls=_AmbiguousEngine, strict=False)
- assert orch.deploy_config == "x"
- assert "deploy_config" not in engine
- assert any("both OrchestratorArgs" in r.message for r in caplog.records)
-
-
-# Sentinel-default precedence invariants (#3035)
-
-
-def _build_full_serve_parser():
- from vllm.utils.argparse_utils import FlexibleArgumentParser
-
- try:
- from vllm.entrypoints.openai.cli_args import make_arg_parser
- except ImportError:
- pytest.skip("vllm parser not importable")
- return make_arg_parser(FlexibleArgumentParser())
-
-
-def test_nullify_stage_engine_defaults_resets_inherited_defaults():
- import argparse
-
- from vllm_omni.engine.arg_utils import (
- deploy_override_field_names,
- nullify_stage_engine_defaults,
- )
-
- parser = _build_full_serve_parser()
- nullify_stage_engine_defaults(parser)
-
- override_dests = deploy_override_field_names()
- offenders = [
- (a.dest, a.default)
- for a in parser._actions
- if a.dest not in ("help", "version")
- and a.option_strings
- and a.dest in override_dests
- and a.default is not None
- and a.default is not argparse.SUPPRESS
- ]
- assert not offenders, f"Stage flags with non-None defaults after nullify: {offenders}"
-
-
-def test_non_override_flags_keep_real_defaults_after_nullify():
- import argparse
-
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--hsdp-shard-size", type=int, default=-1, help="HSDP shard size.")
- parser.add_argument("--max-num-seqs", type=int, default=64, help="Max num seqs.")
- nullify_stage_engine_defaults(parser)
-
- hsdp = next(a for a in parser._actions if a.dest == "hsdp_shard_size")
- max_num_seqs = next(a for a in parser._actions if a.dest == "max_num_seqs")
- assert hsdp.default == -1
- assert max_num_seqs.default is None
-
-
-def test_help_text_preserves_default_after_nullify():
- # Real defaults must stay visible in --help even though parser stores None.
- import argparse
-
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--max-num-seqs", type=int, default=42, help="Example knob.")
- nullify_stage_engine_defaults(parser)
-
- action = next(a for a in parser._actions if a.dest == "max_num_seqs")
- assert action.default is None
- assert "(default: 42)" in action.help
-
-
-_OMNIENGINEARGS_USER_INPUT_FIELDS = frozenset(
- {
- "model_stage",
- "model_arch",
- "engine_output_type",
- "hf_config_name",
- "custom_process_next_stage_input_func",
- "subtalker_sampling_params",
- "async_chunk",
- "omni_kv_config",
- "quantization_config",
- "worker_type",
- "task_type",
- "worker_cls",
- "enable_sleep_mode",
- "omni_master_address",
- "omni_master_port",
- "stage_configs_path",
- "output_modalities",
- "log_stats",
- "custom_pipeline_args",
- }
-)
-
-
-def test_omniengineargs_user_input_fields_default_to_none():
- try:
- from vllm_omni.engine.arg_utils import OmniEngineArgs
- except Exception as exc:
- pytest.skip(f"OmniEngineArgs not importable: {exc}")
-
- offenders = [
- (f.name, f.default)
- for f in fields(OmniEngineArgs)
- if f.name in _OMNIENGINEARGS_USER_INPUT_FIELDS
- and f.default is not dataclasses.MISSING
- and f.default is not None
- ]
- assert not offenders, f"User-input fields with non-None defaults: {offenders}"
-
-
-def test_omniengineargs_create_tracks_explicit_fields():
- try:
- from vllm_omni.engine.arg_utils import OmniEngineArgs
- except Exception as exc:
- pytest.skip(f"OmniEngineArgs not importable: {exc}")
-
- ea = OmniEngineArgs.create(model="x", gpu_memory_utilization=0.5)
- assert ea._explicit_fields == frozenset({"model", "gpu_memory_utilization"})
- assert ea.explicit_kwargs() == {"model": "x", "gpu_memory_utilization": 0.5}
-
-
-def test_omniengineargs_bare_constructor_has_no_explicit_tracking():
- try:
- from vllm_omni.engine.arg_utils import OmniEngineArgs
- except Exception as exc:
- pytest.skip(f"OmniEngineArgs not importable: {exc}")
-
- ea = OmniEngineArgs(model="x")
- assert not hasattr(ea, "_explicit_fields")
- assert "model" in ea.explicit_kwargs()
-
-
-# dataclasses already imported via ``from dataclasses import dataclass, fields``
-import dataclasses # noqa: E402 -- needed for MISSING sentinel above
diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py
index 16d49034fa1..e284de48d0b 100644
--- a/tests/test_config_factory.py
+++ b/tests/test_config_factory.py
@@ -4,27 +4,12 @@
Unit tests for StageConfigFactory and related classes.
"""
-import warnings
-from dataclasses import dataclass
-from pathlib import Path
-
-import pytest
-
from vllm_omni.config.stage_config import (
- _EXECUTION_TYPE_TO_SCHEDULER,
- _PIPELINE_REGISTRY,
ModelPipeline,
- PipelineConfig,
StageConfig,
StageConfigFactory,
- StageExecutionType,
- StagePipelineConfig,
StageType,
- build_stage_runtime_overrides,
- register_pipeline,
- strip_parent_engine_args,
)
-from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys
class TestStageType:
@@ -256,9 +241,8 @@ def test_default_diffusion_no_yaml(self):
def test_default_diffusion_with_parallel_config(self):
"""Test diffusion config calculates devices from parallel_config."""
- @dataclass
class MockParallelConfig:
- world_size: int = 4
+ world_size = 4
kwargs = {
"parallel_config": MockParallelConfig(),
@@ -286,7 +270,7 @@ def test_cli_override_forwards_engine_registered_args(self):
stage = StageConfig(stage_id=0, model_stage="thinker", input_sources=[])
cli_overrides = {
"gpu_memory_utilization": 0.9, # Well-known param
- "custom_engine_flag": True, # Not orchestrator-owned, so forwarded
+ "custom_engine_flag": True, # Not in _INTERNAL_KEYS, so forwarded
}
overrides = StageConfigFactory._merge_cli_overrides(stage, cli_overrides)
@@ -327,81 +311,6 @@ def test_per_stage_override_excludes_internal_keys(self):
assert "batch_timeout" not in overrides
-class TestStageResolutionHelpers:
- """Tests for shared stage override / filtering helpers."""
-
- def test_build_stage_runtime_overrides_ignores_other_stage_and_internal_keys(self):
- # Pass the same filter set the function uses by default
- # (orchestrator-only fields plus SHARED_FIELDS so ``model`` is
- # treated as not-per-stage-overridable).
- overrides = build_stage_runtime_overrides(
- 0,
- {
- "gpu_memory_utilization": 0.5,
- "stage_0_gpu_memory_utilization": 0.9,
- "stage_1_gpu_memory_utilization": 0.1,
- "stage_0_model": "should_be_ignored",
- "parallel_config": {"world_size": 2},
- },
- internal_keys=internal_blacklist_keys() | SHARED_FIELDS,
- )
-
- assert overrides["gpu_memory_utilization"] == 0.9
- assert "model" not in overrides
- assert "parallel_config" not in overrides
-
- def test_strip_parent_engine_args_reports_only_surprising_parent_overrides(self):
- from dataclasses import fields as dc_fields
-
- from vllm.engine.arg_utils import EngineArgs
-
- parent_fields = {f.name: f for f in dc_fields(EngineArgs)}
- filtered, overridden = strip_parent_engine_args(
- {
- "model": "some/model",
- "stage_configs_path": "/tmp/stages.yaml",
- "tensor_parallel_size": 4,
- "worker_extension_cls": "some.Extension",
- "custom_pipeline_args": {"pipeline_class": "demo.Pipeline"},
- },
- parent_fields=parent_fields,
- keep_keys={"worker_extension_cls"},
- strip_keys={"stage_configs_path"},
- no_warn_keys={"model"},
- )
-
- assert filtered == {
- "worker_extension_cls": "some.Extension",
- "custom_pipeline_args": {"pipeline_class": "demo.Pipeline"},
- }
- assert overridden == ["tensor_parallel_size"]
-
- def test_strip_parent_engine_args_keeps_allowed_media_access_controls(self):
- from dataclasses import fields as dc_fields
-
- from vllm.engine.arg_utils import EngineArgs
-
- parent_fields = {f.name: f for f in dc_fields(EngineArgs)}
- filtered, overridden = strip_parent_engine_args(
- {
- "model": "some/model",
- "stage_configs_path": "/tmp/stages.yaml",
- "allowed_local_media_path": "/data/qwentts",
- "allowed_media_domains": ["example.com"],
- },
- parent_fields=parent_fields,
- keep_keys={"allowed_local_media_path", "allowed_media_domains"},
- strip_keys={"stage_configs_path"},
- no_warn_keys={"model"},
- )
-
- assert filtered == {
- "allowed_local_media_path": "/data/qwentts",
- "allowed_media_domains": ["example.com"],
- }
- assert overridden == []
-
-
class TestPipelineYamlParsing:
"""Tests for pipeline YAML file parsing (@ZJY0516)."""
@@ -700,637 +609,16 @@ def test_parse_missing_async_chunk_defaults_false(self, tmp_path):
assert pipeline.async_chunk is False
-class TestPipelineDiscovery:
- """Tests for the central pipeline registry (``pipeline_registry._OMNI_PIPELINES``)."""
-
- def test_registry_has_known_models(self):
- """Built-in pipelines are lazy-loaded from the central declaration
- on first access; no eager import or discovery walk needed."""
- # ``in`` triggers the lazy-map lookup without forcing a load.
- assert "qwen2_5_omni" in _PIPELINE_REGISTRY
- assert "qwen3_omni_moe" in _PIPELINE_REGISTRY
- assert "qwen3_tts" in _PIPELINE_REGISTRY
-
- def test_registry_loads_pipeline_on_getitem(self):
- """Looking up a registered model_type returns the matching PipelineConfig."""
- pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
- assert pipeline.model_type == "qwen3_omni_moe"
- assert len(pipeline.stages) == 3 # thinker + talker + code2wav
-
- def test_registry_returns_none_for_unknown(self):
- """Unknown model_types aren't found; ``get()`` returns None."""
- assert "definitely_not_a_real_model" not in _PIPELINE_REGISTRY
- assert _PIPELINE_REGISTRY.get("definitely_not_a_real_model") is None
-
- def test_pipeline_config_supports_hf_architectures(self):
- """PipelineConfig accepts hf_architectures for HF-arch fallback
- (replaces the old _ARCHITECTURE_MODELS dict)."""
- p = PipelineConfig(
- model_type="custom_collide",
- hf_architectures=("SomeCollidingArch",),
- )
- assert p.hf_architectures == ("SomeCollidingArch",)
-
-
-class TestStagePipelineConfig:
- def test_frozen(self):
- s = StagePipelineConfig(stage_id=0, model_stage="a")
- with pytest.raises(AttributeError):
- s.model_stage = "changed"
-
- def test_defaults(self):
- s = StagePipelineConfig(stage_id=0, model_stage="a")
- assert s.execution_type == StageExecutionType.LLM_AR
- assert s.input_sources == ()
- assert s.final_output is False
- assert s.sampling_constraints == {}
- assert s.engine_output_type is None
-
-
-class TestPipelineConfigNew:
- def test_frozen(self):
- p = PipelineConfig(model_type="t", model_arch="A")
- with pytest.raises(AttributeError):
- p.model_type = "changed"
-
- def test_validate_valid(self):
- p = PipelineConfig(
- model_type="t",
- model_arch="A",
- stages=(
- StagePipelineConfig(stage_id=0, model_stage="a"),
- StagePipelineConfig(stage_id=1, model_stage="b", input_sources=(0,)),
- ),
- )
- assert p.validate() == []
-
- def test_validate_no_stages(self):
- p = PipelineConfig(model_type="t", model_arch="A")
- assert any("no stages" in e.lower() for e in p.validate())
-
- def test_get_scheduler_cls(self):
- p = PipelineConfig(
- model_type="t",
- model_arch="A",
- stages=(
- StagePipelineConfig(stage_id=0, model_stage="a", execution_type=StageExecutionType.LLM_AR),
- StagePipelineConfig(
- stage_id=1, model_stage="b", execution_type=StageExecutionType.LLM_GENERATION, input_sources=(0,)
- ),
- ),
- )
- assert "OmniARScheduler" in p.get_scheduler_cls(0)
- assert "OmniGenerationScheduler" in p.get_scheduler_cls(1)
-
-
-class TestExecutionTypeToScheduler:
- def test_all_types_mapped(self):
- for et in StageExecutionType:
- assert et in _EXECUTION_TYPE_TO_SCHEDULER
-
-
-class TestPipelineRegistry:
- def test_register_and_lookup(self):
- p = PipelineConfig(
- model_type="__test_only__",
- model_arch="A",
- stages=(StagePipelineConfig(stage_id=0, model_stage="a"),),
- )
- register_pipeline(p)
- assert _PIPELINE_REGISTRY["__test_only__"] is p
- del _PIPELINE_REGISTRY["__test_only__"]
-
-
-class TestDeployConfigLoading:
- def test_load_deploy_config(self):
- from pathlib import Path
-
- from vllm_omni.config.stage_config import load_deploy_config
-
- deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not deploy_path.exists():
- pytest.skip("Deploy config not found")
-
- deploy = load_deploy_config(deploy_path)
- assert len(deploy.stages) == 3
- assert deploy.async_chunk is True
- assert deploy.connectors is not None
- assert deploy.platforms is not None
-
- def test_merge_pipeline_deploy(self):
- from pathlib import Path
-
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
- from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
-
- pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
- deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not deploy_path.exists():
- pytest.skip("Deploy config not found")
-
- deploy = load_deploy_config(deploy_path)
- stages = merge_pipeline_deploy(pipeline, deploy)
-
- assert len(stages) == 3
- s0 = stages[0]
- assert s0.model_stage == "thinker"
- assert s0.yaml_engine_args["model_arch"] == "Qwen3OmniMoeForConditionalGeneration"
- assert s0.yaml_engine_args["engine_output_type"] == "latent"
- assert s0.yaml_extras["default_sampling_params"]["detokenize"] is True
-
-
-class TestQwen3OmniPipeline:
- def test_registered(self):
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
-
- p = _PIPELINE_REGISTRY.get("qwen3_omni_moe")
- assert p is not None
- assert p.model_arch == "Qwen3OmniMoeForConditionalGeneration"
- assert len(p.stages) == 3
- assert p.validate() == []
-
- def test_thinker(self):
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
-
- s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(0)
- assert s.model_stage == "thinker"
- assert s.execution_type == StageExecutionType.LLM_AR
- assert s.owns_tokenizer is True
- assert s.engine_output_type == "latent"
- assert s.sampling_constraints["detokenize"] is True
-
- def test_talker(self):
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
-
- s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(1)
- assert s.input_sources == (0,)
- assert s.sampling_constraints["stop_token_ids"] == [2150]
- assert s.custom_process_input_func is not None
- assert s.custom_process_next_stage_input_func is not None
-
- def test_code2wav(self):
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
-
- s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(2)
- assert s.execution_type == StageExecutionType.LLM_GENERATION
- assert s.final_output_type == "audio"
- assert s.custom_process_input_func is not None
-
-
-class TestQwen2_5OmniPipeline:
- def test_registered(self):
- import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
-
- p = _PIPELINE_REGISTRY.get("qwen2_5_omni")
- assert p is not None
- assert p.model_arch == "Qwen2_5OmniForConditionalGeneration"
- assert len(p.stages) == 3
- assert p.validate() == []
-
- def test_thinker(self):
- import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
-
- s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(0)
- assert s.model_stage == "thinker"
- assert s.execution_type == StageExecutionType.LLM_AR
- assert s.owns_tokenizer is True
- assert s.engine_output_type == "latent"
- assert s.requires_multimodal_data is True
-
- def test_talker(self):
- import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
-
- s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(1)
- assert s.input_sources == (0,)
- assert s.sampling_constraints["stop_token_ids"] == [8294]
- assert s.custom_process_input_func is not None
-
- def test_code2wav(self):
- import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
-
- s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(2)
- assert s.execution_type == StageExecutionType.LLM_GENERATION
- assert s.final_output_type == "audio"
- assert s.engine_output_type == "audio"
-
-
-class TestQwen3TTSPipeline:
- def test_registered(self):
- import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
-
- p = _PIPELINE_REGISTRY.get("qwen3_tts")
- assert p is not None
- assert p.model_arch == "Qwen3TTSTalkerForConditionalGeneration"
- assert len(p.stages) == 2
- assert p.validate() == []
-
- def test_talker_stage(self):
- import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
-
- s = _PIPELINE_REGISTRY["qwen3_tts"].get_stage(0)
- assert s.model_stage == "qwen3_tts"
- assert s.execution_type == StageExecutionType.LLM_AR
- assert s.owns_tokenizer is True
- assert s.engine_output_type == "latent"
- assert s.sampling_constraints["stop_token_ids"] == [2150]
- # Stage 0 inherits the pipeline-level model_arch
- assert s.model_arch is None
-
- def test_code2wav_stage_has_per_stage_model_arch(self):
- import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
-
- s = _PIPELINE_REGISTRY["qwen3_tts"].get_stage(1)
- assert s.execution_type == StageExecutionType.LLM_GENERATION
- assert s.final_output_type == "audio"
- assert s.engine_output_type == "audio"
- # Per-stage model_arch override (different from pipeline-level talker)
- assert s.model_arch == "Qwen3TTSCode2Wav"
- # tts_args is passed through via extras
- assert s.extras["tts_args"]["max_instructions_length"] == 500
-
- def test_per_stage_model_arch_flows_through_merge(self, tmp_path):
- """Verify the new ps.model_arch override survives merge_pipeline_deploy."""
- import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
- from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
-
- deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_tts.yaml"
- if not deploy_path.exists():
- pytest.skip("qwen3_tts deploy yaml not found")
-
- deploy = load_deploy_config(deploy_path)
- pipeline = _PIPELINE_REGISTRY["qwen3_tts"]
- stages = merge_pipeline_deploy(pipeline, deploy)
-
- # Stage 0 inherits pipeline-level model_arch
- assert stages[0].yaml_engine_args["model_arch"] == "Qwen3TTSTalkerForConditionalGeneration"
- # Stage 1 uses its per-stage override
- assert stages[1].yaml_engine_args["model_arch"] == "Qwen3TTSCode2Wav"
-
- def test_subtalker_sampling_params_deep_merge_preserves_base_keys(self):
- """Verify subtalker sampling params participate in stage deep-merge."""
- from vllm_omni.config.stage_config import _deep_merge_stage
-
- base = {
- "stage_id": 0,
- "subtalker_sampling_params": {
- "do_sample": True,
- "temperature": 0.9,
- "top_k": 50,
- "top_p": 1.0,
- },
- }
- overlay = {
- "stage_id": 0,
- "subtalker_sampling_params": {
- "temperature": 0.7,
- "top_k": 32,
- },
- }
-
- merged = _deep_merge_stage(base, overlay)
-
- assert merged["subtalker_sampling_params"] == {
- "do_sample": True,
- "temperature": 0.7,
- "top_k": 32,
- "top_p": 1.0,
- }
-
-
-class TestBaseConfigInheritance:
- """Test deploy YAML base_config inheritance."""
-
- def test_ci_inherits_from_main(self):
- from tests.helpers.stage_config import get_deploy_config_path
- from vllm_omni.config.stage_config import load_deploy_config
-
- ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml"))
- if not ci_path.exists():
- pytest.skip("CI deploy config not found")
-
- deploy = load_deploy_config(ci_path)
- assert len(deploy.stages) == 3
- # CI overrides
- assert deploy.stages[0].engine_extras.get("load_format") == "dummy"
- assert deploy.stages[0].max_num_seqs == 5
- # Inherited from base
- assert deploy.stages[0].gpu_memory_utilization == 0.9
- assert deploy.connectors is not None
- assert "connector_of_shared_memory" in deploy.connectors
- # CI overlay explicitly sets async_chunk: False (see
- # tests.helpers.stage_config._CI_OVERLAYS and PR #2383 discussion). Overlay
- # bool overrides base even when the base yaml has async_chunk: true.
- assert deploy.async_chunk is False
-
- def test_ci_sampling_merge(self):
- from tests.helpers.stage_config import get_deploy_config_path
- from vllm_omni.config.stage_config import load_deploy_config
-
- ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml"))
- if not ci_path.exists():
- pytest.skip("CI deploy config not found")
-
- deploy = load_deploy_config(ci_path)
- s0 = deploy.stages[0].default_sampling_params
- # CI overrides max_tokens
- assert s0["max_tokens"] == 150
- # Inherited from base
- assert s0["temperature"] == 0.4
- assert s0["seed"] == 42
-
- def test_pure_inheritance_overlay(self, tmp_path):
- """An overlay with only ``base_config`` inherits everything."""
- from vllm_omni.config.stage_config import load_deploy_config
-
- base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not base.exists():
- pytest.skip("Base deploy config not found")
-
- overlay = tmp_path / "overlay.yaml"
- overlay.write_text(f"base_config: {base}\n")
-
- deploy = load_deploy_config(overlay)
- assert len(deploy.stages) == 3
- assert deploy.stages[0].gpu_memory_utilization == 0.9
-
- def test_single_field_overlay(self, tmp_path):
- """An overlay overriding one stage field merges with the base."""
- from vllm_omni.config.stage_config import load_deploy_config
-
- base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not base.exists():
- pytest.skip("Base deploy config not found")
-
- overlay = tmp_path / "overlay.yaml"
- overlay.write_text(f"base_config: {base}\nstages:\n - stage_id: 2\n max_num_batched_tokens: 1000000\n")
-
- deploy = load_deploy_config(overlay)
- assert deploy.stages[2].max_num_batched_tokens == 1000000
- # Rest inherited
- assert deploy.stages[0].gpu_memory_utilization == 0.9
-
-
-class TestPlatformOverrides:
- """Test platform-specific deploy config overrides."""
-
- def test_npu_overrides(self):
- from pathlib import Path
-
- from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
-
- deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not deploy_path.exists():
- pytest.skip("Deploy config not found")
-
- deploy = load_deploy_config(deploy_path)
- deploy = _apply_platform_overrides(deploy, platform="npu")
-
- assert deploy.stages[0].gpu_memory_utilization == 0.6
- assert deploy.stages[0].tensor_parallel_size == 2
- assert deploy.stages[0].devices == "0,1"
- # Stage 2 unaffected fields stay at base
- assert deploy.stages[2].enforce_eager is True
-
- def test_xpu_overrides(self):
- from pathlib import Path
-
- from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
-
- deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not deploy_path.exists():
- pytest.skip("Deploy config not found")
-
- deploy = load_deploy_config(deploy_path)
- deploy = _apply_platform_overrides(deploy, platform="xpu")
-
- assert deploy.stages[0].tensor_parallel_size == 4
- assert deploy.stages[0].devices == "0,1,2,3"
- assert deploy.stages[0].engine_extras.get("max_cudagraph_capture_size") == 0
-
- def test_unknown_platform_noop(self):
- from pathlib import Path
-
- from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
-
- deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not deploy_path.exists():
- pytest.skip("Deploy config not found")
-
- deploy = load_deploy_config(deploy_path)
- original_mem = deploy.stages[0].gpu_memory_utilization
- deploy = _apply_platform_overrides(deploy, platform="unknown_hw")
- assert deploy.stages[0].gpu_memory_utilization == original_mem
-
- def test_platforms_deep_merge_inheritance(self, tmp_path):
- """Overlay's platforms: block layers onto base's, per-stage."""
- from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
-
- base = tmp_path / "base.yaml"
- base.write_text(
- "stages:\n"
- " - stage_id: 0\n"
- " gpu_memory_utilization: 0.9\n"
- "platforms:\n"
- " rocm:\n"
- " stages:\n"
- " - stage_id: 0\n"
- " enforce_eager: true\n"
- )
- overlay = tmp_path / "overlay.yaml"
- overlay.write_text(
- f"base_config: {base.name}\n"
- "platforms:\n"
- " rocm:\n"
- " stages:\n"
- " - stage_id: 0\n"
- " max_num_seqs: 1\n"
- )
-
- deploy = load_deploy_config(overlay)
- deploy = _apply_platform_overrides(deploy, platform="rocm")
- # Both base's enforce_eager and overlay's max_num_seqs should apply.
- assert deploy.stages[0].enforce_eager is True
- assert deploy.stages[0].max_num_seqs == 1
- # Inherited stage default not touched by overlay platforms section.
- assert deploy.stages[0].gpu_memory_utilization == 0.9
-
-
-class TestCLIOverrideFlow:
- """Test --stage-overrides JSON merge into StageConfig."""
-
- def test_stage_overrides_merge(self):
- from pathlib import Path
-
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
- from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
-
- pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
- deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not deploy_path.exists():
- pytest.skip("Deploy config not found")
-
- deploy = load_deploy_config(deploy_path)
- stages = merge_pipeline_deploy(pipeline, deploy)
-
- # Simulate --stage-overrides '{"0": {"gpu_memory_utilization": 0.5}}'
- overrides = {"stage_0_gpu_memory_utilization": 0.5}
- stages[0].runtime_overrides = StageConfigFactory._merge_cli_overrides(stages[0], overrides)
- assert stages[0].runtime_overrides["gpu_memory_utilization"] == 0.5
-
- def test_global_override_applies_to_all(self):
- from pathlib import Path
-
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
- from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
-
- pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
- deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not deploy_path.exists():
- pytest.skip("Deploy config not found")
-
- deploy = load_deploy_config(deploy_path)
- stages = merge_pipeline_deploy(pipeline, deploy)
-
- overrides = {"enforce_eager": True}
- for s in stages:
- s.runtime_overrides = StageConfigFactory._merge_cli_overrides(s, overrides)
- assert s.runtime_overrides["enforce_eager"] is True
-
-
-class TestSentinelDefaultPrecedence:
- """Caller-typed (non-None) values win over YAML; None values fall through
- to YAML / dataclass defaults (#3035)."""
-
- def _stages(self, cli_overrides):
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
-
- return StageConfigFactory._create_from_registry(
- "qwen3_omni_moe",
- cli_overrides=cli_overrides,
- )
-
- def test_typed_kwarg_overrides_yaml(self):
- stages = self._stages({"max_num_seqs": 999})
- assert stages[2].runtime_overrides.get("max_num_seqs") == 999
-
- def test_none_value_skipped_yaml_wins(self):
- stages = self._stages({"max_num_seqs": None})
- assert stages[2].runtime_overrides.get("max_num_seqs") is None
- assert stages[2].yaml_engine_args.get("max_num_seqs") == 1
-
- def test_empty_kwargs_yaml_only(self):
- stages = self._stages({})
- for stage in stages:
- assert stage.runtime_overrides == {}
-
- def test_typed_kwarg_equal_to_dataclass_default_still_overrides(self):
- # Caller intent honored regardless of value coincidence (no heuristic).
- stages = self._stages({"gpu_memory_utilization": 0.9})
- assert stages[2].runtime_overrides.get("gpu_memory_utilization") == 0.9
-
- def test_per_stage_kwarg_routed_to_correct_stage(self):
- stages = self._stages({"stage_0_gpu_memory_utilization": 0.42})
- assert stages[0].runtime_overrides.get("gpu_memory_utilization") == 0.42
- assert stages[2].runtime_overrides.get("gpu_memory_utilization") is None
-
- def test_async_chunk_false_overrides_yaml_true(self):
- stages = self._stages({"async_chunk": False})
- for stage in stages:
- assert stage.yaml_engine_args.get("async_chunk") is not True
-
- def test_async_chunk_none_keeps_yaml_true(self):
- stages = self._stages({"async_chunk": None})
- for stage in stages:
- assert stage.yaml_engine_args.get("async_chunk") is True
-
- def test_enable_prefix_caching_typed_overrides_yaml(self):
- stages = self._stages({"enable_prefix_caching": True})
- for stage in stages:
- assert stage.runtime_overrides.get("enable_prefix_caching") is True
-
- def test_omni_with_vars_args_anti_pattern_is_safe(self):
- # Omni(**vars(args)) with mostly-None namespace must not clobber YAML.
- simulated_vars_args = {
- "gpu_memory_utilization": None,
- "max_num_seqs": None,
- "async_chunk": None,
- "enable_prefix_caching": None,
- "dtype": None,
- }
- stages = self._stages(simulated_vars_args)
- for stage in stages:
- assert stage.runtime_overrides == {}
-
- def test_create_from_registry_no_cli_explicit_keys_param(self):
- import inspect
-
- sig = inspect.signature(StageConfigFactory._create_from_registry)
- named = [p for p in sig.parameters.values() if p.kind != p.VAR_KEYWORD]
- assert "cli_explicit_keys" not in {p.name for p in named}
-
- def test_cli_explicit_keys_kwarg_emits_deprecation(self):
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
-
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always")
- StageConfigFactory._create_from_registry(
- "qwen3_omni_moe",
- cli_overrides={},
- cli_explicit_keys={"max_num_seqs"},
- )
- assert any(issubclass(x.category, DeprecationWarning) for x in w)
-
- def test_async_chunk_dispatches_processors(self):
- """A single ``qwen3_tts`` pipeline picks per-chunk vs end-to-end
- processors based on ``deploy.async_chunk``, without needing a
- separate variant pipeline registration."""
- import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
- from vllm_omni.config.stage_config import (
- _PIPELINE_REGISTRY,
- DeployConfig,
- merge_pipeline_deploy,
- )
-
- pipeline = _PIPELINE_REGISTRY["qwen3_tts"]
-
- # async_chunk=True → stage 0's per-chunk processor wires up, stage 1
- # has no sync input processor.
- async_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=True))
- assert (
- async_stages[0]
- .yaml_engine_args.get("custom_process_next_stage_input_func", "")
- .endswith("talker2code2wav_async_chunk")
- )
- assert async_stages[1].custom_process_input_func is None
-
- # async_chunk=False → stage 0 has no streaming processor, stage 1's
- # batch-end processor wires up.
- sync_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=False))
- assert "custom_process_next_stage_input_func" not in sync_stages[0].yaml_engine_args
- assert sync_stages[1].custom_process_input_func is not None
- assert sync_stages[1].custom_process_input_func.endswith("talker2code2wav")
-
-
-class TestSamplingConstraintsPrecedence:
- """Test that pipeline sampling_constraints override deploy defaults."""
-
- def test_constraints_win(self):
- from pathlib import Path
-
- import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
- from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
-
- pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
- deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
- if not deploy_path.exists():
- pytest.skip("Deploy config not found")
+class TestArchitectureFallback:
+ """Tests for architecture-based model detection fallback."""
- deploy = load_deploy_config(deploy_path)
- stages = merge_pipeline_deploy(pipeline, deploy)
+ def test_architecture_models_mapping_exists(self):
+ """Test that _ARCHITECTURE_MODELS contains expected entries."""
+ assert "MiMoAudioForConditionalGeneration" in StageConfigFactory._ARCHITECTURE_MODELS
+ assert StageConfigFactory._ARCHITECTURE_MODELS["MiMoAudioForConditionalGeneration"] == "mimo_audio"
+ assert "HunyuanImage3ForCausalMM" in StageConfigFactory._ARCHITECTURE_MODELS
+ assert StageConfigFactory._ARCHITECTURE_MODELS["HunyuanImage3ForCausalMM"] == "hunyuan_image3"
- # Pipeline says detokenize=True for thinker, deploy can't override
- assert stages[0].yaml_extras["default_sampling_params"]["detokenize"] is True
- # Pipeline says stop_token_ids=[2150] for talker
- assert stages[1].yaml_extras["default_sampling_params"]["stop_token_ids"] == [2150]
- # Deploy temperature still flows through
- assert stages[0].yaml_extras["default_sampling_params"]["temperature"] == 0.4
+ def test_mimo_audio_in_pipeline_models(self):
+ """Test that mimo_audio is registered in PIPELINE_MODELS."""
+ assert "mimo_audio" in StageConfigFactory.PIPELINE_MODELS
diff --git a/tests/test_data_entry_keys.py b/tests/test_data_entry_keys.py
deleted file mode 100644
index f4a50e677de..00000000000
--- a/tests/test_data_entry_keys.py
+++ /dev/null
@@ -1,252 +0,0 @@
-"""Tests for data_entry_keys: TypedDict payload structure, flatten/unflatten, serialize/deserialize."""
-
-import torch
-
-from vllm_omni.data_entry_keys import (
- OmniPayload,
- deserialize_payload,
- flatten_payload,
- serialize_payload,
- unflatten_payload,
-)
-from vllm_omni.engine import AdditionalInformationPayload
-
-
-class TestOmniPayload:
- def test_nested_payload_structure(self):
- """Verify OmniPayload can be constructed with nested dicts."""
- payload: OmniPayload = {
- "hidden_states": {"output": torch.tensor([1.0])},
- "embed": {"prefill": torch.tensor([2.0])},
- "codes": {"audio": torch.tensor([3.0])},
- "ids": {"all": [1, 2, 3]},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
- }
- assert torch.equal(payload["hidden_states"]["output"], torch.tensor([1.0]))
- assert torch.equal(payload["embed"]["prefill"], torch.tensor([2.0]))
- assert torch.equal(payload["codes"]["audio"], torch.tensor([3.0]))
- assert payload["ids"]["all"] == [1, 2, 3]
- assert payload["meta"]["finished"].item() is True
-
- def test_partial_payload(self):
- """OmniPayload fields are all optional (total=False)."""
- payload: OmniPayload = {"meta": {"finished": torch.tensor(False, dtype=torch.bool)}}
- assert payload["meta"]["finished"].item() is False
-
- def test_empty_payload(self):
- payload: OmniPayload = {}
- assert len(payload) == 0
-
-
-class TestFlattenPayload:
- def test_basic_nested_to_dotted(self):
- nested = {
- "codes": {"audio": torch.tensor([1.0])},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool), "left_context_size": 5},
- }
- flat = flatten_payload(nested)
- assert torch.equal(flat["codes.audio"], torch.tensor([1.0]))
- assert flat["meta.finished"].item() is True
- assert flat["meta.left_context_size"] == 5
- assert "codes" not in flat
- assert "meta" not in flat
-
- def test_top_level_keys_preserved(self):
- nested = {
- "latent": torch.tensor([9.0]),
- "generated_len": 42,
- }
- flat = flatten_payload(nested)
- assert torch.equal(flat["latent"], torch.tensor([9.0]))
- assert flat["generated_len"] == 42
-
- def test_hidden_states_layers_expanded(self):
- nested = {
- "hidden_states": {
- "output": torch.tensor([1.0]),
- "layers": {
- 0: torch.tensor([2.0]),
- 24: torch.tensor([3.0]),
- },
- },
- }
- flat = flatten_payload(nested)
- assert torch.equal(flat["hidden_states.output"], torch.tensor([1.0]))
- assert torch.equal(flat["hidden_states.layer_0"], torch.tensor([2.0]))
- assert torch.equal(flat["hidden_states.layer_24"], torch.tensor([3.0]))
- assert "hidden_states.layers" not in flat
-
- def test_empty_payload(self):
- assert flatten_payload({}) == {}
-
- def test_mixed_nested_and_top_level(self):
- nested: OmniPayload = {
- "codes": {"audio": torch.tensor([1.0])},
- "latent": torch.tensor([2.0]),
- "meta": {"finished": torch.tensor(False, dtype=torch.bool)},
- }
- flat = flatten_payload(nested)
- assert set(flat.keys()) == {"codes.audio", "latent", "meta.finished"}
-
-
-class TestUnflattenPayload:
- def test_basic_dotted_to_nested(self):
- flat = {
- "codes.audio": torch.tensor([1.0]),
- "meta.finished": torch.tensor(True, dtype=torch.bool),
- "meta.left_context_size": 5,
- }
- nested = unflatten_payload(flat)
- assert torch.equal(nested["codes"]["audio"], torch.tensor([1.0]))
- assert nested["meta"]["finished"].item() is True
- assert nested["meta"]["left_context_size"] == 5
-
- def test_top_level_keys_preserved(self):
- flat = {"latent": torch.tensor([9.0]), "generated_len": 42}
- nested = unflatten_payload(flat)
- assert torch.equal(nested["latent"], torch.tensor([9.0]))
- assert nested["generated_len"] == 42
-
- def test_hidden_states_layers_collected(self):
- flat = {
- "hidden_states.output": torch.tensor([1.0]),
- "hidden_states.layer_0": torch.tensor([2.0]),
- "hidden_states.layer_24": torch.tensor([3.0]),
- }
- nested = unflatten_payload(flat)
- assert torch.equal(nested["hidden_states"]["output"], torch.tensor([1.0]))
- assert torch.equal(nested["hidden_states"]["layers"][0], torch.tensor([2.0]))
- assert torch.equal(nested["hidden_states"]["layers"][24], torch.tensor([3.0]))
-
- def test_empty_payload(self):
- assert unflatten_payload({}) == {}
-
-
-class TestFlattenUnflattenRoundTrip:
- def test_round_trip_simple(self):
- original: OmniPayload = {
- "codes": {"audio": torch.tensor([1.0, 2.0])},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool), "left_context_size": 10},
- "ids": {"prompt": [1, 2, 3]},
- "latent": torch.tensor([5.0]),
- }
- restored = unflatten_payload(flatten_payload(original))
- assert torch.equal(restored["codes"]["audio"], original["codes"]["audio"])
- assert restored["meta"]["finished"].item() is True
- assert restored["meta"]["left_context_size"] == 10
- assert restored["ids"]["prompt"] == [1, 2, 3]
- assert torch.equal(restored["latent"], original["latent"])
-
- def test_round_trip_with_layers(self):
- original = {
- "hidden_states": {
- "output": torch.tensor([1.0]),
- "layers": {0: torch.tensor([2.0]), 24: torch.tensor([3.0])},
- },
- }
- restored = unflatten_payload(flatten_payload(original))
- assert torch.equal(restored["hidden_states"]["output"], torch.tensor([1.0]))
- assert torch.equal(restored["hidden_states"]["layers"][0], torch.tensor([2.0]))
- assert torch.equal(restored["hidden_states"]["layers"][24], torch.tensor([3.0]))
-
- def test_round_trip_all_categories(self):
- original: OmniPayload = {
- "hidden_states": {"output": torch.tensor([1.0]), "last": torch.tensor([2.0])},
- "embed": {"prefill": torch.tensor([3.0]), "tts_bos": torch.tensor([4.0])},
- "codes": {"audio": torch.tensor([5.0]), "ref": torch.tensor([6.0])},
- "ids": {"all": [1, 2], "prompt": [3, 4]},
- "meta": {"finished": torch.tensor(False, dtype=torch.bool), "ar_width": 8},
- }
- restored = unflatten_payload(flatten_payload(original))
- assert torch.equal(restored["hidden_states"]["output"], torch.tensor([1.0]))
- assert torch.equal(restored["hidden_states"]["last"], torch.tensor([2.0]))
- assert torch.equal(restored["embed"]["prefill"], torch.tensor([3.0]))
- assert torch.equal(restored["embed"]["tts_bos"], torch.tensor([4.0]))
- assert torch.equal(restored["codes"]["audio"], torch.tensor([5.0]))
- assert torch.equal(restored["codes"]["ref"], torch.tensor([6.0]))
- assert restored["ids"]["all"] == [1, 2]
- assert restored["ids"]["prompt"] == [3, 4]
- assert restored["meta"]["finished"].item() is False
- assert restored["meta"]["ar_width"] == 8
-
-
-class TestSerializeDeserializePayload:
- def test_tensor_round_trip(self):
- original: OmniPayload = {
- "hidden_states": {"output": torch.tensor([[1.0, 2.0], [3.0, 4.0]])},
- }
- wire = serialize_payload(original)
- assert isinstance(wire, AdditionalInformationPayload)
- restored = deserialize_payload(wire)
- assert torch.equal(restored["hidden_states"]["output"], original["hidden_states"]["output"])
-
- def test_list_round_trip(self):
- original: OmniPayload = {
- "ids": {"prompt": [10, 20, 30]},
- }
- wire = serialize_payload(original)
- restored = deserialize_payload(wire)
- assert restored["ids"]["prompt"] == [10, 20, 30]
-
- def test_finished_tensor_round_trip(self):
- original: OmniPayload = {
- "meta": {"finished": torch.tensor(True, dtype=torch.bool), "left_context_size": 5},
- }
- wire = serialize_payload(original)
- restored = deserialize_payload(wire)
- assert isinstance(restored["meta"]["finished"], torch.Tensor)
- assert restored["meta"]["finished"].dtype == torch.bool
- assert restored["meta"]["finished"].item() is True
- assert restored["meta"]["left_context_size"] == 5
-
- def test_mixed_types_round_trip(self):
- original: OmniPayload = {
- "hidden_states": {"output": torch.tensor([1.0, 2.0])},
- "ids": {"all": [1, 2, 3]},
- "meta": {"finished": torch.tensor(False, dtype=torch.bool), "ar_width": 4},
- "codes": {"audio": torch.tensor([3.0])},
- }
- wire = serialize_payload(original)
- restored = deserialize_payload(wire)
- assert torch.equal(restored["hidden_states"]["output"], original["hidden_states"]["output"])
- assert restored["ids"]["all"] == [1, 2, 3]
- assert restored["meta"]["finished"].item() is False
- assert restored["meta"]["ar_width"] == 4
- assert torch.equal(restored["codes"]["audio"], original["codes"]["audio"])
-
- def test_hidden_states_layers_round_trip(self):
- original = {
- "hidden_states": {
- "output": torch.tensor([1.0]),
- "layers": {0: torch.tensor([2.0]), 24: torch.tensor([3.0])},
- },
- }
- wire = serialize_payload(original)
- restored = deserialize_payload(wire)
- assert torch.equal(restored["hidden_states"]["output"], torch.tensor([1.0]))
- assert torch.equal(restored["hidden_states"]["layers"][0], torch.tensor([2.0]))
- assert torch.equal(restored["hidden_states"]["layers"][24], torch.tensor([3.0]))
-
- def test_tensor_dtype_preserved(self):
- # bfloat16 excluded: numpy() doesn't support it; callers must cast before serializing.
- for dtype in [torch.float16, torch.float32, torch.int64, torch.int32, torch.bool]:
- original: OmniPayload = {"codes": {"audio": torch.tensor([1], dtype=dtype)}}
- wire = serialize_payload(original)
- restored = deserialize_payload(wire)
- assert restored["codes"]["audio"].dtype == dtype, f"dtype mismatch for {dtype}"
-
- def test_tensor_shape_preserved(self):
- t = torch.randn(3, 4, 5)
- original: OmniPayload = {"hidden_states": {"output": t}}
- wire = serialize_payload(original)
- restored = deserialize_payload(wire)
- assert restored["hidden_states"]["output"].shape == (3, 4, 5)
- assert torch.allclose(restored["hidden_states"]["output"], t)
-
- def test_empty_payload_returns_none(self):
- assert serialize_payload({}) is None
-
- def test_none_values_skipped(self):
- original: OmniPayload = {"meta": {"finished": None}}
- wire = serialize_payload(original)
- assert wire is None
diff --git a/tests/test_diffusion_config_fields.py b/tests/test_diffusion_config_fields.py
deleted file mode 100644
index b87ceec1df6..00000000000
--- a/tests/test_diffusion_config_fields.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Ensure diffusion stage YAML configs only use valid OmniDiffusionConfig fields.
-
-Regression test for https://github.com/vllm-project/vllm-omni/issues/2563
-"""
-
-from dataclasses import fields
-from pathlib import Path
-
-import pytest
-import yaml
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-try:
- from vllm_omni.diffusion.data import OmniDiffusionConfig
-except Exception:
- OmniDiffusionConfig = None
-
-
-@pytest.mark.skipif(
- OmniDiffusionConfig is None,
- reason="OmniDiffusionConfig could not be imported (missing torch?)",
-)
-def test_diffusion_stage_configs_only_contain_valid_fields():
- """Diffusion stage engine_args must only contain OmniDiffusionConfig fields.
-
- Regression test for https://github.com/vllm-project/vllm-omni/issues/2563
- """
- # Scan both main configs and test configs
- repo_root = Path(__file__).parent.parent
- config_dirs = [
- repo_root / "vllm_omni" / "model_executor" / "stage_configs",
- ]
- # Also scan test directories recursively
- test_dir = repo_root / "tests"
-
- yaml_paths: list[Path] = []
- for config_dir in config_dirs:
- yaml_paths.extend(sorted(config_dir.glob("*.yaml")))
- yaml_paths.extend(sorted(test_dir.rglob("*.yaml")))
-
- valid_fields = {f.name for f in fields(OmniDiffusionConfig)}
- # model_stage is consumed by the stage init layer, not OmniDiffusionConfig
- valid_fields.add("model_stage")
- # model_arch is consumed by the stage init layer for diffusion model class resolution
- valid_fields.add("model_arch")
- # "quantization" is mapped to "quantization_config" by from_kwargs() backwards-compat
- valid_fields.add("quantization")
-
- invalid_entries: list[tuple[str, set[str]]] = []
- for yaml_path in yaml_paths:
- with open(yaml_path) as fh:
- config = yaml.safe_load(fh)
-
- stages = config.get("stage_args", config.get("stages", []))
- for stage in stages:
- if stage.get("stage_type") != "diffusion":
- continue
- engine_args = stage.get("engine_args", {})
- invalid = set(engine_args.keys()) - valid_fields
- if invalid:
- invalid_entries.append((yaml_path.relative_to(repo_root), invalid))
-
- assert not invalid_entries, "Diffusion stage configs contain fields not in OmniDiffusionConfig:\n" + "\n".join(
- f" {name}: {sorted(bad)}" for name, bad in invalid_entries
- )
diff --git a/tests/test_diffusion_config_propagation.py b/tests/test_diffusion_config_propagation.py
index eeb3505efe9..58eb6097cad 100644
--- a/tests/test_diffusion_config_propagation.py
+++ b/tests/test_diffusion_config_propagation.py
@@ -7,7 +7,6 @@
from collections.abc import Mapping
-import pytest
import torch
from vllm_omni.config.stage_config import StageConfigFactory
@@ -15,9 +14,6 @@
DiffusionParallelConfig,
OmniDiffusionConfig,
)
-from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
def _roundtrip_diffusion_config(**kwargs) -> OmniDiffusionConfig:
@@ -110,12 +106,3 @@ def test_extra_kwargs_forwarded(self):
ea = stages[0]["engine_args"]
assert ea["enforce_eager"] is True
assert ea["lora_path"] == "/tmp/lora"
-
-
-def test_qwen_image_edit_plus_sets_generic_multimodal_limit():
- od_config = OmniDiffusionConfig(model="Qwen/Qwen-Image-Edit-2511", model_class_name="QwenImageEditPlusPipeline")
-
- od_config.update_multimodal_support()
-
- assert od_config.supports_multimodal_inputs is True
- assert od_config.max_multimodal_image_inputs == QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
diff --git a/tests/test_fish_speech_voice_cache.py b/tests/test_fish_speech_voice_cache.py
deleted file mode 100644
index 1c299d80142..00000000000
--- a/tests/test_fish_speech_voice_cache.py
+++ /dev/null
@@ -1,227 +0,0 @@
-"""Tests for Fish Speech DAC-code caching via VoiceEmbeddingCache.
-
-Covers:
- - Cache miss → DAC encode → store
- - Cache hit → skip DAC encode, reuse cached ref_codes_fq
- - Inline ref_audio (no voice name) → no caching, full encode path
- - Stale-cache protection via created_at
- - Temp file cleanup on cache hit
-"""
-
-import os
-import tempfile
-from pathlib import Path
-
-import numpy as np
-import pytest
-import torch
-from pytest_mock import MockerFixture
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _make_info_dict(
- *,
- text: str = "Hello world",
- ref_text: str = "Reference transcript",
- ref_audio_sr: int = 44100,
- voice_name: str | None = None,
- voice_created_at: float | None = None,
- ref_audio_path: str | None = None,
-) -> dict:
- """Build a minimal info_dict for _build_structured_voice_clone_prefill_embeds."""
- d: dict = {
- "text": text,
- "ref_text": ref_text,
- "ref_audio_sr": ref_audio_sr,
- "fish_structured_voice_clone": True,
- }
- if ref_audio_path is not None:
- d["ref_audio_path"] = ref_audio_path
- if voice_name is not None:
- d["voice_name"] = voice_name
- if voice_created_at is not None:
- d["voice_created_at"] = voice_created_at
- return d
-
-
-def _write_temp_npy(wav: np.ndarray | None = None) -> str:
- """Write a temporary .npy file with dummy audio and return its path."""
- if wav is None:
- wav = np.random.randn(44100).astype(np.float32) # 1 second @ 44.1kHz
- with tempfile.NamedTemporaryFile(prefix="fish_test_", suffix=".npy", delete=False) as f:
- np.save(f, wav)
- return f.name
-
-
-# Fake ref_codes_fq: [frames, codebooks]
-_FAKE_REF_CODES = torch.randint(0, 1024, (10, 10), dtype=torch.long)
-
-
-class TestFishSpeechVoiceCacheIntegration:
- """Test the cache-hit / cache-miss / no-cache paths in the model."""
-
- @pytest.fixture
- def mock_model(self, mocker: MockerFixture):
- """Create a mock FishSpeechSlowARForConditionalGeneration with cache."""
- from vllm_omni.utils.voice_cache import VoiceEmbeddingCache
-
- model = mocker.MagicMock()
- model._voice_cache = VoiceEmbeddingCache(max_entries=4)
- model._semantic_begin_id = 151678
- model._num_codebooks = 10
- model._codebook_size = 4096
- model.model_path = "/fake/model"
- model.codebook_embeddings = mocker.MagicMock()
- model.codebook_embeddings.weight = mocker.MagicMock()
- model.codebook_embeddings.weight.device = torch.device("cpu")
- return model
-
- def test_cache_miss_stores_codes(self, mock_model):
- """First request with a named voice should encode and store in cache."""
- cache = mock_model._voice_cache
- voice_name = "alice"
- created_at = 1712345678.0
-
- # Verify cache starts empty.
- key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at)
- assert cache.get(key) is None
-
- # Simulate a cache store (what the model does on miss).
- cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()})
-
- # Verify it's now cached.
- cached = cache.get(key)
- assert cached is not None
- assert torch.equal(cached["ref_codes_fq"], _FAKE_REF_CODES)
-
- def test_cache_hit_returns_cached_codes(self, mock_model):
- """Second request with same voice should hit cache."""
- cache = mock_model._voice_cache
- voice_name = "alice"
- created_at = 1712345678.0
-
- key = cache.make_cache_key(voice_name, xvec_only=False, created_at=created_at)
- cache.put(key, {"ref_codes_fq": _FAKE_REF_CODES.detach().cpu()})
-
- # Hit.
- cached = cache.get(key)
- assert cached is not None
- ref_codes = cached["ref_codes_fq"].to(device=torch.device("cpu"), dtype=torch.long)
- assert torch.equal(ref_codes, _FAKE_REF_CODES)
- assert cache.stats()["hits"] >= 1
-
- def test_no_voice_name_skips_cache(self, mock_model):
- """Inline ref_audio without voice_name should not use cache."""
- cache = mock_model._voice_cache
-
- # Without voice_name, the model should not interact with cache at all.
- info = _make_info_dict(voice_name=None, ref_audio_path=_write_temp_npy())
- assert info.get("voice_name") is None
- # Cache should remain untouched.
- assert cache.stats()["hits"] == 0
- assert cache.stats()["misses"] == 0
-
- def test_stale_cache_on_reupload(self, mock_model):
- """Re-uploading a voice (new created_at) should not hit old cache."""
- cache = mock_model._voice_cache
- voice_name = "alice"
-
- key_old = cache.make_cache_key(voice_name, xvec_only=False, created_at=1000.0)
- cache.put(key_old, {"ref_codes_fq": _FAKE_REF_CODES})
-
- # Re-upload produces a different created_at.
- key_new = cache.make_cache_key(voice_name, xvec_only=False, created_at=2000.0)
- assert cache.get(key_new) is None # miss
- assert cache.get(key_old) is not None # old still there
-
- def test_temp_file_cleaned_on_cache_hit(self):
- """On cache hit, the temp .npy file written by the entrypoint should be deleted."""
- tmp_path = _write_temp_npy()
- assert os.path.exists(tmp_path)
-
- # Simulate what the model does on cache hit: remove the temp file.
- try:
- os.remove(tmp_path)
- except OSError:
- pass
- assert not os.path.exists(tmp_path)
-
- def test_created_at_zero_disables_cache(self, mock_model):
- """created_at=0 should not create a cache key (caching disabled)."""
- cache = mock_model._voice_cache
-
- info = _make_info_dict(
- voice_name="bob",
- voice_created_at=0.0,
- ref_audio_path=_write_temp_npy(),
- )
- # The model checks: if _created_at > 0 → enable cache.
- # With 0.0, no cache interaction should happen.
- _created_at = float(info.get("voice_created_at", 0))
- assert _created_at <= 0
- assert cache.stats()["hits"] == 0
- assert cache.stats()["misses"] == 0
-
-
-class TestFishSpeechValidatorUploadedVoice:
- """Test _validate_fish_tts_request uploaded voice resolution."""
-
- def test_uploaded_voice_resolves_ref_audio(
- self,
- monkeypatch: pytest.MonkeyPatch,
- mocker: MockerFixture,
- ):
- """When voice matches an uploaded speaker, ref_audio should be auto-set."""
- request = mocker.MagicMock()
- request.input = "Hello"
- request.voice = "alice"
- request.ref_audio = None
- request.ref_text = None
- request.max_new_tokens = None
-
- # Uploaded speaker with ref_text.
- uploaded_speakers = {
- "alice": {
- "file_path": "/tmp/fake_audio.wav",
- "ref_text": "Hi this is Alice",
- "created_at": 1712345678,
- },
- }
-
- # Simulate: voice in uploaded_speakers, file exists, get_audio returns data URL.
- monkeypatch.setattr(Path, "exists", lambda self: True)
-
- voice_lower = request.voice.lower()
- assert voice_lower in uploaded_speakers
-
- speaker_info = uploaded_speakers[voice_lower]
- ref_text_from_upload = speaker_info.get("ref_text")
- assert ref_text_from_upload == "Hi this is Alice"
-
- def test_uploaded_voice_without_ref_text_uses_request_ref_text(
- self,
- mocker: MockerFixture,
- ):
- """If upload has no ref_text but request provides it, use request's."""
- request = mocker.MagicMock()
- request.input = "Hello"
- request.voice = "bob"
- request.ref_audio = None
- request.ref_text = "Request-level transcript"
- request.max_new_tokens = None
-
- uploaded_speakers = {
- "bob": {
- "file_path": "/tmp/fake_audio.wav",
- "ref_text": None,
- "created_at": 1712345678,
- },
- }
-
- voice_lower = request.voice.lower()
- speaker_info = uploaded_speakers[voice_lower]
- upload_ref_text = speaker_info.get("ref_text")
- # Upload has no ref_text, so request.ref_text should remain.
- assert upload_ref_text is None
- assert request.ref_text == "Request-level transcript"
diff --git a/tests/test_generate_nightly_perf_excel.py b/tests/test_generate_nightly_perf_excel.py
deleted file mode 100644
index 9b05d6de0fd..00000000000
--- a/tests/test_generate_nightly_perf_excel.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import importlib.util
-import json
-import sys
-from pathlib import Path
-
-import pytest
-from openpyxl import load_workbook
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _load_excel_module():
- repo_root = Path(__file__).resolve().parents[1]
- module_path = repo_root / "tools" / "nightly" / "generate_nightly_perf_excel.py"
- spec = importlib.util.spec_from_file_location("generate_nightly_perf_excel", module_path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[spec.name] = module
- spec.loader.exec_module(module)
- return module
-
-
-def _cell_value_by_header(ws, header_name: str, row_idx: int = 2):
- headers = [c.value for c in ws[1]]
- col_idx = headers.index(header_name) + 1
- return ws.cell(row=row_idx, column=col_idx).value
-
-
-def test_generate_excel_report_with_perf_templates(tmp_path: Path):
- module = _load_excel_module()
- repo_root = Path(__file__).resolve().parents[1]
- perf_scripts_dir = repo_root / "tests" / "dfx" / "perf" / "scripts"
-
- omni_template_path = perf_scripts_dir / "result_omni_template.json"
- diffusion_template_path = perf_scripts_dir / "diffusion_result_template.json"
-
- omni_record = json.loads(omni_template_path.read_text(encoding="utf-8"))
- diffusion_records = json.loads(diffusion_template_path.read_text(encoding="utf-8"))
-
- input_dir = tmp_path / "input"
- diffusion_input_dir = tmp_path / "diffusion_input"
- input_dir.mkdir()
- diffusion_input_dir.mkdir()
-
- # Keep file names compatible with parser conventions in generate_nightly_perf_excel.py
- omni_result_file = input_dir / "result_test_perf_random_1_4_in2500_out900_20260415-185642.json"
- diffusion_result_file = diffusion_input_dir / "diffusion_result_qwen_image_edit_20260415-193200.json"
- omni_result_file.write_text(json.dumps(omni_record, ensure_ascii=False, indent=2), encoding="utf-8")
- diffusion_result_file.write_text(json.dumps(diffusion_records, ensure_ascii=False, indent=2), encoding="utf-8")
-
- output_file = tmp_path / "nightly_perf.xlsx"
- module.generate_excel_report(
- input_dir=str(input_dir),
- diffusion_input_dir=str(diffusion_input_dir),
- output_file=str(output_file),
- commit_sha="test_commit_sha",
- build_id="test_build_id",
- build_url="https://example.com/build/123",
- )
-
- assert output_file.exists()
-
- wb = load_workbook(output_file)
- assert set(wb.sheetnames) >= {"omni_summary", "diffusion_summary", "omni_raw", "diffusion_raw"}
-
- ws_omni_raw = wb["omni_raw"]
- baseline_cell = _cell_value_by_header(ws_omni_raw, "baseline")
- assert baseline_cell == json.dumps(omni_record["baseline"], ensure_ascii=False, sort_keys=True)
-
- ws_omni_summary = wb["omni_summary"]
- assert _cell_value_by_header(ws_omni_summary, "commit_sha") == "test_commit_sha"
- assert _cell_value_by_header(ws_omni_summary, "build_id") == "test_build_id"
diff --git a/tests/test_generate_nightly_perf_html.py b/tests/test_generate_nightly_perf_html.py
deleted file mode 100644
index 4e77eb3adfd..00000000000
--- a/tests/test_generate_nightly_perf_html.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import importlib.util
-import json
-import sys
-from pathlib import Path
-
-import pytest
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-def _load_html_module():
- repo_root = Path(__file__).resolve().parents[1]
- module_path = repo_root / "tools" / "nightly" / "generate_nightly_perf_html.py"
- spec = importlib.util.spec_from_file_location("generate_nightly_perf_html", module_path)
- module = importlib.util.module_from_spec(spec)
- sys.modules[spec.name] = module
- spec.loader.exec_module(module)
- return module
-
-
-def test_generate_html_report_with_perf_templates(tmp_path: Path):
- module = _load_html_module()
- repo_root = Path(__file__).resolve().parents[1]
- perf_scripts_dir = repo_root / "tests" / "dfx" / "perf" / "scripts"
-
- omni_template_path = perf_scripts_dir / "result_omni_template.json"
- diffusion_template_path = perf_scripts_dir / "diffusion_result_template.json"
-
- omni_record = json.loads(omni_template_path.read_text(encoding="utf-8"))
- diffusion_records = json.loads(diffusion_template_path.read_text(encoding="utf-8"))
-
- input_dir = tmp_path / "input"
- diffusion_input_dir = tmp_path / "diffusion_input"
- input_dir.mkdir()
- diffusion_input_dir.mkdir()
-
- omni_result_file = input_dir / "result_test_perf_random_1_4_in2500_out900_20260415-185642.json"
- diffusion_result_file = diffusion_input_dir / "diffusion_result_qwen_image_edit_20260415-193200.json"
- omni_result_file.write_text(json.dumps(omni_record, ensure_ascii=False, indent=2), encoding="utf-8")
- diffusion_result_file.write_text(json.dumps(diffusion_records, ensure_ascii=False, indent=2), encoding="utf-8")
-
- output_file = tmp_path / "nightly_perf_v2.html"
- module.generate_html_report(
- input_dir=str(input_dir),
- diffusion_input_dir=str(diffusion_input_dir),
- output_file=str(output_file),
- )
-
- assert output_file.exists()
- html = output_file.read_text(encoding="utf-8")
- assert "Nightly Performance Report" in html
- assert "Omni records 1 " in html
- assert f"Diffusion records {len(diffusion_records)} " in html
- assert "const DIFF_DATA =" in html
diff --git a/tests/test_version.py b/tests/test_version.py
deleted file mode 100644
index 07e622a7d15..00000000000
--- a/tests/test_version.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for version compatibility warnings."""
-
-import warnings
-from unittest import mock
-
-import pytest
-
-from vllm_omni.version import warn_if_misaligned_vllm_version
-
-
-@mock.patch("vllm_omni.version.__version_tuple__", (0, 19, 0))
-@mock.patch("vllm_omni.version.__version__", "0.19.0")
-@mock.patch("vllm.__version_tuple__", (0, 18, 0))
-@mock.patch("vllm.__version__", "0.18.0")
-def test_version_mismatch_warning():
- """Ensure that we warn when vLLM and vLLM-Omni major/minor versions differ."""
- with pytest.warns(RuntimeWarning, match="mismatched major/minor versions"):
- warn_if_misaligned_vllm_version()
-
-
-@pytest.mark.parametrize(
- "vllm_ver,vllm_tuple,omni_ver,omni_tuple",
- [
- ("0.19.0", (0, 19, 0), "0.19.5", (0, 19, 5)), # Patch differs
- ("0.18.0", (0, 18, 0), "dev", (0, 0, "dev")), # Omni dev
- ("dev", (0, 0, "dev"), "0.19.0", (0, 19, 0)), # vLLM dev
- # Ensure local identifies don't matter for the warning
- ("0.19.0+foo", (0, 19, 0, "foo"), "0.19.5", (0, 19, 0)),
- ("0.19.0", (0, 19, 0), "0.19.5+bar", (0, 19, 0, "bar")),
- ("0.19.0+foo", (0, 19, 0, "foo"), "0.19.5+bar", (0, 19, 0, "bar")),
- ],
-)
-def test_no_warning_cases(vllm_ver, vllm_tuple, omni_ver, omni_tuple):
- """Ensure we don't warn when minor versions match or either is a dev build."""
- with (
- mock.patch.multiple("vllm", __version__=vllm_ver, __version_tuple__=vllm_tuple),
- mock.patch.multiple("vllm_omni.version", __version__=omni_ver, __version_tuple__=omni_tuple),
- ):
- with warnings.catch_warnings():
- warnings.simplefilter("error")
- warn_if_misaligned_vllm_version()
-
-
-@mock.patch("vllm_omni.version.__version_tuple__", (0, 19, 0))
-@mock.patch("vllm_omni.version.__version__", "0.19.0rc2.dev21")
-@mock.patch("vllm.__version_tuple__", (0, 18, 0))
-@mock.patch("vllm.__version__", "0.18.0")
-def test_warning_contains_version_strings():
- """Ensure that the warning contains the full version strings."""
- with pytest.warns(RuntimeWarning) as record:
- warn_if_misaligned_vllm_version()
-
- assert len(record) == 1
- msg = str(record[0].message)
- assert "0.19.0rc2.dev21" in msg
- assert "0.18.0" in msg
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000000..84edbbf3d11
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,621 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Some functions are copied from vllm/tests/utils.py
+import functools
+import os
+import signal
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+from collections.abc import Callable
+from contextlib import ExitStack, contextmanager, suppress
+from typing import Any, Literal
+
+import cloudpickle
+import pytest
+import torch
+from typing_extensions import ParamSpec
+from vllm.platforms import current_platform
+from vllm.utils.torch_utils import cuda_device_count_stateless
+
+from vllm_omni.platforms import current_omni_platform
+
+_P = ParamSpec("_P")
+
+if current_platform.is_rocm():
+    from amdsmi import (
+        amdsmi_get_gpu_vram_usage,
+        amdsmi_get_processor_handles,
+        amdsmi_init,
+        amdsmi_shut_down,
+    )
+    # ROCm: the helper keeps the name `_nvml` but brackets the body with AMD SMI init/shutdown.
+    @contextmanager
+    def _nvml():
+        try:
+            amdsmi_init()
+            yield
+        finally:
+            amdsmi_shut_down()
+elif current_platform.is_cuda():
+    from vllm.third_party.pynvml import (
+        nvmlDeviceGetHandleByIndex,
+        nvmlDeviceGetMemoryInfo,
+        nvmlInit,
+        nvmlShutdown,
+    )
+    # CUDA: bracket the body with NVML init/shutdown.
+    @contextmanager
+    def _nvml():
+        try:
+            nvmlInit()
+            yield
+        finally:
+            nvmlShutdown()
+else:
+    # Any other platform: a no-op context manager, so `@_nvml()` remains usable as a decorator.
+    @contextmanager
+    def _nvml():
+        yield
+
+
+def get_physical_device_indices(devices):
+    """Map logical device indices to physical indices according to CUDA_VISIBLE_DEVICES."""
+    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if visible_devices is None:
+        return devices  # variable unset: the input is returned unchanged
+    visible_indices = [int(x) for x in visible_devices.split(",")]  # entries must parse as integers
+    index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
+    return [index_mapping[i] for i in devices if i in index_mapping]  # unmapped indices are dropped
+
+
+@_nvml()
+def wait_for_gpu_memory_to_clear(
+    *,
+    devices: list[int],
+    threshold_bytes: int | None = None,
+    threshold_ratio: float | None = None,
+    timeout_s: float = 120,
+) -> None:
+    """Block until every listed device reports memory usage within the threshold.
+
+    Usage is polled every 5 seconds. `threshold_bytes` takes precedence when both thresholds
+    are given. Raises ValueError once `timeout_s` has elapsed with the condition still unmet.
+    """
+    import gc
+    assert threshold_bytes is not None or threshold_ratio is not None
+    # Use nvml instead of pytorch to reduce measurement error from torch cuda context.
+    devices = get_physical_device_indices(devices)
+    start_time = time.time()
+
+    # Print waiting start information
+    device_list = ", ".join(str(d) for d in devices)
+    if threshold_bytes is not None:
+        threshold_str = f"{threshold_bytes / 2**30:.2f} GiB"
+        condition_str = f"Memory usage ≤ {threshold_str}"
+    else:
+        threshold_percent = threshold_ratio * 100
+        threshold_str = f"{threshold_percent:.1f}%"
+        condition_str = f"Memory usage ratio ≤ {threshold_str}"
+
+    print(f"[GPU Memory Monitor] Waiting for GPU {device_list} to free memory, Condition: {condition_str}")
+    # Define the is_free predicate by threshold type; it receives the (gb_used, gb_total) pairs.
+    if threshold_bytes is not None:
+
+        def is_free(used, total):
+            return used <= threshold_bytes / 2**30
+    else:
+
+        def is_free(used, total):
+            return used / total <= threshold_ratio
+
+    while True:
+        output: dict[int, str] = {}
+        output_raw: dict[int, tuple[float, float]] = {}
+        for device in devices:
+            if current_platform.is_rocm():
+                dev_handle = amdsmi_get_processor_handles()[device]
+                mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
+                gb_used = mem_info["vram_used"] / 2**10  # presumably reported in MiB — TODO confirm
+                gb_total = mem_info["vram_total"] / 2**10
+            else:
+                dev_handle = nvmlDeviceGetHandleByIndex(device)
+                mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
+                gb_used = mem_info.used / 2**30
+                gb_total = mem_info.total / 2**30
+            output_raw[device] = (gb_used, gb_total)
+            # Format to more readable form
+            usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0
+            output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)"
+
+        # Print the current per-device usage on every poll
+        print("[GPU Memory Status] Current usage:")
+        for device_id, mem_info in output.items():
+            print(f" GPU {device_id}: {mem_info}")
+
+        # Calculate waiting duration
+        dur_s = time.time() - start_time
+        elapsed_minutes = dur_s / 60
+
+        # Check if all devices meet the condition
+        if all(is_free(used, total) for used, total in output_raw.values()):
+            # Report the wait summary and final usage, then stop polling
+            print(f"[GPU Memory Freed] Devices {device_list} meet memory condition")
+            print(f" Condition: {condition_str}")
+            print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)")
+            print(" Final status:")
+            for device_id, mem_info in output.items():
+                print(f" GPU {device_id}: {mem_info}")
+            break
+
+        # Check timeout (after measuring, so at least one poll always happens)
+        if dur_s >= timeout_s:
+            raise ValueError(
+                f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n"
+                f"Condition: {condition_str}\n"
+                f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices)
+            )
+
+        # Add waiting hint (optional)
+        if dur_s > 10 and int(dur_s) % 10 == 0:  # Show hint every 10 seconds
+            print(f"Waiting... Already waited {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)")
+        gc.collect()
+        torch.cuda.empty_cache()
+        time.sleep(5)
+
+
+def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
+    """Decorator to fork a new process for each test function.
+    See https://github.com/vllm-project/vllm/issues/7053 for more details.
+
+    The child runs the test and exits; the parent waits for it, sends SIGTERM to the process
+    group, and re-raises the child's failure (the pickled exception, else an AssertionError).
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+        # Become the leader of a new process group so the killpg() below cannot reach our parent.
+        os.setpgrp()
+        from _pytest.outcomes import Skipped
+
+        # Temp file for exception info from the child; test name and PID in its name avoid collisions.
+        with (
+            tempfile.NamedTemporaryFile(
+                delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc"
+            ) as exc_file,
+            ExitStack() as delete_after,
+        ):
+            exc_file_path = exc_file.name
+            delete_after.callback(os.remove, exc_file_path)
+
+            pid = os.fork()
+            print(f"Fork a new process to run a test {pid}")
+            if pid == 0:
+                # The parent is responsible for deleting the file; drop the callback in the child.
+                delete_after.pop_all()
+                try:
+                    func(*args, **kwargs)
+                except Skipped as e:
+                    # convert Skipped to exit code 0
+                    print(str(e))
+                    os._exit(0)
+                except Exception as e:
+                    import traceback
+
+                    tb_string = traceback.format_exc()
+
+                    # Try to serialize the exception object first
+                    exc_to_serialize: dict[str, Any]
+                    try:
+                        # First, try to pickle the actual exception with its traceback.
+                        exc_to_serialize = {"pickled_exception": e}
+                        # Test if it can be pickled
+                        cloudpickle.dumps(exc_to_serialize)
+                    except (Exception, KeyboardInterrupt):
+                        # Fall back to string-based approach.
+                        exc_to_serialize = {
+                            "exception_type": type(e).__name__,
+                            "exception_msg": str(e),
+                            "traceback": tb_string,
+                        }
+                    try:
+                        with open(exc_file_path, "wb") as f:
+                            cloudpickle.dump(exc_to_serialize, f)
+                    except Exception:
+                        # Fallback: just print the traceback.
+                        print(tb_string)
+                    os._exit(1)
+                else:
+                    os._exit(0)
+            else:
+                pgid = os.getpgid(pid)
+                _pid, wait_status = os.waitpid(pid, 0)
+                # waitpid() returns an encoded wait status (e.g. 256 for exit code 1); decode it for the messages.
+                _exitcode = os.waitstatus_to_exitcode(wait_status)
+                # ignore SIGTERM signal itself
+                old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
+                # kill all child processes
+                os.killpg(pgid, signal.SIGTERM)
+                # restore the signal handler
+                signal.signal(signal.SIGTERM, old_signal_handler)
+                if _exitcode != 0:
+                    # Try to read the exception from the child process
+                    exc_info = {}
+                    if os.path.exists(exc_file_path):
+                        with suppress(Exception), open(exc_file_path, "rb") as f:
+                            exc_info = cloudpickle.load(f)
+
+                    if (original_exception := exc_info.get("pickled_exception")) is not None:
+                        # Re-raise the actual exception object if it was successfully pickled.
+                        assert isinstance(original_exception, Exception)
+                        raise original_exception
+
+                    if (original_tb := exc_info.get("traceback")) is not None:
+                        # Use string-based traceback for fallback case
+                        raise AssertionError(
+                            f"Test {func.__name__} failed when called with"
+                            f" args {args} and kwargs {kwargs}"
+                            f" (exit code: {_exitcode}):\n{original_tb}"
+                        ) from None
+
+                    # Fallback to the original generic error
+                    raise AssertionError(
+                        f"function {func.__name__} failed when called with"
+                        f" args {args} and kwargs {kwargs}"
+                        f" (exit code: {_exitcode})"
+                    ) from None
+
+    return wrapper
+
+
+def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]:
+    """Decorator to spawn a new process for each test function."""
+
+    @functools.wraps(f)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+        # Check if we're already in a subprocess
+        if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
+            # If we are, just run the function directly
+            return f(*args, **kwargs)
+
+        import torch.multiprocessing as mp
+
+        with suppress(RuntimeError):
+            mp.set_start_method("spawn")
+
+        # The module that defines the test; the subprocess executes it via `-m`
+        module_name = f.__module__
+
+        # Create a process with environment variable set
+        env = os.environ.copy()
+        env["RUNNING_IN_SUBPROCESS"] = "1"
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            output_filepath = os.path.join(tempdir, "new_process.tmp")
+            # NOTE(review): args/kwargs are not part of the payload, and output_filepath is never read here.
+            # `cloudpickle` allows pickling complex functions directly
+            input_bytes = cloudpickle.dumps((f, output_filepath))
+
+            cmd = [sys.executable, "-m", f"{module_name}"]
+            # The pickled payload is passed to the subprocess on stdin.
+            returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env)
+
+            # check if the subprocess is successful
+            try:
+                returned.check_returncode()
+            except Exception as e:
+                # wrap raised exception to provide more information
+                raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e
+
+    return wrapper
+
+
+def create_new_process_for_each_test(
+    method: Literal["spawn", "fork"] | None = None,
+) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
+    """Creates a decorator that runs each test function in a new process.
+
+    Args:
+        method: The process creation method. Can be either "spawn" or "fork".
+            If not specified, it defaults to "spawn" on XPU platforms and
+            "fork" otherwise (ROCm included; see the TODO in the body).
+
+    Returns:
+        A decorator to run test functions in separate processes.
+    """
+    if method is None:
+        # TODO: Spawn is not working correctly on ROCm
+        # The test content will not run and tests passed immediately.
+        # For now, using `fork` for ROCm as it can run with `fork`
+        # and tests are running correctly.
+        use_spawn = current_platform.is_xpu()
+        method = "spawn" if use_spawn else "fork"
+
+    assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'"
+
+    if method == "fork":
+        return fork_new_process_for_each_test
+
+    return spawn_new_process_for_each_test
+
+
+def cuda_marks(*, res: str, num_cards: int):
+    """
+    Get a collection of pytest marks to apply for `@cuda_test`.
+    Raises ValueError when `res` is neither "L4" nor "H100".
+
+    Args:
+        res: Resource type, e.g., "L4" or "H100".
+        num_cards: Number of GPU cards required.
+
+    Returns:
+        List of pytest marks to apply.
+    """
+    test_platform_detail = pytest.mark.cuda
+    if res == "L4":
+        test_resource = pytest.mark.L4
+    elif res == "H100":
+        test_resource = pytest.mark.H100
+    else:
+        raise ValueError(f"Invalid CUDA resource type: {res}. Supported: L4, H100")
+
+    marks = [test_resource, test_platform_detail]
+
+    if num_cards == 1:
+        return marks
+    else:
+        test_distributed = pytest.mark.distributed_cuda(num_cards=num_cards)
+        test_skipif = pytest.mark.skipif_cuda(
+            cuda_device_count_stateless() < num_cards,  # evaluated now, while the marks are built
+            reason=f"Need at least {num_cards} CUDA GPUs to run the test.",
+        )
+        return marks + [test_distributed, test_skipif]
+
+
+def rocm_marks(*, res: str, num_cards: int):
+    """
+    Get a collection of pytest marks to apply for `@rocm_test`.
+    Raises ValueError when `res` is not "MI325".
+
+    Args:
+        res: Resource type, e.g., "MI325".
+        num_cards: Number of GPU cards required.
+
+    Returns:
+        List of pytest marks to apply.
+    """
+    test_platform_detail = pytest.mark.rocm
+    if res == "MI325":
+        test_resource = pytest.mark.MI325
+    else:
+        raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325")
+
+    marks = [test_resource, test_platform_detail]
+
+    if num_cards == 1:
+        return marks
+    else:
+        test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
+        # TODO: add ROCm support for `skipif_rocm` marker
+        return marks + [test_distributed]
+
+
+def xpu_marks(*, res: str, num_cards: int):
+    """
+    Get a collection of pytest marks to apply for `@xpu_test`.
+    Raises ValueError when `res` is not "B60".
+
+    Args:
+        res: Resource type, e.g., "B60".
+        num_cards: Number of GPU cards required.
+
+    Returns:
+        List of pytest marks to apply.
+    """
+    test_platform_detail = pytest.mark.xpu
+    if res == "B60":
+        test_resource = pytest.mark.B60
+    else:
+        raise ValueError(f"Invalid XPU resource type: {res}. Supported: B60")
+
+    marks = [test_resource, test_platform_detail]
+    if num_cards == 1:
+        return marks
+    else:
+        # NOTE(review): this applies the `distributed_rocm` mark to XPU tests — looks copy-pasted; confirm.
+        test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
+        # TODO: add XPU support for `skipif_xpu` marker
+        return marks + [test_distributed]
+
+
+def musa_marks(*, res: str, num_cards: int):
+    """
+    Get a collection of pytest marks to apply for `@musa_test`.
+    Raises ValueError when `res` is not "S5000".
+
+    Args:
+        res: Resource type, e.g., "S5000".
+        num_cards: Number of GPU cards required.
+
+    Returns:
+        List of pytest marks to apply.
+    """
+    test_platform_detail = pytest.mark.musa
+    if res == "S5000":
+        test_resource = pytest.mark.S5000
+    else:
+        raise ValueError(f"Invalid MUSA resource type: {res}. Supported: S5000")
+
+    marks = [test_resource, test_platform_detail]
+
+    if num_cards == 1:
+        return marks
+    else:
+        test_distributed = pytest.mark.distributed_musa(num_cards=num_cards)
+        # TODO: add MUSA support for `skipif_musa` marker
+        return marks + [test_distributed]
+
+
+def gpu_marks(*, res: str, num_cards: int):
+    """
+    Get a collection of pytest marks to apply for `@gpu_test`.
+    Platform is determined from the resource type; the result always starts with `pytest.mark.gpu`.
+
+    Args:
+        res: Resource type, e.g., "L4", "H100" for CUDA, or "MI325" for ROCm, or "B60" for XPU, or "S5000" for MUSA.
+        num_cards: Number of GPU cards required.
+
+    Returns:
+        List of pytest marks to apply (ValueError is raised for an unsupported `res`).
+    """
+    test_platform = pytest.mark.gpu
+    if res in ("L4", "H100"):
+        return [test_platform] + cuda_marks(res=res, num_cards=num_cards)
+    if res == "MI325":
+        return [test_platform] + rocm_marks(res=res, num_cards=num_cards)
+    if res == "B60":
+        return [test_platform] + xpu_marks(res=res, num_cards=num_cards)
+    if res == "S5000":
+        return [test_platform] + musa_marks(res=res, num_cards=num_cards)
+    raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325, B60, S5000")
+
+
+def npu_marks(*, res: str, num_cards: int):
+    """Get pytest marks for `@npu_test`; a `res` other than "A2"/"A3" gets no resource mark (no error)."""
+    test_platform = pytest.mark.npu
+    if res == "A2":
+        test_resource = pytest.mark.A2
+    elif res == "A3":
+        test_resource = pytest.mark.A3
+    else:
+        # TODO: Currently we don't have various NPU card types defined
+        # Use None to skip resource-specific marking for unknown types
+        test_resource = None
+
+    if num_cards == 1:
+        return [mark for mark in [test_platform, test_resource] if mark is not None]
+    else:
+        # Multiple cards scenario needs distributed_npu mark
+        test_distributed = pytest.mark.distributed_npu(num_cards=num_cards)
+        # TODO: add NPU support for `skipif_npu` marker
+        return [mark for mark in [test_platform, test_resource, test_distributed] if mark is not None]
+
+
+def hardware_marks(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
+    """
+    Get a collection of pytest marks to apply for `@hardware_test`, covering CUDA, ROCm, XPU, NPU
+    and MUSA, from a platform -> resource mapping. `num_cards` is an int for all platforms or a
+    per-platform dict (missing platforms default to 1, and the caller's dict is not modified).
+    """
+    # Validate platform names only; resource names are checked by the per-platform helpers.
+    for platform in res:
+        if platform not in ("cuda", "rocm", "xpu", "npu", "musa"):
+            raise ValueError(f"Unsupported platform: {platform}")
+
+    # Normalize num_cards into a per-platform dict
+    if isinstance(num_cards, int):
+        num_cards_dict = {platform: num_cards for platform in res}
+    else:
+        # Work on a copy: the defaults filled in below must not leak into the caller's dict.
+        num_cards_dict = dict(num_cards)
+        for platform in num_cards_dict:
+            if platform not in res:
+                raise ValueError(
+                    f"Platform '{platform}' in num_cards but not in res. Available platforms: {list(res.keys())}"
+                )
+        for platform in res:
+            num_cards_dict.setdefault(platform, 1)
+
+    # Collect marks from all platforms
+    all_marks: list[pytest.MarkDecorator] = []
+    for platform, resource in res.items():
+        cards = num_cards_dict[platform]
+        if platform in ("cuda", "rocm", "xpu"):
+            marks = gpu_marks(res=resource, num_cards=cards)
+        elif platform == "musa":
+            marks = musa_marks(res=resource, num_cards=cards)
+        elif platform == "npu":
+            marks = npu_marks(res=resource, num_cards=cards)
+        else:
+            # Unreachable after the validation above; kept as a safeguard.
+            raise ValueError(f"Unsupported platform: {platform}")
+        all_marks.extend(marks)
+    return all_marks
+
+
+def hardware_test(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
+    """
+    Decorate a test for multiple hardware platforms with a single call.
+    Only applies the marks from `hardware_marks`; the test function itself is not wrapped.
+
+    Args:
+        res: Mapping from platform to resource type. Supported platforms/resources:
+            - cuda: L4, H100
+            - rocm: MI325
+            - xpu: B60
+            - npu: A2, A3
+            - musa: S5000
+        num_cards: Number of cards required. Can be:
+            - int: same card count for all platforms (default: 1)
+            - dict: per-platform card count, e.g., {"cuda": 2, "rocm": 2}
+
+    Example:
+        @hardware_test(
+            res={"cuda": "L4", "rocm": "MI325", "npu": "A2", "musa": "S5000"},
+            num_cards={"cuda": 2, "rocm": 2, "npu": 2, "musa": 2},
+        )
+        def test_multi_platform():
+            ...
+    """
+    all_marks = hardware_marks(res=res, num_cards=num_cards)
+
+    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
+        func = f
+        for mark in reversed(all_marks):  # last mark is applied first, as with stacked decorators
+            func = mark(func)
+        return func
+
+    return wrapper
+
+
+class DeviceMemoryMonitor:
+    """Poll global device memory usage on a background thread and remember the peak (in MiB)."""
+
+    def __init__(self, device_index: int, interval: float = 0.05):
+        self.device_index = device_index
+        self.interval = interval  # seconds between samples
+        self._peak_used_mb = 0.0
+        self._stop_event = threading.Event()
+        self._thread: threading.Thread | None = None
+
+    def start(self) -> None:
+        def monitor_loop() -> None:
+            while not self._stop_event.is_set():
+                try:
+                    with current_omni_platform.device(self.device_index):
+                        free_bytes, total_bytes = current_omni_platform.mem_get_info()
+                    used_mb = (total_bytes - free_bytes) / (1024**2)
+                    self._peak_used_mb = max(self._peak_used_mb, used_mb)
+                except Exception:
+                    pass  # best effort: a failed sample is skipped
+                time.sleep(self.interval)
+        # Daemon thread: the loop closure keeps `self` alive, so __del__ cannot stop a forgotten
+        # monitor; as a non-daemon thread it would block interpreter exit forever.
+        self._thread = threading.Thread(target=monitor_loop, daemon=True)
+        self._thread.start()
+
+    def stop(self) -> None:
+        if self._thread is None:
+            return
+        self._stop_event.set()
+        self._thread.join(timeout=2.0)
+    @property
+    def peak_used_mb(self) -> float:
+        fallback_alloc = current_omni_platform.max_memory_allocated(device=self.device_index) / (1024**2)
+        fallback_reserved = current_omni_platform.max_memory_reserved(device=self.device_index) / (1024**2)
+        return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
+
+    def __del__(self):
+        self.stop()
diff --git a/tests/utils/test_audio.py b/tests/utils/test_audio.py
deleted file mode 100644
index 0e483e64685..00000000000
--- a/tests/utils/test_audio.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Unit tests for vllm_omni.utils.audio."""
-
-import numpy as np
-import pytest
-import torch
-
-from vllm_omni.utils.audio import mel_filter_bank, peak_normalize
-
-# Parameter combinations used across the codebase.
-_PARAM_SETS = [
- # Qwen3-TTS talker / speaker encoder (sr=24000)
- dict(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000),
- # CosyVoice3 whisper encoder, Qwen3-TTS 25Hz tokenizer (sr=16000, 80 mels)
- dict(sr=16000, n_fft=400, n_mels=80),
- # CosyVoice3 whisper encoder (sr=16000, 128 mels)
- dict(sr=16000, n_fft=400, n_mels=128),
-]
-
-_parametrize_params = pytest.mark.parametrize(
- "params", _PARAM_SETS, ids=lambda p: f"{p['sr']}_{p['n_fft']}_{p['n_mels']}"
-)
-
-
-class TestMelFilterBank:
- @_parametrize_params
- def test_output_shape(self, params):
- fb = mel_filter_bank(**params)
- n_freqs = params["n_fft"] // 2 + 1
- assert fb.shape == (params["n_mels"], n_freqs)
-
- @_parametrize_params
- def test_non_negative(self, params):
- fb = mel_filter_bank(**params)
- assert (fb >= 0).all()
-
- def test_dtype_is_float(self):
- fb = mel_filter_bank(sr=16000, n_fft=400, n_mels=80)
- assert fb.dtype == torch.float32
-
- def test_fmax_defaults_to_nyquist(self):
- """When fmax is omitted it should equal sr / 2."""
- fb_default = mel_filter_bank(sr=16000, n_fft=400, n_mels=80)
- fb_explicit = mel_filter_bank(sr=16000, n_fft=400, n_mels=80, fmax=8000.0)
- torch.testing.assert_close(fb_default, fb_explicit)
-
- def test_each_mel_band_has_nonzero_energy(self):
- """Every mel band should have at least one nonzero frequency bin."""
- fb = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)
- for i in range(fb.shape[0]):
- assert fb[i].sum() > 0, f"mel band {i} is all zeros"
-
- def test_higher_fmax_extends_coverage(self):
- """A higher fmax should produce nonzero weights at higher frequency bins."""
- fb_low = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=6000)
- fb_high = mel_filter_bank(sr=24000, n_fft=1024, n_mels=128, fmin=0, fmax=12000)
- # The highest nonzero column should be larger for fb_high.
- last_nonzero_low = (fb_low.sum(dim=0) > 0).nonzero()[-1].item()
- last_nonzero_high = (fb_high.sum(dim=0) > 0).nonzero()[-1].item()
- assert last_nonzero_high > last_nonzero_low
-
-
-class TestPeakNormalize:
- def test_silence_unchanged(self):
- """All-zero input should remain all-zero."""
- audio = np.zeros(1600, dtype=np.float32)
- result = peak_normalize(audio, db_level=-6.0)
- np.testing.assert_array_equal(result, audio)
-
- def test_peak_reaches_target(self):
- """After normalization, peak amplitude should be at target dB."""
- rng = np.random.default_rng(7)
- audio = rng.uniform(-0.4, 0.4, size=16000).astype(np.float32)
-
- result = peak_normalize(audio, db_level=-6.0)
- peak_db = 20 * np.log10(np.abs(result).max())
- np.testing.assert_allclose(peak_db, -6.0, atol=1e-4)
diff --git a/tests/worker/test_cudagraph_wrapper_perf.py b/tests/worker/test_cudagraph_wrapper_perf.py
new file mode 100644
index 00000000000..d73fe46c903
--- /dev/null
+++ b/tests/worker/test_cudagraph_wrapper_perf.py
@@ -0,0 +1,185 @@
+"""Tests for CUDAGraphWrapper.__getattr__ performance optimization.
+
+This module tests that the patched CUDAGraphWrapper avoids expensive __repr__
+calls when hasattr() is used for non-existent attributes. The original vLLM
+implementation includes {self.runnable} in the AttributeError message, which
+triggers model tree traversal and can take ~6ms on large models.
+"""
+
+import time
+
+import pytest
+import torch
+import torch.nn as nn
+
+from vllm_omni.worker.gpu_model_runner import CUDAGraphWrapper
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+class SlowReprModel(nn.Module):
+    """A mock model with artificially slow __repr__ to detect unwanted calls."""
+
+    def __init__(self, repr_delay_ms: float = 10.0):
+        super().__init__()
+        self.linear = nn.Linear(16, 16)
+        self.repr_delay_ms = repr_delay_ms
+        self.repr_call_count = 0  # number of __repr__ invocations so far
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def __repr__(self):
+        self.repr_call_count += 1
+        # Simulate expensive repr by sleeping (delay is given in milliseconds)
+        time.sleep(self.repr_delay_ms / 1000.0)
+        return f"SlowReprModel(delay={self.repr_delay_ms}ms)"
+
+
+class MockCUDAGraphWrapper:
+    """A minimal mock that mimics CUDAGraphWrapper structure for CPU testing."""
+
+    def __init__(self, runnable):
+        # Set through object.__setattr__; __getattr__ below is only consulted for failed lookups
+        object.__setattr__(self, "runnable", runnable)
+
+    def __getattr__(self, key: str):
+        # Local copy of the optimized lookup; presumably mirrors the real wrapper — TODO keep in sync
+        runnable = object.__getattribute__(self, "runnable")
+        if hasattr(runnable, key):
+            return getattr(runnable, key)
+        # Key optimization: DO NOT include {self.runnable} in error message
+        # as it triggers expensive __repr__ on large models
+        raise AttributeError(f"Attribute {key} not exists in the runnable of cudagraph wrapper")
+
+
+def test_hasattr_nonexistent_does_not_trigger_repr():
+    """Verify that hasattr for non-existent attributes doesn't call __repr__ (on the mock wrapper)."""
+    model = SlowReprModel(repr_delay_ms=100.0)  # Very slow repr
+    wrapper = MockCUDAGraphWrapper(model)
+
+    # Reset counter (already 0 after construction; kept explicit)
+    model.repr_call_count = 0
+
+    # Call hasattr for non-existent attribute multiple times
+    for _ in range(10):
+        result = hasattr(wrapper, "nonexistent_attribute_xyz")
+        assert result is False
+
+    # __repr__ should never have been called
+    assert model.repr_call_count == 0, (
+        f"__repr__ was called {model.repr_call_count} times when checking "
+        "for non-existent attributes. This indicates the AttributeError "
+        "message contains {self.runnable} which triggers expensive repr."
+    )
+
+
+def test_hasattr_nonexistent_is_fast():
+    """Verify that hasattr for non-existent attributes is fast (<1ms per call)."""
+    model = SlowReprModel(repr_delay_ms=100.0)
+    wrapper = MockCUDAGraphWrapper(model)
+
+    num_iterations = 100
+    start = time.perf_counter()
+    for _ in range(num_iterations):
+        hasattr(wrapper, "nonexistent_attribute_xyz")
+    elapsed_ms = (time.perf_counter() - start) * 1000
+
+    avg_ms = elapsed_ms / num_iterations
+    # If __repr__ were being called, each would take ~100ms
+    # We expect <1ms per call with the fix (NOTE(review): wall-clock bound, load-sensitive)
+    assert avg_ms < 1.0, (
+        f"hasattr for non-existent attribute took {avg_ms:.2f}ms on average. "
+        "Expected <1ms. This suggests __repr__ is being triggered."
+    )
+
+
+def test_hasattr_existing_attribute_works():
+    """Verify that hasattr for existing attributes returns True and works correctly."""
+    model = SlowReprModel()
+    wrapper = MockCUDAGraphWrapper(model)
+
+    # 'forward' is defined on the wrapped model, not on the wrapper itself
+    assert hasattr(wrapper, "forward") is True
+
+    # 'linear' exists on our model
+    assert hasattr(wrapper, "linear") is True
+
+    # Can actually access the attribute (delegated through __getattr__)
+    linear = wrapper.linear
+    assert isinstance(linear, nn.Linear)
+
+
+def test_getattr_existing_attribute_returns_value():
+    """Verify that getattr for existing attributes returns the correct value."""
+    model = SlowReprModel()
+    wrapper = MockCUDAGraphWrapper(model)
+
+    # Access forward method
+    forward_method = wrapper.forward
+    assert callable(forward_method)
+
+    # Access linear layer; SlowReprModel builds it as nn.Linear(16, 16)
+    linear = wrapper.linear
+    assert isinstance(linear, nn.Linear)
+    assert linear.in_features == 16
+    assert linear.out_features == 16
+
+
+def test_getattr_nonexistent_raises_attribute_error():
+    """Verify that getattr for non-existent attributes raises AttributeError."""
+    model = SlowReprModel()
+    wrapper = MockCUDAGraphWrapper(model)
+
+    with pytest.raises(AttributeError) as exc_info:
+        _ = wrapper.nonexistent_attribute
+
+    # Verify error message format: names the attribute and the wrapper, never the model repr
+    error_msg = str(exc_info.value)
+    assert "nonexistent_attribute" in error_msg
+    assert "cudagraph wrapper" in error_msg
+    # Should NOT contain the slow repr output
+    assert "SlowReprModel(delay=" not in error_msg
+
+
+def test_attribute_error_message_does_not_contain_runnable_repr():
+    """Explicitly verify the error message doesn't trigger runnable repr."""
+    model = SlowReprModel(repr_delay_ms=100.0)
+    wrapper = MockCUDAGraphWrapper(model)
+    model.repr_call_count = 0
+
+    # pytest.raises (rather than try/except/pass) also fails the test when no
+    # AttributeError is raised, so the repr check below cannot pass vacuously.
+    with pytest.raises(AttributeError):
+        _ = wrapper.nonexistent_attr
+
+    # __repr__ should not have been called during error construction
+    assert model.repr_call_count == 0, (
+        "AttributeError message construction triggered __repr__. The error message should not include {self.runnable}."
+    )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_real_cudagraph_wrapper_hasattr_performance():
+    """Test the actual CUDAGraphWrapper from vllm_omni (requires CUDA)."""
+    from vllm.config import CUDAGraphMode
+
+    model = SlowReprModel(repr_delay_ms=50.0).cuda()
+    model.repr_call_count = 0
+
+    # Create actual CUDAGraphWrapper; NOTE(review): any constructor error becomes a skip, not a failure
+    try:
+        wrapper = CUDAGraphWrapper(model, runtime_mode=CUDAGraphMode.NONE)
+    except Exception:
+        pytest.skip("Could not create CUDAGraphWrapper")
+
+    # Test hasattr performance (wall-clock average plus the repr call counter)
+    num_iterations = 50
+    start = time.perf_counter()
+    for _ in range(num_iterations):
+        hasattr(wrapper, "nonexistent_xyz")
+    elapsed_ms = (time.perf_counter() - start) * 1000
+
+    avg_ms = elapsed_ms / num_iterations
+    assert avg_ms < 1.0, f"Real CUDAGraphWrapper hasattr took {avg_ms:.2f}ms avg. Expected <1ms with the optimization."
+    assert model.repr_call_count == 0, f"__repr__ called {model.repr_call_count} times"
diff --git a/tests/worker/test_omni_connector_mixin.py b/tests/worker/test_omni_connector_mixin.py
deleted file mode 100644
index 0e4539cff19..00000000000
--- a/tests/worker/test_omni_connector_mixin.py
+++ /dev/null
@@ -1,1417 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unit tests for OmniConnectorModelRunnerMixin.
-
-These tests use a mock connector (in-memory dict store) and do not require
-GPU or vLLM runtime.
-"""
-
-from __future__ import annotations
-
-import time
-import unittest
-from types import SimpleNamespace
-from typing import Any
-from unittest.mock import MagicMock, patch
-
-import pytest
-import torch
-
-from vllm_omni.outputs import OmniConnectorOutput
-from vllm_omni.worker.omni_connector_model_runner_mixin import (
- OmniConnectorModelRunnerMixin,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-# ------------------------------------------------------------------ #
-# Mock helpers
-# ------------------------------------------------------------------ #
-
-
-class MockConnector:
- """In-memory connector for testing (mimics OmniConnectorBase)."""
-
- def __init__(self, stage_id: int = 0):
- self.stage_id = stage_id
- self._store: dict[str, Any] = {}
-
- def put(self, from_stage, to_stage, put_key, data):
- key = f"{from_stage}_{to_stage}_{put_key}"
- self._store[key] = data
- return True, len(str(data)), None
-
- def get(self, from_stage, to_stage, get_key, metadata=None):
- key = f"{from_stage}_{to_stage}_{get_key}"
- data = self._store.pop(key, None)
- if data is None:
- return None
- return data, len(str(data))
-
- def close(self):
- pass
-
-
-def _make_model_config(
- stage_id: int = 0,
- async_chunk: bool = False,
- worker_type: str = "ar",
- custom_func: str | None = None,
-) -> SimpleNamespace:
- return SimpleNamespace(
- stage_connector_config=None,
- async_chunk=async_chunk,
- worker_type=worker_type,
- custom_process_next_stage_input_func=custom_func,
- )
-
-
-def _make_request(req_id: str, external_req_id: str | None = None):
- r = SimpleNamespace(
- request_id=req_id,
- external_req_id=external_req_id or req_id,
- additional_information=None,
- prompt_token_ids=[],
- num_computed_tokens=0,
- )
- return r
-
-
-class MixinHost(OmniConnectorModelRunnerMixin):
- """Minimal class that mixes in the mixin for testing."""
-
- pass
-
-
-class _FakeTPGroup:
- def __init__(self, *, world_size: int, rank_in_group: int, follower_result: Any = None):
- self.world_size = world_size
- self.rank_in_group = rank_in_group
- self.follower_result = follower_result
- self.broadcast_inputs: list[Any] = []
-
- def broadcast_object(self, obj: Any | None = None, src: int = 0):
- self.broadcast_inputs.append(obj)
- if self.rank_in_group == src:
- return obj
- return self.follower_result
-
-
-# ------------------------------------------------------------------ #
-# Test cases
-# ------------------------------------------------------------------ #
-
-
-class TestMixinAsyncChunkSendRecv(unittest.TestCase):
- """Test 2: Async chunk send/recv + bg threads."""
-
- def test_send_chunk_passes_is_finished_and_connector(self):
- connector = MockConnector(stage_id=0)
-
- sender = MixinHost()
- sender.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=0, async_chunk=True),
- )
- sender._omni_connector = connector
- sender._stage_id = 0
- sender._async_chunk = True
-
- seen = {}
-
- def mock_process(transfer_manager, pooling_output, request, is_finished=False):
- seen["connector"] = transfer_manager.connector
- seen["is_finished"] = is_finished
- return {"data": pooling_output, "finished": is_finished}
-
- sender._custom_process_func = mock_process
-
- request = _make_request("req-1", "ext-req-1")
- request.is_finished = lambda: True
- sender._send_single_request(
- {
- "stage_id": 0,
- "next_stage_id": 1,
- "request_id": "ext-req-1",
- "request": request,
- "pooling_output": {"value": 42},
- }
- )
- self.assertIs(seen["connector"], connector)
- self.assertTrue(seen["is_finished"])
-
- sender.shutdown_omni_connectors()
-
- def test_send_chunk_does_not_retry_real_type_error(self):
- connector = MockConnector(stage_id=0)
-
- sender = MixinHost()
- sender.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=0, async_chunk=True),
- )
- sender._omni_connector = connector
- sender._stage_id = 0
- sender._async_chunk = True
-
- seen = {"calls": 0}
-
- def broken_process(transfer_manager, pooling_output, request, is_finished=""):
- seen["calls"] += 1
- return {"data": is_finished + "tail"}
-
- sender._custom_process_func = broken_process
-
- request = _make_request("req-1", "ext-req-1")
- request.is_finished = lambda: True
- ok = sender.send_chunk(request, pooling_output={"value": 42})
- self.assertFalse(ok)
- self.assertEqual(seen["calls"], 1)
-
- sender.shutdown_omni_connectors()
-
-
-class TestMixinKVCacheTransfer(unittest.TestCase):
- """Test 3: KV cache delegation to OmniKVTransferManager."""
-
- def test_send_kv_delegates(self):
- mock_kvm = MagicMock()
- mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1"]
-
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- kv_transfer_manager=mock_kvm,
- )
-
- result = host.send_kv_cache(
- finished_reqs={"req-1": {"seq_len": 10, "block_ids": [0]}},
- kv_caches=[],
- block_size=16,
- cache_dtype="float16",
- )
- self.assertEqual(result, ["req-1"])
- mock_kvm.handle_finished_requests_kv_transfer.assert_called_once()
-
- host.shutdown_omni_connectors()
-
- def test_recv_kv_delegates(self):
- mock_kvm = MagicMock()
- mock_kvm.receive_kv_cache_for_request.return_value = ({"layer_blocks": {}}, 100)
-
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- kv_transfer_manager=mock_kvm,
- )
-
- data, size = host.recv_kv_cache("req-1")
- self.assertIsNotNone(data)
- self.assertEqual(size, 100)
- mock_kvm.receive_kv_cache_for_request.assert_called_once()
-
- host.shutdown_omni_connectors()
-
- def test_receive_multi_kv_fetches_companions_via_mixin(self):
- mock_kvm = MagicMock()
-
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- kv_transfer_manager=mock_kvm,
- )
-
- host.recv_kv_cache = MagicMock(
- side_effect=[({"layer_blocks": {"k": [1]}}, 64), ({"layer_blocks": {"k": [2]}}, 32)]
- )
- seen = {}
-
- def collect_cfg(request_id, cfg_role_payloads):
- seen["request_id"] = request_id
- seen["cfg_role_payloads"] = cfg_role_payloads
- return {"cfg_text_kv_metadata": {"seq_len": 3}}
-
- req = SimpleNamespace(
- request_id="req-1",
- sampling_params=SimpleNamespace(cfg_kv_request_ids={"cfg_text": "req-1__cfg_text"}),
- )
- ok = host.receive_multi_kv_cache(req, cfg_kv_collect_func=collect_cfg)
- self.assertTrue(ok)
- host.recv_kv_cache.assert_any_call("req-1", target_device=None)
- host.recv_kv_cache.assert_any_call("req-1__cfg_text", target_device=None)
- mock_kvm.apply_kv_cache_to_request.assert_called_once_with(req, {"layer_blocks": {"k": [1]}})
- self.assertEqual(seen["request_id"], "req-1")
- self.assertEqual(
- seen["cfg_role_payloads"],
- {"cfg_text": ({"layer_blocks": {"k": [2]}}, 32)},
- )
- self.assertEqual(req.sampling_params.cfg_text_kv_metadata, {"seq_len": 3})
-
- host.shutdown_omni_connectors()
-
- def test_receive_multi_kv_skips_inactive_request(self):
- mock_kvm = MagicMock()
-
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- kv_transfer_manager=mock_kvm,
- )
-
- host.requests = {}
- host.recv_kv_cache = MagicMock(return_value=({"layer_blocks": {"k": [1]}}, 64))
- req = SimpleNamespace(request_id="req-1", sampling_params=None)
-
- ok = host.receive_multi_kv_cache(req)
-
- self.assertFalse(ok)
- host.recv_kv_cache.assert_not_called()
- mock_kvm.apply_kv_cache_to_request.assert_not_called()
-
- host.shutdown_omni_connectors()
-
-
-class TestOmniConnectorOutput(unittest.TestCase):
- """Test 4: Output aggregation across transfer modes."""
-
- def test_output_aggregation(self):
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- )
-
- host._chunk_ready_req_ids.add("req-1")
- host._chunk_finished_req_ids.add("req-2")
- host._local_request_metadata["req-1"] = {"next_stage_prompt_len": 10}
- host._stage_recv_req_ids.add("req-3")
-
- output = host.get_omni_connector_output()
- self.assertIsInstance(output, OmniConnectorOutput)
- self.assertEqual(output.chunk_ready_req_ids, {"req-1"})
- self.assertEqual(output.chunk_finished_req_ids, {"req-2"})
- self.assertEqual(output.request_metadata, {"req-1": {"next_stage_prompt_len": 10}})
- self.assertEqual(output.stage_recv_req_ids, {"req-3"})
-
- output2 = host.get_omni_connector_output()
- self.assertEqual(output2.chunk_ready_req_ids, set())
- self.assertEqual(output2.request_metadata, {})
-
- host.shutdown_omni_connectors()
-
-
-class TestMixinNoConnector(unittest.TestCase):
- """Edge case: mixin works gracefully without a connector."""
-
- def test_no_connector(self):
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- )
- self.assertIsNone(host._omni_connector)
-
- results = host.recv_full_payload_inputs(scheduler_output=None)
- self.assertIsNone(results)
-
- sent = host.send_full_payload_outputs(None, {"req-1": {}})
- self.assertEqual(sent, [])
-
- ok = host.send_chunk(_make_request("req-1"), pooling_output={})
- self.assertFalse(ok)
-
- output = host.get_omni_connector_output()
- self.assertIsInstance(output, OmniConnectorOutput)
-
- host.shutdown_omni_connectors()
-
-
-class TestFinishedLoadReqsDrain(unittest.TestCase):
- """Test A1 fix: get_omni_connector_output drains _finished_load_reqs."""
-
- def test_finished_load_reqs_flow_to_chunk_ready(self):
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- )
-
- host._finished_load_reqs.add("req-1")
- host._finished_load_reqs.add("req-2")
-
- output = host.get_omni_connector_output()
- self.assertIn("req-1", output.chunk_ready_req_ids)
- self.assertIn("req-2", output.chunk_ready_req_ids)
-
- self.assertEqual(len(host._finished_load_reqs), 0)
- self.assertEqual(len(host._chunk_ready_req_ids), 0)
-
- host.shutdown_omni_connectors()
-
-
-class TestLoadCustomFuncSelection(unittest.TestCase):
- def test_skips_legacy_stage_list_processors_for_full_payload_mode(self):
- legacy_paths = [
- "vllm_omni.model_executor.stage_input_processors.mimo_audio.llm2code2wav",
- "vllm_omni.model_executor.stage_input_processors.mammoth_moda2.ar2dit",
- "vllm_omni.model_executor.stage_input_processors.cosyvoice3.text2flow",
- "vllm_omni.model_executor.stage_input_processors.glm_image.ar2diffusion",
- ]
-
- for func_path in legacy_paths:
- selected_path, func = MixinHost._load_custom_func(
- SimpleNamespace(
- async_chunk=False,
- custom_process_input_func=func_path,
- custom_process_next_stage_input_func=None,
- )
- )
- assert selected_path != func_path
- assert func is None or MixinHost._is_connector_payload_builder(func)
-
-
-class TestFullPayloadSendWithCustomFunc(unittest.TestCase):
- """Test B4: send_full_payload_outputs with full_payload_mode custom process func."""
-
- def test_full_payload_send_passes_is_finished_and_connector(self):
- seen = {}
-
- def full_payload_func(transfer_manager, pooling_output, request, is_finished=False):
- seen["connector"] = transfer_manager.connector
- seen["is_finished"] = is_finished
- seen["data"] = pooling_output
- seen["rid"] = request.request_id if request else None
- return {"processed": True, "finished": is_finished}
-
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- )
- host._omni_connector = MockConnector(stage_id=0)
- host._stage_id = 0
- host._custom_process_func = full_payload_func
-
- req = _make_request("req-1")
- req.is_finished = lambda: True
- sent = host.send_full_payload_outputs(
- scheduler_output=None,
- outputs={"req-1": ({"raw": 100}, req)},
- )
- self.assertEqual(sent, ["req-1"])
- self.assertEqual(
- seen,
- {
- "connector": host._omni_connector,
- "is_finished": True,
- "data": {"raw": 100},
- "rid": "req-1",
- },
- )
-
- host.shutdown_omni_connectors()
-
- def test_accumulate_and_flush(self):
- call_log = []
-
- def full_payload_func(transfer_manager, pooling_output, request):
- call_log.append(request.request_id if request else None)
- return {"processed": True}
-
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- )
- host._omni_connector = MockConnector(stage_id=0)
- host._stage_id = 0
- host._custom_process_func = full_payload_func
-
- req = _make_request("req-1")
- host.accumulate_full_payload_output("req-1", {"raw": 42}, req)
- self.assertEqual(len(host._pending_full_payload_send), 1)
-
- host.flush_full_payload_outputs({"req-1"})
- self.assertEqual(len(host._pending_full_payload_send), 0)
- self.assertEqual(len(call_log), 1)
- self.assertEqual(call_log[0], "req-1")
-
- time.sleep(0.1)
- host.shutdown_omni_connectors()
-
-
-class TestKVSentReqIdsAccumulation(unittest.TestCase):
- """Test that kv_sent_req_ids accumulates results from send_kv_cache."""
-
- def test_kv_sent_accumulation(self):
- mock_kvm = MagicMock()
- mock_kvm.handle_finished_requests_kv_transfer.return_value = ["req-1", "req-2"]
-
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(),
- kv_transfer_manager=mock_kvm,
- )
-
- host.send_kv_cache(
- finished_reqs={"req-1": {}, "req-2": {}},
- kv_caches=[],
- block_size=16,
- cache_dtype="float16",
- )
-
- output = host.get_omni_connector_output()
- self.assertIn("req-1", output.kv_sent_req_ids)
- self.assertIn("req-2", output.kv_sent_req_ids)
-
- output2 = host.get_omni_connector_output()
- self.assertEqual(output2.kv_sent_req_ids, [])
-
- host.shutdown_omni_connectors()
-
-
-class TestChunkStreamCompletedGuard(unittest.TestCase):
- """Test that register_chunk_recv is skipped after finish sentinel.
-
- This validates the fix for the race condition where the scheduling
- coordinator re-registers a request for chunk polling after its
- upstream chunk stream has already finished (is_finished sentinel
- received), causing the bg recv thread to poll for a non-existent
- shared-memory segment (e.g. ``_0_7`` when only 7 chunks 0–6 exist).
- """
-
- def _make_host(self, stage_id: int = 1) -> MixinHost:
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=stage_id, async_chunk=True),
- )
- host._omni_connector = MockConnector(stage_id=stage_id)
- host._stage_id = stage_id
- host._async_chunk = True
- return host
-
- def test_register_blocked_after_finish_sentinel(self):
- """register_chunk_recv must be a no-op after the finish sentinel."""
- host = self._make_host(stage_id=1)
-
- req = _make_request("req-1", "ext-req-1")
-
- # Simulate the bg thread having received the finish sentinel:
- with host._lock:
- host._chunk_stream_completed.add("req-1")
-
- # Now try to re-register — this mimics the coordinator asking
- # the model runner to poll for the next (non-existent) chunk.
- host.register_chunk_recv(req)
-
- # The request must NOT appear in _pending_load_reqs
- self.assertNotIn(
- "req-1",
- host._pending_load_reqs,
- "register_chunk_recv should skip requests whose chunk stream is already complete",
- )
-
- host.shutdown_omni_connectors()
-
- def test_register_allowed_before_finish(self):
- """register_chunk_recv works normally before finish sentinel."""
- host = self._make_host(stage_id=1)
- req = _make_request("req-1", "ext-req-1")
-
- host.register_chunk_recv(req)
- self.assertIn(
- "req-1",
- host._pending_load_reqs,
- "register_chunk_recv should add request to pending when stream is not yet complete",
- )
-
- host.shutdown_omni_connectors()
-
- def test_finish_sentinel_populates_completed_set(self):
- """Receiving is_finished=True adds to _chunk_stream_completed."""
- host = self._make_host(stage_id=1)
-
- # Simulate _poll_single_request receiving is_finished=True
- req_id = "req-1"
- with host._lock:
- host._chunk_finished_req_ids.add(req_id)
- host._chunk_stream_completed.add(req_id)
- host._local_stage_payload_cache[req_id] = {"finished": True}
- host._local_request_metadata[req_id] = {}
- host._finished_load_reqs.add(req_id)
- host._pending_load_reqs.pop(req_id, None)
-
- self.assertIn(req_id, host._chunk_stream_completed)
-
- # Subsequent register_chunk_recv should be blocked
- req = _make_request(req_id, f"ext-{req_id}")
- host.register_chunk_recv(req)
- self.assertNotIn(req_id, host._pending_load_reqs)
-
- host.shutdown_omni_connectors()
-
- def test_stage_0_always_skipped(self):
- """Stage-0 has no upstream, register_chunk_recv is always no-op."""
- host = self._make_host(stage_id=0)
- host._stage_id = 0
-
- req = _make_request("req-1")
- host.register_chunk_recv(req)
- self.assertNotIn("req-1", host._pending_load_reqs)
-
- host.shutdown_omni_connectors()
-
- def test_full_payload_recv_guard_still_works(self):
- """Pre-existing guard: staged full-payload results prevent registration."""
- host = self._make_host(stage_id=1)
-
- with host._lock:
- host._stage_recv_req_ids.add("req-1")
-
- req = _make_request("req-1", "ext-req-1")
- host.register_chunk_recv(req)
- self.assertNotIn("req-1", host._pending_load_reqs)
-
- host.shutdown_omni_connectors()
-
-
-class TestCleanupFinishedRequest(unittest.TestCase):
- """Test cleanup_finished_request frees per-request mixin state."""
-
- def _make_host(self, stage_id: int = 1) -> MixinHost:
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=stage_id, async_chunk=True),
- )
- host._omni_connector = MockConnector(stage_id=stage_id)
- host._stage_id = stage_id
- host._async_chunk = True
- return host
-
- def test_cleanup_removes_all_state(self):
- """cleanup_finished_request removes all tracking dicts/sets."""
- host = self._make_host(stage_id=1)
- req_id = "req-1"
- ext_id = "ext-req-1"
-
- # Simulate state accumulated during a request's lifetime
- host._request_ids_mapping[req_id] = ext_id
- host._put_req_chunk[ext_id] = 5
- host._get_req_chunk[req_id] = 3
- host._send_side_request_payload[ext_id] = {"some": "data"}
- host._code_prompt_token_ids[ext_id] = [[1, 2, 3]]
- host._chunk_stream_completed.add(req_id)
- host._stage_recv_req_ids.add(req_id)
- host._local_stage_payload_cache[req_id] = {"engine_inputs": {}}
- host._local_request_metadata[req_id] = {"prompt_len": 10}
-
- # Cleanup
- host.cleanup_finished_request(req_id)
-
- # All state should be gone
- self.assertNotIn(req_id, host._request_ids_mapping)
- self.assertNotIn(ext_id, host._put_req_chunk)
- self.assertNotIn(req_id, host._get_req_chunk)
- self.assertNotIn(ext_id, host._send_side_request_payload)
- self.assertNotIn(ext_id, host._code_prompt_token_ids)
- self.assertNotIn(req_id, host._chunk_stream_completed)
- self.assertNotIn(req_id, host._stage_recv_req_ids)
- self.assertNotIn(req_id, host._local_stage_payload_cache)
- self.assertNotIn(req_id, host._local_request_metadata)
-
- host.shutdown_omni_connectors()
-
- def test_cleanup_removes_per_cycle_ready_state(self):
- """cleanup_finished_request clears ready/finished carry-over for req-id reuse."""
- host = self._make_host(stage_id=1)
- req_id = "req-1"
-
- host._pending_load_reqs[req_id] = _make_request(req_id, "ext-req-1")
- host._finished_load_reqs.add(req_id)
- host._chunk_ready_req_ids.add(req_id)
- host._chunk_finished_req_ids.add(req_id)
-
- host.cleanup_finished_request(req_id)
-
- self.assertNotIn(req_id, host._pending_load_reqs)
- self.assertNotIn(req_id, host._finished_load_reqs)
- self.assertNotIn(req_id, host._chunk_ready_req_ids)
- self.assertNotIn(req_id, host._chunk_finished_req_ids)
-
- host.shutdown_omni_connectors()
-
- def test_cleanup_without_mapping(self):
- """cleanup works for Stage-0 where _request_ids_mapping isn't set."""
- host = self._make_host(stage_id=0)
- host._stage_id = 0
- req_id = "req-1"
-
- # Stage-0 uses req_id directly (no ext_id mapping)
- host._put_req_chunk[req_id] = 3
- host._get_req_chunk[req_id] = 0
-
- host.cleanup_finished_request(req_id)
-
- self.assertNotIn(req_id, host._put_req_chunk)
- self.assertNotIn(req_id, host._get_req_chunk)
-
- host.shutdown_omni_connectors()
-
- def test_prune_inactive_requests_cleans_stale_state_but_keeps_active(self):
- """Inactive request IDs should be pruned without touching active ones."""
- host = self._make_host(stage_id=1)
- active_req_id = "req-active"
- stale_req_id = "req-stale"
- stale_ext_id = "ext-stale"
-
- host._request_ids_mapping[active_req_id] = "ext-active"
- host._request_ids_mapping[stale_req_id] = stale_ext_id
- host._put_req_chunk[stale_ext_id] = 2
- host._get_req_chunk[stale_req_id] = 1
- host._finished_load_reqs.add(stale_req_id)
- host._chunk_ready_req_ids.update({active_req_id, stale_req_id})
- host._chunk_finished_req_ids.add(stale_req_id)
- host._chunk_stream_completed.add(stale_req_id)
- host._stage_recv_req_ids.add(active_req_id)
- host._send_side_request_payload[stale_ext_id] = {"stale": True}
- host._code_prompt_token_ids[stale_ext_id] = [[1, 2, 3]]
-
- pruned = host.prune_inactive_requests({active_req_id})
-
- self.assertEqual(pruned, {stale_req_id})
- self.assertIn(active_req_id, host._request_ids_mapping)
- self.assertIn(active_req_id, host._chunk_ready_req_ids)
- self.assertIn(active_req_id, host._stage_recv_req_ids)
- self.assertNotIn(stale_req_id, host._request_ids_mapping)
- self.assertNotIn(stale_ext_id, host._put_req_chunk)
- self.assertNotIn(stale_req_id, host._get_req_chunk)
- self.assertNotIn(stale_req_id, host._pending_load_reqs)
- self.assertNotIn(stale_req_id, host._finished_load_reqs)
- self.assertNotIn(stale_req_id, host._chunk_ready_req_ids)
- self.assertNotIn(stale_req_id, host._chunk_finished_req_ids)
- self.assertNotIn(stale_req_id, host._chunk_stream_completed)
- self.assertNotIn(stale_req_id, host._stage_recv_req_ids)
- self.assertNotIn(stale_ext_id, host._send_side_request_payload)
- self.assertNotIn(stale_ext_id, host._code_prompt_token_ids)
-
- host.shutdown_omni_connectors()
-
- def test_prune_inactive_requests_keeps_recently_received_full_payload_state(self):
- """Late bg-thread receives must survive until the scheduler catches up."""
- host = self._make_host(stage_id=1)
- req_id = "req-recv-race"
- ext_id = "ext-recv-race"
-
- host._request_ids_mapping[req_id] = ext_id
- host._put_req_chunk[ext_id] = 1
- host._local_stage_payload_cache[req_id] = {"engine_inputs": {"ids": [1, 2, 3]}}
- host._local_request_metadata[req_id] = {"next_stage_prompt_len": 3}
- host._stage_recv_req_ids.add(req_id)
-
- pruned = host.prune_inactive_requests(set())
-
- self.assertEqual(pruned, set())
- self.assertIn(req_id, host._request_ids_mapping)
- self.assertIn(req_id, host._local_stage_payload_cache)
- self.assertIn(req_id, host._local_request_metadata)
- self.assertIn(req_id, host._stage_recv_req_ids)
- self.assertIn(ext_id, host._put_req_chunk)
-
- # Once the scheduler has consumed the wake-up and the request really
- # disappears from all protected sets, prune should clean it up.
- host._stage_recv_req_ids.clear()
- host._local_stage_payload_cache.clear()
- host._local_request_metadata.clear()
-
- pruned = host.prune_inactive_requests(set())
-
- self.assertEqual(pruned, {req_id})
- self.assertNotIn(req_id, host._request_ids_mapping)
- self.assertNotIn(ext_id, host._put_req_chunk)
-
- host.shutdown_omni_connectors()
-
-
-class TestSendChunkCachesMapping(unittest.TestCase):
- """Test that send_chunk caches internal→external req ID mapping."""
-
- def test_send_chunk_populates_request_ids_mapping(self):
- """send_chunk should cache the internal→external mapping."""
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=0, async_chunk=True),
- )
- host._omni_connector = MockConnector(stage_id=0)
- host._stage_id = 0
- host._async_chunk = True
-
- def mock_process(transfer_manager, pooling_output, request):
- return {"data": "test", "finished": False}
-
- host._custom_process_func = mock_process
-
- request = _make_request("internal-1", "external-1")
- host.send_chunk(request, pooling_output={"v": 1})
-
- # The mapping should be cached
- self.assertEqual(
- host._request_ids_mapping.get("internal-1"),
- "external-1",
- )
-
- time.sleep(0.1)
- host.shutdown_omni_connectors()
-
-
-class TestLocalPayloadCacheLifecycle(unittest.TestCase):
- """Unit tests for the local payload cache API (RFC §2.4)."""
-
- def _make_host(self) -> MixinHost:
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=0),
- )
- host._omni_connector = MockConnector(stage_id=0)
- host._stage_id = 0
- return host
-
- def test_put_get_pop(self):
- host = self._make_host()
- payload = {"engine_inputs": {"ids": [1, 2, 3]}}
- host.put_local_stage_payload("r1", payload)
-
- self.assertEqual(host.get_local_stage_payload("r1"), payload)
- popped = host.pop_local_stage_payload("r1")
- self.assertEqual(popped, payload)
- self.assertIsNone(host.get_local_stage_payload("r1"))
- host.shutdown_omni_connectors()
-
- def test_recv_full_payload_inputs_populates_local_cache(self):
- host = self._make_host()
- host._omni_connector = MockConnector(stage_id=0)
- host._stage_id = 0
-
- # Simulate a full payload already staged by the bg recv path
- with host._lock:
- host._local_stage_payload_cache["r1"] = {"tok": [10]}
- host._stage_recv_req_ids.add("r1")
-
- host.recv_full_payload_inputs(scheduler_output=None)
- self.assertEqual(host.get_local_stage_payload("r1"), {"tok": [10]})
- host.shutdown_omni_connectors()
-
- def test_rank0_only_polls_connector_for_tp_full_payload(self):
- host = self._make_host()
- host._omni_connector = MagicMock()
- host._stage_id = 2
- host._local_rank = 0
- host._request_ids_mapping["r1"] = "ext-r1"
- host._get_req_chunk["r1"] = 0
- payload = {"tok": [10], "finished": torch.tensor(True)}
- connector_result = (payload, 123)
- host._omni_connector.get.return_value = connector_result
- tp_group = _FakeTPGroup(world_size=2, rank_in_group=0)
-
- with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
- made_progress = host._poll_single_request("r1")
-
- self.assertTrue(made_progress)
- host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0")
- self.assertEqual(tp_group.broadcast_inputs, [])
- self.assertEqual(host.get_local_stage_payload("r1"), payload)
- self.assertIn("r1", host._full_payload_pending_broadcast_req_ids)
- self.assertNotIn("r1", host._stage_recv_req_ids)
- self.assertIsNone(host.get_local_request_metadata("r1"))
- host.shutdown_omni_connectors()
-
- def test_tp_follower_skips_connector_poll_for_full_payload(self):
- host = self._make_host()
- host._omni_connector = MagicMock()
- host._stage_id = 2
- host._local_rank = 1
- host._request_ids_mapping["r1"] = "ext-r1"
- host._get_req_chunk["r1"] = 0
- tp_group = _FakeTPGroup(world_size=2, rank_in_group=1)
-
- with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
- made_progress = host._poll_single_request("r1")
-
- self.assertFalse(made_progress)
- host._omni_connector.get.assert_not_called()
- self.assertEqual(tp_group.broadcast_inputs, [])
- self.assertNotIn("r1", host._local_stage_payload_cache)
- host.shutdown_omni_connectors()
-
- def test_recv_full_payload_inputs_broadcasts_tp_leader_results_to_followers(self):
- host = self._make_host()
- host._omni_connector = MagicMock()
- host._stage_id = 2
- host._local_rank = 1
- host._pending_load_reqs["r1"] = object()
- payload = {"tok": [10], "finished": torch.tensor(True)}
- tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result={"r1": payload})
-
- with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
- results = host.recv_full_payload_inputs(scheduler_output=None)
-
- self.assertEqual(results, {"r1": payload})
- self.assertEqual(host.get_local_stage_payload("r1"), payload)
- self.assertEqual(host.get_local_request_metadata("r1"), {})
- self.assertEqual(host._stage_recv_req_ids, {"r1"})
- self.assertNotIn("r1", host._pending_load_reqs)
- self.assertEqual(tp_group.broadcast_inputs, [None])
- host.shutdown_omni_connectors()
-
-
-class TestTPAsyncChunkFanout(unittest.TestCase):
- def _make_host(self, rank: int) -> MixinHost:
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"),
- )
- host._omni_connector = MagicMock()
- host._stage_id = 2
- host._async_chunk = True
- host._model_mode = "gen"
- host._local_rank = rank
- host._request_ids_mapping["r1"] = "ext-r1"
- host._get_req_chunk["r1"] = 0
- return host
-
- def test_rank0_only_polls_connector_for_tp_async_chunk(self):
- host = self._make_host(rank=0)
- payload = {
- "codes": {"audio": [10, 11]},
- "meta": {"left_context_size": 0, "finished": torch.tensor(False)},
- }
- host._omni_connector.get.return_value = (payload, 123)
- tp_group = _FakeTPGroup(world_size=2, rank_in_group=0)
-
- with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
- made_progress = host._poll_single_request("r1")
-
- self.assertTrue(made_progress)
- host._omni_connector.get.assert_called_once_with("1", "2", "ext-r1_1_0")
- self.assertEqual(host.get_local_stage_payload("r1"), payload)
- self.assertIn("r1", host._finished_load_reqs)
- self.assertIn("r1", host._async_chunk_updated_req_ids)
- self.assertEqual(tp_group.broadcast_inputs, [])
- host.shutdown_omni_connectors()
-
- def test_tp_follower_skips_connector_poll_for_async_chunk(self):
- host = self._make_host(rank=1)
- tp_group = _FakeTPGroup(world_size=2, rank_in_group=1)
-
- with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
- made_progress = host._poll_single_request("r1")
-
- self.assertFalse(made_progress)
- host._omni_connector.get.assert_not_called()
- self.assertIsNone(host.get_local_stage_payload("r1"))
- self.assertEqual(tp_group.broadcast_inputs, [])
- host.shutdown_omni_connectors()
-
- def test_get_output_broadcasts_tp_async_chunk_payloads_to_followers(self):
- host = self._make_host(rank=1)
- host._pending_load_reqs["r1"] = object()
- payload = {
- "code_predictor_codes": [10, 11],
- "left_context_size": 0,
- "finished": torch.tensor(True),
- }
- packet = {
- "staged_payloads": {"r1": payload},
- "request_metadata": {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}},
- "newly_finished": {"r1"},
- "chunk_finished": {"r1"},
- }
- tp_group = _FakeTPGroup(world_size=2, rank_in_group=1, follower_result=packet)
-
- with patch("vllm_omni.worker.omni_connector_model_runner_mixin.get_tp_group", return_value=tp_group):
- output = host.get_omni_connector_output()
-
- self.assertEqual(output.chunk_ready_req_ids, {"r1"})
- self.assertEqual(output.chunk_finished_req_ids, {"r1"})
- self.assertEqual(
- output.request_metadata,
- {"r1": {"code_predictor_codes": [10, 11], "left_context_size": 0}},
- )
- self.assertEqual(host.get_local_stage_payload("r1"), payload)
- self.assertNotIn("r1", host._pending_load_reqs)
- self.assertIn("r1", host._chunk_stream_completed)
- self.assertEqual(tp_group.broadcast_inputs, [None])
- host.shutdown_omni_connectors()
-
-
-class TestKVTransferLifecycle(unittest.TestCase):
- """Unit tests for KV transfer lifecycle methods."""
-
- def _make_host(self) -> MixinHost:
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=0),
- )
- return host
-
- def test_mark_drain_ack_complete(self):
- host = self._make_host()
- self.assertFalse(host.has_pending_kv_work())
-
- host.mark_kv_transfer("r1", seq_len=100, block_ids=[0, 1, 2])
- self.assertTrue(host.has_pending_kv_work())
- self.assertTrue(host.is_kv_transfer_triggered("r1"))
-
- # Drain moves pending → active
- pending = host.drain_pending_kv_transfers()
- self.assertEqual(pending, {"r1": {"seq_len": 100, "block_ids": [0, 1, 2]}})
- self.assertIn("r1", host._kv_active_transfers)
- self.assertTrue(host.has_pending_kv_work())
-
- # Ack moves active → completed
- host.ack_kv_transfers(["r1"])
- self.assertNotIn("r1", host._kv_active_transfers)
- self.assertIn("r1", host._kv_completed_transfers)
-
- # Drain completed
- completed = host.drain_completed_kv_transfers()
- self.assertEqual(completed, {"r1"})
- self.assertFalse(host.has_pending_kv_work())
- host.shutdown_omni_connectors()
-
- def test_mark_dedup(self):
- host = self._make_host()
- host.mark_kv_transfer("r1", seq_len=100, block_ids=[0])
- host.mark_kv_transfer("r1", seq_len=200, block_ids=[0, 1])
- # Second mark is a no-op
- self.assertEqual(host._kv_pending_transfers["r1"]["seq_len"], 100)
- host.shutdown_omni_connectors()
-
- def test_cleanup_removes_kv_state(self):
- host = self._make_host()
- host.mark_kv_transfer("r1", seq_len=50, block_ids=[0])
- host.drain_pending_kv_transfers()
- host.cleanup_finished_request("r1")
- self.assertFalse(host.is_kv_transfer_triggered("r1"))
- self.assertNotIn("r1", host._kv_active_transfers)
- self.assertFalse(host.has_pending_kv_work())
- host.shutdown_omni_connectors()
-
-
-class TestAsyncPayloadLifecycle(unittest.TestCase):
- """Regression tests for async payload delivery lifecycle."""
-
- def test_send_side_request_payload_not_cleared_before_payload_is_consumable(self):
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
- )
- host._request_ids_mapping["r1"] = "r1"
- payload = {
- "thinker_decode_embeddings": torch.ones(1, 2),
- "thinker_output_token_ids": [1],
- "override_keys": ["thinker_decode_embeddings", "thinker_output_token_ids"],
- "finished": torch.tensor(False),
- }
-
- host._accumulate_payload("r1", dict(payload))
- with host._lock:
- host._finished_load_reqs.add("r1")
-
- host.get_omni_connector_output()
- self.assertIn("r1", host._send_side_request_payload)
- host.shutdown_omni_connectors()
-
- def test_payload_consumable_ignores_token_horizon_only_updates(self):
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
- )
- payload = {
- "thinker_output_token_ids": [1, 2, 3],
- "finished": torch.tensor(False),
- "override_keys": [
- "thinker_output_token_ids",
- "thinker_decode_embeddings_token_start",
- "thinker_decode_embeddings_token_end",
- ],
- "thinker_decode_embeddings_token_start": 2,
- "thinker_decode_embeddings_token_end": 3,
- }
- self.assertFalse(host._payload_is_consumable(payload))
- host.shutdown_omni_connectors()
-
- def test_payload_consumable_accepts_decode_embeddings(self):
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
- )
- payload = {
- "thinker_output_token_ids": [1, 2, 3],
- "thinker_decode_embeddings": torch.ones(1, 2),
- "finished": torch.tensor(False),
- }
- self.assertTrue(host._payload_is_consumable(payload))
- host.shutdown_omni_connectors()
-
- def test_ar_metadata_only_followup_chunk_does_not_rewake_request(self):
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=1, async_chunk=True, worker_type="ar"),
- )
- host._omni_connector = MagicMock()
- host._stage_id = 1
- host._async_chunk = True
- host._model_mode = "ar"
- host._request_ids_mapping["r1"] = "ext-r1"
- host._get_req_chunk["r1"] = 0
-
- host._omni_connector.get.side_effect = [
- (
- {
- "thinker_decode_embeddings": torch.ones(1, 2),
- "finished": torch.tensor(False),
- },
- 1,
- ),
- (
- {
- "next_stage_prompt_len": 7,
- "finished": torch.tensor(False),
- },
- 1,
- ),
- ]
-
- host._poll_single_request("r1")
- output1 = host.get_omni_connector_output()
- self.assertEqual(output1.chunk_ready_req_ids, {"r1"})
-
- host._poll_single_request("r1")
- output2 = host.get_omni_connector_output()
- self.assertEqual(output2.chunk_ready_req_ids, set())
- self.assertEqual(output2.request_metadata, {"r1": {"next_stage_prompt_len": 7}})
-
- host.shutdown_omni_connectors()
-
- def test_non_ar_recv_does_not_overwrite_unconsumed_staged_chunk(self):
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"),
- )
- host._omni_connector = MagicMock()
- host._stage_id = 2
- host._async_chunk = True
- host._model_mode = "gen"
- host._request_ids_mapping["r1"] = "ext-r1"
- host._get_req_chunk["r1"] = 1
- host._local_stage_payload_cache["r1"] = {
- "code_predictor_codes": [1, 2, 3],
- "left_context_size": 0,
- "finished": torch.tensor(False),
- }
-
- made_progress = host._poll_single_request("r1")
-
- self.assertFalse(made_progress)
- host._omni_connector.get.assert_not_called()
- self.assertEqual(host._get_req_chunk["r1"], 1)
-
- host.shutdown_omni_connectors()
-
- def test_non_ar_recv_waits_for_scheduler_handoff_before_fetching_next_chunk(self):
- host = MixinHost()
- host.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=2, async_chunk=True, worker_type="gen"),
- )
- host._omni_connector = MagicMock()
- host._stage_id = 2
- host._async_chunk = True
- host._model_mode = "gen"
- host._request_ids_mapping["r1"] = "ext-r1"
- host._get_req_chunk["r1"] = 1
- host._local_request_metadata["r1"] = {
- "code_predictor_codes": [10, 11, 12],
- "left_context_size": 0,
- }
- host._finished_load_reqs.add("r1")
-
- made_progress = host._poll_single_request("r1")
-
- self.assertFalse(made_progress)
- host._omni_connector.get.assert_not_called()
- self.assertEqual(host._get_req_chunk["r1"], 1)
-
- output = host.get_omni_connector_output()
- self.assertEqual(output.request_metadata["r1"]["code_predictor_codes"], [10, 11, 12])
- self.assertEqual(output.chunk_ready_req_ids, {"r1"})
-
- host._omni_connector.get.return_value = (
- {
- "codes": {"audio": [20, 21, 22]},
- "meta": {"left_context_size": 0, "finished": torch.tensor(False)},
- },
- 1,
- )
- made_progress = host._poll_single_request("r1")
-
- self.assertTrue(made_progress)
- host._omni_connector.get.assert_called_once()
- self.assertEqual(host._get_req_chunk["r1"], 2)
-
- host.shutdown_omni_connectors()
-
-
-class TestRankAwareKVRouting(unittest.TestCase):
- def _make_host(self, *, from_tp: int, to_tp: int, local_rank: int) -> MixinHost:
- host = MixinHost()
- host.init_omni_connectors(vllm_config=None, model_config=_make_model_config(stage_id=1))
- host._from_tp = from_tp
- host._to_tp = to_tp
- host._local_rank = local_rank
- return host
-
- def test_recv_keys_use_remote_rank_as_from_rank(self):
- host = self._make_host(from_tp=4, to_tp=2, local_rank=1)
- self.assertEqual(
- host.get_rank_aware_kv_keys("req", from_stage=0),
- ["req_0_0_2_1", "req_0_0_3_1"],
- )
- host.shutdown_omni_connectors()
-
- def test_send_keys_route_from_rank_gt_to_rank(self):
- host = self._make_host(from_tp=4, to_tp=2, local_rank=3)
- self.assertEqual(host.get_rank_aware_kv_send_keys("req", from_stage=0), ["req_0_0_3_1"])
- host.shutdown_omni_connectors()
-
- def test_invalid_recv_rank_mapping_raises(self):
- host = self._make_host(from_tp=3, to_tp=2, local_rank=1)
- with self.assertRaises(ValueError):
- host.get_rank_aware_kv_keys("req", from_stage=0)
- host.shutdown_omni_connectors()
-
- def test_invalid_send_rank_mapping_raises(self):
- host = self._make_host(from_tp=3, to_tp=2, local_rank=1)
- with self.assertRaises(ValueError):
- host.get_rank_aware_kv_send_keys("req", from_stage=0)
- host.shutdown_omni_connectors()
-
- def test_merge_rank_sharded_payloads_concatenates_head_dimension(self):
- host = self._make_host(from_tp=4, to_tp=2, local_rank=0)
- payloads = [
- {"layer_blocks": {"key_cache": [torch.ones(2, 1, 3)], "value_cache": [torch.ones(2, 1, 3)]}},
- {"layer_blocks": {"key_cache": [torch.full((2, 1, 3), 2.0)], "value_cache": [torch.full((2, 1, 3), 2.0)]}},
- ]
- merged = host._merge_rank_sharded_kv_payloads(payloads)
- self.assertEqual(tuple(merged["layer_blocks"]["key_cache"][0].shape), (2, 2, 3))
- self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 0], torch.ones(2, 3)))
- self.assertTrue(torch.equal(merged["layer_blocks"]["key_cache"][0][:, 1], torch.full((2, 3), 2.0)))
- host.shutdown_omni_connectors()
-
- def test_slice_rank_sharded_payload_splits_head_dimension(self):
- host = self._make_host(from_tp=2, to_tp=4, local_rank=1)
- payload = {
- "layer_blocks": {
- "key_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)],
- "value_cache": [torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)],
- },
- "metadata": {},
- }
- sliced = host._slice_rank_sharded_kv_payload(payload)
- self.assertEqual(tuple(sliced["layer_blocks"]["key_cache"][0].shape), (2, 2, 3))
- expected = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)[:, 2:4, :]
- self.assertTrue(torch.equal(sliced["layer_blocks"]["key_cache"][0], expected))
- host.shutdown_omni_connectors()
-
-
-class TestAttachOmniConnectorOutput(unittest.TestCase):
- def test_wraps_empty_model_runner_output_when_signals_exist(self):
- from vllm.v1.worker.gpu_model_runner import EMPTY_MODEL_RUNNER_OUTPUT
-
- host = MixinHost()
- host.get_omni_connector_output = lambda: OmniConnectorOutput(chunk_ready_req_ids={"req-1"})
-
- wrapped = host.attach_omni_connector_output(EMPTY_MODEL_RUNNER_OUTPUT)
-
- self.assertIsNot(wrapped, EMPTY_MODEL_RUNNER_OUTPUT)
- self.assertEqual(wrapped.omni_connector_output.chunk_ready_req_ids, {"req-1"})
-
-
-class TestConnectorConfigValidation(unittest.TestCase):
- def test_invalid_connector_name_raises(self):
- host = MixinHost()
- model_config = _make_model_config(stage_id=1)
- model_config.stage_connector_config = {"name": " "}
-
- with self.assertRaisesRegex(RuntimeError, "missing connector name"):
- host.init_omni_connectors(vllm_config=None, model_config=model_config)
-
-
-class _FailingConnector:
- """Connector whose put() fails a configurable number of times."""
-
- def __init__(self, fail_count: int = 1, raise_on_fail: bool = False):
- self._fail_count = fail_count
- self._raise_on_fail = raise_on_fail
- self.attempt = 0
-
- def put(self, from_stage, to_stage, put_key, data):
- self.attempt += 1
- if self.attempt <= self._fail_count:
- if self._raise_on_fail:
- raise ConnectionError("transient connector error")
- return False, 0, None
- return True, len(str(data)), None
-
- def get(self, *a, **kw):
- return None
-
- def close(self):
- pass
-
-
-class TestSendRetry(unittest.TestCase):
- """Tests for P1-2: failed connector sends must be retried."""
-
- def _make_sender(self, connector):
- sender = MixinHost()
- sender.init_omni_connectors(
- vllm_config=None,
- model_config=_make_model_config(stage_id=0, async_chunk=True),
- )
- sender._omni_connector = connector
- sender._stage_id = 0
- sender._async_chunk = True
- return sender
-
- def _make_task(self, req_id="r1"):
- return {
- "stage_id": 0,
- "next_stage_id": 1,
- "request_id": req_id,
- "data": {"payload": "test"},
- }
-
- def test_send_single_request_returns_false_on_put_failure(self):
- connector = _FailingConnector(fail_count=999)
- sender = self._make_sender(connector)
-
- result = sender._send_single_request(self._make_task())
- self.assertFalse(result)
- sender.shutdown_omni_connectors()
-
- def test_send_single_request_does_not_decrement_on_failure(self):
- connector = _FailingConnector(fail_count=999)
- sender = self._make_sender(connector)
- sender._pending_save_counts["r1"] = 1
-
- sender._send_single_request(self._make_task())
- self.assertEqual(sender._pending_save_counts.get("r1"), 1, "pending count must NOT be decremented on failure")
- sender.shutdown_omni_connectors()
-
- def test_send_single_request_decrements_on_success(self):
- connector = MockConnector(stage_id=0)
- sender = self._make_sender(connector)
- sender._pending_save_counts["r1"] = 1
-
- result = sender._send_single_request(self._make_task())
- self.assertTrue(result)
- self.assertNotIn("r1", sender._pending_save_counts, "pending count should be zero/removed on success")
- sender.shutdown_omni_connectors()
-
- def test_requeue_or_drop_requeues_on_first_failure(self):
- sender = self._make_sender(MockConnector(stage_id=0))
- task = self._make_task()
-
- sender._requeue_or_drop_failed_send(task)
-
- self.assertEqual(task.get("_retry_count"), 1)
- with sender._lock:
- dq = sender._pending_save_reqs.get("r1")
- self.assertIsNotNone(dq)
- self.assertEqual(len(dq), 1)
- sender.shutdown_omni_connectors()
-
- def test_requeue_or_drop_drops_after_max_retries(self):
- sender = self._make_sender(MockConnector(stage_id=0))
- sender._pending_save_counts["r1"] = 1
- task = self._make_task()
- task["_retry_count"] = sender._MAX_SEND_RETRIES # already at max
-
- sender._requeue_or_drop_failed_send(task)
-
- with sender._lock:
- dq = sender._pending_save_reqs.get("r1")
- self.assertTrue(dq is None or len(dq) == 0, "task should NOT be re-enqueued after max retries")
- self.assertNotIn("r1", sender._pending_save_counts, "pending count should be cleaned up on final drop")
- sender.shutdown_omni_connectors()
-
- def test_save_loop_retries_on_exception(self):
- """Integration: _save_loop retries a task when put() raises."""
- from collections import deque
-
- connector = _FailingConnector(fail_count=1, raise_on_fail=True)
- sender = self._make_sender(connector)
- task = self._make_task()
-
- with sender._lock:
- sender._pending_save_reqs["r1"] = deque([task])
- sender._pending_save_counts["r1"] = 1
-
- sender._stop_event.clear()
-
- def run_one_loop():
- sender._save_loop()
-
- sender._stop_event.set() # will exit after one iteration
- # Run manually instead of threading
- # Simulate: pop task, send fails, requeue
- popped_task = None
- with sender._lock:
- dq = sender._pending_save_reqs.get("r1")
- if dq:
- popped_task = dq.popleft()
- if not dq:
- del sender._pending_save_reqs["r1"]
-
- if popped_task is not None:
- success = False
- try:
- success = sender._send_single_request(popped_task)
- except Exception:
- pass
- if not success:
- sender._requeue_or_drop_failed_send(popped_task)
-
- # After first failure, task should be re-enqueued
- with sender._lock:
- dq = sender._pending_save_reqs.get("r1")
- self.assertIsNotNone(dq)
- self.assertEqual(len(dq), 1)
- requeued = dq[0]
- self.assertEqual(requeued.get("_retry_count"), 1)
-
- # Second attempt should succeed (connector now returns True)
- success = sender._send_single_request(requeued)
- self.assertTrue(success)
- sender.shutdown_omni_connectors()
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py
index 9f908a7adc5..b2d61931558 100644
--- a/tests/worker/test_omni_gpu_model_runner.py
+++ b/tests/worker/test_omni_gpu_model_runner.py
@@ -41,17 +41,7 @@ def __init__(self):
class DummyTalkerMTP(torch.nn.Module):
"""A fake talker_mtp module for deterministic CPU testing."""
- def forward(
- self,
- req_input_ids,
- req_embeds,
- last_talker_hidden,
- text_step,
- do_sample=None,
- temperature=None,
- top_k=None,
- top_p=None,
- ):
+ def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step):
# Deterministic behavior:
# - output embeds = input embeds + 1
# - output codes = [[0], [1], ...]
@@ -61,36 +51,6 @@ def forward(
return new_embeds, codes
-class CaptureTalkerMTP(torch.nn.Module):
- """A fake talker_mtp module that records sampling kwargs."""
-
- def __init__(self):
- super().__init__()
- self.calls = []
-
- def forward(
- self,
- req_input_ids,
- req_embeds,
- last_talker_hidden,
- text_step,
- do_sample=None,
- temperature=None,
- top_k=None,
- top_p=None,
- ):
- self.calls.append(
- {
- "do_sample": do_sample,
- "temperature": temperature,
- "top_k": top_k,
- "top_p": top_p,
- }
- )
- codes = torch.zeros((req_embeds.shape[0], 1), dtype=torch.int64)
- return req_embeds, codes
-
-
@contextmanager
def _noop_forward_context(*args, **kwargs):
"""A no-op context manager to replace vLLM forward context in CPU tests."""
@@ -119,8 +79,8 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32))
runner.talker_mtp = DummyTalkerMTP()
- runner.model = SimpleNamespace(talker_mtp_output_key=("codes", "audio"))
- runner.vllm_config = SimpleNamespace(model_config=SimpleNamespace())
+ runner.model = SimpleNamespace(talker_mtp_output_key="code_predictor_codes")
+ runner.vllm_config = object()
# Provide a minimal implementation that returns the expected 4-tuple.
def _determine_batch_execution_and_padding(**kwargs):
@@ -188,8 +148,8 @@ def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_
# Validate per-request additional_information_cpu was updated
info_r1 = runner.requests["r1"].additional_information_cpu
info_r2 = runner.requests["r2"].additional_information_cpu
- assert int(info_r1["codes"]["audio"][0, 0]) == 0
- assert int(info_r2["codes"]["audio"][0, 0]) == 1
+ assert int(info_r1["code_predictor_codes"][0, 0]) == 0
+ assert int(info_r2["code_predictor_codes"][0, 0]) == 1
def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
@@ -208,43 +168,6 @@ def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
assert torch.allclose(inputs_embeds, before)
-def test_talker_mtp_forward_passes_qwen3_tts_subtalker_sampling_params_to_talker(monkeypatch):
- import vllm_omni.worker.gpu_model_runner as mod
-
- monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)
-
- runner = _make_runner(req_ids=("r1",), hidden_size=4)
- runner.talker_mtp = CaptureTalkerMTP()
- runner.vllm_config = SimpleNamespace(
- model_config=SimpleNamespace(
- subtalker_sampling_params={
- "do_sample": False,
- "temperature": 0.2,
- "top_k": 9,
- "top_p": 0.55,
- }
- )
- )
-
- def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_scheduled_tokens, use_cascade_attn):
- batch_desc = SimpleNamespace(num_tokens=int(num_tokens))
- return (False, batch_desc, None, None, None)
-
- monkeypatch.setattr(runner, "_determine_batch_execution_and_padding", fake_determine.__get__(runner, type(runner)))
-
- inputs_embeds = torch.zeros((2, 4), dtype=torch.float32)
- OmniGPUModelRunner._talker_mtp_forward(runner, ["r1"], inputs_embeds)
-
- assert runner.talker_mtp.calls == [
- {
- "do_sample": False,
- "temperature": 0.2,
- "top_k": 9,
- "top_p": 0.55,
- }
- ]
-
-
def test_update_intermediate_buffer_writes_to_buffer_and_setattr(monkeypatch):
"""Validate that _update_intermediate_buffer writes to model_intermediate_buffer
(forward path) and mirrors to additional_information_cpu setattr (backward compat)."""
diff --git a/tools/nightly/generate_nightly_perf_excel.py b/tools/nightly/generate_nightly_perf_excel.py
index 6ba1d1eef0e..817f37f664e 100644
--- a/tools/nightly/generate_nightly_perf_excel.py
+++ b/tools/nightly/generate_nightly_perf_excel.py
@@ -23,22 +23,6 @@
GREY_BLOCK_FILL = PatternFill(start_color="D3D3D3", fill_type="solid")
# Diffusion sheet columns (Qwen-Image diffusion benchmark).
-# Per-stage latency metrics. Unpack from stage_durations_mean/p50/p99 dicts
-DIFFUSION_STAGE_LATENCY_COLUMNS: tuple[str, ...] = (
- # "vae.encode_mean",
- # "vae.encode_p50",
- # "vae.encode_p99",
- "vae.decode_mean",
- "vae.decode_p50",
- "vae.decode_p99",
- "diffuse_mean",
- "diffuse_p50",
- "diffuse_p99",
- "text_encoder.forward_mean",
- "text_encoder.forward_p50",
- "text_encoder.forward_p99",
-)
-
DIFFUSION_BENCHMARK_COLUMNS: tuple[str, ...] = (
"duration",
"completed_requests",
@@ -52,7 +36,7 @@
"peak_memory_mb_mean",
"peak_memory_mb_median",
"slo_attainment_rate",
-) + DIFFUSION_STAGE_LATENCY_COLUMNS
+)
DIFFUSION_NUMERIC_FORMAT_COLUMNS: tuple[str, ...] = DIFFUSION_BENCHMARK_COLUMNS
@@ -79,7 +63,7 @@
"build_id",
"build_url",
"source_file",
-) + DIFFUSION_STAGE_LATENCY_COLUMNS
+)
# Benchmark metric columns: grey the latest row's cell when value changed vs previous date.
BENCHMARK_COLUMNS: tuple[str, ...] = (
@@ -122,7 +106,7 @@
_COLUMNS_FILENAME = "nightly_perf_summary_columns.txt"
_RESULT_JSON_PREFIX = "result_test_"
-_DIFFUSION_RESULT_PREFIX = "diffusion_result_"
+_DIFFUSION_JSON_PREFIX = "diffusion_perf_"
DEFAULT_INPUT_DIR = os.getenv("DEFAULT_INPUT_DIR") if os.getenv("DEFAULT_INPUT_DIR") else "tests"
DEFAULT_OUTPUT_DIR = os.getenv("DEFAULT_OUTPUT_DIR") if os.getenv("DEFAULT_OUTPUT_DIR") else "tests"
DEFAULT_DIFFUSION_INPUT_DIR = os.getenv("DIFFUSION_BENCHMARK_DIR")
@@ -268,7 +252,7 @@ def parse_args() -> argparse.Namespace:
type=str,
default=None,
help=(
- "Directory containing diffusion_result_*.json files; default is "
+ "Directory containing diffusion_perf_*.json files; default is "
"DIFFUSION_BENCHMARK_DIR, fallback to --input-dir."
),
)
@@ -302,7 +286,7 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
-def _load_json_file(path: str) -> dict[str, Any] | list[Any] | None:
+def _load_json_file(path: str) -> dict[str, Any] | None:
"""Safely load a single JSON file; return None and log a warning on failure."""
try:
with open(path, encoding="utf-8") as f:
@@ -311,18 +295,18 @@ def _load_json_file(path: str) -> dict[str, Any] | list[Any] | None:
LOGGER.warning("failed to load json '%s': %s", path, exc)
return None
- if not isinstance(data, (dict, list)):
- LOGGER.warning("json root in '%s' is not a dict or list, skip", path)
+ if not isinstance(data, dict):
+ LOGGER.warning("json root in '%s' is not an object, skip", path)
return None
return data
def _parse_from_filename(filename: str) -> dict[str, Any]:
- """Parse test-related metadata from a ``result_test_*.json`` filename.
+ """Parse test-related metadata from a result JSON filename.
- Matches ``tests/dfx/perf/scripts/run_benchmark.py`` naming, including optional
- ``_in{X}_out{Y}_`` before the timestamp (``na`` when unset).
+ Expected pattern (after prefix/suffix stripped):
+ <test_name>_<dataset_name>_<max_concurrency>_<num_prompts>_<timestamp>
"""
name, ext = os.path.splitext(filename)
if ext != ".json" or not name.startswith(_RESULT_JSON_PREFIX):
@@ -331,42 +315,22 @@ def _parse_from_filename(filename: str) -> dict[str, Any]:
core = name[len(_RESULT_JSON_PREFIX) :]
parts = core.split("_")
if len(parts) < 5:
- LOGGER.warning(
- "filename '%s' does not match expected pattern (need >= 5 segments), skip parsing",
- filename,
- )
+ LOGGER.warning("filename '%s' does not match expected pattern, skip parsing test metadata", filename)
return {}
- idx = len(parts) - 1
- timestamp = parts[idx]
- idx -= 1
+ timestamp = parts[-1]
+ num_prompts_str = parts[-2]
+ max_concurrency_str = parts[-3]
+ dataset_name = parts[-4]
+ test_name = "_".join(parts[:-4]) if parts[:-4] else ""
parsed: dict[str, Any] = {}
+
if len(timestamp) >= 15:
parsed["date"] = timestamp
- if idx >= 0 and parts[idx].startswith("out"):
- parsed["random_output_len"] = parts[idx][3:]
- idx -= 1
- if idx >= 0 and parts[idx].startswith("in"):
- parsed["random_input_len"] = parts[idx][2:]
- idx -= 1
-
- if idx < 3:
- LOGGER.warning(
- "filename '%s' has too few segments after timestamp / optional in-out (idx=%s)",
- filename,
- idx,
- )
- return parsed
-
- num_prompts_str = parts[idx]
- idx -= 1
- flow_str = parts[idx]
- idx -= 1
- dataset_name = parts[idx]
- idx -= 1
- test_name = "_".join(parts[: idx + 1]) if idx >= 0 else ""
+ if dataset_name in DATASET_NAME_ALLOWED:
+ parsed["dataset_name"] = dataset_name
try:
parsed["num_prompts"] = int(num_prompts_str)
@@ -374,16 +338,13 @@ def _parse_from_filename(filename: str) -> dict[str, Any]:
pass
try:
- parsed["max_concurrency"] = int(flow_str)
+ parsed["max_concurrency"] = int(max_concurrency_str)
except (TypeError, ValueError):
pass
if test_name:
parsed["test_name"] = test_name
- if dataset_name in DATASET_NAME_ALLOWED:
- parsed["dataset_name"] = dataset_name
-
return parsed
@@ -435,29 +396,27 @@ def _iter_omni_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
yield record
-def _parse_diffusion_result_from_filename(filename: str) -> dict[str, Any]:
- """Parse test_name/date from filename: diffusion_result__.json"""
+def _parse_diffusion_from_filename(filename: str) -> dict[str, Any]:
+ """Parse diffusion test_name/date from filename: diffusion_perf_<test_name>_<timestamp>.json"""
name, ext = os.path.splitext(filename)
- if ext != ".json" or not name.startswith(_DIFFUSION_RESULT_PREFIX):
+ if ext != ".json" or not name.startswith(_DIFFUSION_JSON_PREFIX):
return {}
- core = name[len(_DIFFUSION_RESULT_PREFIX) :]
+ core = name[len(_DIFFUSION_JSON_PREFIX) :]
parts = core.split("_")
if len(parts) < 2:
return {}
timestamp = parts[-1]
+ test_name = "_".join(parts[:-1]) if parts[:-1] else ""
parsed: dict[str, Any] = {}
if len(timestamp) >= 15:
parsed["date"] = timestamp
+ if test_name:
+ parsed["test_name"] = test_name
return parsed
-def _iter_diffusion_records(input_dir: str) -> Iterable[dict[str, Any]]:
- """Iterate over diffusion_result_*.json files and yield normalized records.
-
- Unlike omni format where each JSON file contains one test case, diffusion format
- produces a single JSON file containing a list of all test case records.
- Test params (feature toggles) are NOT embedded in the filename.
- """
+def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
+ """Iterate over diffusion_perf_*.json files and yield normalized diffusion records."""
if not os.path.isdir(input_dir):
LOGGER.warning("diffusion input dir '%s' does not exist or is not a directory", input_dir)
return
@@ -465,7 +424,7 @@ def _iter_diffusion_records(input_dir: str) -> Iterable[dict[str, Any]]:
for entry in sorted(os.listdir(input_dir)):
if not entry.endswith(".json"):
continue
- if not entry.startswith(_DIFFUSION_RESULT_PREFIX):
+ if not entry.startswith(_DIFFUSION_JSON_PREFIX):
continue
full_path = os.path.join(input_dir, entry)
if not os.path.isfile(full_path):
@@ -475,63 +434,23 @@ def _iter_diffusion_records(input_dir: str) -> Iterable[dict[str, Any]]:
if data is None:
continue
- filename_meta = _parse_diffusion_result_from_filename(os.path.basename(full_path))
-
- if not isinstance(data, list):
- LOGGER.warning("diffusion result file '%s' root is not a list, skip", full_path)
- continue
-
- for record in data:
- if not isinstance(record, dict):
- continue
- record = dict(record)
- if "date" not in record or not record.get("date"):
- record["date"] = filename_meta.get("date") or datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
- record["source_file"] = os.path.basename(full_path)
- yield record
+ record: dict[str, Any] = dict(data)
+ filename_meta = _parse_diffusion_from_filename(os.path.basename(full_path))
+ if "date" not in record or not record.get("date"):
+ record["date"] = filename_meta.get("date") or datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
+ if "test_name" not in record or not record.get("test_name"):
+ if "test_name" in filename_meta:
+ record["test_name"] = filename_meta["test_name"]
+ record["source_file"] = os.path.basename(full_path)
+ yield record
-def _collect_omni_records(input_dir: str) -> list[dict[str, Any]]:
+def _collect_records(input_dir: str) -> list[dict[str, Any]]:
return list(_iter_omni_json_records(input_dir))
def _collect_diffusion_records(diffusion_input_dir: str) -> list[dict[str, Any]]:
- """Collect diffusion records from diffusion_result_*.json files.
- Their format is different from omni JSON files.
- """
- return [_process_diffusion_record(r) for r in _iter_diffusion_records(diffusion_input_dir)]
-
-
-def _flatten_stage_durations(record: dict[str, Any]) -> dict[str, Any]:
- """Flatten stage_durations dict into individual columns matching DIFFUSION_STAGE_LATENCY_COLUMNS."""
- result = dict(record)
-
- for prefix in ("stage_durations_mean", "stage_durations_p50", "stage_durations_p99"):
- durations = result.pop(prefix, None)
- if not isinstance(durations, dict):
- continue
-
- suffix = prefix.replace("stage_durations_", "") # "mean", "p50", "p99"
-
- for stage_key, value in durations.items(): # e.g., "SomePipeline.vae.decode_mean": 100.0
- stage_key = stage_key.split(".", 1)[-1] # "decode_mean"
- col_name = f"{stage_key}_{suffix}"
- if col_name not in DIFFUSION_STAGE_LATENCY_COLUMNS:
- print(f"skipping stage_key: {col_name}")
- continue
- result[col_name] = value
-
- return result
-
-
-def _process_diffusion_record(record: dict[str, Any]) -> dict[str, Any]:
- """Normalize a diffusion record by merging `result` and flattening stage metrics."""
- flat = record.copy()
- flat.update(flat.pop("result", {}))
- flat = _flatten_stage_durations(flat)
- flat.pop("benchmark_params", None)
- flat.pop("server_params", None)
- return flat
+ return list(_iter_diffusion_json_records(diffusion_input_dir))
def _apply_build_metadata_to_latest_only(
@@ -574,7 +493,7 @@ def _apply_build_metadata_to_latest_only(
def _sort_records_for_summary(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Sort so that same test configuration is grouped, newest date first within each group."""
- by_date_desc = sorted(records, key=lambda r: r.get("date") or "", reverse=True)
+ by_date_desc = sorted(records, key=lambda r: (r.get("date") or ""), reverse=True)
return sorted(
by_date_desc,
key=_omni_group_key,
@@ -582,7 +501,7 @@ def _sort_records_for_summary(records: list[dict[str, Any]]) -> list[dict[str, A
def _sort_diffusion_records_for_summary(records: list[dict[str, Any]]) -> list[dict[str, Any]]:
- by_date_desc = sorted(records, key=lambda r: r.get("date") or "", reverse=True)
+ by_date_desc = sorted(records, key=lambda r: (r.get("date") or ""), reverse=True)
return sorted(by_date_desc, key=_diffusion_group_key)
@@ -665,21 +584,6 @@ def _to_float_if_numeric(value: Any) -> Any:
return value
-def _to_excel_compatible(value: Any) -> Any:
- """Convert non-scalar objects to JSON string for openpyxl cell values."""
- if isinstance(value, (dict, list, tuple)):
- try:
- return json.dumps(value, ensure_ascii=False, sort_keys=True)
- except (TypeError, ValueError):
- return str(value)
- if isinstance(value, set):
- try:
- return json.dumps(sorted(value), ensure_ascii=False)
- except (TypeError, ValueError):
- return str(value)
- return value
-
-
def _write_sheet(
ws,
columns: Sequence[str],
@@ -695,7 +599,6 @@ def _write_sheet(
v = record.get(col)
if col in numeric_set:
v = _to_float_if_numeric(v)
- v = _to_excel_compatible(v)
row_values.append(v)
ws.append(row_values)
@@ -775,7 +678,7 @@ def generate_excel_report(
script_dir = os.path.dirname(os.path.abspath(__file__))
omni_summary_columns = _ensure_omni_summary_columns(_load_summary_columns(script_dir))
- omni_records = _collect_omni_records(input_dir)
+ omni_records = _collect_records(input_dir)
diffusion_records = _collect_diffusion_records(diffusion_input_dir)
if not omni_records:
diff --git a/tools/nightly/generate_nightly_perf_html.py b/tools/nightly/generate_nightly_perf_html.py
index a50a462550a..05dc48d717c 100644
--- a/tools/nightly/generate_nightly_perf_html.py
+++ b/tools/nightly/generate_nightly_perf_html.py
@@ -17,7 +17,7 @@
LOGGER = logging.getLogger(__name__)
_RESULT_JSON_PREFIX = "result_test_"
-_DIFFUSION_JSON_PREFIXES = ("diffusion_perf_", "diffusion_result_")
+_DIFFUSION_JSON_PREFIX = "diffusion_perf_"
DEFAULT_INPUT_DIR = os.getenv("DEFAULT_INPUT_DIR") or "tests"
DEFAULT_OUTPUT_DIR = os.getenv("DEFAULT_OUTPUT_DIR") or "tests"
DEFAULT_DIFFUSION_INPUT_DIR = os.getenv("DIFFUSION_BENCHMARK_DIR")
@@ -51,7 +51,7 @@ def _default_diffusion_input_dir(input_dir: str) -> str:
return DEFAULT_DIFFUSION_INPUT_DIR if DEFAULT_DIFFUSION_INPUT_DIR else input_dir
-def _load_json_file(path: str) -> dict[str, Any] | list[Any] | None:
+def _load_json_file(path: str) -> dict[str, Any] | None:
try:
with open(path, encoding="utf-8") as f:
data = json.load(f)
@@ -59,15 +59,14 @@ def _load_json_file(path: str) -> dict[str, Any] | list[Any] | None:
LOGGER.warning("failed to load json '%s': %s", path, exc)
return None
- if not isinstance(data, (dict, list)):
- LOGGER.warning("json root in '%s' is not an object or list, skip", path)
+ if not isinstance(data, dict):
+ LOGGER.warning("json root in '%s' is not an object, skip", path)
return None
return data
def _parse_from_filename(filename: str) -> dict[str, Any]:
- """Parse ``result_test_*.json`` filenames; same rules as ``generate_nightly_perf_excel``."""
name, ext = os.path.splitext(filename)
if ext != ".json" or not name.startswith(_RESULT_JSON_PREFIX):
return {}
@@ -76,58 +75,32 @@ def _parse_from_filename(filename: str) -> dict[str, Any]:
parts = core.split("_")
if len(parts) < 5:
LOGGER.warning(
- "filename '%s' does not match expected pattern (need >= 5 segments), skip parsing",
+ "filename '%s' does not match expected pattern, skip parsing test metadata",
filename,
)
return {}
- idx = len(parts) - 1
- timestamp = parts[idx]
- idx -= 1
+ timestamp = parts[-1]
+ num_prompts_str = parts[-2]
+ max_concurrency_str = parts[-3]
+ dataset_name = parts[-4]
+ test_name = "_".join(parts[:-4]) if parts[:-4] else ""
parsed: dict[str, Any] = {}
if len(timestamp) >= 15:
parsed["date"] = timestamp
-
- if idx >= 0 and parts[idx].startswith("out"):
- parsed["random_output_len"] = parts[idx][3:]
- idx -= 1
- if idx >= 0 and parts[idx].startswith("in"):
- parsed["random_input_len"] = parts[idx][2:]
- idx -= 1
-
- if idx < 3:
- LOGGER.warning(
- "filename '%s' has too few segments after timestamp / optional in-out (idx=%s)",
- filename,
- idx,
- )
- return parsed
-
- num_prompts_str = parts[idx]
- idx -= 1
- flow_str = parts[idx]
- idx -= 1
- dataset_name = parts[idx]
- idx -= 1
- test_name = "_".join(parts[: idx + 1]) if idx >= 0 else ""
-
+ if dataset_name in ("random", "random-mm"):
+ parsed["dataset_name"] = dataset_name
try:
parsed["num_prompts"] = int(num_prompts_str)
except (TypeError, ValueError):
pass
-
try:
- parsed["max_concurrency"] = int(flow_str)
+ parsed["max_concurrency"] = int(max_concurrency_str)
except (TypeError, ValueError):
pass
-
if test_name:
parsed["test_name"] = test_name
-
- if dataset_name in ("random", "random-mm"):
- parsed["dataset_name"] = dataset_name
-
return parsed
@@ -170,10 +143,9 @@ def _iter_omni_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
def _parse_diffusion_from_filename(filename: str) -> dict[str, Any]:
name, ext = os.path.splitext(filename)
- if ext != ".json" or not any(name.startswith(prefix) for prefix in _DIFFUSION_JSON_PREFIXES):
+ if ext != ".json" or not name.startswith(_DIFFUSION_JSON_PREFIX):
return {}
- matched_prefix = next(prefix for prefix in _DIFFUSION_JSON_PREFIXES if name.startswith(prefix))
- core = name[len(matched_prefix) :]
+ core = name[len(_DIFFUSION_JSON_PREFIX) :]
parts = core.split("_")
if len(parts) < 2:
return {}
@@ -196,7 +168,7 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
return
for entry in sorted(os.listdir(input_dir)):
- if not entry.endswith(".json") or not any(entry.startswith(prefix) for prefix in _DIFFUSION_JSON_PREFIXES):
+ if not entry.endswith(".json") or not entry.startswith(_DIFFUSION_JSON_PREFIX):
continue
full_path = os.path.join(input_dir, entry)
if not os.path.isfile(full_path):
@@ -205,32 +177,17 @@ def _iter_diffusion_json_records(input_dir: str) -> Iterable[dict[str, Any]]:
if data is None:
continue
+ record: dict[str, Any] = dict(data)
filename_meta = _parse_diffusion_from_filename(os.path.basename(full_path))
- if isinstance(data, dict):
- records = [data]
- elif isinstance(data, list):
- records = [r for r in data if isinstance(r, dict)]
- else:
- records = []
-
- if not records:
- LOGGER.warning("diffusion json '%s' has no valid records, skip", full_path)
- continue
-
- for record in records:
- flat: dict[str, Any] = dict(record)
- result = flat.pop("result", None)
- if isinstance(result, dict):
- flat.update(result)
- if "date" not in flat or not flat.get("date"):
- flat["date"] = filename_meta.get("date") or datetime.now(
- timezone.utc,
- ).strftime("%Y%m%d-%H%M%S")
- if "test_name" not in flat or not flat.get("test_name"):
- if "test_name" in filename_meta:
- flat["test_name"] = filename_meta["test_name"]
- flat["source_file"] = os.path.basename(full_path)
- yield flat
+ if "date" not in record or not record.get("date"):
+ record["date"] = filename_meta.get("date") or datetime.now(
+ timezone.utc,
+ ).strftime("%Y%m%d-%H%M%S")
+ if "test_name" not in record or not record.get("test_name"):
+ if "test_name" in filename_meta:
+ record["test_name"] = filename_meta["test_name"]
+ record["source_file"] = os.path.basename(full_path)
+ yield record
def _collect_omni_records(input_dir: str) -> list[dict[str, Any]]:
diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py
index 819a7c8c3dd..1c08a1543d2 100644
--- a/tools/pre_commit/check_pickle_imports.py
+++ b/tools/pre_commit/check_pickle_imports.py
@@ -16,7 +16,8 @@
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
# add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = {
- "tests/helpers/process.py",
+ "tests/e2e/offline_inference/utils.py",
+ "tests/utils.py",
"vllm_omni/diffusion/distributed/group_coordinator.py",
"tests/diffusion/attention/test_attention_sp.py",
}
diff --git a/tools/wan22/assemble_wan22_i2v_diffusers.py b/tools/wan22/assemble_wan22_i2v_diffusers.py
deleted file mode 100644
index 8e14ca3c26d..00000000000
--- a/tools/wan22/assemble_wan22_i2v_diffusers.py
+++ /dev/null
@@ -1,385 +0,0 @@
-#!/usr/bin/env python3
-"""
-Assemble a Wan2.2-I2V-A14B-Diffusers-style model directory using a Diffusers
-skeleton and optional replacement transformer checkpoints.
-
-This tool does NOT run any external conversion step. You can use it in two
-ways:
-- keep the original weights from the Diffusers skeleton
-- replace transformer/transformer_2 with converted checkpoints such as
- LightX2V outputs
-- use legacy LightX2V arg names (--high-noise-weight/--low-noise-weight),
- which are accepted as aliases
-
-Typical use:
- python tools/wan22/assemble_wan22_i2v_diffusers.py \
- --diffusers-skeleton /path/to/Wan2.2-I2V-A14B-Diffusers \
- --transformer-weight /path/to/high_noise_out/diffusion_pytorch_model.safetensors \
- --transformer-2-weight /path/to/low_noise_out/diffusion_pytorch_model.safetensors \
- --output-dir /path/to/Wan2.2-I2V-A14B-Custom-Diffusers
-"""
-
-from __future__ import annotations
-
-import argparse
-import json
-import shutil
-import sys
-from dataclasses import dataclass
-from pathlib import Path
-
-WEIGHT_CANDIDATES = (
- "diffusion_pytorch_model.safetensors",
- "diffusion_pytorch_model.bin",
- "diffusion_pytorch_model.pt",
- "model.safetensors",
- "pytorch_model.bin",
- "model.pt",
-)
-WEIGHT_INDEX_CANDIDATES = (
- "diffusion_pytorch_model.safetensors.index.json",
- "model.safetensors.index.json",
- "pytorch_model.bin.index.json",
-)
-
-ROOT_REQUIRED_FILES = ("model_index.json",)
-ROOT_REQUIRED_DIRS = ("tokenizer", "text_encoder", "vae", "transformer", "transformer_2")
-OPTIONAL_DIRS = ("image_encoder", "image_processor", "scheduler", "feature_extractor")
-
-
-class AssembleError(RuntimeError):
- pass
-
-
-@dataclass(frozen=True)
-class WeightSpec:
- kind: str # "single" | "sharded"
- single_file: Path | None = None
- index_file: Path | None = None
- shard_files: tuple[Path, ...] = ()
-
-
-def _load_shard_files_from_index(index_file: Path, role: str) -> tuple[Path, ...]:
- try:
- with index_file.open(encoding="utf-8") as f:
- payload = json.load(f)
- except Exception as exc:
- raise AssembleError(f"Failed to parse {role} index file: {index_file}. error={exc}") from exc
-
- weight_map = payload.get("weight_map")
- if not isinstance(weight_map, dict) or not weight_map:
- raise AssembleError(f"Invalid {role} index file (missing/empty weight_map): {index_file}")
-
- shard_names = sorted({str(v) for v in weight_map.values()})
- shard_paths: list[Path] = []
- missing: list[str] = []
- for shard_name in shard_names:
- shard_path = index_file.parent / shard_name
- if not shard_path.is_file():
- missing.append(str(shard_path))
- else:
- shard_paths.append(shard_path)
-
- if missing:
- raise AssembleError(f"{role} index references missing shard file(s): " + ", ".join(missing))
-
- if not shard_paths:
- raise AssembleError(f"No shard files referenced by {role} index: {index_file}")
-
- return tuple(shard_paths)
-
-
-def _resolve_weight_spec(path: Path, role: str) -> WeightSpec:
- if path.is_file():
- return WeightSpec(kind="single", single_file=path)
-
- if path.is_dir():
- for name in WEIGHT_CANDIDATES:
- candidate = path / name
- if candidate.is_file():
- return WeightSpec(kind="single", single_file=candidate)
-
- for index_name in WEIGHT_INDEX_CANDIDATES:
- index_file = path / index_name
- if not index_file.is_file():
- continue
- shard_files = _load_shard_files_from_index(index_file, role=role)
- return WeightSpec(
- kind="sharded",
- index_file=index_file,
- shard_files=shard_files,
- )
-
- shard_candidates = sorted(path.glob("diffusion_pytorch_model-*.safetensors"))
- if shard_candidates:
- raise AssembleError(
- f"Detected sharded {role} files under {path}, but index json is missing. "
- f"Expected one of: {', '.join(WEIGHT_INDEX_CANDIDATES)}"
- )
-
- raise AssembleError(
- f"Cannot find {role} weight under directory: {path}. "
- f"Expected one of single files [{', '.join(WEIGHT_CANDIDATES)}] "
- f"or sharded index files [{', '.join(WEIGHT_INDEX_CANDIDATES)}]."
- )
-
- raise AssembleError(f"{role} path does not exist: {path}")
-
-
-def _canonical_weight_name(weight_file: Path) -> str:
- suffix = weight_file.suffix.lower()
- if suffix == ".safetensors":
- return "diffusion_pytorch_model.safetensors"
- if suffix == ".bin":
- return "diffusion_pytorch_model.bin"
- if suffix == ".pt":
- return "diffusion_pytorch_model.pt"
- return weight_file.name
-
-
-def _validate_skeleton(skeleton: Path) -> None:
- if not skeleton.is_dir():
- raise AssembleError(f"--diffusers-skeleton is not a directory: {skeleton}")
-
- for file_name in ROOT_REQUIRED_FILES:
- if not (skeleton / file_name).is_file():
- raise AssembleError(f"Missing required file in skeleton: {skeleton / file_name}")
-
- for dir_name in ROOT_REQUIRED_DIRS:
- if not (skeleton / dir_name).is_dir():
- raise AssembleError(f"Missing required directory in skeleton: {skeleton / dir_name}")
-
- if not (skeleton / "transformer" / "config.json").is_file():
- raise AssembleError(f"Missing transformer config: {skeleton / 'transformer/config.json'}")
-
- if not (skeleton / "transformer_2" / "config.json").is_file():
- raise AssembleError(f"Missing transformer_2 config: {skeleton / 'transformer_2/config.json'}")
-
-
-def _ensure_clean_output(output_dir: Path, overwrite: bool) -> None:
- if output_dir.exists():
- if not overwrite:
- raise AssembleError(
- f"Output directory already exists: {output_dir}. Use --overwrite to remove and recreate it."
- )
- shutil.rmtree(output_dir)
- output_dir.mkdir(parents=True, exist_ok=False)
-
-
-def _copy_or_link_dir(src: Path, dst: Path, asset_mode: str) -> None:
- if asset_mode == "copy":
- shutil.copytree(src, dst)
- elif asset_mode == "symlink":
- dst.symlink_to(src, target_is_directory=True)
- else:
- raise AssembleError(f"Unknown asset mode: {asset_mode}")
-
-
-def _materialize_weight(weight: WeightSpec, dst_dir: Path, role: str) -> tuple[Path, ...]:
- if weight.kind == "single":
- assert weight.single_file is not None
- dst = dst_dir / _canonical_weight_name(weight.single_file)
- shutil.copy2(weight.single_file, dst)
- return (dst,)
-
- if weight.kind == "sharded":
- assert weight.index_file is not None
- copied: list[Path] = []
- index_dst = dst_dir / weight.index_file.name
- shutil.copy2(weight.index_file, index_dst)
- copied.append(index_dst)
- for shard_file in weight.shard_files:
- shard_dst = dst_dir / shard_file.name
- shutil.copy2(shard_file, shard_dst)
- copied.append(shard_dst)
- return tuple(copied)
-
- raise AssembleError(f"Unknown {role} weight kind: {weight.kind}")
-
-
-def _assemble(
- skeleton: Path,
- output_dir: Path,
- transformer_weight: WeightSpec,
- transformer_2_weight: WeightSpec,
- asset_mode: str,
-) -> tuple[tuple[Path, ...], tuple[Path, ...]]:
- shutil.copy2(skeleton / "model_index.json", output_dir / "model_index.json")
-
- for dir_name in ROOT_REQUIRED_DIRS:
- if dir_name in ("transformer", "transformer_2"):
- continue
- _copy_or_link_dir(skeleton / dir_name, output_dir / dir_name, asset_mode)
-
- for dir_name in OPTIONAL_DIRS:
- src_dir = skeleton / dir_name
- if src_dir.is_dir():
- _copy_or_link_dir(src_dir, output_dir / dir_name, asset_mode)
-
- (output_dir / "transformer").mkdir(parents=True, exist_ok=True)
- (output_dir / "transformer_2").mkdir(parents=True, exist_ok=True)
-
- shutil.copy2(skeleton / "transformer" / "config.json", output_dir / "transformer" / "config.json")
- shutil.copy2(skeleton / "transformer_2" / "config.json", output_dir / "transformer_2" / "config.json")
-
- transformer_copied = _materialize_weight(transformer_weight, output_dir / "transformer", role="transformer")
- transformer_2_copied = _materialize_weight(
- transformer_2_weight,
- output_dir / "transformer_2",
- role="transformer_2",
- )
-
- return transformer_copied, transformer_2_copied
-
-
-def _validate_output(
- output_dir: Path,
- transformer_copied: tuple[Path, ...],
- transformer_2_copied: tuple[Path, ...],
-) -> None:
- if not (output_dir / "model_index.json").is_file():
- raise AssembleError("Output validation failed: model_index.json missing")
-
- required_paths = (
- output_dir / "tokenizer",
- output_dir / "text_encoder",
- output_dir / "vae",
- output_dir / "transformer" / "config.json",
- output_dir / "transformer_2" / "config.json",
- *transformer_copied,
- *transformer_2_copied,
- )
- missing = [str(p) for p in required_paths if not p.exists()]
- if missing:
- raise AssembleError("Output validation failed, missing: " + ", ".join(missing))
-
-
-def parse_args() -> argparse.Namespace:
- parser = argparse.ArgumentParser(
- description=(
- "Assemble a Wan2.2-I2V-A14B-Diffusers directory while optionally "
- "replacing transformer and transformer_2 weights."
- )
- )
- parser.add_argument(
- "--diffusers-skeleton",
- type=Path,
- required=True,
- help="Path to a local Wan-AI/Wan2.2-I2V-A14B-Diffusers directory.",
- )
- parser.add_argument(
- "--transformer-weight",
- type=Path,
- help=(
- "Optional checkpoint file, or directory containing either a single-file "
- "weight or sharded index+shards for transformer/. If omitted, keep the "
- "skeleton's original transformer weights."
- ),
- )
- parser.add_argument(
- "--transformer-2-weight",
- type=Path,
- help=(
- "Optional checkpoint file, or directory containing either a single-file "
- "weight or sharded index+shards for transformer_2/. If omitted, keep the "
- "skeleton's original transformer_2 weights."
- ),
- )
- parser.add_argument(
- "--high-noise-weight",
- type=Path,
- help=argparse.SUPPRESS,
- )
- parser.add_argument(
- "--low-noise-weight",
- type=Path,
- help=argparse.SUPPRESS,
- )
- parser.add_argument(
- "--output-dir",
- type=Path,
- required=True,
- help="Output directory for the assembled model.",
- )
- parser.add_argument(
- "--asset-mode",
- choices=("symlink", "copy"),
- default="symlink",
- help=(
- "How to materialize non-transformer assets (tokenizer/text_encoder/vae/optional dirs). "
- "symlink saves disk and is default."
- ),
- )
- parser.add_argument(
- "--overwrite",
- action="store_true",
- help="Overwrite output-dir if it exists.",
- )
- return parser.parse_args()
-
-
-def main() -> int:
- args = parse_args()
-
- skeleton = args.diffusers_skeleton.resolve()
- output_dir = args.output_dir.resolve()
-
- if args.transformer_weight is not None and args.high_noise_weight is not None:
- print(
- "[ERROR] --transformer-weight and --high-noise-weight are aliases; please provide only one.",
- file=sys.stderr,
- )
- return 2
- if args.transformer_2_weight is not None and args.low_noise_weight is not None:
- print(
- "[ERROR] --transformer-2-weight and --low-noise-weight are aliases; please provide only one.",
- file=sys.stderr,
- )
- return 2
-
- transformer_weight_arg = args.transformer_weight if args.transformer_weight is not None else args.high_noise_weight
- transformer_2_weight_arg = (
- args.transformer_2_weight if args.transformer_2_weight is not None else args.low_noise_weight
- )
-
- transformer_input = (
- transformer_weight_arg.resolve() if transformer_weight_arg is not None else skeleton / "transformer"
- )
- transformer_2_input = (
- transformer_2_weight_arg.resolve() if transformer_2_weight_arg is not None else skeleton / "transformer_2"
- )
-
- try:
- _validate_skeleton(skeleton)
- transformer_weight = _resolve_weight_spec(transformer_input, role="transformer")
- transformer_2_weight = _resolve_weight_spec(transformer_2_input, role="transformer_2")
-
- _ensure_clean_output(output_dir, overwrite=args.overwrite)
- transformer_copied, transformer_2_copied = _assemble(
- skeleton=skeleton,
- output_dir=output_dir,
- transformer_weight=transformer_weight,
- transformer_2_weight=transformer_2_weight,
- asset_mode=args.asset_mode,
- )
- _validate_output(output_dir, transformer_copied, transformer_2_copied)
- except AssembleError as exc:
- print(f"[ERROR] {exc}", file=sys.stderr)
- return 2
-
- def _weight_summary(copied: tuple[Path, ...]) -> str:
- if len(copied) == 1:
- return copied[0].name
- return f"{copied[0].name} + {len(copied) - 1} shard files"
-
- print("[OK] Assembled Wan2.2 I2V Diffusers directory:")
- print(f" output_dir: {output_dir}")
- print(f" transformer weight: {_weight_summary(transformer_copied)}")
- print(f" transformer_2 weight: {_weight_summary(transformer_2_copied)}")
- print("\nUse it with vLLM-Omni, for example:")
- print(f" vllm serve {output_dir} --omni --port 8091")
- return 0
-
-
-if __name__ == "__main__":
- raise SystemExit(main())
diff --git a/vllm_omni/__init__.py b/vllm_omni/__init__.py
index f3c9d9afd4e..b093272d2f4 100644
--- a/vllm_omni/__init__.py
+++ b/vllm_omni/__init__.py
@@ -12,12 +12,6 @@
processing
"""
-# We import version early, because it will warn if vLLM / vLLM Omni
-# are not using the same major + minor version (if vLLM is installed).
-# We should do this before applying patch, because vLLM imports might
-# throw in patch if the versions differ.
-from .version import __version__, __version_tuple__ # isort:skip # noqa: F401
-
try:
from . import patch # noqa: F401
except ModuleNotFoundError as exc: # pragma: no cover - optional dependency
@@ -28,26 +22,11 @@
# Register custom configs (AutoConfig, AutoTokenizer) as early as possible.
from vllm_omni.transformers_utils import configs as _configs # noqa: F401, E402
-from vllm_omni.transformers_utils import parsers as _parsers # noqa: F401, E402
from .config import OmniModelConfig
+from .entrypoints import AsyncOmni, Omni
-
-def __getattr__(name: str):
- # Lazy import for AsyncOmni and Omni to avoid pulling in heavy
- # dependencies (vllm model_loader → fused_moe → pynvml) at package
- # import time. This prevents crashes in lightweight subprocesses
- # (e.g. model-architecture inspection) that lack a CUDA context.
- # See: https://github.com/vllm-project/vllm-omni/issues/1793
- if name == "AsyncOmni":
- from .entrypoints.async_omni import AsyncOmni
-
- return AsyncOmni
- if name == "Omni":
- from .entrypoints.omni import Omni
-
- return Omni
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+from .version import __version__, __version_tuple__ # isort:skip
__all__ = [
diff --git a/vllm_omni/assets/video.py b/vllm_omni/assets/video.py
index 6a5f3204a91..98b1f7e4e29 100644
--- a/vllm_omni/assets/video.py
+++ b/vllm_omni/assets/video.py
@@ -1,6 +1,6 @@
+import librosa
import numpy as np
from vllm.assets.video import VideoAsset
-from vllm.multimodal.media.audio import load_audio
def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndarray:
@@ -12,5 +12,5 @@ def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndar
"""
if not path:
path = VideoAsset(name="baby_reading").video_path
- audio_signal, sr = load_audio(path, sr=sampling_rate)
+ audio_signal, sr = librosa.load(path, sr=sampling_rate)
return audio_signal
diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_dataset.py b/vllm_omni/benchmarks/data_modules/daily_omni_dataset.py
deleted file mode 100644
index 65918414f45..00000000000
--- a/vllm_omni/benchmarks/data_modules/daily_omni_dataset.py
+++ /dev/null
@@ -1,1013 +0,0 @@
-"""Daily-Omni Dataset loader for benchmark.
-
-Daily-Omni is an audio-visual reasoning benchmark with 684 videos
-and 1,197 multiple-choice QA pairs across 6 major task types.
-
-Dataset source: https://huggingface.co/datasets/liarliar/Daily-Omni
-
-Supports loading QA metadata from:
-- Local JSON file (``qa_json_path``): recommended for offline/air-gapped environments
-- HuggingFace datasets (``dataset_path``): legacy online mode
-
-Video/audio files normally come from extracted ``Videos.tar``. When ``--daily-omni-video-dir``
-is not set, the first request that needs on-disk media downloads that archive from the Hugging Face
-dataset repo (``huggingface_hub``) and caches it under ``HF_HOME``.
-
-Why ``BenchmarkDataset`` instead of ``HuggingFaceDataset``?
- vLLM's ``HuggingFaceDataset`` is a thin wrapper whose ``__init__`` always ends by calling
- ``load_data()`` → ``datasets.load_dataset(...)`` with a required Hub id and split. That
- contract fits "Hub-only" benches, but Daily-Omni also needs **offline QA metadata** from a
- local ``qa.json`` without touching the network. Subclassing ``HuggingFaceDataset`` would
- mean fighting the parent constructor (fake ``dataset_path``, reordering ``load_data``, or
- duplicating half the parent) and would still imply ``datasets`` is always relevant.
-
- This class therefore inherits only ``BenchmarkDataset`` (minimal: ``dataset_path``,
- ``random_seed``, ``self.data``) and implements **two explicit loaders**:
- ``_load_from_local_json`` (default path for air-gapped runs) and ``_load_from_huggingface``
- (optional legacy path for users who prefer ``datasets`` + Hub cache). The latter is **not**
- inheritance; it is the same Hub rows as before, factored into a helper so one class can
- serve both deployment modes without mandatory ``datasets`` when using ``qa_json_path``.
-
-Usage:
- from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniDataset
-
- # Local JSON mode (recommended)
- dataset = DailyOmniDataset(
- qa_json_path="/path/to/qa.json",
- video_dir="/path/to/Videos",
- random_seed=42,
- )
-
- # HuggingFace mode (legacy, requires network)
- dataset = DailyOmniDataset(
- dataset_path="liarliar/Daily-Omni",
- dataset_split="train",
- random_seed=42,
- )
- requests = dataset.sample(
- tokenizer=tokenizer,
- num_requests=100,
- output_len=256,
- )
-"""
-
-import base64
-import json
-import logging
-import os
-import shutil
-import tarfile
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any, Literal
-
-try:
- from vllm.benchmarks.datasets import BenchmarkDataset, SampleRequest
-except ImportError:
- # Fallback: if BenchmarkDataset not available, use base class from same module
- from vllm.benchmarks.datasets import HuggingFaceDataset as BenchmarkDataset
- from vllm.benchmarks.datasets import SampleRequest
-from vllm.tokenizers import TokenizerLike
-from vllm.tokenizers.hf import get_cached_tokenizer
-
-try:
- from datasets import load_dataset
-except ImportError:
- load_dataset = None
-
-logger = logging.getLogger(__name__)
-
-
-def _daily_omni_hf_cache_root() -> Path:
- return Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface")).expanduser().resolve()
-
-
-def _daily_omni_tar_fingerprint(tar_path: Path) -> str:
- st = tar_path.stat()
- return f"v1:{st.st_size}:{int(st.st_mtime_ns)}"
-
-
-def _daily_omni_find_videos_root_in_extract(tmp: Path) -> Path:
- """Return directory whose children are ``video_id`` folders with ``*_video.mp4``."""
- videos = tmp / "Videos"
- if videos.is_dir():
- return videos
- for child in sorted(tmp.iterdir()):
- if child.is_dir() and not child.name.startswith("."):
- probe = child / f"{child.name}_video.mp4"
- if probe.is_file():
- return tmp
- raise RuntimeError(
- f"Unrecognized layout after extracting Daily-Omni Videos.tar under {tmp} "
- "(expected top-level 'Videos/' or per-video_id subdirs)."
- )
-
-
-def ensure_daily_omni_hub_videos_dir(repo_id: str) -> Path:
- """Download ``Videos.tar`` from the Hugging Face dataset repo and return the ``Videos`` root.
-
- The returned path matches ``--daily-omni-video-dir`` (directory containing ``{{video_id}}/``).
-
- Cached under ``HF_HOME`` / ``vllm_omni/daily_omni_media/``. Reuses extraction when the
- tarball fingerprint matches.
-
- Raises:
- ImportError: if ``huggingface_hub`` is not installed.
- FileNotFoundError / RuntimeError: if the archive is missing or malformed.
- """
- rid = (repo_id or "").strip()
- if not rid:
- raise ValueError("repo_id is required to download Daily-Omni Videos.tar")
-
- try:
- from huggingface_hub import hf_hub_download
- except ImportError as e:
- raise ImportError(
- "Daily-Omni Hub media download requires huggingface_hub. "
- "Install it (e.g. with vLLM) or provide --daily-omni-video-dir with a local extract."
- ) from e
-
- safe = rid.replace("/", "__").replace("\\", "_")
- staging_root = _daily_omni_hf_cache_root() / "vllm_omni" / "daily_omni_media" / safe
- videos_dir = staging_root / "Videos"
- marker = staging_root / ".videos_extracted"
-
- tar_path: Path | None = None
- for fname in ("Videos.tar", "videos.tar"):
- try:
- tar_path = Path(hf_hub_download(repo_id=rid, filename=fname, repo_type="dataset"))
- break
- except Exception:
- continue
- if tar_path is None or not tar_path.is_file():
- raise FileNotFoundError(
- f"Could not download Videos.tar from Hugging Face dataset {rid!r} (tried Videos.tar / videos.tar)."
- )
-
- fp = _daily_omni_tar_fingerprint(tar_path)
- if marker.is_file() and videos_dir.is_dir():
- try:
- if marker.read_text(encoding="utf-8").strip() == fp:
- next(videos_dir.iterdir())
- logger.info("Reusing cached Daily-Omni Videos extract at %s", videos_dir)
- return videos_dir
- except (OSError, StopIteration):
- shutil.rmtree(videos_dir, ignore_errors=True)
- marker.unlink(missing_ok=True)
-
- staging_root.mkdir(parents=True, exist_ok=True)
- work = staging_root / "_extract_work"
- shutil.rmtree(work, ignore_errors=True)
- work.mkdir(parents=True)
- try:
- logger.info("Extracting Daily-Omni Videos.tar from %s (repo=%s)", tar_path, rid)
- with tarfile.open(tar_path, "r:*") as tf:
- tf.extractall(path=work, filter="data")
- found = _daily_omni_find_videos_root_in_extract(work)
- if videos_dir.exists():
- shutil.rmtree(videos_dir, ignore_errors=True)
- shutil.move(str(found), str(videos_dir))
- finally:
- shutil.rmtree(work, ignore_errors=True)
-
- marker.write_text(fp, encoding="utf-8")
- logger.info("Daily-Omni Hub media ready at %s", videos_dir)
- return videos_dir
-
-
-class _ListDatasetIterator:
- """Simple iterator wrapper around a list to mimic HuggingFace streaming dataset behavior."""
-
- def __init__(self, data: list[dict[str, Any]]) -> None:
- self._data = data
- self._index = 0
-
- def __iter__(self):
- self._index = 0
- return self
-
- def __next__(self) -> dict[str, Any]:
- if self._index >= len(self._data):
- raise StopIteration
- item = self._data[self._index]
- self._index += 1
- return item
-
- def __len__(self) -> int:
- return len(self._data)
-
- def __getitem__(self, idx: int | slice) -> dict[str, Any] | list[dict[str, Any]]:
- return self._data[idx]
-
-
-# Aligns with Lliar-liar/Daily-Omni CLI ``--input_mode`` (test_model/*/testmodel.py).
-DailyOmniInputMode = Literal["all", "visual", "audio"]
-
-# ``build_conversation()`` in Daily-Omni ``test_model/Qwen2.5-Omni/testmodel.py`` (verbatim).
-DAILY_OMNI_SYSTEM_TEXT = (
- "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
- "capable of perceiving auditory and visual inputs, as well as generating text and speech."
-)
-
-
-@dataclass
-class DailyOmniSampleRequest(SampleRequest):
- """``SampleRequest`` with Daily-Omni gold labels for post-run accuracy scoring."""
-
- daily_omni_gold_answer: str = ""
- daily_omni_video_id: str = ""
- daily_omni_task_type: str = ""
- #: Official qa.json ``video_duration`` (e.g. ``30s``, ``60s``) for leaderboard-style breakdown.
- daily_omni_video_duration: str = ""
- #: Official ``video_category`` (YouTube-style category string) for per-category accuracy.
- daily_omni_video_category: str = ""
- #: Extra JSON fields merged into chat-completions ``extra_body`` (e.g. ``mm_processor_kwargs``).
- omni_extra_body: dict[str, Any] | None = None
- #: Full OpenAI ``messages`` (system + user) mirroring upstream Daily-Omni conversation.
- omni_chat_messages: list[dict[str, Any]] | None = None
- #: Used only when ``omni_chat_messages`` is None (non-Daily-Omni-style requests).
- omni_chat_mm_position: Literal["first", "last"] = "last"
-
-
-class DailyOmniDataset(BenchmarkDataset):
- """Daily-Omni audio-visual QA dataset for benchmarking.
-
- Inherits ``BenchmarkDataset`` only (not ``HuggingFaceDataset``): see module docstring for why
- Hub loading lives in ``_load_from_huggingface`` instead of subclassing the HF base class.
-
- The dataset includes:
- - 684 videos from daily life scenarios (available in Videos.tar)
- - 1,197 multiple-choice QA pairs in qa.json
- - 6 major task categories
-
- QA metadata can be loaded from:
- - Local JSON file (``qa_json_path``): recommended for offline/air-gapped environments
- - HuggingFace datasets (``dataset_path``): legacy online mode
-
- Video/audio files normally come from extracted ``Videos.tar``. When ``video_dir`` is not set,
- the first sample that needs on-disk media downloads that archive from the Hugging Face dataset
- repo (env ``VLLM_DAILY_OMNI_MEDIA_REPO`` overrides the repo id; else ``dataset_path`` or
- :data:`DEFAULT_HF_DATASET_ID`).
-
- Args:
- qa_json_path: Path to local qa.json file (offline mode, preferred). When provided,
- ``dataset_path`` and ``dataset_split`` are ignored.
- dataset_path: HuggingFace dataset path (e.g., "liarliar/Daily-Omni"). Used only if
- ``qa_json_path`` is not provided (legacy online mode).
- dataset_split: Dataset split to use (default: "train"). Used only in online mode.
- random_seed: Random seed for shuffling
- video_dir: Directory containing extracted video files (default: None; may be filled lazily
- from Hub — see above).
- input_mode: Which modalities to send, matching upstream Daily-Omni ``--input_mode``:
- ``all`` — video + WAV (default; official audio-visual protocol);
- ``visual`` — video only;
- ``audio`` — extracted WAV only (requires ``{video_id}/{video_id}_audio.wav`` under ``video_dir``).
- max_duration_seconds: Reserved for future ffprobe-based filtering; currently **not applied**
- when building requests (metadata ``video_duration`` is still passed through for eval).
- dataset_subset: Optional HuggingFace subset name (``load_dataset(..., name=...)``); used by bench
- ``--hf-subset`` / patch.
- no_stream: If True, load the Hub split non-streaming (matches bench ``--no-stream``).
- inline_local_video: If True, embed local MP4 as ``data:video/mp4;base64,...`` in requests so
- the API server does not need ``--allowed-local-media-path`` (large JSON; use for small runs).
- When ``input_mode`` is ``audio`` or ``all``, local WAV is embedded the same way
- (``data:audio/wav;base64,...``).
- trust_remote_code: Whether to trust remote code when loading HuggingFace dataset
- (online mode only).
- """
-
- SUPPORTED_DATASET_PATHS: set[str] = {
- "liarliar/Daily-Omni",
- }
- #: Default Hub id for synthetic video URLs when ``qa_json_path`` is used (``dataset_path`` None).
- DEFAULT_HF_DATASET_ID = "liarliar/Daily-Omni"
- IS_MULTIMODAL = True
- DEFAULT_OUTPUT_LEN = 256
-
- def __init__(
- self,
- qa_json_path: str | None = None,
- dataset_path: str | None = None,
- dataset_split: str = "train",
- random_seed: int = 0,
- video_dir: str | None = None,
- input_mode: DailyOmniInputMode = "all",
- inline_local_video: bool = False,
- trust_remote_code: bool = False,
- max_duration_seconds: float | None = None,
- dataset_subset: str | None = None,
- no_stream: bool = False,
- **kwargs,
- ) -> None:
- if input_mode not in ("all", "visual", "audio"):
- raise ValueError(f"input_mode must be 'all', 'visual', or 'audio', got {input_mode!r}")
-
- # Validate arguments: need either local JSON or HF path
- if qa_json_path is None and dataset_path is None:
- raise ValueError(
- "Either 'qa_json_path' (local JSON) or 'dataset_path' (HuggingFace) must be provided. "
- "For offline/air-gapped environments, download qa.json and use qa_json_path."
- )
-
- # Store configuration
- self.qa_json_path = Path(qa_json_path) if qa_json_path else None
- self.dataset_path = dataset_path
- self.dataset_split = dataset_split
- self.dataset_subset = dataset_subset
- #: Match vLLM ``HuggingFaceDataset`` / bench CLI ``--no-stream``.
- self._hf_streaming = not no_stream
- self.video_dir = Path(video_dir) if video_dir else None
- self.inline_local_video = inline_local_video
- self.input_mode: DailyOmniInputMode = input_mode
- self.max_duration_seconds = max_duration_seconds
- self.trust_remote_code = trust_remote_code
-
- #: In-process cache of ffprobe durations only (no disk persistence).
- self._video_durations: dict[str, float] = {}
-
- # Initialize parent BenchmarkDataset
- super().__init__(
- dataset_path=dataset_path if qa_json_path is None else None,
- random_seed=random_seed,
- **kwargs,
- )
-
- # Load data based on mode
- self.load_data()
-
- # Verify dataset info
- logger.info(
- "Loaded Daily-Omni dataset: mode=%s, source=%s, random_seed=%d, input_mode=%s, max_duration=%s",
- "local_json" if self.qa_json_path else "huggingface",
- str(self.qa_json_path) if self.qa_json_path else f"{dataset_path}/{dataset_split}",
- random_seed,
- input_mode,
- f"{max_duration_seconds}s" if max_duration_seconds else "unlimited",
- )
-
- def load_data(self) -> None:
- """Populate ``self.data`` from either local JSON or the Hub.
-
- See module docstring: we do not subclass ``HuggingFaceDataset`` because Daily-Omni needs
- a first-class offline path; Hub loading is an optional branch implemented below.
- """
- if self.qa_json_path is not None:
- self._load_from_local_json()
- else:
- self._load_from_huggingface()
-
- def _load_from_local_json(self) -> None:
- """Load QA data from local JSON file."""
- if not self.qa_json_path.exists():
- raise FileNotFoundError(f"QA JSON file not found: {self.qa_json_path}")
-
- with open(self.qa_json_path, encoding="utf-8") as f:
- data = json.load(f)
-
- # Support both list format and dict with "train"/"test" splits
- if isinstance(data, dict):
- # Try to get the requested split, fallback to first available
- split_data = data.get(self.dataset_split)
- if split_data is None:
- available = list(data.keys())
- if available:
- logger.warning(
- "Split '%s' not found in %s, using '%s' instead",
- self.dataset_split,
- self.qa_json_path,
- available[0],
- )
- split_data = data[available[0]]
- else:
- split_data = []
- data = split_data
-
- if not isinstance(data, list):
- raise ValueError(f"Expected list of QA items in JSON, got {type(data).__name__}")
-
- # Shuffle if requested
- if not getattr(self, "disable_shuffle", False) and self.random_seed is not None:
- import random
-
- rng = random.Random(self.random_seed)
- shuffled = data[:]
- rng.shuffle(shuffled)
- data = shuffled
-
- # Create an iterator-like wrapper for compatibility
- self.data = _ListDatasetIterator(data)
-
- def _load_from_huggingface(self) -> None:
- """Load QA rows via ``datasets.load_dataset`` (legacy / convenience path).
-
- Kept for backward compatibility: callers can still pass ``dataset_path=liarliar/Daily-Omni``
- and get the same parquet-backed rows as the Hub dataset card, with streaming (or
- non-streaming if ``no_stream=True``) and shuffle.
-
- This is intentionally **not** implemented by subclassing ``HuggingFaceDataset``: that base
- always runs Hub ``load_dataset`` from its constructor and expects a Hub id as the primary
- API; Daily-Omni instead chooses the source in ``load_data()`` (JSON vs Hub) while sharing
- one ``sample()`` / request-building implementation for both.
- """
- if load_dataset is None:
- raise ImportError(
- "datasets library is required for HuggingFace mode. "
- "Install with: pip install datasets, or use local JSON mode instead."
- )
-
- load_kw: dict[str, Any] = {
- "split": self.dataset_split,
- "streaming": self._hf_streaming,
- "trust_remote_code": self.trust_remote_code,
- }
- if self.dataset_subset is not None:
- load_kw["name"] = self.dataset_subset
- ds = load_dataset(self.dataset_path, **load_kw)
- if not getattr(self, "disable_shuffle", False):
- ds = ds.shuffle(seed=self.random_seed)
- self.data = ds
-
- def get_task_statistics(self) -> dict[str, int]:
- """Get distribution of task types in the dataset.
-
- Returns:
- Dict mapping task type to count
- """
- stats: dict[str, int] = {}
- for item in self.data:
- row = self._coerce_row(item)
- fields = self._normalize_qa_fields(row)
- task_type = fields["task_type"] or "unknown"
- stats[task_type] = stats.get(task_type, 0) + 1
- return stats
-
- @staticmethod
- def _coerce_row(item: Any) -> dict[str, Any]:
- """Turn a dataset row into a plain dict (Arrow / Mapping)."""
- if isinstance(item, dict):
- return item
- if hasattr(item, "as_py"):
- return dict(item.as_py()) # pyarrow Row
- try:
- return dict(item)
- except (TypeError, ValueError):
- return {k: item[k] for k in item} # type: ignore[misc]
-
- @staticmethod
- def _normalize_qa_fields(row: dict[str, Any]) -> dict[str, Any]:
- """Map official Daily-Omni qa.json / Hub schema to internal fields.
-
- Official fields (see liarliar/Daily-Omni ``qa.json``): ``Question``, ``Choice`` (list),
- ``Answer``, ``video_id``, ``Type``, ``video_duration`` (``30s`` / ``60s``), ``video_category``,
- plus other category columns. Legacy aliases (lowercase / older loaders) are still accepted.
- """
- out: dict[str, Any] = {}
-
- out["question"] = str(row.get("Question") or row.get("question") or "").strip()
- vid = row.get("video_id") if row.get("video_id") is not None else row.get("video")
- out["video_id"] = str(vid).strip() if vid is not None else ""
- out["task_type"] = str(row.get("Type") or row.get("task_type") or row.get("type") or "").strip()
- vc = row.get("video_category") if row.get("video_category") is not None else row.get("videoCategory")
- out["video_category"] = str(vc).strip() if vc is not None else ""
- vd = row.get("video_duration") if row.get("video_duration") is not None else row.get("videoDuration")
- out["video_duration"] = str(vd).strip() if vd is not None else ""
- out["answer"] = str(row.get("Answer") or row.get("answer") or "").strip()
- vu = row.get("video_url") if row.get("video_url") is not None else row.get("Video_URL")
- out["video_url"] = str(vu).strip() if vu is not None and str(vu).strip() else None
-
- choice = row.get("Choice")
- if choice is None:
- choice = row.get("options") or row.get("choice")
- out["choice"] = choice
-
- return out
-
- def sample(
- self,
- tokenizer: TokenizerLike,
- num_requests: int,
- output_len: int | None = None,
- request_id_prefix: str = "",
- no_oversample: bool = False,
- **kwargs,
- ) -> list[SampleRequest]:
- """Sample requests from Daily-Omni dataset.
-
- Args:
- tokenizer: Tokenizer for computing prompt length
- num_requests: Number of requests to sample
- output_len: Target output length in tokens (default: 256)
- request_id_prefix: Prefix for request IDs
- no_oversample: If True, do not oversample if fewer examples available
- **kwargs: Additional arguments (ignored)
-
- Returns:
- List of SampleRequest objects with video URLs and prompts
- """
- if output_len is None:
- output_len = self.DEFAULT_OUTPUT_LEN
-
- sampled_requests: list[SampleRequest] = []
- ind = 0
- cached_tokenizer = get_cached_tokenizer(tokenizer)
-
- # Iterate over shuffled dataset
- for item in self.data:
- if len(sampled_requests) >= num_requests:
- break
-
- request = self._create_sample_request(
- self._coerce_row(item), cached_tokenizer, output_len, request_id_prefix, ind
- )
- if request:
- sampled_requests.append(request)
- ind += 1
-
- logger.info("Created %d sample requests from Daily-Omni dataset", len(sampled_requests))
-
- # Handle oversampling if needed
- self.maybe_oversample_requests(sampled_requests, num_requests, request_id_prefix, no_oversample)
-
- return sampled_requests
-
- def _create_sample_request(
- self,
- qa_item: dict[str, Any],
- tokenizer: TokenizerLike,
- output_len: int,
- request_id_prefix: str,
- index: int,
- ) -> SampleRequest | None:
- """Create a SampleRequest from a QA item.
-
- Args:
- qa_item: QA pair from the dataset
- tokenizer: Tokenizer
- output_len: Target output length
- request_id_prefix: Prefix for request ID
- index: Request index
-
- Returns:
- SampleRequest or None if invalid
- """
- fields = self._normalize_qa_fields(qa_item)
- video_id = fields["video_id"]
- question = fields["question"]
- choice = fields["choice"]
- task_type = fields["task_type"]
- video_url = fields["video_url"]
- video_duration = fields.get("video_duration") or ""
- video_category = fields.get("video_category") or ""
-
- if not video_id and not video_url:
- logger.warning("Skipping item: no video_id / video_url")
- return None
-
- if not question:
- logger.warning("Skipping item: no question found")
- return None
-
- # Official layout after extracting Videos.tar (see Lliar-liar/Daily-Omni test_model):
- # {video_base_dir}/{video_id}/{video_id}_video.mp4
- mm_payload, omni_extra, mm_pos = self._compose_daily_omni_multimodal(video_id, video_url)
- if not mm_payload:
- return None
-
- messages = self._build_daily_omni_openai_messages(mm_payload, question, choice)
- user_text = self._official_daily_omni_user_prompt(question, choice)
- # Text-only length estimate (same as before: no MM token count in bench).
- prompt_len = len(tokenizer.encode(f"{DAILY_OMNI_SYSTEM_TEXT}\n{user_text}"))
-
- return DailyOmniSampleRequest(
- prompt=user_text,
- prompt_len=prompt_len,
- expected_output_len=output_len,
- multi_modal_data=None,
- request_id=f"{request_id_prefix}{index}",
- daily_omni_gold_answer=fields["answer"],
- daily_omni_video_id=video_id,
- daily_omni_task_type=task_type,
- daily_omni_video_duration=video_duration,
- daily_omni_video_category=video_category,
- omni_extra_body=omni_extra,
- omni_chat_messages=messages,
- omni_chat_mm_position=mm_pos,
- )
-
- @staticmethod
- def _official_video_relpath(video_id: str) -> str:
- """Relative path inside extracted ``Videos/`` per upstream Daily-Omni scripts."""
- return f"{video_id}/{video_id}_video.mp4"
-
- @staticmethod
- def _official_audio_relpath(video_id: str) -> str:
- """Relative path for extracted WAV per upstream ``get_audio_path``."""
- return f"{video_id}/{video_id}_audio.wav"
-
- def _resolve_local_video_path(self, video_id: str) -> Path | None:
- """Pick an existing file under ``video_dir`` (official layout + flat fallback)."""
- if not self.video_dir or not video_id:
- return None
-
- candidates = [
- self.video_dir / self._official_video_relpath(video_id),
- self.video_dir / f"{video_id}.mp4", # flat layout (custom mirrors / outdated docs)
- ]
- seen: set[Path] = set()
- for p in candidates:
- rp = p.resolve()
- if rp in seen:
- continue
- seen.add(rp)
- if p.exists():
- return p
- return None
-
- def _resolve_local_audio_path(self, video_id: str) -> Path | None:
- """Pick an existing WAV under ``video_dir`` (official layout + flat fallback)."""
- if not self.video_dir or not video_id:
- return None
- candidates = [
- self.video_dir / self._official_audio_relpath(video_id),
- self.video_dir / f"{video_id}.wav",
- ]
- seen: set[Path] = set()
- for p in candidates:
- rp = p.resolve()
- if rp in seen:
- continue
- seen.add(rp)
- if p.exists():
- return p
- return None
-
- def _local_file_to_video_url_payload(self, video_path: Path) -> dict[str, Any]:
- """Build OpenAI-style video_url part for a resolved local file.
-
- vLLM rejects ``file://`` unless the server was started with
- ``--allowed-local-media-path`` set to a directory that **contains** the file
- (typically the extracted ``Videos`` root). Use ``inline_local_video=True`` to
- send base64 data URLs instead (no server path allowlist; larger requests).
- """
- path = video_path.expanduser().resolve()
- if self.inline_local_video:
- raw = path.read_bytes()
- b64 = base64.b64encode(raw).decode("ascii")
- return {
- "type": "video_url",
- "video_url": {"url": f"data:video/mp4;base64,{b64}"},
- }
- return {
- "type": "video_url",
- "video_url": {"url": path.as_uri()},
- }
-
- def _local_file_to_audio_url_payload(self, audio_path: Path) -> dict[str, Any]:
- """Build OpenAI-style ``audio_url`` part for a resolved local WAV file."""
- path = audio_path.expanduser().resolve()
- if self.inline_local_video:
- raw = path.read_bytes()
- b64 = base64.b64encode(raw).decode("ascii")
- return {
- "type": "audio_url",
- "audio_url": {"url": f"data:audio/wav;base64,{b64}"},
- }
- return {
- "type": "audio_url",
- "audio_url": {"url": path.as_uri()},
- }
-
- def _lazy_ensure_hub_media_dir(self) -> None:
- """If ``video_dir`` was not configured, download and extract ``Videos.tar`` once from HF."""
- if self.video_dir is not None:
- return
- repo = os.environ.get("VLLM_DAILY_OMNI_MEDIA_REPO", "").strip()
- if not repo:
- repo = (self.dataset_path or "").strip()
- if not repo:
- repo = self.DEFAULT_HF_DATASET_ID
- self.video_dir = ensure_daily_omni_hub_videos_dir(repo)
-
- def _get_video_content(
- self,
- video_id: str,
- video_url: str | None,
- ) -> dict[str, Any] | None:
- """Resolve video for OpenAI-style ``video_url`` content.
-
- Upstream uses ``get_video_path(video_id, base) -> base/video_id/video_id_video.mp4``.
- The Hub repo only publishes ``Videos.tar``; use ``--daily-omni-video-dir`` pointing
- at the extracted ``Videos`` folder (parent of per-``video_id`` subdirs).
-
- For ``file://`` URLs, start ``vllm serve`` with e.g.
- ``--allowed-local-media-path /same/path/as/daily-omni-video-dir``.
- """
- if video_url:
- url = video_url
- if not url.startswith(("http://", "https://", "file://")):
- url = f"https://{url.lstrip('/')}"
- return {"type": "video_url", "video_url": {"url": url}}
-
- self._lazy_ensure_hub_media_dir()
-
- if self.video_dir and video_id:
- video_path = self._resolve_local_video_path(video_id)
- if video_path is not None:
- return self._local_file_to_video_url_payload(video_path)
- logger.warning(
- "Video not found under video_dir=%s for video_id=%r (expected %s or %s)",
- self.video_dir,
- video_id,
- self._official_video_relpath(video_id),
- f"{video_id}.mp4",
- )
-
- if video_id:
- repo = self.dataset_path or self.DEFAULT_HF_DATASET_ID
- rel = self._official_video_relpath(video_id)
- hf_video_url = f"https://huggingface.co/datasets/{repo}/resolve/main/Videos/{rel}"
- logger.debug(
- "Using HF video URL (likely 404 — Hub ships Videos.tar only): %s",
- hf_video_url,
- )
- return {"type": "video_url", "video_url": {"url": hf_video_url}}
-
- logger.error("Could not determine video source for video_id=%r", video_id)
- return None
-
- def _get_audio_content(self, video_id: str) -> dict[str, Any] | None:
- """Resolve extracted WAV for OpenAI-style ``audio_url`` (local files under ``video_dir``).
-
- Uses the same tree as video (``{video_id}/{video_id}_audio.wav``), including after lazy
- Hub ``Videos.tar`` extraction when ``video_dir`` was unset.
- """
- self._lazy_ensure_hub_media_dir()
-
- if not self.video_dir or not video_id:
- logger.warning(
- "Daily-Omni input_mode %r requires --daily-omni-video-dir with %s",
- self.input_mode,
- self._official_audio_relpath(video_id),
- )
- return None
- audio_path = self._resolve_local_audio_path(video_id)
- if audio_path is not None:
- return self._local_file_to_audio_url_payload(audio_path)
- logger.warning(
- "Audio not found under video_dir=%s for video_id=%r (expected %s or %s)",
- self.video_dir,
- video_id,
- self._official_audio_relpath(video_id),
- f"{video_id}.wav",
- )
- return None
-
- def _compose_daily_omni_multimodal(
- self,
- video_id: str,
- video_url: str | None,
- ) -> tuple[dict[str, Any] | list[dict[str, Any]] | None, dict[str, Any] | None, Literal["first", "last"]]:
- """Build ``multi_modal_data`` and request extras for the active ``input_mode``.
-
- Mirrors upstream Daily-Omni: separate video + WAV with ``use_audio_in_video=False``.
- """
- extra: dict[str, Any] = {"mm_processor_kwargs": {"use_audio_in_video": False}}
- mode = self.input_mode
-
- if mode == "visual":
- v = self._get_video_content(video_id, video_url)
- return v, extra, "last"
-
- if mode == "audio":
- a = self._get_audio_content(video_id)
- return a, extra, "first"
-
- v = self._get_video_content(video_id, video_url)
- a = self._get_audio_content(video_id)
- if not v or not a:
- return None, None, "first"
- return [v, a], extra, "first"
-
- @staticmethod
- def _media_desc_for_official_prompt(mode: DailyOmniInputMode) -> str:
- """``media_desc`` in upstream ``build_conversation``."""
- if mode == "audio":
- return "given audio"
- if mode == "all":
- return "given video and audio together"
- return "given video"
-
- @staticmethod
- def _choices_repr_for_official_prompt(choice: Any) -> str:
- """Format ``Choice`` from qa.json for the model (one option per line when possible).
-
- Using ``str(list)`` embeds Python list brackets and quotes, which is poor for MCQ
- reading; lists/tuples are joined with newlines instead. Other shapes fall back to
- ``str(choice)`` for parity with exotic upstream payloads.
- """
- if choice is None:
- return ""
- if isinstance(choice, (list, tuple)):
- lines = [str(x).strip() for x in choice if str(x).strip()]
- return "\n".join(lines)
- if isinstance(choice, dict):
- return "\n".join(f"{k}. {v}" for k, v in choice.items())
- return str(choice)
-
- def _official_daily_omni_user_prompt(self, question: str, choice: Any) -> str:
- """User text block from Daily-Omni ``build_conversation`` (after media parts)."""
- task_prompt = self._media_desc_for_official_prompt(self.input_mode)
- choices = self._choices_repr_for_official_prompt(choice)
- # Single f-string with explicit newlines avoids accidental implicit concatenation
- # gluing sentences (e.g. ``...media_desc.Select...``) when editing.
- return (
- "Your task is to accurately answer multiple-choice questions "
- f"based on the {task_prompt}.\n"
- "Select the single most accurate answer from the given choices.\n"
- f"Question: {question}\n"
- f"Choices: {choices}\n"
- "Your answer should be a capital letter representing your choice: "
- "A, B, C, or D. Don't generate any other text.\n"
- )
-
- def _build_daily_omni_openai_messages(
- self,
- mm_payload: dict[str, Any] | list[dict[str, Any]],
- question: str,
- choice: Any,
- ) -> list[dict[str, Any]]:
- """Map upstream conversation to OpenAI Chat Completions ``messages`` (video_url / audio_url parts)."""
- user_text = self._official_daily_omni_user_prompt(question, choice)
- mm_list: list[dict[str, Any]] = mm_payload if isinstance(mm_payload, list) else [mm_payload]
- user_content: list[dict[str, Any]] = [*mm_list, {"type": "text", "text": user_text}]
- return [
- {"role": "system", "content": [{"type": "text", "text": DAILY_OMNI_SYSTEM_TEXT}]},
- {"role": "user", "content": user_content},
- ]
-
- def sample_by_task_type(
- self,
- tokenizer: TokenizerLike,
- task_type: str,
- num_samples: int,
- output_len: int | None = None,
- request_id_prefix: str = "",
- **kwargs,
- ) -> list[SampleRequest]:
- """Sample requests filtered by task type.
-
- Args:
- tokenizer: Tokenizer
- task_type: Task type to filter by
- num_samples: Number of samples
- output_len: Target output length
- request_id_prefix: Prefix for request IDs
- **kwargs: Additional sampling arguments
-
- Returns:
- List of SampleRequest objects matching the task type
- """
- if output_len is None:
- output_len = self.DEFAULT_OUTPUT_LEN
-
- filtered = [
- item for item in self.data if self._normalize_qa_fields(self._coerce_row(item))["task_type"] == task_type
- ]
-
- available = len(filtered)
- if available < num_samples:
- logger.warning(
- "Only %d samples available for task type '%s', requested %d",
- available,
- task_type,
- num_samples,
- )
- num_samples = available
-
- sampled_requests: list[SampleRequest] = []
- cached_tokenizer = get_cached_tokenizer(tokenizer)
-
- for i, item in enumerate(filtered[:num_samples]):
- request = self._create_sample_request(item, cached_tokenizer, output_len, request_id_prefix, i)
- if request:
- sampled_requests.append(request)
-
- return sampled_requests
-
- def __repr__(self) -> str:
- return (
- f"DailyOmniDataset("
- f"dataset_path={self.dataset_path!r}, "
- f"dataset_split={self.dataset_split!r}, "
- f"video_dir={self.video_dir!r}, "
- f"input_mode={self.input_mode!r}, "
- f"inline_local_video={self.inline_local_video!r}, "
- f"max_duration_seconds={self.max_duration_seconds}, "
- f"random_seed={self.random_seed}"
- f")"
- )
-
-
-def load_daily_omni_dataset(
- qa_json_path: str | None = None,
- dataset_path: str | None = None,
- dataset_split: str = "train",
- random_seed: int = 0,
- video_dir: str | None = None,
- input_mode: DailyOmniInputMode = "all",
- max_duration_seconds: float | None = None,
- dataset_subset: str | None = None,
- no_stream: bool = False,
- **kwargs,
-) -> DailyOmniDataset:
- """Convenience function to load Daily-Omni dataset.
-
- Args:
- qa_json_path: Path to local qa.json file (recommended for offline/air-gapped environments).
- When provided, ``dataset_path`` is ignored.
- dataset_path: HuggingFace dataset path (default: liarliar/Daily-Omni). Used only if
- ``qa_json_path`` is not provided (legacy online mode).
- dataset_split: Dataset split to use (default: "train")
- random_seed: Random seed for shuffling
- video_dir: Directory containing extracted ``Videos/`` tree (MP4 and, for ``all``/``audio``, WAV)
- input_mode: ``visual`` | ``audio`` | ``all`` (same semantics as upstream Daily-Omni)
- max_duration_seconds: Maximum video duration in seconds (e.g., 30 for 30s subset, 60 for 60s subset);
- uses ffprobe on local files under ``video_dir`` (in-memory cache only for this process).
- **kwargs: Additional arguments passed to DailyOmniDataset
-
- Returns:
- DailyOmniDataset instance
-
- Example:
- >>> from vllm_omni.benchmarks.data_modules.daily_omni_dataset import load_daily_omni_dataset
-
- # Local JSON mode (recommended for offline)
- >>> dataset = load_daily_omni_dataset(
- ... qa_json_path="/path/to/qa.json",
- ... video_dir="/path/to/Daily-Omni/Videos",
- ... random_seed=42,
- ... max_duration_seconds=30,
- ... )
-
- # HuggingFace mode (legacy online)
- >>> dataset = load_daily_omni_dataset(
- ... dataset_path="liarliar/Daily-Omni",
- ... video_dir="/path/to/Daily-Omni/Videos",
- ... random_seed=42,
- ... )
- >>> requests = dataset.sample(tokenizer, num_requests=100)
- """
- return DailyOmniDataset(
- qa_json_path=qa_json_path,
- dataset_path=dataset_path,
- dataset_split=dataset_split,
- random_seed=random_seed,
- video_dir=video_dir,
- input_mode=input_mode,
- max_duration_seconds=max_duration_seconds,
- dataset_subset=dataset_subset,
- no_stream=no_stream,
- **kwargs,
- )
-
-
-def get_daily_omni_statistics(
- qa_json_path: str | None = None,
- dataset_path: str | None = DailyOmniDataset.DEFAULT_HF_DATASET_ID,
- dataset_split: str = "train",
-) -> dict[str, Any]:
- """Get statistics about the Daily-Omni dataset.
-
- Args:
- qa_json_path: Path to local qa.json file (recommended for offline/air-gapped environments).
- When provided, ``dataset_path`` is ignored.
- dataset_path: HuggingFace dataset path. Defaults to ``DailyOmniDataset.DEFAULT_HF_DATASET_ID``
- when ``qa_json_path`` is omitted. Pass ``None`` only together with ``qa_json_path``.
- dataset_split: Dataset split to use (default: "train")
-
- Returns:
- Statistics dict with task type distribution and other info
-
- Example:
- >>> from vllm_omni.benchmarks.data_modules.daily_omni_dataset import get_daily_omni_statistics
-
- # Local JSON mode
- >>> stats = get_daily_omni_statistics(qa_json_path="/path/to/qa.json")
-
- # HuggingFace mode
- >>> stats = get_daily_omni_statistics(dataset_path="liarliar/Daily-Omni")
- >>> print(f"Total QA pairs: {stats['total_qa_pairs']}")
- >>> print(f"Task distribution: {stats['task_distribution']}")
- """
- dataset = DailyOmniDataset(
- qa_json_path=qa_json_path,
- dataset_path=dataset_path,
- dataset_split=dataset_split,
- )
- task_stats = dataset.get_task_statistics()
-
- source = str(qa_json_path) if qa_json_path else f"{dataset_path}/{dataset_split}"
- return {
- "source": source,
- "total_qa_pairs": len(list(dataset.data)),
- "task_distribution": task_stats,
- }
diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_eval.py b/vllm_omni/benchmarks/data_modules/daily_omni_eval.py
deleted file mode 100644
index f191cf2febc..00000000000
--- a/vllm_omni/benchmarks/data_modules/daily_omni_eval.py
+++ /dev/null
@@ -1,417 +0,0 @@
-"""Daily-Omni multiple-choice accuracy scoring for vLLM-Omni bench serve.
-
-Compares model ``generated_text`` to dataset ``Answer`` (A/B/C/D).
-
-**Alignment with open-source** (`Lliar-liar/Daily-Omni` ``test_model/.../testmodel.py``):
-
-- Answer extraction defaults to the same rules as ``extract_choice_letter`` (strip after an
- ``assistant`` marker, then leading ``A``–``D``, else ``\\b[A-D]\\b``, else a CJK-safe
- non-letter-boundary pass). Set env
- ``DAILY_OMNI_EXTRACT_MODE=relaxed`` to use the older vLLM-Omni heuristics (last ``answer:``,
- tail scan, etc.).
-- Overall accuracy comparable to the official script uses **successful HTTP responses only** as
- the denominator (their ``valid_questions = total - failed`` excludes inference / I/O skips).
- We also report ``daily_omni_accuracy_incl_http_fail`` where each failed request counts as a
- wrong answer in the denominator (stricter throughput-bench view).
-- **By video length:** mirrors upstream ``--- Accuracy by Video Duration ---`` for ``30s`` /
- ``60s`` (``qa.json`` ``video_duration``): ``daily_omni_per_duration*`` metrics and a printed block.
-- **By video category:** mirrors ``--- Accuracy by Video Category ---`` using ``video_category``
- from ``qa.json`` (``daily_omni_per_category*``; empty category is bucketed as ``unknown``).
-- **Correctness:** uses the same ``evaluate_answer`` rule as upstream (truthy extracted letter vs
- raw ``Answer`` string, both ``strip().upper()``). Rows with empty ``Answer`` are skipped
- (``no_gold``), matching missing-field skips in the official loop.
-"""
-
-from __future__ import annotations
-
-import os
-import re
-from typing import Any
-
-from vllm.benchmarks.lib.endpoint_request_func import RequestFuncOutput
-
-from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniSampleRequest
-
-_VALID = frozenset("ABCD")
-
-# Official ``testmodel.py`` buckets (``qa.json`` ``video_duration``).
-DAILY_OMNI_DURATION_KEYS: tuple[str, ...] = ("30s", "60s")
-
-
-def extract_choice_letter_official(text: str | None) -> str | None:
- """Port of Daily-Omni ``extract_choice_letter`` (first A–D, assistant-tail semantics)."""
- if not text:
- return None
- raw = str(text).strip()
- if not raw:
- return None
- match = re.search(r"assistant\s*([\s\S]*)$", raw, flags=re.IGNORECASE)
- candidate = match.group(1).strip() if match else raw
- direct = re.match(r"(?i)^\s*([A-D])(?:[\s\.\)::]|$)", candidate)
- if direct:
- return direct.group(1).upper()
- fallback = re.search(r"\b([A-D])\b", candidate.upper())
- if fallback:
- return fallback.group(1)
- # ``\b`` is ASCII/Latin-word-centric; CJK (e.g. "选B", "答案:B") has no boundary before B.
- loose = list(
- re.finditer(
- r"(?:[^A-Za-z]|^)([A-D])(?:[^A-Za-z]|$)",
- candidate,
- flags=re.IGNORECASE,
- )
- )
- if loose:
- return loose[-1].group(1).upper()
- return None
-
-
-def evaluate_answer_official(model_answer: str | None, correct_answer: str) -> bool:
- """Port of Daily-Omni ``evaluate_answer`` (strict string match after strip/upper)."""
- if not model_answer:
- return False
- return model_answer.strip().upper() == (correct_answer or "").strip().upper()
-
-
-def normalize_gold_answer(gold: str) -> str | None:
- """Best-effort single letter from ``Answer`` (for ``gold_normalized`` in saved items only)."""
- g = (gold or "").strip().upper()
- if len(g) == 1 and g in _VALID:
- return g
- m = re.search(r"([ABCD])\b", g)
- if m:
- return m.group(1).upper()
- return None
-
-
-def _extract_predicted_choice_relaxed(text: str) -> str | None:
- """Legacy vLLM-Omni heuristics (last ``answer:`` patterns, tail scan)."""
- if not text or not str(text).strip():
- return None
- t = str(text).strip()
-
- strong_patterns = [
- r"(?i)\*\*answer\*\*\s*[::]?\s*\(?([ABCD])\)?",
- r"(?i)\banswer\s*[::]?\s*\(?([ABCD])\)?",
- r"(?i)\bfinal\s+answer\s*[::]?\s*\(?([ABCD])\)?",
- r"(?i)\bcorrect\s+(?:answer|option)\s*[::]?\s*\(?([ABCD])\)?",
- r"(?i)\bthe\s+(?:correct\s+)?option\s+(?:is|would\s+be)\s*\(?([ABCD])\)?",
- r"(?i)\bI\s+(?:would\s+)?(?:choose|select|pick)\s*\(?([ABCD])\)?",
- ]
- last_letter: str | None = None
- for pat in strong_patterns:
- for m in re.finditer(pat, t):
- last_letter = m.group(1).upper()
- if last_letter:
- return last_letter
-
- # Weaker phrases: first match can be spurious; still prefer last occurrence.
- weak_patterns = [
- r"(?i)\boption\s*[::]?\s*\(?([ABCD])\)?",
- r"(?i)\bchoice\s*[::]?\s*\(?([ABCD])\)?",
- ]
- for pat in weak_patterns:
- for m in re.finditer(pat, t):
- last_letter = m.group(1).upper()
- if last_letter:
- return last_letter
-
- paren = list(re.finditer(r"\(([ABCD])\)", t))
- if paren:
- return paren[-1].group(1).upper()
-
- # First line sometimes is just "B" or "B." — allow if whole output is short
- one_line = t.split("\n", 1)[0].strip()
- if len(t) < 120 and len(one_line) <= 6:
- m0 = re.match(r"^([ABCD])\s*[.:\)]?\s*$", one_line, re.I)
- if m0:
- return m0.group(1).upper()
-
- # Tail-only: avoids matching echoed "A. ..." option blocks at the start
- tail_len = min(500, len(t))
- tail = t[-tail_len:]
- # ``\b`` after the letter avoids "Because"/"Definitely" false positives
- m = re.search(r"(?:^|[^\w])([ABCD])\b", tail, re.I)
- if m:
- return m.group(1).upper()
-
- return None
-
-
-def extract_predicted_choice(text: str | None) -> str | None:
- """Parse model output to A–D (official Daily-Omni rules by default)."""
- if not text or not str(text).strip():
- return None
- mode = os.environ.get("DAILY_OMNI_EXTRACT_MODE", "official").strip().lower()
- if mode in ("relaxed", "heuristic", "legacy"):
- return _extract_predicted_choice_relaxed(str(text))
- return extract_choice_letter_official(text)
-
-
-def compute_daily_omni_accuracy_metrics(
- input_requests: list[Any],
- outputs: list[RequestFuncOutput],
- *,
- include_per_item: bool = False,
-) -> dict[str, Any] | None:
- """If all requests are :class:`DailyOmniSampleRequest`, compute accuracy stats.
-
- Rows with empty ``Answer`` (after strip) are skipped as ``no_gold``, like upstream missing
- ``correct_answer``.
-
- **Denominators:** The open-source script excludes items that hit inference / I/O failures
- from ``valid_questions``; we mirror that with ``daily_omni_accuracy`` (= correct /
- successful responses). Failed HTTP requests are also tracked and used in
- ``daily_omni_accuracy_incl_http_fail`` (each failure counts as incorrect in the
- denominator).
- """
- if not input_requests or len(input_requests) != len(outputs):
- return None
- if not all(isinstance(r, DailyOmniSampleRequest) for r in input_requests):
- return None
-
- # total / correct: all rows with gold (incl. HTTP fail in total)
- # total_ok / correct_ok: successful HTTP only (GitHub-style per-type denominator)
- per_task: dict[str, dict[str, int]] = {}
- per_category: dict[str, dict[str, int]] = {}
- per_duration: dict[str, dict[str, int]] = {
- k: {"correct": 0, "total": 0, "correct_ok": 0, "total_ok": 0} for k in DAILY_OMNI_DURATION_KEYS
- }
- items: list[dict[str, Any]] = []
- correct = 0
- evaluated = 0
- no_gold = 0
- request_failed = 0
- parse_failed = 0 # success but could not extract A–D
-
- for req, out in zip(input_requests, outputs, strict=True):
- assert isinstance(req, DailyOmniSampleRequest)
- gold_raw = (req.daily_omni_gold_answer or "").strip()
- gold_norm = normalize_gold_answer(req.daily_omni_gold_answer)
- tt = (req.daily_omni_task_type or "unknown").strip() or "unknown"
- dur_key = (req.daily_omni_video_duration or "").strip()
- dur_active = dur_key in per_duration
- cat_key = (req.daily_omni_video_category or "").strip() or "unknown"
- if tt not in per_task:
- per_task[tt] = {"correct": 0, "total": 0, "correct_ok": 0, "total_ok": 0}
- if cat_key not in per_category:
- per_category[cat_key] = {"correct": 0, "total": 0, "correct_ok": 0, "total_ok": 0}
-
- if not gold_raw:
- no_gold += 1
- items.append(
- {
- "request_id": req.request_id,
- "skipped": True,
- "reason": "no_gold",
- "task_type": tt,
- "video_id": req.daily_omni_video_id,
- "video_duration": dur_key or None,
- "video_category": cat_key if cat_key != "unknown" else None,
- }
- )
- continue
-
- if not out.success:
- request_failed += 1
- evaluated += 1
- per_task[tt]["total"] += 1
- per_category[cat_key]["total"] += 1
- if dur_active:
- per_duration[dur_key]["total"] += 1
- # GitHub: failed inference not in valid_questions — do not increment total_ok
- items.append(
- {
- "request_id": req.request_id,
- "gold": gold_raw,
- "gold_normalized": gold_norm,
- "predicted": None,
- "correct": False,
- "task_type": tt,
- "video_id": req.daily_omni_video_id,
- "video_duration": dur_key or None,
- "video_category": cat_key if cat_key != "unknown" else None,
- "error": (out.error or "")[:500],
- }
- )
- continue
-
- pred = extract_predicted_choice(out.generated_text)
- evaluated += 1
- per_task[tt]["total"] += 1
- per_task[tt]["total_ok"] += 1
- per_category[cat_key]["total"] += 1
- per_category[cat_key]["total_ok"] += 1
- if dur_active:
- per_duration[dur_key]["total"] += 1
- per_duration[dur_key]["total_ok"] += 1
- if pred is None:
- parse_failed += 1
- is_correct = evaluate_answer_official(pred, req.daily_omni_gold_answer)
- if is_correct:
- correct += 1
- per_task[tt]["correct"] += 1
- per_task[tt]["correct_ok"] += 1
- per_category[cat_key]["correct"] += 1
- per_category[cat_key]["correct_ok"] += 1
- if dur_active:
- per_duration[dur_key]["correct"] += 1
- per_duration[dur_key]["correct_ok"] += 1
-
- items.append(
- {
- "request_id": req.request_id,
- "gold": gold_raw,
- "gold_normalized": gold_norm,
- "predicted": pred,
- "correct": is_correct,
- "parse_failed": pred is None,
- "task_type": tt,
- "video_id": req.daily_omni_video_id,
- "video_duration": dur_key or None,
- "video_category": cat_key if cat_key != "unknown" else None,
- }
- )
-
- evaluated_ok = evaluated - request_failed
- accuracy_github = (correct / evaluated_ok) if evaluated_ok else None
- accuracy_incl_fail = (correct / evaluated) if evaluated else None
-
- per_task_accuracy: dict[str, float | None] = {}
- per_task_accuracy_github: dict[str, float | None] = {}
- for name, st in per_task.items():
- tot = st["total"]
- per_task_accuracy[name] = (st["correct"] / tot) if tot else None
- tok = st["total_ok"]
- per_task_accuracy_github[name] = (st["correct_ok"] / tok) if tok else None
-
- per_category_accuracy: dict[str, float | None] = {}
- per_category_accuracy_github: dict[str, float | None] = {}
- for name, st in per_category.items():
- tot = st["total"]
- per_category_accuracy[name] = (st["correct"] / tot) if tot else None
- tok = st["total_ok"]
- per_category_accuracy_github[name] = (st["correct_ok"] / tok) if tok else None
-
- per_duration_accuracy: dict[str, float | None] = {}
- per_duration_accuracy_github: dict[str, float | None] = {}
- for name, st in per_duration.items():
- tot = st["total"]
- per_duration_accuracy[name] = (st["correct"] / tot) if tot else None
- tok = st["total_ok"]
- per_duration_accuracy_github[name] = (st["correct_ok"] / tok) if tok else None
-
- out: dict[str, Any] = {
- # Comparable to GitHub testmodel.py: correct / successful inferences
- "daily_omni_accuracy": accuracy_github,
- "daily_omni_accuracy_incl_http_fail": accuracy_incl_fail,
- "daily_omni_correct": correct,
- "daily_omni_evaluated": evaluated,
- "daily_omni_evaluated_ok": evaluated_ok,
- "daily_omni_no_gold": no_gold,
- "daily_omni_request_failed": request_failed,
- "daily_omni_parse_failed": parse_failed,
- "daily_omni_per_task": {k: dict(v) for k, v in per_task.items()},
- "daily_omni_per_task_accuracy": per_task_accuracy,
- "daily_omni_per_task_accuracy_github_style": per_task_accuracy_github,
- "daily_omni_per_category": {k: dict(v) for k, v in per_category.items()},
- "daily_omni_per_category_accuracy": per_category_accuracy,
- "daily_omni_per_category_accuracy_github_style": per_category_accuracy_github,
- "daily_omni_per_duration": {k: dict(v) for k, v in per_duration.items()},
- "daily_omni_per_duration_accuracy": per_duration_accuracy,
- "daily_omni_per_duration_accuracy_github_style": per_duration_accuracy_github,
- }
- if include_per_item:
- out["daily_omni_eval_items"] = items
- return out
-
-
-def print_daily_omni_accuracy_summary(metrics: dict[str, Any]) -> None:
- """Pretty-print accuracy block (stdout)."""
- acc = metrics.get("daily_omni_accuracy")
- acc_fail = metrics.get("daily_omni_accuracy_incl_http_fail")
- if acc is None and acc_fail is None and metrics.get("daily_omni_evaluated", 0) == 0:
- return
- print("{s:{c}^{n}}".format(s=" Daily-Omni accuracy (MCQ) ", n=50, c="="))
- ok = int(metrics.get("daily_omni_evaluated_ok", 0) or 0)
- cor = int(metrics.get("daily_omni_correct", 0) or 0)
- if ok > 0 and acc is not None:
- print(f"Overall Accuracy: {cor}/{ok} = {acc:.2%}")
- elif int(metrics.get("daily_omni_evaluated", 0) or 0) > 0:
- print("Overall Accuracy: 0/0 = N/A (no successful HTTP responses)")
- print(
- "{:<40} {:<10}".format(
- "Submitted (gold present):",
- metrics.get("daily_omni_evaluated", 0),
- )
- )
- print(
- "{:<40} {:<10}".format(
- "Successful HTTP (GitHub denom.):",
- metrics.get("daily_omni_evaluated_ok", 0),
- )
- )
- print("{:<40} {:<10}".format("Correct:", metrics.get("daily_omni_correct", 0)))
- if acc is not None:
- print("{:<40} {:<10.4f}".format("Accuracy (ratio, same as above):", acc))
- if acc_fail is not None and metrics.get("daily_omni_request_failed", 0):
- print(
- "{:<40} {:<10.4f}".format(
- "Accuracy (incl. HTTP as wrong):",
- acc_fail,
- )
- )
- print("{:<40} {:<10}".format("Skipped (no gold):", metrics.get("daily_omni_no_gold", 0)))
- print(
- "{:<40} {:<10}".format(
- "HTTP failed (excl. from GitHub acc.):",
- metrics.get("daily_omni_request_failed", 0),
- )
- )
- print(
- "{:<40} {:<10}".format(
- "Parsed OK but no A–D found:",
- metrics.get("daily_omni_parse_failed", 0),
- )
- )
- pt = metrics.get("daily_omni_per_task") or {}
- pta = metrics.get("daily_omni_per_task_accuracy_github_style") or {}
- if pta:
- print("\n--- Accuracy by QA Type ---")
- for name in sorted(pta.keys()):
- a = pta[name]
- st = pt.get(name) or {}
- tok = int(st.get("total_ok", 0) or 0)
- cok = int(st.get("correct_ok", 0) or 0)
- if tok and a is not None:
- print(f"{name}: {cok}/{tok} = {a:.2%}")
- else:
- print(f"{name}: 0/0 = N/A")
-
- pc = metrics.get("daily_omni_per_category") or {}
- ptc = metrics.get("daily_omni_per_category_accuracy_github_style") or {}
- if ptc:
- print("\n--- Accuracy by Video Category ---")
- for name in sorted(ptc.keys()):
- a = ptc[name]
- st = pc.get(name) or {}
- tok = int(st.get("total_ok", 0) or 0)
- cok = int(st.get("correct_ok", 0) or 0)
- if tok and a is not None:
- print(f"{name}: {cok}/{tok} = {a:.2%}")
- else:
- print(f"{name}: 0/0 = N/A")
-
- pdf = metrics.get("daily_omni_per_duration_accuracy_github_style") or {}
- if pdf:
- print("\n--- Accuracy by Video Duration ---")
- for name in DAILY_OMNI_DURATION_KEYS:
- a = pdf.get(name)
- st = (metrics.get("daily_omni_per_duration") or {}).get(name) or {}
- tok = int(st.get("total_ok", 0) or 0)
- cor = int(st.get("correct_ok", 0) or 0)
- if tok and a is not None:
- print(f"{name} Duration: {cor}/{tok} = {a:.2%}")
- else:
- print(f"{name} Duration: 0/0 = N/A")
- print("=" * 50)
diff --git a/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py b/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py
deleted file mode 100644
index 69fbe026bd8..00000000000
--- a/vllm_omni/benchmarks/data_modules/daily_omni_text_audio.py
+++ /dev/null
@@ -1,255 +0,0 @@
-"""Daily-Omni: optional consistency check between text stream and generated speech.
-
-The benchmark MCQ accuracy uses ``generated_text`` only. When the omni server also
-streams ``modality=audio`` (TTS), this module can transcribe the concatenated WAV
-with Whisper and compare the inferred option letter to the one parsed from text.
-
-Requires ``openai-whisper`` (``pip install openai-whisper``). Enable via env
-``DAILY_OMNI_TEXT_AUDIO_CONSISTENCY=1`` or CLI ``--daily-omni-text-audio-consistency``.
-
-Whisper model name defaults to ``tiny`` (override with ``DAILY_OMNI_WHISPER_MODEL``).
-"""
-
-from __future__ import annotations
-
-import logging
-import os
-import re
-import threading
-from typing import Any
-
-from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniSampleRequest
-from vllm_omni.benchmarks.data_modules.daily_omni_eval import extract_predicted_choice
-
-logger = logging.getLogger(__name__)
-
-_whisper_model = None
-_whisper_model_name: str | None = None
-_whisper_lock = threading.Lock()
-
-
-def env_text_audio_check_enabled() -> bool:
- return os.environ.get("DAILY_OMNI_TEXT_AUDIO_CONSISTENCY", "").lower() in (
- "1",
- "true",
- "yes",
- )
-
-
-def extract_choice_from_asr_transcript(transcript: str) -> str | None:
- """Parse A–D from ASR text; extends :func:`extract_predicted_choice` with spoken Chinese phrases."""
- c = extract_predicted_choice(transcript)
- if c:
- return c
- t = transcript or ""
- for pat in (
- r"(?i)选项\s*([ABCD])\b",
- r"(?i)选\s*([ABCD])\b",
- r"(?i)答案\s*是\s*([ABCD])\b",
- r"(?i)答案\s*([ABCD])\b",
- ):
- m = re.search(pat, t)
- if m:
- return m.group(1).upper()
- return None
-
-
-def _get_whisper_model(model_name: str):
- global _whisper_model, _whisper_model_name
- with _whisper_lock:
- if _whisper_model is None or _whisper_model_name != model_name:
- import whisper
-
- logger.warning(
- "Loading Whisper model %r for Daily-Omni text/audio consistency (one-time)...",
- model_name,
- )
- _whisper_model = whisper.load_model(model_name)
- _whisper_model_name = model_name
- return _whisper_model
-
-
-def transcribe_wav_bytes(
- wav_bytes: bytes,
- *,
- language: str | None = None,
- model_name: str | None = None,
-) -> tuple[str | None, str | None]:
- """Transcribe WAV bytes. Returns ``(transcript, error)`` — one of them is set.
-
- Args:
- wav_bytes: RIFF WAV file bytes.
- language: Optional Whisper language code (e.g. ``en``, ``zh``); improves accuracy/latency.
- model_name: Override model id; else ``DAILY_OMNI_WHISPER_MODEL`` or ``tiny``.
- """
- if not wav_bytes:
- return None, "empty_wav"
- if model_name is None or not str(model_name).strip():
- model_name = os.environ.get("DAILY_OMNI_WHISPER_MODEL") or "tiny"
- model_name = str(model_name).strip() or "tiny"
- path: str | None = None
- try:
- import tempfile
-
- model = _get_whisper_model(model_name)
- fd, path = tempfile.mkstemp(suffix=".wav")
- with os.fdopen(fd, "wb") as fp:
- fp.write(wav_bytes)
- kwargs: dict = {}
- if language:
- kwargs["language"] = language
- result = model.transcribe(path, **kwargs)
- text = (result.get("text") or "").strip()
- return (text if text else None), None
- except ImportError:
- return None, "openai-whisper is not installed (pip install openai-whisper)"
- except Exception as e:
- return None, str(e)[:500]
- finally:
- if path:
- try:
- os.unlink(path)
- except OSError:
- pass
-
-
-def compute_daily_omni_text_audio_consistency_metrics(
- input_requests: list[Any],
- outputs: list[Any],
- *,
- include_per_item: bool = False,
-) -> dict[str, Any] | None:
- """Compare option letter from ``generated_text`` vs Whisper transcript of output audio.
-
- Only considers requests where ``outputs[i]`` has ``generated_audio_wav_bytes`` set
- (populated by the omni benchmark when TA check is enabled).
- """
- if not input_requests or len(input_requests) != len(outputs):
- return None
- if not all(isinstance(r, DailyOmniSampleRequest) for r in input_requests):
- return None
-
- ta_no_wav = 0
- ta_asr_failed = 0
- ta_text_unparsed = 0
- ta_audio_unparsed = 0
- ta_consistent = 0
- ta_mismatch = 0
- ta_both_parsed = 0
- items: list[dict[str, Any]] = []
-
- for req, out in zip(input_requests, outputs, strict=True):
- assert isinstance(req, DailyOmniSampleRequest)
- rid = req.request_id
- if not getattr(out, "success", False):
- if include_per_item:
- items.append(
- {
- "request_id": rid,
- "skipped": True,
- "reason": "request_not_success",
- }
- )
- continue
-
- wav = getattr(out, "generated_audio_wav_bytes", None)
- if not wav:
- ta_no_wav += 1
- if include_per_item:
- items.append(
- {
- "request_id": rid,
- "skipped": False,
- "reason": "no_output_audio",
- "text_choice": extract_predicted_choice(getattr(out, "generated_text", "") or ""),
- }
- )
- continue
-
- transcript, asr_err = transcribe_wav_bytes(wav)
- if asr_err:
- ta_asr_failed += 1
- if include_per_item:
- items.append(
- {
- "request_id": rid,
- "asr_error": asr_err,
- "text_choice": extract_predicted_choice(getattr(out, "generated_text", "") or ""),
- }
- )
- continue
-
- text_choice = extract_predicted_choice(getattr(out, "generated_text", "") or "")
- audio_choice = extract_choice_from_asr_transcript(transcript or "")
-
- if text_choice is None:
- ta_text_unparsed += 1
- if audio_choice is None:
- ta_audio_unparsed += 1
-
- if text_choice is not None and audio_choice is not None:
- ta_both_parsed += 1
- if text_choice == audio_choice:
- ta_consistent += 1
- else:
- ta_mismatch += 1
-
- if include_per_item:
- consistent: bool | None
- if text_choice is None or audio_choice is None:
- consistent = None
- else:
- consistent = text_choice == audio_choice
- items.append(
- {
- "request_id": rid,
- "text_choice": text_choice,
- "audio_choice": audio_choice,
- "asr_transcript": (transcript or "")[:500],
- "text_audio_consistent": consistent,
- }
- )
-
- comparable = ta_consistent + ta_mismatch
- rate = (ta_consistent / comparable) if comparable else None
-
- out: dict[str, Any] = {
- "daily_omni_ta_enabled": True,
- "daily_omni_ta_no_output_audio": ta_no_wav,
- "daily_omni_ta_asr_failed": ta_asr_failed,
- "daily_omni_ta_text_unparsed": ta_text_unparsed,
- "daily_omni_ta_audio_unparsed": ta_audio_unparsed,
- "daily_omni_ta_both_parsed": ta_both_parsed,
- "daily_omni_ta_consistent": ta_consistent,
- "daily_omni_ta_mismatch": ta_mismatch,
- "daily_omni_ta_consistency_rate": rate,
- }
- if include_per_item:
- out["daily_omni_ta_items"] = items
- return out
-
-
-def print_daily_omni_text_audio_summary(metrics: dict[str, Any]) -> None:
- if not metrics.get("daily_omni_ta_enabled"):
- return
- print("{s:{c}^{n}}".format(s=" Daily-Omni text vs audio (ASR) ", n=50, c="="))
- print("{:<40} {:<10}".format("No output audio captured:", metrics.get("daily_omni_ta_no_output_audio", 0)))
- print("{:<40} {:<10}".format("ASR failed:", metrics.get("daily_omni_ta_asr_failed", 0)))
- print("{:<40} {:<10}".format("Both text+audio letter parsed:", metrics.get("daily_omni_ta_both_parsed", 0)))
- print("{:<40} {:<10}".format("Consistent (same letter):", metrics.get("daily_omni_ta_consistent", 0)))
- print("{:<40} {:<10}".format("Mismatch:", metrics.get("daily_omni_ta_mismatch", 0)))
- r = metrics.get("daily_omni_ta_consistency_rate")
- if r is not None:
- print("{:<40} {:<10.4f}".format("Consistency rate (of both parsed):", r))
- print(
- "{:<40} {:<10}".format(
- "Text unparsed (among w/ audio):",
- metrics.get("daily_omni_ta_text_unparsed", 0),
- )
- )
- print(
- "{:<40} {:<10}".format(
- "Audio unparsed (among w/ audio):",
- metrics.get("daily_omni_ta_audio_unparsed", 0),
- )
- )
diff --git a/vllm_omni/benchmarks/data_modules/seed_tts_dataset.py b/vllm_omni/benchmarks/data_modules/seed_tts_dataset.py
deleted file mode 100644
index 447495bffbc..00000000000
--- a/vllm_omni/benchmarks/data_modules/seed_tts_dataset.py
+++ /dev/null
@@ -1,481 +0,0 @@
-"""Seed-TTS zero-shot evaluation-style prompts for ``vllm bench serve``.
-
-Loads rows from the `meta.lst` format used in `BytedanceSpeech/seed-tts-eval`_ (or any
-HuggingFace dataset repo with the same layout)::
-
- utt_id|prompt_transcript|prompt_wav_relative_path|text_to_synthesize
-
-Each benchmark request supplies target text plus ``ref_text`` / ``ref_audio`` (Qwen3-TTS ``Base`` /
-voice clone), merged into the JSON body. By default ``ref_audio`` is an inline ``data:`` URL so
-the server does not need ``--allowed-local-media-path``. Use ``--seed-tts-file-ref-audio`` for
-``file://`` (smaller bodies; requires that flag). Use ``--backend openai-audio-speech``
-(``/v1/audio/speech``) or ``--backend openai-chat-omni`` (``/v1/chat/completions`` with the same
-fields on the body plus a Qwen3-Omni-style ``system`` message and the target text as ``user`` content).
-
-.. _BytedanceSpeech/seed-tts-eval: https://github.com/BytedanceSpeech/seed-tts-eval
-"""
-
-from __future__ import annotations
-
-import base64
-import logging
-import random
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any
-
-from vllm.benchmarks.datasets import BenchmarkDataset, SampleRequest
-from vllm.tokenizers import TokenizerLike
-from vllm.tokenizers.hf import get_cached_tokenizer
-
-logger = logging.getLogger(__name__)
-
-# Matches Qwen3-Omni serving examples (``openai_chat_completion_client_for_multimodal_generation`` /
-# ``qwen3_omni/gradio_demo``) plus explicit TTS / voice-clone instructions for chat completions.
-SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT = (
- "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
- "capable of perceiving auditory and visual inputs, as well as generating text and speech.\n"
- "For this request you act as a text-to-speech engine with zero-shot voice cloning: "
- "the API provides reference audio and its transcript (ref_audio, ref_text) and task_type Base. "
- "The user message is the exact text you must speak. "
- "Synthesize natural speech in the same language as that user text, "
- "matching the timbre, prosody, and speaking style of the reference audio while reading the new content clearly."
-)
-
-
-@dataclass
-class SeedTTSSampleRequest(SampleRequest):
- """``SampleRequest`` with per-row fields merged into ``/v1/audio/speech`` JSON."""
-
- #: Shallow-merged into ``RequestFuncInput.extra_body`` (ref_audio, ref_text, task_type, …).
- seed_tts_speech_extra: dict[str, Any] | None = None
- seed_tts_utterance_id: str = ""
- seed_tts_locale: str = ""
- #: For ``openai-chat-omni``: becomes the chat ``system`` message (Qwen3-Omni + TTS behavior).
- seed_tts_system_prompt: str = ""
- #: Local path to reference prompt WAV (for SIM vs. synthesized PCM in ``seed_tts_eval``).
- seed_tts_ref_wav_path: str = ""
-
-
-@dataclass
-class _SeedTTSRow:
- utterance_id: str
- ref_text: str
- prompt_wav_rel: str
- target_text: str
-
-
-def _parse_meta_line(line: str) -> _SeedTTSRow | None:
- line = line.strip()
- if not line or line.startswith("#"):
- return None
- parts = line.split("|")
- if len(parts) < 4:
- logger.warning("Skipping malformed meta.lst line (need 4 '|'-fields): %r", line[:120])
- return None
- utt_id, ref_text, wav_rel, target = parts[0], parts[1], parts[2], parts[3]
- if not target.strip():
- return None
- return _SeedTTSRow(
- utterance_id=utt_id.strip(),
- ref_text=ref_text.strip(),
- prompt_wav_rel=wav_rel.strip(),
- target_text=target.strip(),
- )
-
-
-def _load_meta_rows(meta_file: Path) -> list[_SeedTTSRow]:
- text = meta_file.read_text(encoding="utf-8")
- rows: list[_SeedTTSRow] = []
- for line in text.splitlines():
- r = _parse_meta_line(line)
- if r is not None:
- rows.append(r)
- return rows
-
-
-def resolve_seed_tts_root(dataset_path: str | None, *, explicit_root: str | None, locale: str = "en") -> Path:
- """Return directory containing ``{locale}/meta.lst`` and ``{locale}/prompt-wavs/``."""
- if explicit_root:
- root = Path(explicit_root).expanduser().resolve()
- if not root.is_dir():
- raise FileNotFoundError(f"--seed-tts-root is not a directory: {root}")
- return root
-
- if not dataset_path:
- raise ValueError("Seed-TTS requires --dataset-path (HF repo id or local root) or --seed-tts-root.")
-
- p = Path(dataset_path).expanduser()
- if p.exists() and p.is_dir():
- return p.resolve()
-
- repo_id = dataset_path.strip()
- try:
- from huggingface_hub import snapshot_download
- except ImportError as e:
- raise ImportError(
- "Install huggingface_hub to download Seed-TTS from the Hub, or clone the dataset "
- "locally and pass --dataset-path / --seed-tts-root to that directory."
- ) from e
- # Download only the requested locale subtree instead of the whole dataset
- # repo. This avoids large, flaky nightly downloads when we only need e.g.
- # ``en/meta.lst`` + ``en/prompt-wavs/**``.
- cache = snapshot_download(
- repo_id=repo_id,
- repo_type="dataset",
- allow_patterns=[f"{locale}/**"],
- )
- return Path(cache).resolve()
-
-
-def _ref_audio_payload(wav_path: Path, *, inline: bool) -> str:
- if inline:
- raw = wav_path.read_bytes()
- b64 = base64.b64encode(raw).decode("ascii")
- return f"data:audio/wav;base64,{b64}"
- return wav_path.expanduser().resolve().as_uri()
-
-
-class SeedTTSDataset(BenchmarkDataset):
- """Seed-TTS-style zero-shot TTS rows for throughput/latency benchmarking.
-
- Args:
- dataset_path: HuggingFace dataset repo id (``org/dataset``) or local directory with
- ``en/meta.lst`` (and ``zh/meta.lst`` if using zh).
- locale: ``en`` or ``zh`` — which subfolder under the root to read.
- inline_ref_audio: If True (default), embed prompt WAV as ``data:audio/wav;base64,...``
- so Qwen3-TTS / ``/v1/audio/speech`` works without server
- ``--allowed-local-media-path``. If False, use ``file://`` (smaller
- requests; server must set ``--allowed-local-media-path`` to the dataset root).
- seed_tts_root: Optional override for the root directory (same layout as HF dataset).
- system_prompt: Optional override for the chat system message when using
- ``--backend openai-chat-omni``; defaults to :data:`SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT`.
- """
-
- IS_MULTIMODAL = False
- DEFAULT_OUTPUT_LEN = 2048
-
- def __init__(
- self,
- dataset_path: str,
- random_seed: int = 0,
- locale: str = "en",
- inline_ref_audio: bool = True,
- seed_tts_root: str | None = None,
- system_prompt: str | None = None,
- disable_shuffle: bool = False,
- **kwargs: Any,
- ) -> None:
- if locale not in ("en", "zh"):
- raise ValueError("locale must be 'en' or 'zh'")
- self.locale = locale
- self.inline_ref_audio = inline_ref_audio
- self._explicit_root = seed_tts_root
- sp = (system_prompt or "").strip()
- self._system_prompt = sp if sp else SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT
- super().__init__(
- dataset_path=dataset_path,
- random_seed=random_seed,
- disable_shuffle=disable_shuffle,
- **kwargs,
- )
- self._root = resolve_seed_tts_root(self.dataset_path, explicit_root=self._explicit_root, locale=self.locale)
- self._rows: list[_SeedTTSRow] = []
- self.load_data()
-
- def load_data(self) -> None:
- meta = self._root / self.locale / "meta.lst"
- if not meta.is_file():
- raise FileNotFoundError(
- f"Seed-TTS meta not found: {meta}. "
- f"Expected layout from seed-tts-eval (e.g. {self._root}/{self.locale}/meta.lst)."
- )
- self._rows = _load_meta_rows(meta)
- if not self._rows:
- raise ValueError(f"No valid rows in {meta}")
- if not self.disable_shuffle:
- rng = random.Random(self.random_seed)
- rng.shuffle(self._rows)
- self.data = self._rows
- logger.info(
- "Loaded Seed-TTS: root=%s locale=%s rows=%d inline_ref_audio=%s",
- self._root,
- self.locale,
- len(self._rows),
- self.inline_ref_audio,
- )
-
- def sample(
- self,
- tokenizer: TokenizerLike,
- num_requests: int,
- output_len: int | None = None,
- request_id_prefix: str = "",
- no_oversample: bool = False,
- **kwargs: Any,
- ) -> list[SampleRequest]:
- if output_len is None:
- output_len = self.DEFAULT_OUTPUT_LEN
-
- tok = get_cached_tokenizer(tokenizer)
- out: list[SampleRequest] = []
- for i, row in enumerate(self._rows):
- if len(out) >= num_requests:
- break
- wav_path = (self._root / self.locale / row.prompt_wav_rel).resolve()
- if not wav_path.is_file():
- logger.warning("Missing prompt wav for %s: %s", row.utterance_id, wav_path)
- continue
-
- target = row.target_text
- prompt_len = len(tok.encode(f"{self._system_prompt}\n{target}"))
- lang = "English" if self.locale == "en" else "Chinese"
- ref_uri = _ref_audio_payload(wav_path, inline=self.inline_ref_audio)
- speech_extra: dict[str, Any] = {
- "ref_audio": ref_uri,
- "ref_text": row.ref_text,
- "task_type": "Base",
- "language": lang,
- "max_new_tokens": output_len,
- }
-
- out.append(
- SeedTTSSampleRequest(
- prompt=target,
- prompt_len=prompt_len,
- expected_output_len=output_len,
- multi_modal_data=None,
- request_id=f"{request_id_prefix}{i}",
- seed_tts_speech_extra=speech_extra,
- seed_tts_utterance_id=row.utterance_id,
- seed_tts_locale=self.locale,
- seed_tts_system_prompt=self._system_prompt,
- seed_tts_ref_wav_path=str(wav_path),
- )
- )
-
- logger.info("Seed-TTS: built %d requests (asked %d)", len(out), num_requests)
- self.maybe_oversample_requests(out, num_requests, request_id_prefix, no_oversample)
- return out
-
-
-@dataclass
-class _SeedTTSDesignRow:
- utterance_id: str
- target_text: str
- voice_description: str
-
-
-def _parse_design_meta_line(line: str) -> _SeedTTSDesignRow | None:
- """Parse a 5-field design meta.lst line.
-
- Format: ``utt_id|ref_text|wav_rel|target_text|voice_description``
-
- Returns None (with a warning) if the line has fewer than 5 fields or if
- voice_description is empty.
- """
- line = line.strip()
- if not line or line.startswith("#"):
- return None
- parts = line.split("|")
- if len(parts) < 5:
- logger.warning("Skipping malformed design meta.lst line (need 5 '|'-fields): %r", line[:120])
- return None
- utt_id = parts[0].strip()
- target_text = parts[3].strip()
- voice_description = parts[4].strip()
- if not voice_description:
- logger.warning("Skipping design meta.lst line with empty voice_description: %r", line[:120])
- return None
- return _SeedTTSDesignRow(
- utterance_id=utt_id,
- target_text=target_text,
- voice_description=voice_description,
- )
-
-
-@dataclass
-class SeedTTSDesignSampleRequest(SeedTTSSampleRequest):
- """SampleRequest for voice-design TTS (no ref_audio; voice described via natural language).
-
- The ``seed_tts_speech_extra`` dict carries ``instructions`` (natural-language
- voice description, forwarded as-is to the Qwen3-TTS VoiceDesign endpoint) and
- ``task_type="VoiceDesign"`` instead of ``ref_audio`` / ``ref_text``.
- SIM is skipped (``seed_tts_ref_wav_path`` is empty).
- """
-
-
-class SeedTTSDesignDataset(SeedTTSDataset):
- """Seed-TTS prompts for voice-design benchmarking (dataset name: ``seed-tts-design``).
-
- Loads a 5-field ``meta.lst``::
-
- utt_id|ref_text|wav_rel|target_text|voice_description
-
- and builds requests with ``task_type="VoiceDesign"`` and the natural-language
- ``voice_description`` column forwarded via the ``instructions`` field
- (the Qwen3-TTS VoiceDesign endpoint's expected key) instead of
- ``ref_audio`` / ``ref_text``. Speaker-similarity (SIM) is not computed.
- """
-
- def load_data(self) -> None:
- # Does NOT call super().load_data() — the format is different (5 fields,
- # no wav file). self._rows is intentionally left empty; the parent
- # sample() is fully overridden so an empty self._rows is safe.
- meta = self._root / self.locale / "meta.lst"
- if not meta.is_file():
- raise FileNotFoundError(
- f"Seed-TTS-Design meta not found: {meta}. Expected layout: {self._root}/{self.locale}/meta.lst"
- )
- text = meta.read_text(encoding="utf-8")
- design_rows: list[_SeedTTSDesignRow] = []
- for line in text.splitlines():
- r = _parse_design_meta_line(line)
- if r is not None:
- design_rows.append(r)
- if not design_rows:
- raise ValueError(f"No valid rows in {meta}")
- if not self.disable_shuffle:
- rng = random.Random(self.random_seed)
- rng.shuffle(design_rows)
- self._design_rows = design_rows
- # Keep self._rows empty — parent sample() is overridden.
- self._rows = []
- self.data = self._design_rows
- logger.info(
- "Loaded Seed-TTS-Design: root=%s locale=%s rows=%d",
- self._root,
- self.locale,
- len(self._design_rows),
- )
-
- def sample(
- self,
- tokenizer: TokenizerLike,
- num_requests: int,
- output_len: int | None = None,
- request_id_prefix: str = "",
- no_oversample: bool = False,
- **kwargs: Any,
- ) -> list[SampleRequest]:
- if output_len is None:
- output_len = self.DEFAULT_OUTPUT_LEN
-
- tok = get_cached_tokenizer(tokenizer)
- lang = "English" if self.locale == "en" else "Chinese"
- out: list[SampleRequest] = []
- for i, row in enumerate(self._design_rows):
- if len(out) >= num_requests:
- break
- target = row.target_text
- prompt_len = len(tok.encode(target))
- speech_extra: dict[str, Any] = {
- "instructions": row.voice_description,
- "task_type": "VoiceDesign",
- "language": lang,
- "max_new_tokens": output_len,
- }
- out.append(
- SeedTTSDesignSampleRequest(
- prompt=target,
- prompt_len=prompt_len,
- expected_output_len=output_len,
- multi_modal_data=None,
- request_id=f"{request_id_prefix}{i}",
- seed_tts_speech_extra=speech_extra,
- seed_tts_utterance_id=row.utterance_id,
- seed_tts_locale=self.locale,
- seed_tts_system_prompt=self._system_prompt,
- seed_tts_ref_wav_path="", # SIM skipped for voice-design
- )
- )
-
- logger.info(
- "Seed-TTS-Design: built %d requests (asked %d) — no ref_audio (voice design)",
- len(out),
- num_requests,
- )
- self.maybe_oversample_requests(out, num_requests, request_id_prefix, no_oversample)
- return out
-
-
-def load_seed_tts_dataset(
- dataset_path: str,
- random_seed: int = 0,
- locale: str = "en",
- inline_ref_audio: bool = True,
- seed_tts_root: str | None = None,
- system_prompt: str | None = None,
- **kwargs: Any,
-) -> SeedTTSDataset:
- return SeedTTSDataset(
- dataset_path=dataset_path,
- random_seed=random_seed,
- locale=locale,
- inline_ref_audio=inline_ref_audio,
- seed_tts_root=seed_tts_root,
- system_prompt=system_prompt,
- **kwargs,
- )
-
-
-@dataclass
-class SeedTTSTextSampleRequest(SeedTTSSampleRequest):
- """SampleRequest for default-voice TTS (no ref_audio, no ref_text).
-
- The voice param (e.g. ``voice: "Vivian"``) is supplied at request time via
- ``--extra-body`` in the benchmark config. SIM is skipped (empty ref_wav_path).
- WER and UTMOS are computed normally.
- """
-
-
-class SeedTTSTextDataset(SeedTTSDataset):
- """Seed-TTS prompts for default-voice benchmarking (dataset name: ``seed-tts-text``).
-
- Loads the same ``meta.lst`` as :class:`SeedTTSDataset` but builds requests
- WITHOUT ``ref_audio`` / ``ref_text`` body fields. The named voice must be
- supplied via ``--extra-body`` in the benchmark config.
- Speaker-similarity (SIM) is not computed.
- """
-
- def sample(
- self,
- tokenizer: TokenizerLike,
- num_requests: int,
- output_len: int | None = None,
- request_id_prefix: str = "",
- no_oversample: bool = False,
- **kwargs: Any,
- ) -> list[SampleRequest]:
- if output_len is None:
- output_len = self.DEFAULT_OUTPUT_LEN
-
- tok = get_cached_tokenizer(tokenizer)
- out: list[SampleRequest] = []
- for i, row in enumerate(self._rows):
- if len(out) >= num_requests:
- break
- target = row.target_text
- prompt_len = len(tok.encode(target))
- out.append(
- SeedTTSTextSampleRequest(
- prompt=target,
- prompt_len=prompt_len,
- expected_output_len=output_len,
- multi_modal_data=None,
- request_id=f"{request_id_prefix}{i}",
- seed_tts_speech_extra=None, # voice supplied via --extra-body in config
- seed_tts_utterance_id=row.utterance_id,
- seed_tts_locale=self.locale,
- seed_tts_system_prompt=self._system_prompt,
- seed_tts_ref_wav_path="", # empty → SIM skipped in seed_tts_eval
- )
- )
-
- logger.info(
- "Seed-TTS-Text: built %d requests (asked %d) — no ref_audio (default voice)",
- len(out),
- num_requests,
- )
- self.maybe_oversample_requests(out, num_requests, request_id_prefix, no_oversample)
- return out
diff --git a/vllm_omni/benchmarks/data_modules/seed_tts_eval.py b/vllm_omni/benchmarks/data_modules/seed_tts_eval.py
deleted file mode 100644
index d8c37af1300..00000000000
--- a/vllm_omni/benchmarks/data_modules/seed_tts_eval.py
+++ /dev/null
@@ -1,729 +0,0 @@
-"""Seed-TTS WER aligned with Bytedance ``seed-tts-eval`` / ``run_wer.py``.
-
-Matches the published protocol (see Hugging Face dataset card and
-https://github.com/zhaochenyang20/seed-tts-eval):
-
-- **EN**: ``openai/whisper-large-v3`` via ``transformers``, audio resampled to **16 kHz**
- (same as ``run_wer.py``).
-- **ZH**: ``funasr`` **paraformer-zh**, hypothesis converted with **zhconv** to zh-cn.
-- **WER**: ``jiwer`` after punctuation stripping (``zhon.hanzi.punctuation`` + ``string.punctuation``,
- preserving ``'``) and EN lowercasing / ZH per-character spacing. Supports jiwer 3.x
- (``compute_measures``) and 4.x (``process_words``).
-
-- **SIM** (speaker similarity proxy): cosine similarity of L2-normalized mean-pooled **WavLM**
- embeddings (reference prompt WAV vs. synthesized PCM), 16 kHz. Official ``cal_sim.sh`` uses
- UniSpeech ``verification_pair_list_v2.py`` with a **fine-tuned** WavLM SV checkpoint — set
- ``SEED_TTS_WAVLM_MODEL`` to another HF id if you need closer parity. Disable with
- ``SEED_TTS_SIM_EVAL=0``. Optional: ``SEED_TTS_SIM_DEVICE`` (e.g. ``cpu``) to avoid GPU
- issues when Whisper already uses CUDA; ``SEED_TTS_WAVLM_MIN_SAMPLES`` pads very short
- waveforms so the WavLM CNN front-end does not fail.
-
-- **UTMOS** (predicted MOS from TorchScript): default ``balacoon/utmos`` → ``utmos.jit``
- (Sarulab-style demo export). Uses ``torch`` + ``huggingface_hub`` only. Aggregate metrics
- are over **all requests with captured PCM** (independent of ASR/WER). Non-finite scores are
- dropped and counted as failures. Override repo/file via ``SEED_TTS_UTMOS_HF_REPO`` /
- ``SEED_TTS_UTMOS_JIT_FILE``. **Device**: defaults to **CPU** when ``SEED_TTS_UTMOS_DEVICE``
- is unset; set ``SEED_TTS_UTMOS_DEVICE=cuda:0`` (or ``cuda:1`` etc.) to run on GPU. The JIT
- model is loaded directly onto the target device via ``map_location`` to avoid cross-device
- issues (some PyTorch builds/Windows have problems moving TorchScript modules after load).
- Forward uses **float32** waveform in ``[-1, 1]`` (same as the WER resampled array) so
- tensor dtypes match JIT weights; using int16 triggers
- ``RuntimeError: input type and weight type should be same`` on common exports. Disable
- with ``SEED_TTS_UTMOS_EVAL=0``.
-
-Enable with ``SEED_TTS_WER_EVAL=1`` or ``--seed-tts-wer-eval``. Install optional deps::
-
- pip install 'vllm-omni[seed-tts-eval]'
-
-Env: ``SEED_TTS_EVAL_DEVICE`` (e.g. ``cuda:0``, ``cpu``); ``SEED_TTS_HF_WHISPER_MODEL``
-defaults to ``openai/whisper-large-v3`` (override for debugging only).
-"""
-
-from __future__ import annotations
-
-import io
-import logging
-import math
-import os
-import statistics
-import string
-import tempfile
-import threading
-import wave
-from typing import Any
-
-import numpy as np
-from vllm.benchmarks.datasets import SampleRequest
-
-from vllm_omni.benchmarks.data_modules.seed_tts_dataset import SeedTTSSampleRequest
-
-logger = logging.getLogger(__name__)
-
-# Mirrors seed-tts-eval/run_wer.py
-OFFICIAL_WHISPER_HF_ID = "openai/whisper-large-v3"
-PARAFORMER_MODEL_ID = "paraformer-zh"
-
-_lock = threading.Lock()
-_device: str | None = None
-_en_processor = None
-_en_model = None
-_zh_paraformer = None
-_wavlm_model = None
-_wavlm_processor = None
-_wavlm_device: str | None = None
-_utmos_jit_model = None
-_utmos_jit_device: str | None = None
-_utmos_jit_load_failed = False
-_utmos_forward_warned = False
-
-
-def pcm_s16le_mono_to_wav_bytes(pcm: bytes, *, sample_rate: int = 24000) -> bytes:
- buf = io.BytesIO()
- with wave.open(buf, "wb") as wf:
- wf.setnchannels(1)
- wf.setsampwidth(2)
- wf.setframerate(sample_rate)
- wf.writeframes(pcm)
- return buf.getvalue()
-
-
-def _get_eval_device() -> str:
- explicit = os.environ.get("SEED_TTS_EVAL_DEVICE", "").strip()
- if explicit:
- return explicit
- try:
- import torch
-
- return "cuda:0" if torch.cuda.is_available() else "cpu"
- except ImportError:
- return "cpu"
-
-
-def _punctuation_all() -> str:
- from zhon.hanzi import punctuation
-
- return punctuation + string.punctuation
-
-
-def _jiwer_wer(reference: str, hypothesis: str) -> float:
- """Word-level WER; strings are normalized like ``run_wer.process_one``.
-
- jiwer 4.x removed ``compute_measures`` (``ImportError``); fall back to ``process_words``.
- """
- try:
- from jiwer import compute_measures
-
- return float(compute_measures(reference, hypothesis)["wer"])
- except ImportError:
- import jiwer
-
- out = jiwer.process_words(reference, hypothesis)
- return float(out.wer)
-
-
-def process_one_official(hypo: str, truth: str, lang: str) -> tuple[float, str, str]:
- """Same normalization + ``jiwer`` call as ``run_wer.process_one`` (hypo=ASR, truth=reference)."""
- raw_truth = truth
- raw_hypo = hypo
- truth_n = truth
- hypo_n = hypo
- for x in _punctuation_all():
- if x == "'":
- continue
- truth_n = truth_n.replace(x, "")
- hypo_n = hypo_n.replace(x, "")
- truth_n = truth_n.replace(" ", " ")
- hypo_n = hypo_n.replace(" ", " ")
- if lang == "zh":
- truth_n = " ".join([x for x in truth_n])
- hypo_n = " ".join([x for x in hypo_n])
- elif lang == "en":
- truth_n = truth_n.lower()
- hypo_n = hypo_n.lower()
- else:
- raise ValueError(f"unsupported lang {lang!r}")
- wer = _jiwer_wer(truth_n, hypo_n)
- return wer, raw_truth, raw_hypo
-
-
-def _pcm_s16le_to_f32_16k(pcm: bytes, pcm_sample_rate: int = 24000) -> np.ndarray:
- import scipy.signal
-
- if not pcm:
- return np.zeros(0, dtype=np.float32)
- raw = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0
- target_len = int(len(raw) * 16000 / pcm_sample_rate)
- if target_len <= 0:
- return np.zeros(0, dtype=np.float32)
- return scipy.signal.resample(raw, target_len).astype(np.float32)
-
-
-def _eval_submetric_enabled(env_name: str, *, default: bool = True) -> bool:
- raw = os.environ.get(env_name, "").strip().lower()
- if raw in ("0", "false", "no", "off"):
- return False
- if raw in ("1", "true", "yes", "on"):
- return True
- return default
-
-
-def _audio_path_to_f32_16k(path: str) -> np.ndarray:
- import scipy.signal
- import soundfile as sf
-
- data, sr = sf.read(path, dtype="float32", always_2d=True)
- mono = np.mean(data, axis=1).astype(np.float32)
- if int(sr) == 16000:
- return mono
- target_len = max(1, int(len(mono) * 16000 / int(sr)))
- return scipy.signal.resample(mono, target_len).astype(np.float32)
-
-
-def _ensure_wavlm_sim() -> None:
- global _wavlm_model, _wavlm_processor, _wavlm_device
- with _lock:
- if _wavlm_model is not None:
- return
- from transformers import AutoFeatureExtractor, AutoModel
-
- mid = os.environ.get("SEED_TTS_WAVLM_MODEL", "microsoft/wavlm-base-plus").strip() or "microsoft/wavlm-base-plus"
- _wavlm_device = os.environ.get("SEED_TTS_SIM_DEVICE", "").strip() or _get_eval_device()
- logger.warning(
- "Loading WavLM %r on %s for Seed-TTS SIM (embedding cosine; not identical to "
- "seed-tts-eval UniSpeech SV checkpoint).",
- mid,
- _wavlm_device,
- )
- _wavlm_processor = AutoFeatureExtractor.from_pretrained(mid)
- _wavlm_model = AutoModel.from_pretrained(mid).to(_wavlm_device)
- _wavlm_model.eval()
-
-
-def _wavlm_prepare_waveform(wav: np.ndarray) -> np.ndarray:
- """Trim, pad to a minimum length WavLM/Wav2Vec2 CNN stack accepts, float32 mono."""
- max_sec = float(os.environ.get("SEED_TTS_WAVLM_MAX_SECONDS", "30"))
- cap = int(max_sec * 16000)
- w = np.asarray(wav, dtype=np.float32).reshape(-1)
- if len(w) == 0:
- return w
- if len(w) > cap:
- w = w[:cap].copy()
- # Very short clips make the strided conv front-end fail (shape / empty time dim).
- min_samples = int(os.environ.get("SEED_TTS_WAVLM_MIN_SAMPLES", "4000"))
- if len(w) < min_samples:
- w = np.pad(w, (0, min_samples - len(w)), mode="constant")
- return w
-
-
-def _wavlm_mean_embedding_f32_16k(wav: np.ndarray) -> np.ndarray | None:
- import torch
-
- _ensure_wavlm_sim()
- w = _wavlm_prepare_waveform(wav)
- if len(w) == 0:
- return None
- assert _wavlm_processor is not None and _wavlm_model is not None and _wavlm_device is not None
- # Single utterance: avoid padding=True (adds zeros that distort mean pooling). Still pass
- # attention_mask when the extractor provides it (sample-level; do not mix with hidden length).
- try:
- inputs = _wavlm_processor(
- w,
- sampling_rate=16000,
- return_tensors="pt",
- padding=False,
- return_attention_mask=True,
- )
- except TypeError:
- inputs = _wavlm_processor(
- w,
- sampling_rate=16000,
- return_tensors="pt",
- padding=False,
- )
- iv = inputs["input_values"].to(_wavlm_device)
- am = inputs.get("attention_mask")
- if am is not None:
- am = am.to(_wavlm_device)
- with torch.inference_mode():
- out = _wavlm_model(iv, attention_mask=am)
- h = out.last_hidden_state
- v = h.mean(dim=1).squeeze(0).float().cpu().numpy()
- n = float(np.linalg.norm(v))
- if not np.isfinite(n) or n < 1e-8:
- return None
- return (v / n).astype(np.float32)
-
-
-def _cosine_similarity_unit_vectors(a: np.ndarray, b: np.ndarray) -> float:
- return float(np.dot(a, b))
-
-
-def _ensure_utmos_jit_model() -> Any | None:
- """Load UTMOS as TorchScript (``balacoon/utmos`` style): no ``import utmos`` / fairseq."""
- global _utmos_jit_model, _utmos_jit_device, _utmos_jit_load_failed
- with _lock:
- if _utmos_jit_load_failed:
- return None
- if _utmos_jit_model is not None:
- return _utmos_jit_model
- try:
- import torch
- from huggingface_hub import hf_hub_download
-
- repo = os.environ.get("SEED_TTS_UTMOS_HF_REPO", "balacoon/utmos").strip() or "balacoon/utmos"
- fname = os.environ.get("SEED_TTS_UTMOS_JIT_FILE", "utmos.jit").strip() or "utmos.jit"
- logger.warning(
- "Loading UTMOS TorchScript from Hugging Face %r file %r (one-time download/cache)...",
- repo,
- fname,
- )
- path = hf_hub_download(repo_id=repo, filename=fname, repo_type="model")
-
- # TODO The model weights in UTMOS must be loaded in cuda:0; otherwise, the model execution will fail.
- want = "cuda:0"
- if want.startswith("cuda") and torch.cuda.is_available():
- idx = want.split(":")[-1] if ":" in want else "0"
- target_dev = f"cuda:{idx}"
- else:
- target_dev = "cpu"
-
- try:
- m = torch.jit.load(path, map_location=target_dev)
- m.eval()
- _utmos_jit_device = target_dev
- except Exception as load_e:
- if target_dev.startswith("cuda"):
- logger.warning(
- "UTMOS JIT load on %s failed (%s), retrying on CPU...",
- target_dev,
- load_e,
- )
- m = torch.jit.load(path, map_location="cpu")
- m.eval()
- _utmos_jit_device = "cpu"
- else:
- raise
- _utmos_jit_model = m
- except Exception as e:
- logger.warning(
- "UTMOS JIT unavailable (install torch + huggingface_hub; check HF access): %s",
- e,
- )
- _utmos_jit_load_failed = True
- return None
- return _utmos_jit_model
-
-
-def _utmos_predict_f32_16k(wav_f32: np.ndarray) -> float | None:
- """MOS from JIT model; input is float32 mono @ 16 kHz in ``[-1, 1]`` (WER pipeline).
-
- ``balacoon/utmos`` demos sometimes use int16 numpy, but the exported ``.jit`` weights are
- float32; passing int16 tensors causes: "RuntimeError: ... input type and weight type
- should be same".
- """
- import torch
-
- if len(wav_f32) == 0:
- return None
- model = _ensure_utmos_jit_model()
- if model is None:
- return None
- # Infer model's device from its first parameter/buffer to guarantee input sits with weights.
- try:
- model_dev = next(model.parameters()).device
- except StopIteration:
- try:
- model_dev = next(model.buffers()).device
- except StopIteration:
- model_dev = torch.device("cpu")
- w = np.ascontiguousarray(wav_f32, dtype=np.float32)
- x = torch.from_numpy(w).unsqueeze(0).to(device=model_dev, dtype=torch.float32)
- with torch.no_grad():
- out = model(x)
- val = float(out.reshape(-1)[0].item())
- if not math.isfinite(val):
- return None
- return val
-
-
-def _ensure_en_asr() -> None:
- global _en_processor, _en_model, _device
- with _lock:
- if _en_processor is not None:
- return
- from transformers import WhisperForConditionalGeneration, WhisperProcessor
-
- _device = _get_eval_device()
- mid = os.environ.get("SEED_TTS_HF_WHISPER_MODEL", OFFICIAL_WHISPER_HF_ID).strip() or OFFICIAL_WHISPER_HF_ID
- logger.warning(
- "Loading Seed-TTS eval Whisper HF model %r on %s (one-time, seed-tts-eval protocol)...",
- mid,
- _device,
- )
- _en_processor = WhisperProcessor.from_pretrained(mid)
- _en_model = WhisperForConditionalGeneration.from_pretrained(mid).to(_device)
- _en_model.eval()
-
-
-def _ensure_zh_asr() -> None:
- global _zh_paraformer, _device
- with _lock:
- if _zh_paraformer is not None:
- return
- from funasr import AutoModel
-
- _device = _get_eval_device()
- logger.warning(
- "Loading Seed-TTS eval Paraformer %r on %s (one-time, seed-tts-eval protocol)...",
- PARAFORMER_MODEL_ID,
- _device,
- )
- try:
- _zh_paraformer = AutoModel(model=PARAFORMER_MODEL_ID, device=_device)
- except TypeError:
- _zh_paraformer = AutoModel(model=PARAFORMER_MODEL_ID)
-
-
-def _transcribe_en_f32_16k(wav_f32: np.ndarray) -> str:
- import torch
-
- _ensure_en_asr()
- if len(wav_f32) == 0:
- return ""
- with _lock:
- assert _en_processor is not None and _en_model is not None and _device is not None
- inputs = _en_processor(wav_f32, sampling_rate=16000, return_tensors="pt")
- input_features = inputs.input_features.to(_device)
- with torch.no_grad():
- try:
- forced = _en_processor.get_decoder_prompt_ids(language="english", task="transcribe")
- predicted_ids = _en_model.generate(input_features, forced_decoder_ids=forced)
- except Exception:
- predicted_ids = _en_model.generate(
- input_features,
- language="english",
- task="transcribe",
- )
- text = _en_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
- return (text or "").strip()
-
-
-def _transcribe_zh_wav_path(wav_path: str) -> str:
- import zhconv
-
- _ensure_zh_asr()
- with _lock:
- assert _zh_paraformer is not None
- res = _zh_paraformer.generate(input=wav_path, batch_size_s=300)
- transcription = res[0]["text"] if res else ""
- return zhconv.convert(transcription, "zh-cn").strip()
-
-
-def _missing_deps_message(lang: str) -> str | None:
- try:
- import jiwer # noqa: F401
- from zhon.hanzi import punctuation # noqa: F401
- except ImportError as e:
- return f"Seed-TTS WER eval needs jiwer and zhon ({e!s}). Install: pip install 'vllm-omni[seed-tts-eval]'"
- try:
- import scipy.signal # noqa: F401
- import soundfile # noqa: F401
- except ImportError as e:
- return f"Seed-TTS WER eval needs scipy and soundfile ({e!s})."
- if lang == "en":
- try:
- import torch # noqa: F401
- from transformers import WhisperForConditionalGeneration # noqa: F401
- except ImportError as e:
- return f"English WER needs torch and transformers ({e!s}). Install: pip install 'vllm-omni[seed-tts-eval]'"
- else:
- try:
- import zhconv # noqa: F401
- from funasr import AutoModel # noqa: F401
- except ImportError as e:
- return f"Chinese WER needs funasr and zhconv ({e!s}). Install: pip install 'vllm-omni[seed-tts-eval]'"
- return None
-
-
-def compute_seed_tts_wer_metrics(
- input_requests: list[SampleRequest],
- outputs: list[Any],
- *,
- include_per_item: bool = False,
-) -> dict[str, Any] | None:
- """If all requests are :class:`SeedTTSSampleRequest`, run seed-tts-eval-style WER."""
- global _utmos_forward_warned
- if not input_requests or len(input_requests) != len(outputs):
- return None
- if not all(isinstance(r, SeedTTSSampleRequest) for r in input_requests):
- return None
-
- first = input_requests[0]
- assert isinstance(first, SeedTTSSampleRequest)
- lang = "zh" if (first.seed_tts_locale or "en").lower().startswith("zh") else "en"
-
- setup_err = _missing_deps_message(lang)
- if setup_err:
- logger.error("%s", setup_err)
- return {
- "seed_tts_eval_setup_error": setup_err,
- "seed_tts_eval_protocol": "seed-tts-eval",
- "seed_tts_content_evaluated": 0,
- "seed_tts_content_error_mean": None,
- "seed_tts_content_error_median": None,
- "seed_tts_request_failed": 0,
- "seed_tts_no_pcm": 0,
- "seed_tts_asr_failed": 0,
- "seed_tts_content_metric": "wer",
- }
-
- import soundfile as sf
-
- errs: list[float] = []
- items: list[dict[str, Any]] = []
- asr_failed = 0
- no_pcm = 0
- request_failed = 0
- sim_values: list[float] = []
- utmos_values: list[float] = []
- sim_failed = 0
- sim_skipped_no_ref = 0
- utmos_failed = 0
- utmos_on = _eval_submetric_enabled("SEED_TTS_UTMOS_EVAL", default=False)
-
- for req, out in zip(input_requests, outputs, strict=True):
- assert isinstance(req, SeedTTSSampleRequest)
- ref = req.prompt
- locale = req.seed_tts_locale or "en"
- row_lang = "zh" if locale.lower().startswith("zh") else "en"
- utmos_v: float | None = None
-
- if not out.success:
- request_failed += 1
- if include_per_item:
- items.append(
- {
- "utterance_id": req.seed_tts_utterance_id,
- "locale": locale,
- "error": "request_failed",
- "detail": (out.error or "")[:500],
- }
- )
- continue
-
- pcm = getattr(out, "tts_output_pcm_bytes", None)
- if not pcm:
- no_pcm += 1
- if include_per_item:
- items.append(
- {
- "utterance_id": req.seed_tts_utterance_id,
- "locale": locale,
- "error": "no_pcm",
- }
- )
- continue
-
- wav_16k = _pcm_s16le_to_f32_16k(pcm)
- if len(wav_16k) == 0:
- asr_failed += 1
- if include_per_item:
- items.append(
- {
- "utterance_id": req.seed_tts_utterance_id,
- "locale": locale,
- "error": "empty_audio",
- }
- )
- continue
-
- # UTMOS scores synthesized audio only; do not gate on ASR/WER (those can fail independently).
- if utmos_on:
- try:
- utmos_v = _utmos_predict_f32_16k(wav_16k)
- if utmos_v is not None:
- utmos_values.append(utmos_v)
- elif not _utmos_jit_load_failed:
- utmos_failed += 1
- except Exception:
- if not _utmos_forward_warned:
- _utmos_forward_warned = True
- logger.warning(
- "UTMOS JIT forward failed (first utterance=%s; set logging DEBUG for "
- "full trace). Check sample rate (16 kHz), input shape, or "
- "SEED_TTS_UTMOS_DEVICE.",
- req.seed_tts_utterance_id,
- exc_info=True,
- )
- else:
- logger.debug(
- "UTMOS forward failed for %s",
- req.seed_tts_utterance_id,
- exc_info=True,
- )
- utmos_failed += 1
-
- try:
- if row_lang == "en":
- hyp = _transcribe_en_f32_16k(wav_16k)
- else:
- fd, tmp_wav = tempfile.mkstemp(suffix=".wav")
- os.close(fd)
- try:
- sf.write(tmp_wav, wav_16k, 16000, subtype="PCM_16")
- hyp = _transcribe_zh_wav_path(tmp_wav)
- finally:
- try:
- os.unlink(tmp_wav)
- except OSError:
- pass
- except Exception as e:
- logger.exception("Seed-TTS ASR failed for %s", req.seed_tts_utterance_id)
- asr_failed += 1
- if include_per_item:
- items.append(
- {
- "utterance_id": req.seed_tts_utterance_id,
- "locale": locale,
- "error": "asr_exception",
- "detail": str(e)[:500],
- }
- )
- continue
-
- if not hyp:
- asr_failed += 1
- if include_per_item:
- items.append(
- {
- "utterance_id": req.seed_tts_utterance_id,
- "locale": locale,
- "error": "empty_asr",
- }
- )
- continue
-
- try:
- wer, raw_truth, raw_hypo = process_one_official(hyp, ref, row_lang)
- except Exception as e:
- logger.warning("jiwer/normalize failed for %s: %s", req.seed_tts_utterance_id, e)
- asr_failed += 1
- if include_per_item:
- items.append(
- {
- "utterance_id": req.seed_tts_utterance_id,
- "locale": locale,
- "error": "wer_compute_failed",
- "detail": str(e)[:500],
- }
- )
- continue
-
- errs.append(wer)
- sim_v: float | None = None
-
- if _eval_submetric_enabled("SEED_TTS_SIM_EVAL", default=False):
- ref_path = getattr(req, "seed_tts_ref_wav_path", "") or ""
- if ref_path and os.path.isfile(ref_path):
- try:
- ref_wav = _audio_path_to_f32_16k(ref_path)
- e_ref = _wavlm_mean_embedding_f32_16k(ref_wav)
- e_hyp = _wavlm_mean_embedding_f32_16k(wav_16k)
- if e_ref is not None and e_hyp is not None:
- sim_v = _cosine_similarity_unit_vectors(e_ref, e_hyp)
- sim_values.append(sim_v)
- except Exception as e:
- logger.warning(
- "SIM embedding failed for utterance=%s: %s: %s",
- req.seed_tts_utterance_id,
- type(e).__name__,
- e,
- )
- sim_failed += 1
- else:
- sim_skipped_no_ref += 1
-
- if include_per_item:
- row: dict[str, Any] = {
- "utterance_id": req.seed_tts_utterance_id,
- "locale": locale,
- "wer": wer,
- "reference_raw": raw_truth,
- "asr_raw": raw_hypo,
- }
- if sim_v is not None:
- row["sim"] = sim_v
- if utmos_v is not None:
- row["utmos"] = utmos_v
- items.append(row)
-
- result: dict[str, Any] = {
- "seed_tts_eval_protocol": "seed-tts-eval",
- "seed_tts_content_evaluated": len(errs),
- "seed_tts_content_error_mean": statistics.fmean(errs) if errs else None,
- "seed_tts_content_error_median": statistics.median(errs) if errs else None,
- "seed_tts_request_failed": request_failed,
- "seed_tts_no_pcm": no_pcm,
- "seed_tts_asr_failed": asr_failed,
- "seed_tts_content_metric": "wer",
- "seed_tts_sim_evaluated": len(sim_values),
- "seed_tts_sim_mean": statistics.fmean(sim_values) if sim_values else None,
- "seed_tts_sim_median": statistics.median(sim_values) if sim_values else None,
- "seed_tts_sim_failed": sim_failed,
- "seed_tts_sim_skipped_no_ref": sim_skipped_no_ref,
- "seed_tts_utmos_evaluated": len(utmos_values),
- "seed_tts_utmos_mean": statistics.fmean(utmos_values) if utmos_values else None,
- "seed_tts_utmos_median": statistics.median(utmos_values) if utmos_values else None,
- "seed_tts_utmos_failed": utmos_failed,
- }
- if include_per_item:
- result["seed_tts_wer_eval_items"] = items
- return result
-
-
-def print_seed_tts_wer_summary(metrics: dict[str, Any]) -> None:
- setup = metrics.get("seed_tts_eval_setup_error")
- if setup:
- print("{s:{c}^{n}}".format(s=" Seed-TTS eval (seed-tts-eval protocol) ", n=50, c="="))
- print(setup)
- return
-
- ev = int(metrics.get("seed_tts_content_evaluated", 0) or 0)
- rf = int(metrics.get("seed_tts_request_failed", 0) or 0)
- npc = int(metrics.get("seed_tts_no_pcm", 0) or 0)
- af = int(metrics.get("seed_tts_asr_failed", 0) or 0)
- sim_ev = int(metrics.get("seed_tts_sim_evaluated", 0) or 0)
- ut_ev = int(metrics.get("seed_tts_utmos_evaluated", 0) or 0)
- if ev == 0 and rf == 0 and npc == 0 and af == 0 and sim_ev == 0 and ut_ev == 0:
- return
- print("{s:{c}^{n}}".format(s=" Seed-TTS eval (seed-tts-eval protocol) ", n=50, c="="))
- print("{:<40} {:<10}".format("Evaluated (WER, lower is better):", ev))
- mean = metrics.get("seed_tts_content_error_mean")
- if mean is not None:
- print("{:<40} {:<10.4f}".format("Mean WER:", float(mean)))
- med = metrics.get("seed_tts_content_error_median")
- if med is not None:
- print("{:<40} {:<10.4f}".format("Median WER:", float(med)))
- print("{:<40} {:<10}".format("Request failed:", metrics.get("seed_tts_request_failed", 0)))
- print("{:<40} {:<10}".format("No PCM captured:", metrics.get("seed_tts_no_pcm", 0)))
- print("{:<40} {:<10}".format("ASR / WER failed:", metrics.get("seed_tts_asr_failed", 0)))
- if sim_ev or metrics.get("seed_tts_sim_skipped_no_ref") or metrics.get("seed_tts_sim_failed"):
- print("{:<40} {:<10}".format("SIM evaluated (higher ~ closer):", sim_ev))
- sm = metrics.get("seed_tts_sim_mean")
- if sm is not None:
- print("{:<40} {:<10.4f}".format("Mean SIM:", float(sm)))
- s_med = metrics.get("seed_tts_sim_median")
- if s_med is not None:
- print("{:<40} {:<10.4f}".format("Median SIM:", float(s_med)))
- print("{:<40} {:<10}".format("SIM skipped (no ref path):", metrics.get("seed_tts_sim_skipped_no_ref", 0)))
- print("{:<40} {:<10}".format("SIM embedding errors:", metrics.get("seed_tts_sim_failed", 0)))
- if ut_ev or metrics.get("seed_tts_utmos_failed"):
- print("{:<40} {:<10}".format("UTMOS evaluated (JIT MOS, higher better):", ut_ev))
- um = metrics.get("seed_tts_utmos_mean")
- if um is not None:
- print("{:<40} {:<10.4f}".format("Mean UTMOS:", float(um)))
- u_med = metrics.get("seed_tts_utmos_median")
- if u_med is not None:
- print("{:<40} {:<10.4f}".format("Median UTMOS:", float(u_med)))
- print("{:<40} {:<10}".format("UTMOS errors:", metrics.get("seed_tts_utmos_failed", 0)))
- print("=" * 50)
diff --git a/vllm_omni/benchmarks/metrics/metrics.py b/vllm_omni/benchmarks/metrics/metrics.py
index dbf764698a0..a2acc7d7567 100644
--- a/vllm_omni/benchmarks/metrics/metrics.py
+++ b/vllm_omni/benchmarks/metrics/metrics.py
@@ -185,7 +185,7 @@ def calculate_metrics(
# Note : this may inflate the output token count slightly
output_len = len(tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids)
actual_output_lens.append(output_len)
- total_input += outputs[i].prompt_len
+ total_input += input_requests[i].prompt_len
tpot = 0
if output_len > 1:
latency_minus_ttft = outputs[i].text_latency - outputs[i].ttft
diff --git a/vllm_omni/benchmarks/patch/__init__.py b/vllm_omni/benchmarks/patch/__init__.py
index ca6b41ba8f7..e69de29bb2d 100644
--- a/vllm_omni/benchmarks/patch/__init__.py
+++ b/vllm_omni/benchmarks/patch/__init__.py
@@ -1,3 +0,0 @@
-"""Omni benchmark monkey-patches (side effects in ``patch.patch``)."""
-
-from . import patch as _patch_module # noqa: F401
diff --git a/vllm_omni/benchmarks/patch/patch.py b/vllm_omni/benchmarks/patch/patch.py
index bda75ef624d..d8145c40bcd 100644
--- a/vllm_omni/benchmarks/patch/patch.py
+++ b/vllm_omni/benchmarks/patch/patch.py
@@ -6,7 +6,6 @@
import os
import random
import ssl
-import sys
import time
import traceback
from collections.abc import Iterable
@@ -34,252 +33,15 @@
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
-
-from vllm_omni.benchmarks.data_modules.daily_omni_dataset import DailyOmniDataset, DailyOmniSampleRequest
from vllm_omni.benchmarks.data_modules.random_multi_modal_dataset import OmniRandomMultiModalDataset
-from vllm_omni.benchmarks.data_modules.seed_tts_dataset import (
- SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT,
- SeedTTSDataset,
- SeedTTSDesignDataset,
- SeedTTSSampleRequest,
- SeedTTSTextDataset,
-)
get_samples_old = datasets.get_samples
-_DEFAULT_DAILY_OMNI_REPO = "liarliar/Daily-Omni"
-
-
-def _seed_tts_capture_pcm_for_wer() -> bool:
- return os.environ.get("SEED_TTS_WER_EVAL", "").lower() in (
- "1",
- "true",
- "yes",
- )
-
-
-def _merge_extra_body_mm_kwargs(base: dict | None, overlay: dict | None) -> dict | None:
- """Shallow-merge ``extra_body`` dicts; deep-merge ``mm_processor_kwargs`` if both set."""
- if not base and not overlay:
- return None
- out = dict(base or {})
- if not overlay:
- return out
- for k, v in overlay.items():
- if k == "mm_processor_kwargs" and isinstance(v, dict):
- prev = out.get("mm_processor_kwargs")
- merged_kw = {**(prev if isinstance(prev, dict) else {}), **v}
- out["mm_processor_kwargs"] = merged_kw
- else:
- out[k] = v
- return out
-
-
-def _attach_daily_omni_to_request_func_input(sample: SampleRequest, rfi: RequestFuncInput) -> None:
- """Apply per-request OpenAI fields (``mm_processor_kwargs``, messages) for Daily-Omni."""
- if not isinstance(sample, DailyOmniSampleRequest):
- return
- rfi.extra_body = _merge_extra_body_mm_kwargs(rfi.extra_body, sample.omni_extra_body)
- if sample.omni_chat_messages is not None:
- setattr(rfi, "omni_chat_messages", sample.omni_chat_messages)
- else:
- setattr(rfi, "mm_position", sample.omni_chat_mm_position)
-
-
-def _attach_seed_tts_to_request_func_input(sample: SampleRequest, rfi: RequestFuncInput) -> None:
- """Merge Seed-TTS per-row TTS fields into ``extra_body`` and mark for PCM capture.
-
- Always sets ``seed_tts_row=True`` on the RequestFuncInput for any
- :class:`SeedTTSSampleRequest` subclass (including text-only and design
- variants that carry no ``ref_audio``). This enables PCM capture for WER /
- UTMOS evaluation even when there is no reference audio.
- """
- if not isinstance(sample, SeedTTSSampleRequest):
- return
- # Mark for PCM capture (WER / UTMOS eval) regardless of extra body presence.
- setattr(rfi, "seed_tts_row", True)
- sys_prompt = (sample.seed_tts_system_prompt or "").strip() or SEED_TTS_DEFAULT_OMNI_SYSTEM_PROMPT
- setattr(
- rfi,
- "omni_chat_messages",
- [
- {"role": "system", "content": [{"type": "text", "text": sys_prompt}]},
- {"role": "user", "content": [{"type": "text", "text": sample.prompt}]},
- ],
- )
- ex = sample.seed_tts_speech_extra
- if not ex:
- return # voice comes from --extra-body in config; no ref_audio to merge
- base = dict(rfi.extra_body) if rfi.extra_body else {}
- base.update(ex)
- rfi.extra_body = base
-
-
-def _daily_omni_repo_from_args(args) -> str | None:
- """Resolve HuggingFace repo id for Daily-Omni from CLI args.
-
- vLLM allows ``--dataset-path`` to be a local path while the real HF id is
- passed via ``--hf-name``. Upstream ``get_samples`` for ``hf`` only matches
- a fixed elif-chain and never discovers Omni's loader, so we must detect
- Daily-Omni here using either field.
- """
- dp = getattr(args, "dataset_path", None)
- hn = getattr(args, "hf_name", None)
- if dp in DailyOmniDataset.SUPPORTED_DATASET_PATHS:
- return dp
- if hn in DailyOmniDataset.SUPPORTED_DATASET_PATHS:
- return hn
- return None
-
def get_samples(args, tokenizer):
- # Daily-Omni: explicit dataset name, or hf + matching path/hf-name
- is_daily_omni = args.dataset_name == "daily-omni" or (
- args.dataset_name == "hf" and _daily_omni_repo_from_args(args) is not None
- )
- is_seed_tts = args.dataset_name in ("seed-tts", "seed-tts-text", "seed-tts-design")
-
- # Check if we need to handle omni-related backends/datasets
- is_omni_backend = args.backend in ["openai-chat-omni", "openai-audio-speech", "daily-omni"]
- is_omni_dataset = is_daily_omni or is_seed_tts or args.dataset_name == "random-mm"
-
- if not is_omni_backend and not is_omni_dataset:
- # Not an omni-related request, delegate to original implementation
+ if args.backend not in ["openai-chat-omni", "openai-audio-speech"]:
return get_samples_old(args, tokenizer)
-
- # Handle Daily-Omni dataset
- if is_daily_omni:
- # Support:
- # --dataset-name daily-omni [--dataset-path liarliar/Daily-Omni]
- # --dataset-name daily-omni --daily-omni-qa-json /path/to/qa.json (offline QA)
- # --dataset-name hf --dataset-path liarliar/Daily-Omni
- # --dataset-name hf --hf-name liarliar/Daily-Omni (dataset-path may be local)
-
- # Validate backend supports multimodal (video)
- if args.backend not in ["openai-chat-omni", "daily-omni"]:
- raise ValueError(
- f"Daily-Omni dataset requires a multimodal backend that supports video. "
- f"Got backend='{args.backend}'. Please use '--backend openai-chat-omni'"
- )
-
- # Determine video directory if specified (for local video files)
- video_dir = getattr(args, "daily_omni_video_dir", None)
-
- # Get HF split (default to "train"; unused when loading from local qa.json)
- dataset_split = getattr(args, "hf_split", None) or "train"
-
- qa_json = getattr(args, "daily_omni_qa_json", None)
- if isinstance(qa_json, str):
- qa_json = qa_json.strip() or None
-
- if qa_json is not None:
- logger.info(
- "Loading Daily-Omni dataset: qa_json=%s, video_dir=%s (Hub not used for QA)",
- qa_json,
- video_dir,
- )
- dataset = DailyOmniDataset(
- qa_json_path=qa_json,
- dataset_path=None,
- dataset_split=dataset_split,
- random_seed=args.seed,
- video_dir=video_dir,
- input_mode=getattr(args, "daily_omni_input_mode", "all"),
- inline_local_video=getattr(args, "daily_omni_inline_local_video", False),
- trust_remote_code=getattr(args, "trust_remote_code", False),
- disable_shuffle=getattr(args, "disable_shuffle", False),
- )
- else:
- repo_id = _daily_omni_repo_from_args(args)
- if args.dataset_name == "daily-omni":
- if repo_id is None:
- repo_id = _DEFAULT_DAILY_OMNI_REPO
- elif repo_id is None:
- raise ValueError(
- "Daily-Omni with --dataset-name hf requires "
- f"--dataset-path {_DEFAULT_DAILY_OMNI_REPO} or "
- f"--hf-name {_DEFAULT_DAILY_OMNI_REPO}."
- )
-
- logger.info(
- "Loading Daily-Omni dataset: hf_repo=%s, split=%s, video_dir=%s",
- repo_id,
- dataset_split,
- video_dir,
- )
-
- dataset = DailyOmniDataset(
- dataset_path=repo_id,
- dataset_split=dataset_split,
- dataset_subset=getattr(args, "hf_subset", None),
- random_seed=args.seed,
- video_dir=video_dir,
- input_mode=getattr(args, "daily_omni_input_mode", "all"),
- inline_local_video=getattr(args, "daily_omni_inline_local_video", False),
- trust_remote_code=getattr(args, "trust_remote_code", False),
- no_stream=getattr(args, "no_stream", False),
- disable_shuffle=getattr(args, "disable_shuffle", False),
- )
-
- out_len = getattr(args, "output_len", None)
- if out_len is None:
- out_len = getattr(args, "hf_output_len", None)
- if out_len is None:
- out_len = DailyOmniDataset.DEFAULT_OUTPUT_LEN
-
- input_requests = dataset.sample(
- tokenizer=tokenizer,
- num_requests=args.num_prompts,
- output_len=out_len,
- request_id_prefix=args.request_id_prefix,
- no_oversample=args.no_oversample,
- )
- return input_requests
-
- if is_seed_tts:
- if args.backend not in ("openai-audio-speech", "openai-chat-omni"):
- raise ValueError(
- "Seed-TTS requires --backend openai-audio-speech (POST /v1/audio/speech) or "
- "--backend openai-chat-omni (POST /v1/chat/completions with ref_audio/ref_text). "
- f"Got backend={args.backend!r}."
- )
- repo_id = getattr(args, "dataset_path", None) or getattr(args, "hf_name", None)
- if not repo_id:
- raise ValueError(
- "Seed-TTS requires --dataset-path (HF dataset repo id or local directory) or "
- "--hf-name for the Hub dataset id."
- )
-
- _cls_map = {
- "seed-tts": SeedTTSDataset,
- "seed-tts-text": SeedTTSTextDataset,
- "seed-tts-design": SeedTTSDesignDataset,
- }
- DatasetCls = _cls_map[args.dataset_name]
- dataset = DatasetCls(
- dataset_path=repo_id,
- random_seed=args.seed,
- locale=getattr(args, "seed_tts_locale", "en"),
- inline_ref_audio=not getattr(args, "seed_tts_file_ref_audio", False),
- seed_tts_root=getattr(args, "seed_tts_root", None),
- system_prompt=getattr(args, "seed_tts_system_prompt", None),
- disable_shuffle=getattr(args, "disable_shuffle", False),
- )
- out_len = getattr(args, "output_len", None)
- if out_len is None:
- out_len = getattr(args, "hf_output_len", None)
- if out_len is None:
- out_len = SeedTTSDataset.DEFAULT_OUTPUT_LEN
- return dataset.sample(
- tokenizer=tokenizer,
- num_requests=args.num_prompts,
- output_len=out_len,
- request_id_prefix=args.request_id_prefix,
- no_oversample=args.no_oversample,
- )
-
- # Handle random-mm dataset (Omni's synthetic multimodal dataset)
- if args.dataset_name == "random-mm":
+ elif args.dataset_name == "random-mm":
dataset = OmniRandomMultiModalDataset(random_seed=args.seed, dataset_path=args.dataset_path)
input_requests = dataset.sample(
tokenizer=tokenizer,
@@ -302,10 +64,6 @@ def get_samples(args, tokenizer):
datasets.get_samples = get_samples
-_serve_mod = sys.modules.get("vllm.benchmarks.serve")
-if _serve_mod is not None:
- _serve_mod.get_samples = get_samples
-
@dataclass
class MixRequestFuncOutput(RequestFuncOutput):
@@ -314,9 +72,6 @@ class MixRequestFuncOutput(RequestFuncOutput):
audio_frames: int = 0
audio_rtf: float = 0.0
text_latency: float = 0.0
- #: Raw PCM s16le mono at 24 kHz for Seed-TTS WER: from ``/v1/audio/speech`` stream or
- #: resampled export after ``openai-chat-omni`` audio deltas.
- tts_output_pcm_bytes: bytes | None = None
async def async_request_openai_chat_omni_completions(
@@ -328,17 +83,13 @@ async def async_request_openai_chat_omni_completions(
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")
- omni_messages = getattr(request_func_input, "omni_chat_messages", None)
- if omni_messages is not None:
- messages_payload = omni_messages
- else:
- effective_mm_position = getattr(request_func_input, "mm_position", mm_position)
- content = _get_chat_content(request_func_input, mm_position=effective_mm_position)
- messages_payload = [{"role": "user", "content": content}]
+ content = _get_chat_content(request_func_input, mm_position=mm_position)
payload = {
"model": request_func_input.model_name if request_func_input.model_name else request_func_input.model,
- "messages": messages_payload,
+ "messages": [
+ {"role": "user", "content": content},
+ ],
"temperature": 0.0,
"max_tokens": request_func_input.output_len,
"stream": True,
@@ -347,10 +98,6 @@ async def async_request_openai_chat_omni_completions(
},
}
_update_payload_common(payload, request_func_input)
- # Seed-TTS via chat: voice-clone fields live on the body; ensure audio is streamed.
- if getattr(request_func_input, "seed_tts_row", False):
- if payload.get("modalities") is None:
- payload["modalities"] = ["text", "audio"]
response_format = payload.get("response_format", "wav")
if response_format == "pcm":
@@ -396,11 +143,7 @@ async def async_request_openai_chat_omni_completions(
if response.status == 200:
handler = StreamedResponseHandler()
async for chunk_bytes in response.content.iter_any():
- # NOTE: Do NOT strip() here; TCP may fragment the SSE messages,
- # so stripping here can cause problems depending on how it is split.
- #
- # Simple example: [b'data: ', b'{json}\n\n'] <- stripping the first
- # chunk will break SSE parsing because the space after 'data:' is required.
+ chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
@@ -420,10 +163,7 @@ async def async_request_openai_chat_omni_completions(
data = json.loads(chunk)
if choices := data.get("choices"):
modality = data.get("modality")
- delta = choices[0].get("delta") or {}
- content = delta.get("content")
- if not content and isinstance(delta.get("audio"), dict):
- content = delta["audio"].get("data")
+ content = choices[0]["delta"].get("content")
if modality == "text":
# First token
if ttft == 0.0:
@@ -438,7 +178,7 @@ async def async_request_openai_chat_omni_completions(
if output.audio_ttfp == 0.0:
output.audio_ttfp = timestamp - st
audio_generate_time = timestamp - st
- if content:
+ if content != "":
audio_bytes = base64.b64decode(content)
seg = AudioSegment.from_file(io.BytesIO(audio_bytes))
if seg is not None:
@@ -450,10 +190,6 @@ async def async_request_openai_chat_omni_completions(
if metrics := data.get("metrics"):
output.output_tokens = metrics.get("num_tokens_out", 0)
- if usage := data.get("usage"):
- if (pt := usage.get("prompt_tokens")) is not None:
- output.prompt_len = pt
-
output.latency = timestamp - st
output.generated_text = generated_text
if generated_audio is not None:
@@ -470,12 +206,6 @@ async def async_request_openai_chat_omni_completions(
else:
output.audio_rtf = 0
logger.warning("Audio duration is zero")
- if _seed_tts_capture_pcm_for_wer() and getattr(request_func_input, "seed_tts_row", False):
- try:
- seg = generated_audio.set_frame_rate(24000).set_channels(1).set_sample_width(2)
- output.tts_output_pcm_bytes = bytes(seg.raw_data)
- except Exception as ex:
- logger.warning("seed_tts WER PCM export failed: %s", ex)
output.success = True
else:
output.error = response.reason or ""
@@ -530,10 +260,6 @@ async def async_request_openai_audio_speech(
"response_format": "pcm",
}
_update_payload_common(payload, request_func_input)
- # Seed-TTS + WER: ``--extra-body`` may set stream=false / other formats; speech must stream PCM.
- if getattr(request_func_input, "seed_tts_row", False) and _seed_tts_capture_pcm_for_wer():
- payload["stream"] = True
- payload["response_format"] = "pcm"
headers = {
"Content-Type": "application/json",
@@ -552,8 +278,6 @@ async def async_request_openai_audio_speech(
st = time.perf_counter()
output.start_time = st
total_pcm_bytes = 0
- capture_wer_pcm = _seed_tts_capture_pcm_for_wer() and getattr(request_func_input, "seed_tts_row", False)
- pcm_capture = bytearray() if capture_wer_pcm else None
try:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
@@ -565,8 +289,6 @@ async def async_request_openai_audio_speech(
output.audio_ttfp = timestamp - st
output.ttft = output.audio_ttfp
total_pcm_bytes += len(chunk)
- if pcm_capture is not None:
- pcm_capture.extend(chunk)
end_time = time.perf_counter()
output.latency = end_time - st
@@ -579,16 +301,6 @@ async def async_request_openai_audio_speech(
else:
output.audio_rtf = 0
logger.warning("Audio duration is zero")
- if pcm_capture is not None and pcm_capture:
- output.tts_output_pcm_bytes = bytes(pcm_capture)
- elif capture_wer_pcm:
- ct = response.headers.get("Content-Type", "")
- logger.warning(
- "Seed-TTS WER: HTTP 200 but no PCM bytes (Content-Type=%r, url=%s). "
- "Check stream=true and response_format=pcm on the server.",
- ct,
- api_url,
- )
output.success = True
else:
output.error = response.reason or ""
@@ -611,12 +323,6 @@ async def async_request_openai_audio_speech(
if "openai-audio-speech" not in OPENAI_COMPATIBLE_BACKENDS:
OPENAI_COMPATIBLE_BACKENDS.append("openai-audio-speech")
-# Daily-Omni backend for audio-visual reasoning benchmark
-# Reuses openai-chat-omni completions for video+text understanding
-ASYNC_REQUEST_FUNCS["daily-omni"] = async_request_openai_chat_omni_completions
-if "daily-omni" not in OPENAI_COMPATIBLE_BACKENDS:
- OPENAI_COMPATIBLE_BACKENDS.append("daily-omni")
-
# ruff: noqa: E402
# Prevent import order from causing patch failures
from vllm.benchmarks import serve
@@ -708,8 +414,6 @@ async def benchmark(
extra_headers=extra_headers,
extra_body=extra_body,
)
- _attach_daily_omni_to_request_func_input(input_requests[0], test_input)
- _attach_seed_tts_to_request_func_input(input_requests[0], test_input)
if ready_check_timeout_sec > 0:
test_output = await wait_for_endpoint(
@@ -772,8 +476,6 @@ async def warmup_limited_request_func():
extra_headers=extra_headers,
extra_body=extra_body,
)
- _attach_daily_omni_to_request_func_input(input_requests[0], profile_input)
- _attach_seed_tts_to_request_func_input(input_requests[0], profile_input)
profile_output = await request_func(request_func_input=profile_input, session=session)
if profile_output.success:
print("Profiler started")
@@ -854,8 +556,6 @@ async def limited_request_func(request_func_input, session, pbar):
extra_body=extra_body,
request_id=request_id,
)
- _attach_daily_omni_to_request_func_input(request, request_func_input)
- _attach_seed_tts_to_request_func_input(request, request_func_input)
tasks.append(
asyncio.create_task(limited_request_func(request_func_input=request_func_input, session=session, pbar=pbar))
)
@@ -923,37 +623,6 @@ async def limited_request_func(request_func_input, session, pbar):
"errors": [output.error for output in outputs],
}
- from vllm_omni.benchmarks.data_modules.daily_omni_eval import (
- compute_daily_omni_accuracy_metrics,
- print_daily_omni_accuracy_summary,
- )
-
- _save_items = os.environ.get("DAILY_OMNI_SAVE_EVAL_ITEMS", "").lower() in (
- "1",
- "true",
- "yes",
- )
- _daily_acc = compute_daily_omni_accuracy_metrics(input_requests, outputs, include_per_item=_save_items)
- if _daily_acc is not None:
- result.update(_daily_acc)
- print_daily_omni_accuracy_summary(_daily_acc)
-
- if _seed_tts_capture_pcm_for_wer():
- from vllm_omni.benchmarks.data_modules.seed_tts_eval import (
- compute_seed_tts_wer_metrics,
- print_seed_tts_wer_summary,
- )
-
- _save_wer = os.environ.get("SEED_TTS_WER_SAVE_ITEMS", "").lower() in (
- "1",
- "true",
- "yes",
- )
- _wer_m = compute_seed_tts_wer_metrics(input_requests, outputs, include_per_item=_save_wer)
- if _wer_m is not None:
- result.update(_wer_m)
- print_seed_tts_wer_summary(_wer_m)
-
if rps_change_events:
result["rps_change_events"] = rps_change_events
diff --git a/vllm_omni/benchmarks/serve.py b/vllm_omni/benchmarks/serve.py
index d3f3510c567..fe946036931 100644
--- a/vllm_omni/benchmarks/serve.py
+++ b/vllm_omni/benchmarks/serve.py
@@ -1,21 +1,9 @@
import argparse
import asyncio
-import os
from typing import Any
from vllm.benchmarks.serve import main_async
-# Import patch to register daily-omni dataset and omni backends
-# This monkey-patches vllm.benchmarks.datasets.get_samples before it's used
-# Must be imported before any vllm.benchmarks module usage
-import vllm_omni.benchmarks.patch.patch # noqa: F401
-
def main(args: argparse.Namespace) -> dict[str, Any]:
- if getattr(args, "seed_tts_wer_eval", False):
- os.environ["SEED_TTS_WER_EVAL"] = "1"
- if getattr(args, "seed_tts_wer_save_items", False):
- os.environ["SEED_TTS_WER_SAVE_ITEMS"] = "1"
- if getattr(args, "daily_omni_save_eval_items", False):
- os.environ["DAILY_OMNI_SAVE_EVAL_ITEMS"] = "1"
return asyncio.run(main_async(args))
diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py
index f02c0758805..2aa236e69f5 100644
--- a/vllm_omni/config/__init__.py
+++ b/vllm_omni/config/__init__.py
@@ -5,18 +5,10 @@
from vllm_omni.config.lora import LoRAConfig
from vllm_omni.config.model import OmniModelConfig
from vllm_omni.config.stage_config import (
- DeployConfig,
ModelPipeline,
- PipelineConfig,
StageConfig,
StageConfigFactory,
- StageDeployConfig,
- StageExecutionType,
- StagePipelineConfig,
StageType,
- load_deploy_config,
- merge_pipeline_deploy,
- register_pipeline,
)
from vllm_omni.config.yaml_util import (
create_config,
@@ -32,14 +24,6 @@
"StageConfigFactory",
"ModelPipeline",
"StageType",
- "StageExecutionType",
- "StagePipelineConfig",
- "PipelineConfig",
- "StageDeployConfig",
- "DeployConfig",
- "load_deploy_config",
- "merge_pipeline_deploy",
- "register_pipeline",
"create_config",
"load_yaml_config",
"merge_configs",
diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py
index 72731eda85f..588efabfc4f 100644
--- a/vllm_omni/config/model.py
+++ b/vllm_omni/config/model.py
@@ -5,10 +5,7 @@
from vllm.config import ModelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
-from vllm.transformers_utils.config import (
- get_hf_text_config,
- thinker_uses_mrope,
-)
+from vllm.transformers_utils.config import get_hf_text_config
from vllm.transformers_utils.model_arch_config_convertor import (
ModelArchConfigConvertorBase,
)
@@ -52,25 +49,6 @@ def get_quantization_config(self):
if quant_cfg is not None:
return quant_cfg
- # Fall back to top-level quantization_config
- top_quant = super().get_quantization_config()
- if top_quant is not None:
- block_names = top_quant.get("block_name_to_quantize")
- if block_names is not None:
- # NOTE: This assumes stage_config_name follows the HF
- # ``_config`` convention (e.g. thinker_config →
- # prefix "thinker."). removesuffix is a no-op when
- # the suffix doesn't match, so a non-standard name
- # would just use itself as prefix — safe but worth
- # verifying if new stage names are introduced.
- hf_prefix = self.stage_config_name.removesuffix("_config") + "."
- if isinstance(block_names, str):
- block_names = [b.strip() for b in block_names.split(",")]
- if isinstance(block_names, list) and not any(b.startswith(hf_prefix) for b in block_names):
- # This stage is not listed → no quantization.
- return None
- return top_quant
-
# For non-thinker stages (talker, code2wav) whose text_config
# has no quantization_config, return None so quantization is
# not applied to stages that were not quantized.
@@ -131,12 +109,9 @@ class OmniModelConfig(ModelConfig):
"extra": {},
}
)
- subtalker_sampling_params: dict[str, Any] | None = None
omni_kv_config: dict | None = None
codec_frame_rate_hz: float | None = None
task_type: str | None = None
- enable_sleep_mode: bool = False
- has_sampling_extra_args: bool = False
@property
def registry(self):
@@ -148,18 +123,6 @@ def architectures(self) -> list[str]:
return [self.model_arch]
return super().architectures
- @property
- def uses_mrope(self) -> bool:
- if self.hf_config_name is not None:
- # talker_config/thinker_config/etc
- stage_config = getattr(self.hf_config, self.hf_config_name, None)
- if stage_config is None:
- # Check the named sub-config's text_config directly.
- # Handles mrope resolution of stage-specific cls
- # (e.g., talker runs as a standalone cls)
- return thinker_uses_mrope(self.hf_config)
- return super().uses_mrope
-
@property
def embedding_size(self):
if self.hf_config_name is not None:
diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py
deleted file mode 100644
index 9372209144f..00000000000
--- a/vllm_omni/config/pipeline_registry.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Central declarative registry of all vllm-omni pipelines.
-
-Mirrors the pattern in ``vllm/model_executor/models/registry.py``: each entry
-is ``model_type -> (module_path, variable_name)``, and the module is imported
-lazily on first lookup (see ``_LazyPipelineRegistry`` in
-``vllm_omni/config/stage_config.py``). Keeping every pipeline declared in one
-file makes it easy to spot a missing registration, which was the original
-motivation in https://github.com/vllm-project/vllm-omni/issues/2887 (item 4).
-
-Per-model ``pipeline.py`` modules still define the ``PipelineConfig`` instance;
-they just no longer need to self-register via ``register_pipeline(...)``.
-
-Adding a new pipeline:
- 1. Define the ``PipelineConfig`` instance as a module-level variable in
- ``vllm_omni/.../pipeline.py``.
- 2. Add one line to ``_OMNI_PIPELINES`` below.
-
-Single-stage diffusion models continue to use the
-``_create_default_diffusion_stage_cfg`` fallback in
-``async_omni_engine.py`` — they don't need a registry entry. The empty
-``_DIFFUSION_PIPELINES`` placeholder previously here (#2915) was removed
-once #2987 (which would have populated it) was deferred.
-
-``register_pipeline(config)`` in ``stage_config`` is still supported for
-out-of-tree plugins and tests that create pipelines at runtime; those override
-the entries declared here.
-"""
-
-from __future__ import annotations
-
-# --- Multi-stage omni pipelines (LLM-centric; audio / video I/O) ---
-_OMNI_PIPELINES: dict[str, tuple[str, str]] = {
- # model_type -> (module_path, variable_name)
- "qwen2_5_omni": (
- "vllm_omni.model_executor.models.qwen2_5_omni.pipeline",
- "QWEN2_5_OMNI_PIPELINE",
- ),
- "qwen2_5_omni_thinker_only": (
- "vllm_omni.model_executor.models.qwen2_5_omni.pipeline",
- "QWEN2_5_OMNI_THINKER_ONLY_PIPELINE",
- ),
- "qwen3_omni_moe": (
- "vllm_omni.model_executor.models.qwen3_omni.pipeline",
- "QWEN3_OMNI_PIPELINE",
- ),
- "qwen3_tts": (
- "vllm_omni.model_executor.models.qwen3_tts.pipeline",
- "QWEN3_TTS_PIPELINE",
- ),
- "bagel": (
- "vllm_omni.model_executor.models.bagel.pipeline",
- "BAGEL_PIPELINE",
- ),
- "bagel_think": (
- "vllm_omni.model_executor.models.bagel.pipeline",
- "BAGEL_THINK_PIPELINE",
- ),
- "bagel_single_stage": (
- "vllm_omni.model_executor.models.bagel.pipeline",
- "BAGEL_SINGLE_STAGE_PIPELINE",
- ),
- "glm_image": (
- "vllm_omni.model_executor.models.glm_image.pipeline",
- "GLM_IMAGE_PIPELINE",
- ),
- "voxcpm2": (
- "vllm_omni.model_executor.models.voxcpm2.pipeline",
- "VOXCPM2_PIPELINE",
- ),
- "cosyvoice3": (
- "vllm_omni.model_executor.models.cosyvoice3.pipeline",
- "COSYVOICE3_PIPELINE",
- ),
- "mimo_audio": (
- "vllm_omni.model_executor.models.mimo_audio.pipeline",
- "MIMO_AUDIO_PIPELINE",
- ),
- "voxtral_tts": (
- "vllm_omni.model_executor.models.voxtral_tts.pipeline",
- "VOXTRAL_TTS_PIPELINE",
- ),
- "fish_qwen3_omni": (
- "vllm_omni.model_executor.models.fish_speech.pipeline",
- "FISH_SPEECH_PIPELINE",
- ),
-}
diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py
index 950685d6219..a4e186c3bd2 100644
--- a/vllm_omni/config/stage_config.py
+++ b/vllm_omni/config/stage_config.py
@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Stage configuration system for vLLM-Omni."""
+"""
+Stage Configuration System for vLLM-Omni.
+
+Pipeline structure (stages, types, data-flow) is defined in per-model YAML
+files and is set by model developers at integration time.
+Runtime parameters (gpu_memory_utilization, tp_size, etc.) come from CLI.
+"""
from __future__ import annotations
-import dataclasses
import re
import warnings
-from dataclasses import asdict, dataclass, field, fields
+from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any
@@ -15,838 +20,76 @@
from vllm.logger import init_logger
from vllm_omni.config.yaml_util import create_config, load_yaml_config, to_dict
-from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler
-from vllm_omni.core.sched.omni_generation_scheduler import OmniGenerationScheduler
+# Pipeline YAMLs live alongside model code in model_executor/models/<model_dir>/
_MODELS_DIR = Path(__file__).resolve().parent.parent / "model_executor" / "models"
def get_pipeline_path(model_dir: str, filename: str) -> Path:
- return _MODELS_DIR / model_dir / filename
-
-
-logger = init_logger(__name__)
-
-
-def _warn_deprecated_kwargs(kwargs: dict[str, Any]) -> None:
- if "cli_explicit_keys" in kwargs:
- warnings.warn(
- "cli_explicit_keys= is deprecated and ignored. Remove the kwarg.",
- DeprecationWarning,
- stacklevel=3,
- )
-
-
-_STAGE_OVERRIDE_PATTERN = re.compile(r"^stage_(\d+)_(.+)$")
-
+ """Return the full path to a pipeline YAML file.
-def build_stage_runtime_overrides(
- stage_id: int,
- cli_overrides: dict[str, Any],
- *,
- internal_keys: set[str] | frozenset[str] | None = None,
-) -> dict[str, Any]:
- """Build per-stage runtime overrides from global and ``stage_<id>_*`` kwargs.
+ Args:
+ model_dir: Model subdirectory name (e.g., "qwen3_omni").
+ filename: Name of the YAML file (e.g., "pipeline.yaml").
- ``internal_keys`` defaults to the union of
- ``arg_utils.internal_blacklist_keys()`` and ``arg_utils.SHARED_FIELDS``
- so that neither orchestrator-only fields nor shared-pipeline fields
- (``model`` / ``stage_configs_path`` / ``log_stats`` / ``stage_id``) leak
- into a stage's per-stage runtime overrides — the orchestrator sets those
- uniformly for every stage, they are not per-stage knobs. Callers can
- pass an explicit set for tests or specialized flows.
+ Returns:
+ Absolute path to the file.
"""
- if internal_keys is None:
- from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys
-
- internal_keys = internal_blacklist_keys() | SHARED_FIELDS
-
- result: dict[str, Any] = {}
-
- for key, value in cli_overrides.items():
- if value is None or key in internal_keys:
- continue
-
- match = _STAGE_OVERRIDE_PATTERN.match(key)
- if match is not None:
- override_stage_id = int(match.group(1))
- param_name = match.group(2)
- if override_stage_id == stage_id and param_name not in internal_keys:
- result[param_name] = value
- continue
-
- result[key] = value
-
- return result
-
-
-def strip_parent_engine_args(
- kwargs: dict[str, Any],
- *,
- parent_fields: dict[str, dataclasses.Field],
- keep_keys: set[str] | frozenset[str] = frozenset(),
- strip_keys: set[str] | frozenset[str] = frozenset(),
- no_warn_keys: set[str] | frozenset[str] = frozenset(),
-) -> tuple[dict[str, Any], list[str]]:
- """Strip parent ``EngineArgs`` fields before merging into stage YAML."""
- overridden: list[str] = []
- result: dict[str, Any] = {}
-
- for key, value in kwargs.items():
- if key in strip_keys:
- continue
-
- if key not in parent_fields or key in keep_keys:
- result[key] = value
- continue
-
- field_def = parent_fields[key]
- if field_def.default is not dataclasses.MISSING:
- default = field_def.default
- elif field_def.default_factory is not dataclasses.MISSING:
- default = field_def.default_factory()
- else:
- default = dataclasses.MISSING
-
- if default is dataclasses.MISSING or value is None:
- continue
-
- if dataclasses.is_dataclass(default) and not isinstance(default, type):
- default = asdict(default)
+ return _MODELS_DIR / model_dir / filename
- if value != default and key not in no_warn_keys:
- overridden.append(key)
- return result, sorted(overridden)
+logger = init_logger(__name__)
class StageType(str, Enum):
"""Type of processing stage in the Omni pipeline."""
- # TODO(@lishunyang12): remove once all models migrate to StageExecutionType
LLM = "llm"
DIFFUSION = "diffusion"
-class StageExecutionType(str, Enum):
- """Merged StageType + WorkerType — 3 combinations today."""
-
- LLM_AR = "llm_ar"
- LLM_GENERATION = "llm_generation"
- DIFFUSION = "diffusion"
-
-
-# Mapping class refs (not dotted-path strings) so module/class renames fail
-# at import time instead of lazily at scheduler resolution. YAML overrides
-# and downstream serialization still use the dotted-path string form; the
-# conversion happens at the map lookup site via _scheduler_path().
-_EXECUTION_TYPE_TO_SCHEDULER: dict[StageExecutionType, type | None] = {
- StageExecutionType.LLM_AR: OmniARScheduler,
- StageExecutionType.LLM_GENERATION: OmniGenerationScheduler,
- StageExecutionType.DIFFUSION: None,
-}
-
-
-def _scheduler_path(cls: type | None) -> str | None:
- """Return the dotted import path for a scheduler class (``None`` passes through)."""
- if cls is None:
- return None
- return f"{cls.__module__}.{cls.__qualname__}"
-
-
-@dataclass(frozen=True)
-class StagePipelineConfig:
- """Fixed topology for one stage (frozen, not user-configurable)."""
-
- stage_id: int
- model_stage: str
- execution_type: StageExecutionType = StageExecutionType.LLM_AR
- input_sources: tuple[int, ...] = ()
- final_output: bool = False
- final_output_type: str | None = None
- owns_tokenizer: bool = False
- requires_multimodal_data: bool = False
- hf_config_name: str | None = None
- engine_output_type: str | None = None
- model_arch: str | None = None
- sampling_constraints: dict[str, Any] = field(default_factory=dict)
- custom_process_input_func: str | None = None
- custom_process_next_stage_input_func: str | None = None
- # Alternates picked by ``merge_pipeline_deploy`` based on ``deploy.async_chunk``.
- async_chunk_process_next_stage_input_func: str | None = None
- sync_process_input_func: str | None = None
- prompt_expand_func: str | None = None
- cfg_kv_collect_func: str | None = None
- omni_kv_config: dict[str, Any] | None = None
- # Model subdirectory indirections: for multi-component HF repos where the
- # stage's config/tokenizer lives in a subdirectory (e.g. GLM-Image's AR
- # config is in ``vision_language_encoder/``). Consumed at stage-init time
- # by ``stage_init_utils._resolve_model_tokenizer_paths``.
- model_subdir: str | None = None
- tokenizer_subdir: str | None = None
- extras: dict[str, Any] = field(default_factory=dict)
-
-
-@dataclass(frozen=True)
-class PipelineConfig:
- """Complete pipeline topology for a model (frozen)."""
-
- model_type: str
- model_arch: str = ""
- stages: tuple[StagePipelineConfig, ...] = ()
- # HF architecture aliases: used by StageConfigFactory when the model's
- # HF config reports a generic model_type that collides with a different
- # model (e.g. MiMo Audio reports model_type="qwen2"). The factory
- # matches ``hf_config.architectures[*]`` against this tuple to route
- # to the correct pipeline. Leave empty for models with unique model_type.
- hf_architectures: tuple[str, ...] = ()
- # Diffusers pipeline class name: for models that ship a ``model_index.json``
- # (no root ``config.json``), the ``_class_name`` field is matched against
- # this value to auto-detect the pipeline. Only needed for diffusers-style
- # multi-component repos (e.g. GLM-Image). ``None`` = not a diffusers model.
- diffusers_class_name: str | None = None
-
- def get_stage(self, stage_id: int) -> StagePipelineConfig | None:
- """Look up a stage by its ID."""
- for stage in self.stages:
- if stage.stage_id == stage_id:
- return stage
- return None
-
- def get_scheduler_cls(self, stage_id: int) -> str | None:
- """Return the inferred scheduler class path for a stage.
-
- Returns ``None`` for DIFFUSION stages (no vLLM scheduler). Raises
- ``ValueError`` if ``stage_id`` doesn't exist in this pipeline, and
- ``KeyError`` if ``execution_type`` isn't in the scheduler map.
- """
- stage = self.get_stage(stage_id)
- if stage is None:
- raise ValueError(f"Pipeline {self.model_type!r} has no stage with id {stage_id}")
- return _scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER[stage.execution_type])
-
- def validate(self) -> list[str]:
- """Return list of topology errors (empty if valid)."""
- errors: list[str] = []
- if not self.stages:
- errors.append("Pipeline has no stages defined")
- return errors
- stage_ids = [s.stage_id for s in self.stages]
- if len(stage_ids) != len(set(stage_ids)):
- errors.append("Duplicate stage IDs found")
- stage_id_set = set(stage_ids)
- for stage in self.stages:
- for src in stage.input_sources:
- if src not in stage_id_set:
- errors.append(f"Stage {stage.stage_id} references non-existent input source {src}")
- if src == stage.stage_id:
- errors.append(f"Stage {stage.stage_id} references itself")
- if not any(not s.input_sources for s in self.stages):
- errors.append("No entry point (stage with empty input_sources)")
- return errors
-
-
-class _LazyPipelineRegistry:
- """Dict-like registry that lazy-loads pipelines from the central declaration.
-
- In-tree pipelines are declared once in
- ``vllm_omni/config/pipeline_registry.py`` as
- ``model_type -> (module_path, variable_name)`` entries; the module is
- imported only when the pipeline is first looked up. This mirrors the
- pattern in ``vllm/model_executor/models/registry.py`` and addresses
- https://github.com/vllm-project/vllm-omni/issues/2887 (item 4): having
- every registration in one file makes a missing entry easy to spot.
-
- Out-of-tree / dynamic registrations via ``register_pipeline()`` are stored
- directly in ``_loaded`` and take precedence over the lazy-map entry with
- the same ``model_type``.
-
- The class exposes the subset of ``dict`` operations the rest of this
- module relies on (``__contains__``, ``__getitem__``, ``__setitem__``,
- ``get``, ``keys``, ``values``, ``items``, ``__iter__``), so existing call
- sites don't need to change.
- """
-
- def __init__(self) -> None:
- self._loaded: dict[str, PipelineConfig] = {}
- # Populated lazily to avoid a circular import at module init time.
- self._lazy_map: dict[str, tuple[str, str]] | None = None
-
- def _get_lazy_map(self) -> dict[str, tuple[str, str]]:
- if self._lazy_map is None:
- from vllm_omni.config.pipeline_registry import _OMNI_PIPELINES
-
- self._lazy_map = _OMNI_PIPELINES
- return self._lazy_map
-
- def _load_lazy(self, model_type: str) -> PipelineConfig | None:
- entry = self._get_lazy_map().get(model_type)
- if entry is None:
- return None
- module_path, var_name = entry
- import importlib
-
- try:
- module = importlib.import_module(module_path)
- except ImportError as exc:
- logger.error(
- "Failed to import pipeline module %r for %r: %s",
- module_path,
- model_type,
- exc,
- )
- return None
- pipeline = getattr(module, var_name, None)
- if pipeline is None:
- logger.error(
- "Pipeline variable %r not found in module %r (registered for %r)",
- var_name,
- module_path,
- model_type,
- )
- return None
- errors = pipeline.validate()
- if errors:
- logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors)
- self._loaded[model_type] = pipeline
- return pipeline
-
- def __contains__(self, model_type: str) -> bool:
- if model_type in self._loaded:
- return True
- return model_type in self._get_lazy_map()
-
- def __getitem__(self, model_type: str) -> PipelineConfig:
- if model_type in self._loaded:
- return self._loaded[model_type]
- pipeline = self._load_lazy(model_type)
- if pipeline is None:
- raise KeyError(model_type)
- return pipeline
-
- def get(self, model_type: str, default: PipelineConfig | None = None) -> PipelineConfig | None:
- if model_type in self._loaded:
- return self._loaded[model_type]
- pipeline = self._load_lazy(model_type)
- return pipeline if pipeline is not None else default
-
- def __setitem__(self, model_type: str, pipeline: PipelineConfig) -> None:
- self._loaded[model_type] = pipeline
-
- def __delitem__(self, model_type: str) -> None:
- """Remove a dynamically-registered pipeline.
-
- Only the dynamic-cache side of the registry can be mutated; the
- central declarative registry is immutable at runtime. Calling ``del``
- on a model_type that only exists in the central registry raises
- ``KeyError``.
- """
- if model_type in self._loaded:
- del self._loaded[model_type]
- return
- if model_type in self._get_lazy_map():
- raise KeyError(
- f"{model_type!r} is declared in the central pipeline_registry and "
- "cannot be removed at runtime. Edit "
- "vllm_omni/config/pipeline_registry.py to delete a built-in entry."
- )
- raise KeyError(model_type)
-
- def keys(self) -> set[str]:
- return set(self._get_lazy_map().keys()) | set(self._loaded.keys())
-
- def values(self):
- # Iterating values forces load of every lazy pipeline.
- for key in self.keys():
- yield self[key]
-
- def items(self):
- for key in self.keys():
- yield key, self[key]
-
- def __iter__(self):
- return iter(self.keys())
-
-
-_PIPELINE_REGISTRY = _LazyPipelineRegistry()
-
-
-def register_pipeline(pipeline: PipelineConfig) -> None:
- """Register a pipeline config dynamically.
-
- In-tree pipelines are declared in ``pipeline_registry._OMNI_PIPELINES``
- and loaded lazily; calling ``register_pipeline`` is only needed for
- out-of-tree plugins or tests that build a ``PipelineConfig`` at runtime.
- A dynamic registration overrides the central-registry entry with the same
- ``model_type``.
- """
- errors = pipeline.validate()
- if errors:
- logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors)
- _PIPELINE_REGISTRY[pipeline.model_type] = pipeline
-
-
-_DEPLOY_DIR = Path(__file__).resolve().parent.parent / "deploy"
-
-
-@dataclass
-class StageDeployConfig:
- """Per-stage deployment knobs.
-
- Only fields whose value legitimately varies across stages of the same
- pipeline live here (e.g. ``max_num_seqs`` on thinker vs talker,
- ``devices`` for GPU placement). Pipeline-wide settings
- (``trust_remote_code``, ``distributed_executor_backend``, ``dtype``,
- ``quantization``, prefix/chunked prefill, DP/PP sizes) are declared at
- the top level of ``DeployConfig`` and propagated to every stage.
- """
-
- stage_id: int
- max_num_seqs: int = 64
- gpu_memory_utilization: float = 0.9
- tensor_parallel_size: int = 1
- enforce_eager: bool = False
- max_num_batched_tokens: int = 32768
- max_model_len: int | None = None
- async_scheduling: bool | None = None
- devices: str = "0"
- output_connectors: dict[str, str] | None = None
- input_connectors: dict[str, str] | None = None
- default_sampling_params: dict[str, Any] | None = None
- subtalker_sampling_params: dict[str, Any] | None = None
- engine_extras: dict[str, Any] = field(default_factory=dict)
-
-
-@dataclass
-class DeployConfig:
- """Loaded from deploy/.yaml — the only config file users edit.
-
- Top-level fields (``trust_remote_code``, ``distributed_executor_backend``,
- ``dtype``, ``quantization``, ``enable_prefix_caching``,
- ``enable_chunked_prefill``, ``data_parallel_size``,
- ``pipeline_parallel_size``) are pipeline-wide: they apply uniformly to
- every stage. Fields that legitimately vary per stage live in the
- individual ``StageDeployConfig`` entries under ``stages:``.
- """
-
- async_chunk: bool = True
- connectors: dict[str, Any] | None = None
- edges: list[dict[str, Any]] | None = None
- stages: list[StageDeployConfig] = field(default_factory=list)
- platforms: dict[str, Any] | None = None
- # Overrides the auto-detected pipeline registry key for structural variants.
- pipeline: str | None = None
-
- # === Pipeline-wide engine settings (applied uniformly to every stage) ===
- trust_remote_code: bool = True
- distributed_executor_backend: str | None = None
- dtype: str | None = None
- quantization: str | None = None
- enable_prefix_caching: bool = False
- enable_chunked_prefill: bool | None = None
- data_parallel_size: int = 1
- pipeline_parallel_size: int = 1
-
-
-_STAGE_NON_ENGINE_KEYS = frozenset(
- {
- "stage_id",
- "devices",
- "output_connectors",
- "input_connectors",
- "default_sampling_params",
- "engine_extras",
- }
-)
-
-# Fields on StageDeployConfig that are populated from engine_args dict
-_STAGE_DEPLOY_FIELDS = {f.name: f for f in fields(StageDeployConfig) if f.name not in _STAGE_NON_ENGINE_KEYS}
-
-
-def _parse_stage_deploy(stage_data: dict[str, Any]) -> StageDeployConfig:
- """Parse a single stage entry from deploy YAML into StageDeployConfig."""
- if "engine_args" in stage_data:
- engine_args = dict(stage_data["engine_args"])
- devices = stage_data.get("runtime", {}).get("devices", stage_data.get("devices", "0"))
- else:
- engine_args = {k: v for k, v in stage_data.items() if k not in _STAGE_NON_ENGINE_KEYS and k != "stage_id"}
- devices = stage_data.get("devices", "0")
-
- kwargs: dict[str, Any] = {"stage_id": stage_data["stage_id"], "devices": devices}
- for name, f in _STAGE_DEPLOY_FIELDS.items():
- if name in engine_args:
- kwargs[name] = engine_args.pop(name)
-
- kwargs["output_connectors"] = stage_data.get("output_connectors")
- kwargs["input_connectors"] = stage_data.get("input_connectors")
- kwargs["default_sampling_params"] = stage_data.get("default_sampling_params")
- kwargs["engine_extras"] = engine_args
- return StageDeployConfig(**kwargs)
-
-
-_DEEP_MERGE_KEYS = frozenset({"default_sampling_params", "subtalker_sampling_params", "engine_extras", "engine_args"})
-
-
-def _deep_merge_stage(base: dict, overlay: dict) -> dict:
- """Deep-merge ``_DEEP_MERGE_KEYS`` so thin overlays don't drop base keys."""
- merged = dict(base)
- for k, v in overlay.items():
- if k in _DEEP_MERGE_KEYS:
- base_val = merged.get(k)
- if isinstance(v, dict) and isinstance(base_val, dict):
- merged[k] = {**base_val, **v}
- continue
- # Deep-merge key but at least one side isn't a dict: surface the
- # silent clobber so mismatched YAML types don't get past review.
- if base_val is not None:
- logger.warning(
- "Deep-merge key %r has non-dict value (base=%s, overlay=%s); "
- "overlay will fully replace base instead of merging.",
- k,
- type(base_val).__name__,
- type(v).__name__,
- )
- merged[k] = v
- return merged
-
-
-def _merge_stage_lists(
- base_stages: list[dict[str, Any]] | None,
- overlay_stages: list[dict[str, Any]] | None,
-) -> list[dict[str, Any]]:
- """Merge two ``stages:`` lists by ``stage_id`` (overlay wins per field)."""
- by_id: dict[int, dict[str, Any]] = {s["stage_id"]: s for s in (base_stages or [])}
- for overlay_stage in overlay_stages or []:
- sid = overlay_stage["stage_id"]
- if sid in by_id:
- by_id[sid] = _deep_merge_stage(by_id[sid], overlay_stage)
- else:
- by_id[sid] = overlay_stage
- return list(by_id.values())
-
-
-def _merge_platforms(
- base: dict[str, Any] | None,
- overlay: dict[str, Any] | None,
-) -> dict[str, Any] | None:
- """Deep-merge two ``platforms:`` blocks per-platform, per-stage_id."""
- if not base and not overlay:
- return None
- base = base or {}
- overlay = overlay or {}
- merged: dict[str, Any] = {}
- for plat in set(base) | set(overlay):
- bp = base.get(plat) or {}
- op = overlay.get(plat) or {}
- merged_plat = {**bp, **{k: v for k, v in op.items() if k != "stages"}}
- merged_plat["stages"] = _merge_stage_lists(bp.get("stages"), op.get("stages"))
- merged[plat] = merged_plat
- return merged
-
-
-def resolve_deploy_yaml(path: str | Path) -> dict[str, Any]:
- """Load a deploy YAML with optional ``base_config`` inheritance."""
- raw_dict = to_dict(load_yaml_config(path))
-
- base_path = raw_dict.pop("base_config", None)
- if base_path is None:
- return raw_dict
-
- # Resolve relative to the overlay file's directory
- base_path = Path(path).parent / base_path
- base_dict = resolve_deploy_yaml(base_path)
-
- # Merge top-level scalars: overlay wins. ``stages:`` and ``platforms:``
- # are deep-merged below so an overlay can layer on top of the base.
- merged = {
- **base_dict,
- **{k: v for k, v in raw_dict.items() if k not in ("stages", "platforms")},
- }
- merged["stages"] = _merge_stage_lists(base_dict.get("stages"), raw_dict.get("stages"))
- merged_platforms = _merge_platforms(base_dict.get("platforms"), raw_dict.get("platforms"))
- if merged_platforms is not None:
- merged["platforms"] = merged_platforms
-
- return merged
-
-
-def load_deploy_config(path: str | Path) -> DeployConfig:
- """Load a deploy YAML (with optional base_config inheritance)."""
- raw_dict = resolve_deploy_yaml(path)
-
- stages = [_parse_stage_deploy(s) for s in raw_dict.get("stages", [])]
-
- kwargs: dict[str, Any] = {
- "async_chunk": raw_dict.get("async_chunk", True),
- "connectors": raw_dict.get("connectors", None),
- "edges": raw_dict.get("edges", None),
- "stages": stages,
- "platforms": raw_dict.get("platforms", None),
- "pipeline": raw_dict.get("pipeline", None),
- }
- # Pipeline-wide engine settings: only set if explicitly present in YAML
- # so the DeployConfig dataclass defaults take effect otherwise.
- for name in (
- "trust_remote_code",
- "distributed_executor_backend",
- "dtype",
- "quantization",
- "enable_prefix_caching",
- "enable_chunked_prefill",
- "data_parallel_size",
- "pipeline_parallel_size",
- ):
- if name in raw_dict:
- kwargs[name] = raw_dict[name]
- return DeployConfig(**kwargs)
-
-
-def _extract_platform_overrides(ps: dict[str, Any]) -> tuple[dict[str, Any], str | None]:
- """Return ``(overrides, devices)`` from a platform stage entry.
-
- Handles both the nested layout (``engine_args:`` / ``runtime.devices``) and
- the flat layout. ``devices`` is ``None`` when no override is set.
- """
- if "engine_args" in ps:
- return dict(ps["engine_args"]), ps.get("runtime", {}).get("devices")
- overrides = {k: v for k, v in ps.items() if k not in ("stage_id", "devices")}
- return overrides, ps.get("devices")
-
-
-def _apply_platform_overrides(
- deploy: DeployConfig,
- platform: str | None = None,
-) -> DeployConfig:
- """Merge platform-specific stage overrides into deploy config."""
- if platform is None:
- from vllm_omni.platforms import current_omni_platform
-
- platform = current_omni_platform.device_name.lower()
- if platform is None or deploy.platforms is None:
- return deploy
- platform_section = deploy.platforms.get(platform)
- if platform_section is None:
- return deploy
-
- platform_stages = platform_section.get("stages", [])
- base_by_id = {s.stage_id: s for s in deploy.stages}
-
- for ps in platform_stages:
- base = base_by_id.get(ps["stage_id"])
- if base is None:
- continue
- overrides, devices = _extract_platform_overrides(ps)
- if devices is not None:
- base.devices = devices
- for key, val in overrides.items():
- if hasattr(base, key):
- setattr(base, key, val)
- else:
- base.engine_extras[key] = val
-
- return deploy
-
-
-_EXECUTION_TYPE_TO_STAGE_WORKER: dict[StageExecutionType, tuple[StageType, str | None]] = {
- StageExecutionType.LLM_AR: (StageType.LLM, "ar"),
- StageExecutionType.LLM_GENERATION: (StageType.LLM, "generation"),
- StageExecutionType.DIFFUSION: (StageType.DIFFUSION, None),
-}
-
-
-def _resolve_execution_mode(
- execution_type: StageExecutionType,
-) -> tuple[StageType, str | None]:
- """Map ``execution_type`` → ``(stage_type, worker_type)`` legacy tuple."""
- return _EXECUTION_TYPE_TO_STAGE_WORKER.get(execution_type, (StageType.LLM, None))
-
-
-def _select_processor_funcs(
- ps: StagePipelineConfig,
- async_chunk: bool,
-) -> tuple[str | None, str | None]:
- """Pick ``(input_proc, next_stage_proc)`` based on the async_chunk mode."""
- next_stage_proc = ps.custom_process_next_stage_input_func
- input_proc = ps.custom_process_input_func
- if async_chunk and ps.async_chunk_process_next_stage_input_func:
- next_stage_proc = ps.async_chunk_process_next_stage_input_func
- elif not async_chunk and ps.sync_process_input_func:
- input_proc = ps.sync_process_input_func
- return input_proc, next_stage_proc
-
-
-# Pipeline-wide DeployConfig fields that are propagated to every stage's
-# engine args during merge. These live at top level of the deploy YAML.
-_PIPELINE_WIDE_ENGINE_FIELDS: tuple[str, ...] = (
- "trust_remote_code",
- "distributed_executor_backend",
- "dtype",
- "quantization",
- "enable_prefix_caching",
- "enable_chunked_prefill",
- "data_parallel_size",
- "pipeline_parallel_size",
-)
-
-
-def _build_engine_args(
- ps: StagePipelineConfig,
- ds: StageDeployConfig | None,
- pipeline: PipelineConfig,
- deploy: DeployConfig,
- next_stage_proc: str | None,
-) -> dict[str, Any]:
- """Assemble the flat ``yaml_engine_args`` dict for one stage.
-
- Pipeline-wide DeployConfig fields are applied uniformly to every stage;
- per-stage StageDeployConfig overrides take precedence when present (e.g.
- ``engine_extras`` can still carry a stage-specific ``dtype``).
- """
- engine_args: dict[str, Any] = {"model_arch": ps.model_arch or pipeline.model_arch}
- if ps.engine_output_type:
- engine_args["engine_output_type"] = ps.engine_output_type
- if next_stage_proc:
- engine_args["custom_process_next_stage_input_func"] = next_stage_proc
- # Subdirectory indirections from StagePipelineConfig (structural, not
- # deployment knobs). Deploy YAML ``engine_extras`` can still override
- # these per-stage if needed.
- if ps.model_subdir:
- engine_args["model_subdir"] = ps.model_subdir
- if ps.tokenizer_subdir:
- engine_args["tokenizer_subdir"] = ps.tokenizer_subdir
-
- # Pipeline-wide top-level DeployConfig settings, applied to every stage.
- for name in _PIPELINE_WIDE_ENGINE_FIELDS:
- value = getattr(deploy, name)
- if value is not None:
- engine_args[name] = value
-
- # Per-stage StageDeployConfig values override pipeline-wide settings.
- if ds is not None:
- for k, v in asdict(ds).items():
- if k in _STAGE_NON_ENGINE_KEYS or v is None:
- continue
- engine_args[k] = v
- engine_args.update(ds.engine_extras)
- # Materialize the resolved pipeline-wide async_chunk value into every
- # stage so explicit False overrides do not get lost downstream.
- engine_args["async_chunk"] = bool(deploy.async_chunk)
- if ps.omni_kv_config:
- engine_args["omni_kv_config"] = dict(ps.omni_kv_config)
- return engine_args
-
-
-def _build_extras(
- ps: StagePipelineConfig,
- ds: StageDeployConfig | None,
-) -> dict[str, Any]:
- """Assemble ``yaml_extras`` (sampling + connectors + pipeline extras)."""
- extras: dict[str, Any] = {}
- sampling: dict[str, Any] = {}
- if ds is not None and ds.default_sampling_params:
- sampling.update(ds.default_sampling_params)
- sampling.update(ps.sampling_constraints)
- if sampling:
- extras["default_sampling_params"] = sampling
- if ds is not None and ds.output_connectors:
- extras["output_connectors"] = dict(ds.output_connectors)
- if ds is not None and ds.input_connectors:
- extras["input_connectors"] = dict(ds.input_connectors)
- if ps.prompt_expand_func:
- extras["prompt_expand_func"] = ps.prompt_expand_func
- if ps.cfg_kv_collect_func:
- extras["cfg_kv_collect_func"] = ps.cfg_kv_collect_func
- if ps.extras:
- extras.update(ps.extras)
- return extras
-
-
-def merge_pipeline_deploy(
- pipeline: PipelineConfig,
- deploy: DeployConfig,
- cli_overrides: dict[str, Any] | None = None,
-) -> list[StageConfig]:
- """Merge pipeline + deploy + platform overrides → list[StageConfig]."""
- if cli_overrides is None:
- cli_overrides = {}
-
- deploy = _apply_platform_overrides(deploy)
- deploy_by_id = {s.stage_id: s for s in deploy.stages}
-
- # A pipeline supports async_chunk if any stage has either an explicit
- # async-chunk-only processor slot OR a custom next-stage processor (some
- # pipelines like qwen3_omni wire async-chunk processing directly through
- # ``custom_process_next_stage_input_func``). Only raise when neither is
- # present — that's the "user enabled async_chunk but pipeline has no
- # inter-stage processing at all" case.
- if deploy.async_chunk and not any(
- ps.async_chunk_process_next_stage_input_func or ps.custom_process_next_stage_input_func
- for ps in pipeline.stages
- ):
- raise ValueError(
- f"Pipeline {pipeline.model_type!r} has async_chunk=True in deploy but no stage "
- "declares a next-stage input processor "
- "(``async_chunk_process_next_stage_input_func`` or ``custom_process_next_stage_input_func``). "
- "Either set async_chunk=False or implement an async-chunk processor on the pipeline."
- )
-
- result: list[StageConfig] = []
- for ps in pipeline.stages:
- ds = deploy_by_id.get(ps.stage_id)
- stage_type, worker_type = _resolve_execution_mode(ps.execution_type)
- input_proc, next_stage_proc = _select_processor_funcs(ps, deploy.async_chunk)
- engine_args = _build_engine_args(ps, ds, pipeline, deploy, next_stage_proc)
- extras = _build_extras(ps, ds)
- runtime: dict[str, Any] = {"process": True}
- if ds is not None:
- runtime["devices"] = ds.devices
-
- result.append(
- StageConfig(
- stage_id=ps.stage_id,
- model_stage=ps.model_stage,
- stage_type=stage_type,
- input_sources=list(ps.input_sources),
- custom_process_input_func=input_proc,
- final_output=ps.final_output,
- final_output_type=ps.final_output_type,
- worker_type=worker_type,
- scheduler_cls=_scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER.get(ps.execution_type)),
- hf_config_name=ps.hf_config_name,
- is_comprehension=ps.owns_tokenizer,
- yaml_engine_args=engine_args,
- yaml_runtime=runtime,
- yaml_extras=extras,
- )
- )
- return result
-
-
@dataclass
class StageConfig:
- """Per-stage config (legacy path). Used by both new and legacy loaders.
+ """Per-stage configuration from pipeline YAML.
- TODO(@lishunyang12): replace with ResolvedStageConfig once all models are migrated.
+ Topology fields (stage_id, input_sources, etc.) define the DAG.
+ Engine and runtime defaults come from the YAML; CLI overrides take
+ precedence via ``runtime_overrides``.
"""
+ # Identity
stage_id: int
model_stage: str
+
+ # Stage type
stage_type: StageType = StageType.LLM
+
input_sources: list[int] = field(default_factory=list)
custom_process_input_func: str | None = None
final_output: bool = False
- final_output_type: str | None = None
- worker_type: str | None = None
+ final_output_type: str | None = None # "text", "audio", "image"
+ worker_type: str | None = None # "ar" or "generation"
scheduler_cls: str | None = None
hf_config_name: str | None = None
is_comprehension: bool = False
+
+ # Per-stage engine args from pipeline YAML (defaults)
yaml_engine_args: dict[str, Any] = field(default_factory=dict)
+ # Per-stage runtime config from pipeline YAML (devices, etc.)
yaml_runtime: dict[str, Any] = field(default_factory=dict)
+ # Pass-through fields from pipeline YAML (default_sampling_params,
+ # output_connectors, input_connectors, tts_args, etc.)
yaml_extras: dict[str, Any] = field(default_factory=dict)
+
+ # Runtime overrides (populated from CLI, not from pipeline YAML)
runtime_overrides: dict[str, Any] = field(default_factory=dict)
def to_omegaconf(self) -> Any:
- """TODO(@lishunyang12): remove once engine consumes ResolvedStageConfig directly."""
+ """Convert to OmegaConf for backward compatibility with OmniStage.
+
+ Returns:
+ OmegaConf DictConfig with stage configuration in legacy format.
+ """
# Start with YAML engine_args defaults
engine_args: dict[str, Any] = dict(self.yaml_engine_args)
@@ -909,9 +152,9 @@ def to_omegaconf(self) -> Any:
@dataclass
class ModelPipeline:
- """Complete pipeline definition for a multi-stage model (legacy).
+ """Complete pipeline definition for a multi-stage model.
- TODO(@lishunyang12): remove once all models migrate to PipelineConfig.
+ Defined by model developers, bundled with the model, not user-editable.
"""
model_type: str
@@ -982,50 +225,49 @@ class StageConfigFactory:
"""Factory that loads pipeline YAML and merges CLI overrides.
Handles both single-stage and multi-stage models.
-
- Pipelines are declared in ``vllm_omni/config/pipeline_registry.py`` and
- loaded lazily via ``_PIPELINE_REGISTRY``; no hardcoded model-type →
- directory mapping is maintained here. Models with generic HF
- ``model_type`` collisions (e.g. MiMo Audio reports ``qwen2``) should
- declare ``hf_architectures=(...)`` on their ``PipelineConfig`` so the
- factory can disambiguate via ``hf_config.architectures``.
"""
+ # Mapping of model types to directories under model_executor/models/.
+ PIPELINE_MODELS: dict[str, str] = {
+ "qwen3_omni_moe": "qwen3_omni",
+ "qwen2_5_omni": "qwen2_5_omni",
+ "bagel": "bagel",
+ "qwen3_tts": "qwen3_tts",
+ "voxtral_tts": "voxtral_tts",
+ "mimo_audio": "mimo_audio",
+ "glm-image": "glm_image",
+ "cosyvoice3": "cosyvoice3",
+ "mammothmoda2": "mammoth_moda2",
+ }
+
+ # Fallback: map HF architecture class names to pipeline dirs.
+ # Used when model_type collides with another model (e.g. MiMo Audio
+ # reports model_type="qwen2" which matches plain Qwen2, not our pipeline).
+ _ARCHITECTURE_MODELS: dict[str, str] = {
+ "MiMoAudioForConditionalGeneration": "mimo_audio",
+ "HunyuanImage3ForCausalMM": "hunyuan_image3",
+ }
+
@classmethod
def create_from_model(
cls,
model: str,
cli_overrides: dict[str, Any] | None = None,
- deploy_config_path: str | None = None,
- **deprecated_kwargs: Any,
) -> list[StageConfig] | None:
- """Load pipeline + deploy config, merge with CLI overrides.
+ """Load pipeline YAML, merge with CLI overrides.
- Checks _PIPELINE_REGISTRY first (new path), falls back to legacy YAML.
- """
- _warn_deprecated_kwargs(deprecated_kwargs)
+ Args:
+ model: Model name or path.
+ cli_overrides: CLI overrides from VllmConfig/OmniDiffusionConfig.
+ Returns:
+ List of StageConfig objects with CLI overrides applied,
+ or None if no pipeline definition was found for this model.
+ """
if cli_overrides is None:
cli_overrides = {}
trust_remote_code = cli_overrides.get("trust_remote_code", True)
-
- # --- New path: check pipeline registry by model_type first ---
- model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code)
- if model_type and model_type in _PIPELINE_REGISTRY:
- return cls._create_from_registry(model_type, cli_overrides, deploy_config_path)
-
- # --- HF architecture fallback: some models report a generic
- # model_type that collides with another model. Match by the
- # hf_architectures declared on each registered PipelineConfig.
- if hf_config is not None:
- hf_archs = set(getattr(hf_config, "architectures", []) or [])
- if hf_archs:
- for registered in _PIPELINE_REGISTRY.values():
- if hf_archs.intersection(registered.hf_architectures):
- return cls._create_from_registry(registered.model_type, cli_overrides, deploy_config_path)
-
- # --- Legacy path: load from pipeline YAML ---
pipeline = cls._load_pipeline(model, trust_remote_code=trust_remote_code)
if pipeline is None:
@@ -1035,14 +277,14 @@ def create_from_model(
if errors:
logger.warning(f"Pipeline validation warnings for {model}: {errors}")
- # Materialize the resolved pipeline-wide async_chunk value into every
- # stage so build_engine_args_dict() can inject the stage connector
- # spec and explicit False overrides are preserved.
- resolved_async_chunk = cli_overrides.get("async_chunk")
- if resolved_async_chunk is None:
- resolved_async_chunk = bool(pipeline.async_chunk)
- for stage in pipeline.stages:
- stage.yaml_engine_args["async_chunk"] = bool(resolved_async_chunk)
+ # Inject pipeline-wide async_chunk into ALL stages' engine_args.
+ # The legacy loader (load_stage_configs_from_yaml) sets async_chunk
+ # on every stage so that build_engine_args_dict() can inject the
+ # stage_connector_spec. AsyncOmniEngine.__init__ also reads it
+ # from stage_configs[0].engine_args.async_chunk.
+ if pipeline.async_chunk:
+ for stage in pipeline.stages:
+ stage.yaml_engine_args.setdefault("async_chunk", True)
# Apply CLI overrides
result: list[StageConfig] = []
@@ -1053,58 +295,6 @@ def create_from_model(
return result
- @classmethod
- def _create_from_registry(
- cls,
- model_type: str,
- cli_overrides: dict[str, Any],
- deploy_config_path: str | None = None,
- **deprecated_kwargs: Any,
- ) -> list[StageConfig]:
- """Create StageConfigs from pipeline registry + deploy YAML.
-
- Precedence: caller-typed (non-None) value > deploy YAML >
- StageDeployConfig dataclass default.
- """
- _warn_deprecated_kwargs(deprecated_kwargs)
-
- # Resolve deploy config path
- if deploy_config_path is None:
- deploy_path = _DEPLOY_DIR / f"{model_type}.yaml"
- else:
- deploy_path = Path(deploy_config_path)
-
- if not deploy_path.exists():
- logger.warning(
- "Deploy config not found: %s — using pipeline defaults only",
- deploy_path,
- )
- deploy_cfg = DeployConfig()
- else:
- deploy_cfg = load_deploy_config(deploy_path)
-
- cli_async_chunk = cli_overrides.get("async_chunk")
- if cli_async_chunk is not None:
- deploy_cfg.async_chunk = bool(cli_async_chunk)
-
- pipeline_key = deploy_cfg.pipeline or model_type
- if pipeline_key not in _PIPELINE_REGISTRY:
- raise KeyError(
- f"Pipeline {pipeline_key!r} not in registry "
- f"(resolved from {deploy_path.name!r}). Available: "
- f"{sorted(_PIPELINE_REGISTRY.keys())}"
- )
- pipeline_cfg = _PIPELINE_REGISTRY[pipeline_key]
-
- stages = merge_pipeline_deploy(pipeline_cfg, deploy_cfg, cli_overrides)
-
- explicit_overrides = {k: v for k, v in cli_overrides.items() if v is not None}
-
- for stage in stages:
- stage.runtime_overrides = cls._merge_cli_overrides(stage, explicit_overrides)
-
- return stages
-
@classmethod
def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]]:
"""Single-stage diffusion - no YAML needed.
@@ -1132,16 +322,9 @@ def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]
continue
engine_args[key] = value
- # Serialize parallel_config as dict for OmegaConf. Test helpers
- # sometimes pass SimpleNamespace rather than a dataclass instance.
+ # Serialize parallel_config as dict for OmegaConf compatibility
if "parallel_config" in kwargs:
- parallel_config = kwargs["parallel_config"]
- if dataclasses.is_dataclass(parallel_config) and not isinstance(parallel_config, type):
- engine_args["parallel_config"] = asdict(parallel_config)
- elif hasattr(parallel_config, "__dict__"):
- engine_args["parallel_config"] = dict(vars(parallel_config))
- else:
- engine_args["parallel_config"] = parallel_config
+ engine_args["parallel_config"] = asdict(kwargs["parallel_config"])
engine_args.setdefault("cache_backend", "none")
engine_args["model_stage"] = "diffusion"
@@ -1168,49 +351,40 @@ def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]
@classmethod
def _load_pipeline(cls, model: str, trust_remote_code: bool = True) -> ModelPipeline | None:
- """Load a legacy ``pipeline.yaml`` for the model.
+ """Load pipeline YAML for the model.
- Searches ``model_executor/models//pipeline.yaml`` by trying
- (a) the raw ``model_type`` as the directory name, then
- (b) ``model_type`` with hyphens replaced by underscores,
- and finally (c) scanning every ``pipeline.yaml`` for one that
- declares a matching ``model_type`` or ``hf_architectures``.
+ Args:
+ model: Model name or path.
+ trust_remote_code: Whether to trust remote code for HF config loading.
- Returns None if no pipeline.yaml is found — caller handles the
- ``resolve_model_config_path`` fallback via stage_configs/ YAMLs.
+ Returns:
+ ModelPipeline if found, None otherwise.
"""
model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code)
if model_type is None:
return None
- # Direct lookups by convention
- candidates = [model_type, model_type.replace("-", "_")]
- for dir_name in candidates:
- pipeline_path = get_pipeline_path(dir_name, "pipeline.yaml")
- if pipeline_path.exists():
- return cls._parse_pipeline_yaml(pipeline_path, model_type)
-
- # Scan fallback: read every pipeline.yaml and match on declared fields
- hf_archs = set(getattr(hf_config, "architectures", []) or []) if hf_config else set()
- if _MODELS_DIR.exists():
- for subdir in sorted(_MODELS_DIR.iterdir()):
- if not subdir.is_dir():
- continue
- pipeline_path = subdir / "pipeline.yaml"
- if not pipeline_path.exists():
- continue
- try:
- cfg = load_yaml_config(pipeline_path)
- except Exception as exc:
- logger.debug("Skip %s: %s", pipeline_path, exc)
- continue
- declared_type = getattr(cfg, "model_type", None)
- declared_archs = set(getattr(cfg, "hf_architectures", None) or [])
- if declared_type == model_type or (hf_archs and hf_archs.intersection(declared_archs)):
- return cls._parse_pipeline_yaml(pipeline_path, declared_type or model_type)
+ pipeline_dir = cls.PIPELINE_MODELS.get(model_type)
- logger.debug("No pipeline.yaml found for model_type %s (archs=%s)", model_type, sorted(hf_archs))
- return None
+ # Fallback: check HF architectures when model_type doesn't match
+ if pipeline_dir is None and hf_config is not None:
+ for arch in getattr(hf_config, "architectures", []) or []:
+ pipeline_dir = cls._ARCHITECTURE_MODELS.get(arch)
+ if pipeline_dir is not None:
+ model_type = pipeline_dir
+ break
+
+ if pipeline_dir is None:
+ logger.debug(f"No pipeline mapping for model_type: {model_type}")
+ return None
+
+ pipeline_path = get_pipeline_path(pipeline_dir, "pipeline.yaml")
+
+ if not pipeline_path.exists():
+ logger.debug(f"Pipeline file not found: {pipeline_path}")
+ return None
+
+ return cls._parse_pipeline_yaml(pipeline_path, model_type)
# Keys consumed as explicit StageConfig fields — everything else is
# passed through via yaml_extras.
@@ -1361,66 +535,73 @@ def _auto_detect_model_type(cls, model: str, trust_remote_code: bool = True) ->
from vllm.transformers_utils.config import get_hf_file_to_dict
config_dict = get_hf_file_to_dict("config.json", model, revision=None)
- if config_dict:
- if "model_type" in config_dict:
- return config_dict["model_type"], None
- # VoxCPM2-style configs use singular ``architecture`` rather
- # than HF's standard ``model_type`` / ``architectures``. Accept
- # it as a fallback so the pipeline registry can still match.
- if "architecture" in config_dict and isinstance(config_dict["architecture"], str):
- return config_dict["architecture"], None
+ if config_dict and "model_type" in config_dict:
+ return config_dict["model_type"], None
except Exception as e:
logger.debug(f"Failed to auto-detect model type for {model}: {e}")
- # Fallback for diffusers-style models: check model_index.json.
- # Some models (e.g. GLM-Image) have no root config.json but ship a
- # model_index.json with _class_name that maps to a pipeline key via
- # PipelineConfig.diffusers_class_name.
- try:
- from vllm.transformers_utils.config import get_hf_file_to_dict
-
- model_index = get_hf_file_to_dict("model_index.json", model, revision=None)
- if model_index and "_class_name" in model_index:
- class_name = model_index["_class_name"]
- for pipeline_cfg in _PIPELINE_REGISTRY.values():
- if pipeline_cfg.diffusers_class_name == class_name:
- logger.info(
- "Detected pipeline %r from model_index.json (_class_name=%r)",
- pipeline_cfg.model_type,
- class_name,
- )
- return pipeline_cfg.model_type, None
- except Exception:
- pass
-
- # Final fallback: some models (e.g. CosyVoice3) ship an empty
- # config.json and rely on naming conventions. Match the model path
- # basename against registered pipeline keys — longest match wins
- # so "cosyvoice3" (length 10) beats "cosyvoice" (length 9).
- model_lower = model.lower().replace("-", "").replace("_", "")
- best: str | None = None
- best_len = 0
- for registered_key in _PIPELINE_REGISTRY.keys():
- candidate = registered_key.lower().replace("-", "").replace("_", "")
- if candidate and candidate in model_lower and len(candidate) > best_len:
- best = registered_key
- best_len = len(candidate)
- if best is not None:
- return best, None
-
return None, None
+ # Keys that should never be forwarded as engine overrides (internal /
+ # orchestrator-only knobs, complex objects, etc.).
+ _INTERNAL_KEYS: set[str] = {
+ "model",
+ "stage_configs_path",
+ "stage_id",
+ "stage_init_timeout",
+ "init_timeout",
+ "shm_threshold_bytes",
+ "worker_backend",
+ "ray_address",
+ "batch_timeout",
+ "log_stats",
+ "tokenizer",
+ "parallel_config",
+ }
+
@classmethod
def _merge_cli_overrides(
cls,
stage: StageConfig,
cli_overrides: dict[str, Any],
) -> dict[str, Any]:
- """Merge global and per-stage (``stage_N_*``) CLI overrides.
+ """Merge CLI overrides into stage runtime config.
+
+ All CLI arguments registered by engine config classes (e.g.
+ EngineArgs / OmniDiffusionConfig) are accepted as overrides
+ unless they appear in ``_INTERNAL_KEYS``.
- Orchestrator-owned keys are filtered by ``build_stage_runtime_overrides``
- using ``OrchestratorArgs`` as the single source of truth; unknown
- server/uvicorn keys are dropped downstream by
- ``filter_dataclass_kwargs(OmniEngineArgs, ...)``.
+ Handles:
+ - Global overrides (apply to all stages)
+ - Per-stage overrides (--stage-N-* format, take precedence)
+
+ Args:
+ stage: The stage to merge overrides into.
+ cli_overrides: CLI arguments from VllmConfig/OmniDiffusionConfig.
+
+ Returns:
+ Dict of runtime overrides for this stage.
"""
- return build_stage_runtime_overrides(stage.stage_id, cli_overrides)
+ result: dict[str, Any] = {}
+
+ # Apply global overrides – any key not in the internal blocklist
+ # is forwarded so that engine-registered params work out of the box.
+ for key, value in cli_overrides.items():
+ if key in cls._INTERNAL_KEYS:
+ continue
+ if re.match(r"stage_\d+_", key):
+ # Per-stage keys handled below
+ continue
+ if value is not None:
+ result[key] = value
+
+ # Apply per-stage overrides (--stage-N-* format, take precedence)
+ stage_prefix = f"stage_{stage.stage_id}_"
+ for key, value in cli_overrides.items():
+ if key.startswith(stage_prefix) and value is not None:
+ param_name = key[len(stage_prefix) :]
+ if param_name in cls._INTERNAL_KEYS:
+ continue
+ result[param_name] = value
+
+ return result
diff --git a/vllm_omni/core/prefix_cache.py b/vllm_omni/core/prefix_cache.py
deleted file mode 100644
index 69e7346c4c1..00000000000
--- a/vllm_omni/core/prefix_cache.py
+++ /dev/null
@@ -1,264 +0,0 @@
-"""
-Utilities for Prefix Caching in Omni models.
-"""
-
-import torch
-from vllm.logger import init_logger
-from vllm.v1.worker.gpu_input_batch import InputBatch
-
-from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element
-
-logger = init_logger(__name__)
-
-
-class OmniTensorPrefixCache:
- """Prefix cache for hidden states (model outputs) and model specific
- multimodal outputs.
-
- This class implements prefix caching in a non-invasive way on top of
- vLLM by leveraging the same slot mappings that the vLLM scheduler uses
- for the KV Cache.
-
- Conceptually, this means we are mapping vLLM's cache mapping:
- (num_blocks, block_size)
-
- to 3D tensors of shape:
- (num_blocks, block_size, feature_size)
-
- Note that feature_size may vary across multimodal_outputs.
- """
-
- def __init__(
- self,
- num_blocks: int,
- block_size: int,
- hidden_size: int,
- hs_dtype: torch.dtype,
- ):
- self.num_blocks = num_blocks
- self.block_size = block_size
- self.default_hidden_size = hidden_size
-
- # Initialize the hidden states cache immediately
- self.hidden_states_cache = self._get_cache_tensor(dtype=hs_dtype)
-
- # Defer initialization of the mm_outputs_cache until we
- # actually see mm output tensors dependent on num tokens.
- self.mm_outputs_cache = {}
- self.mm_cache_keys = set()
- self._new_req_cache_hit_ids: set[str] = set()
-
- def maybe_init_missing_mm_cache_keys(self, multimodal_outputs: dict, seq_len: int):
- """Given multimodal outputs from executing the model, dynamically
- determine which multimodal outputs are tensors depending on sequence
- length and should be cached, and initialize the cache tensors
- accordingly.
-
- NOTE: This is done to avoid the need for explicit specification of
- cache keys for every model/stage and aligns with the current way
- that we slice the multimodal outputs based on the first dimension.
-
- This will usually be called by the first forward pass, i.e.,
- determined by the warmup.
- """
- for key, val in multimodal_outputs.items():
- if isinstance(val, torch.Tensor) and val.shape[0] == seq_len and key not in self.mm_cache_keys:
- feat_dim = val.shape[-1]
- self.mm_outputs_cache[key] = self._get_cache_tensor(
- dtype=val.dtype,
- hidden_size=feat_dim,
- )
- self.mm_cache_keys.add(key)
- new_tensor_shape = self.mm_outputs_cache[key].shape
- logger.info("Initializing multimodal output cache of size %s for key: %s", list(new_tensor_shape), key)
-
- def _get_cache_tensor(self, dtype: torch.dtype, hidden_size: int | None = None) -> torch.Tensor:
- """Allocate a CPU cache tensor for a specific key."""
- actual_hidden_size = hidden_size if hidden_size is not None else self.default_hidden_size
- return torch.zeros(
- (self.num_blocks, self.block_size, actual_hidden_size),
- dtype=dtype,
- device="cpu",
- )
-
- def add_prefix_cached_new_req_id(self, req_id: str):
- """Adds a new request ID to the set of prefix cache hits on the batch."""
- self._new_req_cache_hit_ids.add(req_id)
-
- def reset_prefix_cached_new_req_ids(self):
- """Clears the cache hit IDs to prepare for a new engine step."""
- self._new_req_cache_hit_ids.clear()
-
- @staticmethod
- def _coerce_to_cpu_tensor(maybe_gpu_tensor: torch.Tensor) -> torch.Tensor:
- """Convert GPU tensors -> contiguous CPU tensors if needed."""
- return maybe_gpu_tensor.detach().cpu().contiguous()
-
- def update_omni_tensor_prefix_cache(
- self,
- hidden_states: torch.Tensor | None,
- multimodal_outputs: dict[str, torch.Tensor] | None,
- num_tokens_unpadded: int,
- slot_mapping: torch.Tensor,
- num_tokens_padded: int | None = None,
- ):
- """Updates the hidden cache state for the provided hidden states and multimodal outputs.
-
- Args:
- hidden_states: Hidden states tensor to cache (if any)
- multimodal_outputs: Multimodal dict whose tensors may be cached
- num_tokens_unpadded: Number of tokens without padding
- slot_mapping: Slot mapping for the input sequence
- num_tokens_padded: Total number of tokens including padding
- """
- unpadded_slot_mapping = slot_mapping[:num_tokens_unpadded]
- if num_tokens_padded is None:
- num_tokens_padded = num_tokens_unpadded
-
- if hidden_states is not None:
- # Slice to unpadded portion before caching
- hidden_states = hidden_states[:num_tokens_unpadded]
- # Ensure that hidden states are on the CPU
- hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states)
- # View the cache as 2D so that we can treat our slots as row indices
- flat_cache = self.hidden_states_cache.view(-1, self.hidden_states_cache.shape[-1])
- flat_cache[unpadded_slot_mapping] = hidden_states
- logger.debug("Writing to hidden states for %s tokens", num_tokens_unpadded)
-
- # Do the same for the stage's cached multimodal outputs
- if multimodal_outputs is not None:
- # If we haven't initialized the keys already, do it now
- # We check against the padded token count since we haven't sliced yet
- self.maybe_init_missing_mm_cache_keys(
- multimodal_outputs,
- seq_len=num_tokens_padded,
- )
-
- for mm_out_key, mm_cache in self.mm_outputs_cache.items():
- if mm_out_key in multimodal_outputs:
- # Slice to unpadded portion before caching
- mm_state = multimodal_outputs[mm_out_key][:num_tokens_unpadded]
- mm_state = OmniTensorPrefixCache._coerce_to_cpu_tensor(mm_state)
- flat_cache = mm_cache.view(-1, mm_cache.shape[-1])
- flat_cache[unpadded_slot_mapping] = mm_state
- logger.debug("Writing to mm output cache for %s tokens", num_tokens_unpadded)
-
- def _coerce_to_payload_dict(
- self,
- element: object,
- query_start_loc: torch.Tensor,
- input_batch: InputBatch,
- num_scheduled_tokens: dict[str, int],
- ) -> dict[str, object]:
- """Build the multimodal passthrough data per request for
- the object under consideration. This is identical to the case
- for no prefix cache when we tensor does have a first dimension
- matching the seq len.
- """
- elem_dict = {}
- for req_id in input_batch.req_ids:
- req_idx = input_batch.req_id_to_index[req_id]
- start = query_start_loc[req_idx]
- end = start + num_scheduled_tokens[req_id]
- elem_dict[req_id] = to_payload_element(
- element, req_idx, start=start, end=end, pass_lists_through=True, seq_len=None
- )
- return elem_dict
-
- def get_merged_multimodal_states(
- self,
- query_start_loc: torch.Tensor,
- input_batch: InputBatch,
- multimodal_outputs: dict,
- num_scheduled_tokens: dict[str, int],
- ):
- """Get the merged multimodal states if hidden state prefix caching is enabled."""
- combined_multimodal_outputs = {}
- # First get the prefix cached tensors that are present in the mm data
- for mm_key in self.mm_cache_keys:
- if mm_key in multimodal_outputs:
- combined_multimodal_outputs[mm_key] = self._get_merged_tensors(
- query_start_loc=query_start_loc,
- input_batch=input_batch,
- cache=self.mm_outputs_cache[mm_key],
- hidden_states=multimodal_outputs[mm_key],
- num_scheduled_tokens=num_scheduled_tokens,
- )
-
- # Then, get everything else (passthrough data); first, convert to CPU
- # tensors similarly to the non prefix cached path, and then populate
- # the subdicts mapping request IDs -> payload objects
- passthrough_keys = set(multimodal_outputs.keys()) - self.mm_cache_keys
- passthrough_mm_data = {k: v for k, v in multimodal_outputs.items() if k in passthrough_keys}
- mm_cpu = build_mm_cpu(multimodal_outputs=passthrough_mm_data)
-
- for mm_key, mm_val in mm_cpu.items():
- combined_multimodal_outputs[mm_key] = self._coerce_to_payload_dict(
- element=mm_val,
- query_start_loc=query_start_loc,
- input_batch=input_batch,
- num_scheduled_tokens=num_scheduled_tokens,
- )
- return combined_multimodal_outputs
-
- def get_merged_hidden_states(self, *args, **kwargs) -> dict[str, torch.Tensor]:
- """Get the merged hidden states."""
- return self._get_merged_tensors(
- *args,
- **kwargs,
- cache=self.hidden_states_cache,
- )
-
- def _get_merged_tensors(
- self,
- query_start_loc: torch.Tensor,
- input_batch: InputBatch,
- cache: torch.Tensor,
- hidden_states: torch.Tensor,
- num_scheduled_tokens: dict[str, int],
- ) -> dict[str, torch.Tensor]:
- """When hidden state caching is enabled, takes the input hidden_states,
- which only correspond to the scheduled tokens, and returns a mapping
- from request IDs to their full hidden states. This is accomplished by
- looking up the block IDs & scheduled token counts to split the
- hidden_states.
- """
- # We do not support hybrid caches at the moment.
- if len(input_batch.block_table.block_tables) > 1:
- logger.warning_once(
- "Omni prefix caching is enabled, but the batch block table appears to"
- " have multiple kv groups; only the first group will be used!"
- )
-
- combined_hidden_states = {}
- hidden_states = OmniTensorPrefixCache._coerce_to_cpu_tensor(hidden_states)
- for req_id in input_batch.req_ids:
- req_idx = input_batch.req_id_to_index[req_id]
-
- if req_id in self._new_req_cache_hit_ids:
- block_ids = self._get_cached_block_ids(req_idx, input_batch)
- cached_hs = cache[block_ids].reshape(-1, cache.shape[-1])
-
- # Slice the hidden states corresponding to this request;
- # we do this by using the query start
- start = query_start_loc[req_idx]
- new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]]
- combined_hidden_states[req_id] = torch.cat([cached_hs, new_hs], dim=0)
- else:
- # cache miss for this request, pass through normally
- start = query_start_loc[req_idx]
- new_hs = hidden_states[start : start + num_scheduled_tokens[req_id]]
- combined_hidden_states[req_id] = new_hs
-
- return combined_hidden_states
-
- def _get_cached_block_ids(self, req_idx: int, input_batch: InputBatch) -> torch.Tensor:
- """Given an input batch and request index in the batch (not ID), get the
- block IDs corresponding to the cache hit.
- """
- num_computed = input_batch.num_computed_tokens_cpu[req_idx]
- # NOTE: vLLM only caches full blocks
- num_cached_blocks = num_computed // self.block_size
- # Get the block IDs attached to this cache hit and reindex into
- # the flattened cached hidden states (i.e., 1 row per token).
- return input_batch.block_table[0].block_table.cpu[req_idx, :num_cached_blocks]
diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py
index a5579dd4640..af178d14d27 100644
--- a/vllm_omni/core/sched/omni_ar_scheduler.py
+++ b/vllm_omni/core/sched/omni_ar_scheduler.py
@@ -15,10 +15,9 @@
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.perf import PerfStats
from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus, StreamingUpdate
+from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
-from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin
from vllm_omni.core.sched.output import OmniSchedulerOutput
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
@@ -39,7 +38,7 @@ def to_dict(self) -> dict[str, Any]:
return asdict(self)
-class OmniARScheduler(OmniSchedulerMixin, VLLMScheduler):
+class OmniARScheduler(VLLMScheduler):
"""
OmniARScheduler: Scheduler for vLLM-Omni multimodal processing.
@@ -60,25 +59,15 @@ def __init__(self, *args, **kwargs):
# Track ACTIVE transfers (submitted to runner but not yet acked via kv_extracted_req_ids)
self.active_kv_transfers: set[str] = set()
- # Requests marked for deferred stop: keep running until KV extraction
- # completes so that kv_ready can be emitted while the request is still
- # alive. Stopped on the first scheduler step after extraction ack.
- self.pending_stop_after_extraction: set[str] = set()
-
# [Omni] Pre-parse KV transfer criteria
self.kv_transfer_criteria = self._get_kv_transfer_criteria()
# Track requests that have already triggered prefill transfer to avoid duplicates
self.transfer_triggered_requests: set[str] = set()
-
- # Cache per-request flag to avoid repeated deserialization of additional_information
- self._omits_kv_transfer_cache: dict[str, bool] = {}
model_config = self.vllm_config.model_config
self.chunk_transfer_adapter = None
if getattr(model_config, "async_chunk", False):
self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config)
- # Snapshot prompt length for each streaming input update
- self._new_prompt_len_snapshot: dict[str, int] = {}
def _get_kv_transfer_criteria(self) -> dict | None:
# Note: vllm_config is available in Scheduler after super().__init__
@@ -93,27 +82,6 @@ def _get_kv_transfer_criteria(self) -> dict | None:
return getattr(omni_kv_config, "kv_transfer_criteria", None)
return None
- def _request_omits_kv_transfer_to_next_stage(self, request: Request) -> bool:
- """True when orchestrator will not run stage 1+ for this request (e.g. text-only).
-
- The result is cached per request to avoid repeated deserialization of
- additional_information on every scheduler tick.
- """
- rid = request.request_id
- cached = self._omits_kv_transfer_cache.get(rid)
- if cached is not None:
- return cached
-
- payload = getattr(request, "additional_information", None)
- if payload is None:
- result = False
- else:
- info = deserialize_additional_information(payload)
- result = info.get("omni_final_stage_id") == 0
-
- self._omits_kv_transfer_cache[rid] = result
- return result
-
def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int]) -> bool:
"""
Check triggers and process side effects (marking transfer).
@@ -123,10 +91,6 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
if not self.kv_transfer_criteria:
return False
- # Text-only requests finalize at stage 0; do not prefill-stop for DiT KV.
- if self._request_omits_kv_transfer_to_next_stage(request):
- return False
-
if request.request_id in self.waiting_for_transfer_free:
return False
@@ -134,16 +98,11 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
stop_decode_on_trigger = self.kv_transfer_criteria.get("stop_after_transfer", True)
if request.request_id in self.transfer_triggered_requests:
- # Deferred stop: once KV extraction is complete (no longer in
- # active_kv_transfers), stop the request. This guarantees the
- # kv_ready signal was emitted while the request was still alive.
- if (
- request.request_id in self.pending_stop_after_extraction
- and request.request_id not in self.active_kv_transfers
- ):
- self.pending_stop_after_extraction.discard(request.request_id)
- request.status = RequestStatus.FINISHED_STOPPED
- return True
+ # Already triggered. When stop_decode_on_trigger is True AND
+ # transfer was actually queued, the request was already stopped
+ # at trigger time (see below). Any request that reaches this
+ # point either has stop_decode_on_trigger=False (continue
+ # decoding) or was not actually queued (should not be stopped).
return False
if criteria_type == "prefill_finished":
@@ -153,11 +112,14 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
actually_queued = request.request_id in self.requests_needing_kv_transfer
if stop_decode_on_trigger and actually_queued:
- # Defer the stop until KV extraction completes so that
- # the kv_ready signal can be emitted while the request
- # is still alive. The request will be stopped on the
- # next scheduler step after extraction ack arrives.
- self.pending_stop_after_extraction.add(request.request_id)
+ # Stop immediately so the request is NOT scheduled in
+ # the next step, freeing scheduling budget for companion
+ # requests whose chunked-prefill boundaries must be
+ # deterministic. waiting_for_transfer_free keeps blocks
+ # alive until the model runner finishes KV extraction.
+ self.waiting_for_transfer_free.add(request.request_id)
+ request.status = RequestStatus.FINISHED_STOPPED
+ return True
return False
@@ -177,7 +139,9 @@ def _process_kv_transfer_trigger(self, request: Request, new_token_ids: list[int
actually_queued = request.request_id in self.requests_needing_kv_transfer
if stop_decode_on_trigger and actually_queued:
- self.pending_stop_after_extraction.add(request.request_id)
+ self.waiting_for_transfer_free.add(request.request_id)
+ request.status = RequestStatus.FINISHED_STOPPED
+ return True
return False
@@ -276,26 +240,6 @@ def update_from_output(
num_scheduled_tokens,
)
- # Pre-process KV extraction acks so that the per-request loop below
- # can see up-to-date active_kv_transfers state and emit kv_ready
- # signals while requests are still alive (before any deferred stop).
- kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
- if kv_extracted_ids:
- for req_id in kv_extracted_ids:
- try:
- self.active_kv_transfers.discard(req_id)
- req = self.requests.get(req_id)
- if req is not None and not req.is_finished():
- outputs[req.client_index].append(
- EngineCoreOutput(
- request_id=req_id,
- new_token_ids=[],
- kv_transfer_params={"kv_ready": True},
- )
- )
- except Exception:
- init_logger(__name__).exception("Failed to pre-process KV extraction for %s", req_id)
-
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
# to avoid expensive operations inside the loop.
@@ -341,7 +285,6 @@ def update_from_output(
)
stopped = False
- is_segment_finished = False
new_logprobs = None
new_token_ids = generated_token_ids
pooler_output = pooler_outputs[req_index] if pooler_outputs else None
@@ -370,10 +313,13 @@ def update_from_output(
# Capture finish_reason BEFORE _handle_stopped_request, which may
# reset the status to WAITING for streaming requests that continue.
finish_reason = request.get_finished_reason()
- is_segment_finished = request.is_finished() and request.resumable
finished = self._handle_stopped_request(request)
if finished:
kv_transfer_params = self._free_request(request)
+ if self.chunk_transfer_adapter is not None:
+ self.chunk_transfer_adapter.cleanup_receiver(
+ request.request_id,
+ )
if status_before_stop == RequestStatus.RUNNING:
stopped_running_reqs.add(request)
elif status_before_stop == RequestStatus.WAITING_FOR_CHUNK:
@@ -423,8 +369,6 @@ def update_from_output(
num_external_computed_tokens=request.num_external_computed_tokens,
routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
- is_segment_finished=is_segment_finished,
- new_prompt_len_snapshot=self._new_prompt_len_snapshot.get(req_id, None),
)
)
if self.chunk_transfer_adapter is not None:
@@ -468,7 +412,6 @@ def update_from_output(
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
- self.pending_stop_after_extraction.discard(req.request_id)
# Same for preempted
for req in stopped_preempted_reqs:
@@ -477,8 +420,6 @@ def update_from_output(
self.transfer_triggered_requests.remove(req.request_id)
if req.request_id in self.active_kv_transfers:
self.active_kv_transfers.remove(req.request_id)
- self.pending_stop_after_extraction.discard(req.request_id)
-
# KV Connector: update state for finished KV Transfers.
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
@@ -524,12 +465,35 @@ def update_from_output(
engine_core_outputs[0] = eco = EngineCoreOutputs()
eco.scheduler_stats = stats
- # Free blocks that were held for transfer (kv_ready and
- # active_kv_transfers updates already done before the per-request loop).
- if kv_extracted_ids:
- for req_id in kv_extracted_ids:
- try:
+ # This is where we free blocks that were held for transfer
+ try:
+ kv_extracted_ids = getattr(model_runner_output, "kv_extracted_req_ids", None)
+ if kv_extracted_ids:
+ for req_id in kv_extracted_ids:
+ # Emit a kv_ready signal so the orchestrator can forward
+ # the request to the DiT stage immediately after KV
+ # extraction, without waiting for AR decode to finish.
+ req = self.requests.get(req_id)
+ if req is not None and not req.is_finished():
+ eco = engine_core_outputs.get(req.client_index)
+ if eco is None:
+ eco = EngineCoreOutputs()
+ engine_core_outputs[req.client_index] = eco
+ eco.outputs.append(
+ EngineCoreOutput(
+ request_id=req_id,
+ new_token_ids=[],
+ kv_transfer_params={"kv_ready": True},
+ )
+ )
+
+ # Mark transfer as finished
+ if req_id in self.active_kv_transfers:
+ self.active_kv_transfers.remove(req_id)
+ logger.debug(f"[Omni] KV Transfer finished for {req_id}")
+
if req_id in self.waiting_for_transfer_free:
+ # Now it's safe to free blocks
req = self.requests.get(req_id)
if req:
self.kv_cache_manager.free(req)
@@ -537,62 +501,27 @@ def update_from_output(
del self.requests[req_id]
if req_id in self.transfer_triggered_requests:
self.transfer_triggered_requests.remove(req_id)
- self.active_kv_transfers.discard(req_id)
- self.pending_stop_after_extraction.discard(req_id)
+ if req_id in self.active_kv_transfers:
+ self.active_kv_transfers.remove(req_id)
+
logger.debug(f"Freed blocks for {req_id} after transfer extraction")
self.waiting_for_transfer_free.remove(req_id)
- except Exception:
- init_logger(__name__).exception("Failed to free blocks for %s after transfer", req_id)
+ except Exception:
+ init_logger(__name__).exception("Failed to process finished transfer requests")
return engine_core_outputs
- def finish_requests(self, request_ids: Any, finished_status: RequestStatus) -> list[tuple[str, int]]:
- """Handles the finish signal from outside the scheduler.
-
- For example, the API server can abort a request when the client
- disconnects.
-
- If request_ids is None, all requests will be finished.
-
- Returns:
- Tuple of (req_id, client_index) for requests that were aborted. Will not
- include any that were already finished.
- """
-
- if self.chunk_transfer_adapter:
- self.chunk_transfer_adapter.finish_requests(request_ids, finished_status, self.requests)
-
- return super().finish_requests(request_ids, finished_status)
-
- def _update_request_as_session(self, session: Request, update: StreamingUpdate) -> None:
- """
- Override: Only extend prompt at stage 0, and replace
- the existing session with the next streaming update at other stages.
-
- Discards the last sampled output token from the prior input chunk at stage 0.
- """
- req_id = session.request_id
- self._new_prompt_len_snapshot[req_id] = len(update.prompt_token_ids)
- if self.vllm_config.model_config.stage_id != 0:
- self._replace_session_with_streaming_update(session, update)
-
- else:
- super()._update_request_as_session(session, update)
-
def _free_request(self, request: Request, delay_free_blocks: bool = False) -> dict[str, Any] | None:
# TODO(wzliu)! for offline mode, we should not end process until all data is transferred
"""Mark a request as finished and free its resources."""
assert request.is_finished()
- self._omits_kv_transfer_cache.pop(request.request_id, None)
-
# 1. Standard cleanup parts from base _free_request
connector_delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
request_id = request.request_id
self.finished_req_ids.add(request_id)
- self._new_prompt_len_snapshot.pop(request_id, None)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)
@@ -609,7 +538,8 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di
kv_xfer_params = None
return kv_xfer_params
elif request_id in self.waiting_for_transfer_free:
- # Blocks held until KV extraction completes in a future step.
+ # Stopped immediately by stop_decode_on_trigger; blocks are
+ # held until KV extraction completes in a future step.
return None
else:
logger.debug(
@@ -712,12 +642,7 @@ def _should_transfer_kv_for_request(self, req_id: str) -> bool:
need_send = omni_kv_config.get("need_send_cache", False)
else:
need_send = getattr(omni_kv_config, "need_send_cache", False)
- if not need_send:
- return False
- request = self.requests.get(req_id)
- if request is not None and self._request_omits_kv_transfer_to_next_stage(request):
- return False
- return True
+ return need_send
def has_requests(self) -> bool:
"""Check if there are any requests to process, including KV transfers."""
diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py
index 81f0b7fc2b4..1c4356d4f5e 100644
--- a/vllm_omni/core/sched/omni_generation_scheduler.py
+++ b/vllm_omni/core/sched/omni_generation_scheduler.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import time
from collections import defaultdict
@@ -13,16 +11,11 @@
from vllm.v1.core.sched.request_queue import create_request_queue
from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.core.sched.utils import remove_all
-from vllm.v1.engine import (
- EngineCoreEventType,
- EngineCoreOutput,
- EngineCoreOutputs,
-)
+from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.perf import PerfStats
-from vllm.v1.request import Request, RequestStatus, StreamingUpdate
+from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
-from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin
from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
@@ -32,7 +25,7 @@
logger = init_logger(__name__)
-class OmniGenerationScheduler(OmniSchedulerMixin, VLLMScheduler):
+class OmniGenerationScheduler(VLLMScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
model_config = self.vllm_config.model_config
@@ -331,24 +324,6 @@ def schedule(self) -> SchedulerOutput:
return scheduler_output
- def finish_requests(self, request_ids, finished_status: RequestStatus) -> list[tuple[str, int]]:
- """Handles the finish signal from outside the scheduler.
-
- For example, the API server can abort a request when the client
- disconnects.
-
- If request_ids is None, all requests will be finished.
-
- Returns:
- Tuple of (req_id, client_index) for requests that were aborted. Will not
- include any that were already finished.
- """
-
- if self.chunk_transfer_adapter:
- self.chunk_transfer_adapter.finish_requests(request_ids, finished_status, self.requests)
-
- return super().finish_requests(request_ids, finished_status)
-
"""
Scheduler for the diffusion model.
This scheduler is modified to stop the request immediately for the diffusion model.
@@ -606,11 +581,3 @@ def update_from_output(
eco.scheduler_stats = stats
return engine_core_outputs
-
- def _update_request_as_session(self, session: Request, update: StreamingUpdate) -> None:
- """
- Override: Just replace the existing session with the next streaming update.
-
- Do not expend prompt id using update.
- """
- self._replace_session_with_streaming_update(session, update)
diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py
deleted file mode 100644
index 36080e63acc..00000000000
--- a/vllm_omni/core/sched/omni_scheduler_mixin.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from __future__ import annotations
-
-from vllm.v1.engine import EngineCoreEventType
-from vllm.v1.request import Request, RequestStatus, StreamingUpdate
-
-
-class OmniSchedulerMixin:
- """Shared scheduler helpers for omni-specific request handling."""
-
- def _replace_session_with_streaming_update(
- self,
- session: Request,
- update: StreamingUpdate,
- ) -> None:
- """For streaming input: Replace an existing streaming session payload with the latest update."""
- session._output_token_ids.clear()
- session._all_token_ids.clear()
- new_prompt = update.prompt_token_ids or ()
- session._all_token_ids.extend(new_prompt)
- session.num_computed_tokens = 0
- session.prompt_token_ids = update.prompt_token_ids or ()
- session.additional_information = update.additional_information or None
- # Update block hashes for the new tokens.
- session.update_block_hashes()
- session.num_prompt_tokens = len(session.prompt_token_ids)
- session.arrival_time = update.arrival_time
- session.sampling_params = update.sampling_params
- if session.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
- self.num_waiting_for_streaming_input -= 1
- session.status = RequestStatus.WAITING
-
- if self.log_stats:
- session.record_event(EngineCoreEventType.QUEUED)
diff --git a/vllm_omni/core/sched/omni_scheduling_coordinator.py b/vllm_omni/core/sched/omni_scheduling_coordinator.py
deleted file mode 100644
index 3b6b4892ecb..00000000000
--- a/vllm_omni/core/sched/omni_scheduling_coordinator.py
+++ /dev/null
@@ -1,380 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Scheduling-side coordination for chunk and full_payload input waiting.
-
-Manages WAITING_FOR_CHUNK and WAITING_FOR_INPUT state transitions
-based on readiness signals from OmniConnectorOutput, without ever
-calling connector.put()/get().
-
-This replaces the scheduling half of OmniChunkTransferAdapter; the
-transport half lives in OmniConnectorModelRunnerMixin.
-"""
-
-from __future__ import annotations
-
-import time
-from collections import deque
-from typing import Any
-
-from vllm.logger import init_logger
-from vllm.v1.request import Request, RequestStatus
-
-logger = init_logger(__name__)
-
-
-class OmniSchedulingCoordinator:
- """Pure-scheduling coordinator for chunk and full_payload input waiting.
-
- The Scheduler owns an instance of this class. It consumes readiness
- signals produced by the Model Runner's ``OmniConnectorModelRunnerMixin``
- (via ``OmniConnectorOutput``) and manages ``WAITING_FOR_CHUNK`` and
- ``WAITING_FOR_INPUT`` state transitions accordingly.
- """
-
- def __init__(self, scheduler_max_num_seqs: int, stage_id: int = 0, async_chunk: bool = False):
- self._stage_id = stage_id
- self._scheduler_max_num_seqs = scheduler_max_num_seqs
- self._async_chunk = async_chunk
-
- self.finished_requests: set[str] = set()
- self.requests_with_ready_chunks: set[str] = set()
- self._full_payload_input_received: set[str] = set()
-
- self._waiting_for_chunk_waiting: deque[Any] = deque()
- self._waiting_for_chunk_running: deque[Any] = deque()
-
- # Request IDs that were newly registered for chunk recv this cycle.
- # The engine/Model Runner should call register_chunk_recv() for these
- # so the bg thread starts polling.
- self.pending_chunk_registrations: list[Any] = []
-
- # Requests waiting for full_payload stage input (WAITING_FOR_INPUT).
- self._waiting_for_input: deque[Any] = deque()
- self.pending_input_registrations: list[Any] = []
-
- # Monotonic timestamp recording when each request first entered
- # WAITING_FOR_CHUNK or WAITING_FOR_INPUT. Used by
- # collect_timed_out_request_ids() to detect orphaned waits.
- self._waiting_since: dict[str, float] = {}
-
- # ------------------------------------------------------------------ #
- # Core scheduling methods
- # ------------------------------------------------------------------ #
-
- def process_pending_chunks(
- self,
- waiting_queue: Any,
- running_queue: list[Request],
- chunk_ready_req_ids: set[str],
- chunk_finished_req_ids: set[str],
- ) -> None:
- """Transition requests whose chunks have arrived.
-
- Args:
- waiting_queue: Scheduler's waiting request queue.
- running_queue: Scheduler's running request list.
- chunk_ready_req_ids: IDs with a newly arrived chunk this cycle.
- chunk_finished_req_ids: IDs whose final chunk has arrived.
- """
- if self._stage_id == 0 or not self._async_chunk:
- return
-
- terminal_ready_req_ids = chunk_ready_req_ids.intersection(chunk_finished_req_ids)
- self.finished_requests.update(chunk_finished_req_ids - terminal_ready_req_ids)
- self.pending_chunk_registrations = []
-
- self._process_chunk_queue(
- waiting_queue,
- self._waiting_for_chunk_waiting,
- RequestStatus.WAITING,
- chunk_ready_req_ids,
- )
- self._process_chunk_queue(
- running_queue,
- self._waiting_for_chunk_running,
- RequestStatus.RUNNING,
- chunk_ready_req_ids,
- )
- self.finished_requests.update(terminal_ready_req_ids)
-
- while len(running_queue) > self._scheduler_max_num_seqs:
- request = running_queue.pop()
- # Must reset status to WAITING so the scheduler treats it as
- # schedulable work. KV blocks are NOT freed here (unlike a
- # real preemption), so PREEMPTED would be incorrect.
- request.status = RequestStatus.WAITING
- waiting_queue.prepend_requests([request])
-
- def process_pending_full_payload_inputs(
- self,
- waiting_queue: Any,
- running_queue: list[Request],
- stage_recv_req_ids: set[str],
- ) -> None:
- """Manage WAITING_FOR_INPUT lifecycle for full_payload_mode.
-
- For non-Stage-0 stages in full_payload_mode (``async_chunk=False``):
- 1. Fresh WAITING requests are transitioned to WAITING_FOR_INPUT
- and registered for bg-thread polling.
- 2. WAITING_FOR_INPUT requests whose data has arrived (in
- ``stage_recv_req_ids``) are transitioned back to WAITING.
- """
- if self._stage_id == 0:
- return
-
- self._full_payload_input_received.update(stage_recv_req_ids)
- if not self._async_chunk and stage_recv_req_ids:
- self.finished_requests.update(stage_recv_req_ids)
- logger.debug(
- "[Coordinator stage-%s] full_payload recv -> finished_requests: %s",
- self._stage_id,
- stage_recv_req_ids,
- )
- self.pending_input_registrations = []
-
- remaining: deque[Any] = deque()
- for request in self._waiting_for_input:
- if request.request_id in stage_recv_req_ids:
- request.status = RequestStatus.WAITING
- self._waiting_since.pop(request.request_id, None)
- waiting_queue.add_request(request)
- else:
- remaining.append(request)
- self._waiting_for_input = remaining
-
- if not self._async_chunk:
- to_remove: list[Any] = []
- queue_snapshot = list(waiting_queue)
- for request in queue_snapshot:
- if request.status == RequestStatus.WAITING:
- if request.request_id in self._full_payload_input_received:
- continue
- if request.request_id in self.requests_with_ready_chunks:
- continue
- if request.request_id in self.finished_requests:
- continue
- request.status = RequestStatus.WAITING_FOR_INPUT
- self._waiting_since.setdefault(request.request_id, time.monotonic())
- to_remove.append(request)
- self._waiting_for_input.append(request)
- self.pending_input_registrations.append(request)
- elif request.status == RequestStatus.WAITING_FOR_INPUT:
- if request.request_id in stage_recv_req_ids:
- request.status = RequestStatus.WAITING
- self._waiting_since.pop(request.request_id, None)
- else:
- to_remove.append(request)
- self._waiting_for_input.append(request)
- self.pending_input_registrations.append(request)
- for request in to_remove:
- waiting_queue.remove(request)
-
- def process_pending_full_payload_inputs_legacy(
- self,
- waiting_queue: Any,
- running_queue: list[Request],
- stage_recv_req_ids: set[str],
- ) -> None:
- """Compatibility wrapper for ``process_pending_full_payload_inputs``."""
- self.process_pending_full_payload_inputs(waiting_queue, running_queue, stage_recv_req_ids)
-
- def free_finished_request(self, request_id: str) -> None:
- """Prune internal tracking sets for a freed request to prevent unbounded growth."""
- self._full_payload_input_received.discard(request_id)
- self.finished_requests.discard(request_id)
- self.requests_with_ready_chunks.discard(request_id)
- self._waiting_since.pop(request_id, None)
-
- def collect_timed_out_request_ids(
- self,
- timeout_s: float,
- ) -> set[str]:
- """Return IDs of requests that have been waiting longer than *timeout_s*.
-
- Uses ``_waiting_since`` timestamps (always up-to-date) to detect
- timed-out requests. This method is safe to call at any point in
- the scheduling cycle — it does **not** rely on coordinator internal
- queues (which are empty after ``restore_queues()``).
-
- Clears ``_waiting_since`` for timed-out IDs and defensively removes
- them from coordinator internal queues if present. The caller
- (scheduler) should then remove the requests from its queues,
- set ``FINISHED_ERROR``, and call ``_free_request()`` so that
- ``cleanup_finished_request()`` fires in the model runner mixin.
- """
- if timeout_s <= 0:
- return set()
- now = time.monotonic()
- timed_out_ids: set[str] = set()
- for req_id, start_time in self._waiting_since.items():
- if now - start_time > timeout_s:
- timed_out_ids.add(req_id)
- if not timed_out_ids:
- return set()
-
- # Defensively remove from coordinator internal queues (may already
- # be empty if restore_queues() has run).
- for queue_attr in (
- "_waiting_for_chunk_waiting",
- "_waiting_for_chunk_running",
- "_waiting_for_input",
- ):
- queue = getattr(self, queue_attr)
- remaining: deque[Any] = deque()
- for request in queue:
- if request.request_id not in timed_out_ids:
- remaining.append(request)
- setattr(self, queue_attr, remaining)
-
- for req_id in timed_out_ids:
- self._waiting_since.pop(req_id, None)
- logger.warning(
- "[Coordinator stage-%s] Request %s timed out waiting for chunk/input (waited > %.0fs)",
- self._stage_id,
- req_id,
- timeout_s,
- )
-
- return timed_out_ids
-
- def restore_queues(
- self,
- waiting_queue: Any,
- running_queue: list[Request],
- ) -> None:
- """Return waiting-for-chunk/input requests to scheduling queues."""
- for request in self._waiting_for_chunk_waiting:
- waiting_queue.add_request(request)
- self._waiting_for_chunk_waiting = deque()
-
- if self._waiting_for_chunk_running:
- running_queue.extend(self._waiting_for_chunk_running)
- self._waiting_for_chunk_running = deque()
-
- for request in self._waiting_for_input:
- waiting_queue.add_request(request)
- self._waiting_for_input = deque()
-
- def update_request_metadata(
- self,
- requests: dict[str, Request],
- request_metadata: dict[str, dict[str, Any]],
- model_mode: str = "ar",
- ) -> None:
- """Apply received scheduling metadata to request objects.
-
- For AR mode: only scheduler-visible metadata is applied locally.
- For Generation mode: updates ``request.prompt_token_ids``.
-
- Additionally, if the payload contains ``next_stage_prompt_len``,
- updates the request's ``prompt_token_ids`` to the correct length.
- """
- for req_id, metadata in request_metadata.items():
- request = requests.get(req_id)
- if request is None:
- continue
-
- # Handle next_stage_prompt_len if present (for models like Qwen3-Omni).
- # Only apply when the request has not started decoding yet
- # (no output tokens). Resetting a mid-decode request would
- # destroy generated tokens and desync KV cache state.
- if "next_stage_prompt_len" in metadata:
- next_len = metadata["next_stage_prompt_len"]
- if isinstance(next_len, int) and next_len > 0:
- output_token_ids = getattr(request, "_output_token_ids", None)
- has_decode_output = output_token_ids is not None and len(output_token_ids) > 0
- if has_decode_output:
- logger.debug(
- "[Coordinator stage-%s] Skipping prompt resize for req %s: "
- "request already has %s output tokens",
- self._stage_id,
- req_id,
- len(output_token_ids),
- )
- else:
- current_prompt_ids = getattr(request, "prompt_token_ids", []) or []
- current_prompt_len = len(current_prompt_ids)
- if current_prompt_len != next_len or getattr(request, "num_prompt_tokens", None) != next_len:
- new_prompt = [0] * next_len
- request.prompt_token_ids = new_prompt
- request.num_prompt_tokens = next_len
- request._all_token_ids.clear()
- request._all_token_ids.extend(new_prompt)
- request._output_token_ids.clear()
- request.num_computed_tokens = 0
- logger.debug(
- "[Coordinator stage-%s] Updated prompt_token_ids length to %s for req %s",
- self._stage_id,
- next_len,
- req_id,
- )
-
- if model_mode != "ar":
- new_ids = metadata.get("code_predictor_codes", [])
- runtime_seed = None
- if "left_context_size" in metadata:
- runtime_seed = {
- "meta": {"left_context_size": metadata["left_context_size"]},
- }
- request._omni_initial_model_buffer = runtime_seed
- if new_ids:
- request.prompt_token_ids = new_ids
- request.num_computed_tokens = 0
-
- def postprocess_scheduler_output(
- self,
- scheduler_output: Any,
- requests: dict[str, Request] | None = None,
- ) -> None:
- """Clear per-cycle ready state after scheduler output is materialized."""
- self._clear_chunk_ready(scheduler_output)
-
- # ------------------------------------------------------------------ #
- # Internal helpers
- # ------------------------------------------------------------------ #
-
- def _process_chunk_queue(
- self,
- queue: Any,
- waiting_for_chunk_list: deque[Any],
- target_status: RequestStatus,
- chunk_ready_req_ids: set[str],
- ) -> None:
- queue_snapshot = list(queue)
- for request in queue_snapshot:
- if request.status != RequestStatus.WAITING_FOR_CHUNK:
- if request.request_id in self.requests_with_ready_chunks:
- continue
- if request.request_id in self.finished_requests:
- continue
- if request.status == RequestStatus.WAITING_FOR_INPUT:
- continue
- if request.request_id in chunk_ready_req_ids:
- self.requests_with_ready_chunks.add(request.request_id)
- continue
- self.pending_chunk_registrations.append(request)
- request.status = RequestStatus.WAITING_FOR_CHUNK
- self._waiting_since.setdefault(request.request_id, time.monotonic())
- else:
- if request.request_id in chunk_ready_req_ids:
- request.status = target_status
- self.requests_with_ready_chunks.add(request.request_id)
- self._waiting_since.pop(request.request_id, None)
- continue
- queue.remove(request)
- waiting_for_chunk_list.append(request)
-
- def _clear_chunk_ready(self, scheduler_output: Any) -> None:
- if scheduler_output.scheduled_new_reqs:
- for req_data in scheduler_output.scheduled_new_reqs:
- self.requests_with_ready_chunks.discard(
- getattr(req_data, "req_id", None),
- )
-
- if scheduler_output.scheduled_cached_reqs:
- for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
- self.requests_with_ready_chunks.discard(req_id)
-
-
-# Backward-compatible alias
-ChunkSchedulingCoordinator = OmniSchedulingCoordinator
diff --git a/vllm_omni/data_entry_keys.py b/vllm_omni/data_entry_keys.py
deleted file mode 100644
index f22aad9d46d..00000000000
--- a/vllm_omni/data_entry_keys.py
+++ /dev/null
@@ -1,277 +0,0 @@
-"""Structured payload types for inter-stage communication.
-
-Adding a new model?
-~~~~~~~~~~~~~~~~~~~
-Every key you put into the inter-stage payload (``additional_information``,
-``multimodal_output``, ``pooling_output``) **must** use the nested
-``OmniPayload`` TypedDict structure. For each category, every known
-qualifier is an explicit field so misspellings are caught statically.
-
-Categories
- hidden_states – intermediate / output hidden-state tensors
- embed – embedding tensors (prefill, decode, special tokens)
- ids – token-ID sequences
- codes – codec / audio code tensors
- meta – scalar metadata, control flags, shapes
-
-This module provides:
-- Structured ``TypedDict`` types for static type checking (``OmniPayload``)
-- ``serialize_payload`` / ``deserialize_payload`` for transport across
- process boundaries via ``AdditionalInformationPayload``
-"""
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, TypedDict
-
-import numpy as np
-import torch
-
-if TYPE_CHECKING:
- from vllm_omni.engine import AdditionalInformationEntry, AdditionalInformationPayload
-
-# ── Structured payload types ──
-# These are TypedDicts (plain dicts at runtime, zero overhead) that give
-# static type checking and IDE autocomplete for inter-stage payloads.
-# Every field is optional (total=False) because each stage only populates
-# the subset it needs.
-
-
-class HiddenStates(TypedDict, total=False):
- output: torch.Tensor
- trailing_text: torch.Tensor
- last: torch.Tensor
- layers: dict[int, torch.Tensor]
-
-
-class Embeddings(TypedDict, total=False):
- prefill: torch.Tensor
- decode: torch.Tensor
- cached_decode: torch.Tensor
- tts_bos: torch.Tensor
- tts_eos: torch.Tensor
- tts_pad: torch.Tensor
- tts_pad_projected: torch.Tensor
- voice: torch.Tensor
- speech_feat: torch.Tensor
- thinker_reply: torch.Tensor
-
-
-class Codes(TypedDict, total=False):
- audio: torch.Tensor
- ref: torch.Tensor
-
-
-class Ids(TypedDict, total=False):
- all: list[int]
- prompt: list[int]
- output: list[int]
- speech_token: list[int]
- prior_image: list[int]
-
-
-class OmniPayloadMeta(TypedDict, total=False):
- finished: torch.Tensor
- left_context_size: int
- override_keys: list[tuple[str, str]]
- num_processed_tokens: int
- next_stage_prompt_len: int
- ar_width: int
- eol_token_id: int
- visual_token_start_id: int
- visual_token_end_id: int
- gen_token_mask: torch.Tensor
- omni_task: list[str]
- height: int
- width: int
- decode_flag: bool
- codec_streaming: bool
- ref_code_len: int
- talker_prefill_offset: int
-
-
-class OmniPayload(TypedDict, total=False):
- hidden_states: HiddenStates
- embed: Embeddings
- ids: Ids
- codes: Codes
- meta: OmniPayloadMeta
- latent: torch.Tensor
- generated_len: int
- model_outputs: list[torch.Tensor]
- mtp_inputs: tuple[torch.Tensor, torch.Tensor]
- speaker: Any
- language: Any
- request_id: str
-
-
-# ── Keys whose values are nested dicts (TypedDict sub-categories) ──
-_NESTED_KEYS = frozenset({"hidden_states", "embed", "ids", "codes", "meta"})
-
-# Sub-TypedDict for each nested category, used by runtime validation.
-_NESTED_SCHEMAS: dict[str, type] = {
- "hidden_states": HiddenStates,
- "embed": Embeddings,
- "ids": Ids,
- "codes": Codes,
- "meta": OmniPayloadMeta,
-}
-
-_ROOT_KEYS: frozenset[str] = frozenset(OmniPayload.__annotations__.keys())
-
-
-def assert_payload(payload: dict[str, Any], *, context: str = "payload") -> None:
- """Validate ``payload`` matches the ``OmniPayload`` nested schema.
-
- TypedDict is a static-only contract in Python; this helper closes the
- loop at runtime by rejecting:
- * non-dict payloads
- * top-level keys not declared on ``OmniPayload``
- * nested-category values that aren't dicts
- * sub-keys not declared on the matching nested TypedDict
-
- Call at producer/consumer boundaries when a schema violation should
- crash the pipeline instead of silently degrading audio quality.
- """
- assert isinstance(payload, dict), f"{context}: expected dict, got {type(payload).__name__}"
- extra_top = set(payload) - _ROOT_KEYS
- assert not extra_top, f"{context}: unknown top-level keys {sorted(extra_top)!r}"
- for nested_key, schema in _NESTED_SCHEMAS.items():
- if nested_key not in payload:
- continue
- sub = payload[nested_key]
- assert isinstance(sub, dict), f"{context}: payload[{nested_key!r}] must be dict, got {type(sub).__name__}"
- known_sub = frozenset(schema.__annotations__.keys())
- extra_sub = set(sub) - known_sub
- assert not extra_sub, f"{context}: payload[{nested_key!r}] unknown sub-keys {sorted(extra_sub)!r}"
-
-
-def flatten_payload(payload: dict[str, Any]) -> dict[str, Any]:
- """Flatten a nested ``OmniPayload`` to dotted keys.
-
- Nested sub-dicts under ``_NESTED_KEYS`` are expanded:
- ``{"codes": {"audio": tensor}}`` → ``{"codes.audio": tensor}``.
- ``hidden_states["layers"]`` is expanded to ``hidden_states.layer_N``.
- Top-level values are kept as-is.
- """
- if not payload:
- return {}
- flat: dict[str, Any] = {}
- for key, value in payload.items():
- if key in _NESTED_KEYS and isinstance(value, dict):
- for qual, val in value.items():
- if qual == "layers" and key == "hidden_states" and isinstance(val, dict):
- for layer_idx, tensor in val.items():
- flat[f"hidden_states.layer_{layer_idx}"] = tensor
- else:
- flat[f"{key}.{qual}"] = val
- else:
- flat[key] = value
- return flat
-
-
-def unflatten_payload(flat: dict[str, Any]) -> dict[str, Any]:
- """Unflatten dotted keys back to nested dicts.
-
- Reverse of :func:`flatten_payload`.
- ``hidden_states.layer_N`` keys are collected into ``hidden_states.layers``.
- """
- result: dict[str, Any] = {}
- for key, value in flat.items():
- if "." in key:
- type_key, qualifier = key.split(".", 1)
- sub = result.setdefault(type_key, {})
- if type_key == "hidden_states" and qualifier.startswith("layer_"):
- layers = sub.setdefault("layers", {})
- layer_idx = int(qualifier[len("layer_") :])
- layers[layer_idx] = value
- else:
- sub[qualifier] = value
- else:
- result[key] = value
- return result
-
-
-# ── dtype helpers ──
-_DTYPE_TO_NAME: dict[torch.dtype, str] = {
- torch.float32: "float32",
- torch.float16: "float16",
- torch.bfloat16: "bfloat16",
- torch.float64: "float64",
- torch.int64: "int64",
- torch.int32: "int32",
- torch.int16: "int16",
- torch.int8: "int8",
- torch.uint8: "uint8",
- torch.bool: "bool",
-}
-
-
-def _dtype_to_name(dtype: torch.dtype) -> str:
- return _DTYPE_TO_NAME.get(dtype, str(dtype).replace("torch.", ""))
-
-
-def _serialize_tensor(t: torch.Tensor) -> AdditionalInformationEntry:
- from vllm_omni.engine import AdditionalInformationEntry
-
- t_cpu = t.detach().to("cpu").contiguous()
- return AdditionalInformationEntry(
- tensor_data=t_cpu.numpy().tobytes(),
- tensor_shape=list(t_cpu.shape),
- tensor_dtype=_dtype_to_name(t_cpu.dtype),
- )
-
-
-def _deserialize_tensor(entry: AdditionalInformationEntry) -> torch.Tensor:
- dt = np.dtype(entry.tensor_dtype or "float32")
- arr = np.frombuffer(entry.tensor_data, dtype=dt) # type: ignore[arg-type]
- arr = arr.reshape(entry.tensor_shape)
- return torch.from_numpy(arr.copy())
-
-
-def serialize_payload(
- payload: OmniPayload,
-) -> AdditionalInformationPayload | None:
- """Serialize an ``OmniPayload`` for EngineCore transport.
-
- Uses :func:`flatten_payload` to produce dotted keys, then converts
- each value to an ``AdditionalInformationEntry``.
- """
- from vllm_omni.engine import (
- AdditionalInformationEntry,
- AdditionalInformationPayload,
- )
-
- flat = flatten_payload(payload)
- entries: dict[str, AdditionalInformationEntry] = {}
-
- for key, value in flat.items():
- if isinstance(value, torch.Tensor):
- entries[key] = _serialize_tensor(value)
- elif isinstance(value, list):
- entries[key] = AdditionalInformationEntry(list_data=value)
- elif value is not None:
- entries[key] = AdditionalInformationEntry(scalar_data=value)
-
- return AdditionalInformationPayload(entries=entries) if entries else None
-
-
-def deserialize_payload(
- wire: AdditionalInformationPayload,
-) -> OmniPayload:
- """Deserialize an ``AdditionalInformationPayload`` back to ``OmniPayload``.
-
- Decodes entries to tensors/lists, then uses :func:`unflatten_payload`
- to reconstruct the nested structure.
- """
- flat: dict[str, Any] = {}
-
- for key, entry in wire.entries.items():
- if entry.tensor_data is not None:
- flat[key] = _deserialize_tensor(entry)
- elif entry.list_data is not None:
- flat[key] = entry.list_data
- elif entry.scalar_data is not None:
- flat[key] = entry.scalar_data
-
- return unflatten_payload(flat) # type: ignore[return-value]
diff --git a/vllm_omni/deploy/bagel.yaml b/vllm_omni/deploy/bagel.yaml
deleted file mode 100644
index 9d2f1f8fffa..00000000000
--- a/vllm_omni/deploy/bagel.yaml
+++ /dev/null
@@ -1,48 +0,0 @@
-# BAGEL-7B-MoT deploy: CUDA defaults, verified on NVIDIA A100 (80GB).
-#
-# Stage 0 (Thinker) and Stage 1 (DiT) share a single GPU by default.
-# For dual-GPU setups, set stage 1 devices to "1".
-#
-# Fields omitted from a stage fall back to StageDeployConfig dataclass
-# defaults (see vllm_omni/config/stage_config.py).
-
-async_chunk: false
-
-stages:
- - stage_id: 0
- max_num_seqs: 3
- gpu_memory_utilization: 0.45
- devices: "0"
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 52
- detokenize: true
- repetition_penalty: 1.05
-
- - stage_id: 1
- max_num_seqs: 1
- enforce_eager: true
- devices: "0"
- input_connectors:
- from_stage_0: shared_memory_connector
- default_sampling_params:
- seed: 52
-
-connectors:
- shared_memory_connector:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
-
- rdma_connector:
- name: MooncakeTransferEngineConnector
- extra:
- host: "auto"
- zmq_port: 50051
- protocol: "rdma"
- device_name: ""
- memory_pool_size: 4294967296
- memory_pool_device: "cpu"
diff --git a/vllm_omni/deploy/bagel_single_stage.yaml b/vllm_omni/deploy/bagel_single_stage.yaml
deleted file mode 100644
index 8470124ec78..00000000000
--- a/vllm_omni/deploy/bagel_single_stage.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-# BAGEL-7B-MoT single-stage deploy: all modalities handled by the DiT stage.
-#
-# The DiT stage contains a full LLM (Qwen2-MoT), ViT, VAE, and tokenizer,
-# so it supports text2img, img2img, img2text, text2text, and think mode
-# without a separate Thinker (AR) stage.
-#
-# Select this topology via:
-# vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni \
-# --deploy-config vllm_omni/deploy/bagel_single_stage.yaml
-#
-# Or programmatically:
-# Omni(model="...", deploy_config_path="vllm_omni/deploy/bagel_single_stage.yaml")
-
-pipeline: bagel_single_stage
-async_chunk: false
-
-stages:
- - stage_id: 0
- max_num_seqs: 1
- devices: "0"
- default_sampling_params:
- seed: 52
diff --git a/vllm_omni/deploy/bagel_think.yaml b/vllm_omni/deploy/bagel_think.yaml
deleted file mode 100644
index d7adf7f66b3..00000000000
--- a/vllm_omni/deploy/bagel_think.yaml
+++ /dev/null
@@ -1,16 +0,0 @@
-# BAGEL-7B-MoT think-mode deploy.
-#
-# Inherits all settings from bagel.yaml; only overrides the pipeline to
-# bagel_think which uses expand_cfg_prompts_think and omits
-# kv_transfer_criteria so the Thinker decodes tokens before
-# transferring KV to DiT.
-#
-# Select this topology via:
-# python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT --think
-#
-# Or explicitly:
-# vllm serve ByteDance-Seed/BAGEL-7B-MoT --omni \
-# --deploy-config vllm_omni/deploy/bagel_think.yaml
-
-base_config: bagel.yaml
-pipeline: bagel_think
diff --git a/vllm_omni/deploy/cosyvoice3.yaml b/vllm_omni/deploy/cosyvoice3.yaml
deleted file mode 100644
index 53e3eb3f301..00000000000
--- a/vllm_omni/deploy/cosyvoice3.yaml
+++ /dev/null
@@ -1,58 +0,0 @@
-# CosyVoice3 deploy: 2-stage talker → flow-matching code2wav.
-#
-# Default mode is async-chunk streaming through SharedMemoryConnector. Pass
-# ``--no-async-chunk`` to the serve CLI for the legacy sync path (stage 1
-# builds flow input from the full speech-token sequence via ``text2flow``).
-# The shared-memory connector definition and input/output_connectors
-# references stay in the yaml unconditionally; the runtime only activates
-# them when ``async_chunk: true``.
-#
-# enforce_eager=true everywhere:
-# * Stage 0 (talker) — cudagraph not verified on this checkpoint.
-# * Stage 1 (code2wav) — CUDA graphs don't work with dynamic conv shapes.
-async_chunk: true
-dtype: float32
-
-connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_left_context_frames: 25
- codec_vocab_size: 6561
-
-stages:
- - stage_id: 0
- max_num_seqs: 1
- gpu_memory_utilization: 0.4
- enforce_eager: true
- devices: "0"
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- max_tokens: 2048
- top_k: 25
- top_p: 0.8
- # near-identity repetition penalty forces vLLM to track
- # output_token_ids for RAS (stop-token logit logsumexp).
- repetition_penalty: 1.0001
- disable_hybrid_kv_cache_manager: true
- mm_processor_cache_gb: 0
- skip_mm_profiling: true
-
- - stage_id: 1
- max_num_seqs: 1
- gpu_memory_utilization: 0.2
- enforce_eager: true
- max_model_len: 32768
- devices: "0"
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- max_tokens: 2048
- disable_hybrid_kv_cache_manager: true
- skip_mm_profiling: true
diff --git a/vllm_omni/deploy/fish_qwen3_omni.yaml b/vllm_omni/deploy/fish_qwen3_omni.yaml
deleted file mode 100644
index a5bee925b68..00000000000
--- a/vllm_omni/deploy/fish_qwen3_omni.yaml
+++ /dev/null
@@ -1,61 +0,0 @@
-# Fish Speech S2 Pro deploy: async-chunk streaming slow_ar → dac_decoder.
-# Verified on 1× H20.
-# Registry key and filename use the HF ``model_type=fish_qwen3_omni`` for
-# auto-detection; human-readable "fish_speech" stays as the source directory.
-async_chunk: true
-enable_chunked_prefill: false
-
-connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- # ~21 Hz codec; 25 frames ≈ 1.16s of audio.
- codec_chunk_frames: 25
- codec_left_context_frames: 25
- initial_codec_chunk_frames: 4
-
-stages:
- - stage_id: 0
- max_num_seqs: 4
- gpu_memory_utilization: 0.6
- enforce_eager: false
- async_scheduling: false
- # vLLM >=0.19 requires max_num_batched_tokens >= max_model_len when
- # enable_chunked_prefill=false. Bumped from legacy 3072 to match
- # max_model_len; chunked prefill stays off because the SlowAR decode
- # loop isn't chunked-prefill-safe.
- max_num_batched_tokens: 16384
- max_model_len: 16384
- devices: "0"
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.8
- top_k: 30
- top_p: 0.9
- max_tokens: 2048
- seed: 42
- repetition_penalty: 1.0
-
- - stage_id: 1
- max_num_seqs: 1
- gpu_memory_utilization: 0.1
- enforce_eager: true
- async_scheduling: false
- max_num_batched_tokens: 16384
- max_model_len: 16384
- devices: "0"
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- repetition_penalty: 1.0
diff --git a/vllm_omni/deploy/glm_image.yaml b/vllm_omni/deploy/glm_image.yaml
deleted file mode 100644
index 28b88fb429a..00000000000
--- a/vllm_omni/deploy/glm_image.yaml
+++ /dev/null
@@ -1,43 +0,0 @@
-# GLM-Image deploy: AR (stage 0) + Diffusion (stage 1).
-# Topology declared in vllm_omni/model_executor/models/glm_image/pipeline.py.
-#
-# Fields omitted from a stage fall back to StageDeployConfig dataclass
-# defaults (see vllm_omni/config/stage_config.py). Pipeline-wide settings
-# (trust_remote_code, distributed_executor_backend, etc.) use DeployConfig
-# defaults unless overridden here.
-async_chunk: false
-
-stages:
- # Stage 0: AR Model (GlmImageForConditionalGeneration)
- # Generates prior_token_ids for conditioning the diffusion process.
- # max_tokens is set to the theoretical maximum (2048x2048 t2i ≈ 4353
- # tokens). The AR model stops naturally at EOS (16385) well before
- # this ceiling at lower resolutions. Do NOT override max_tokens per
- # request — the actual resolution is communicated via target_h/w.
- - stage_id: 0
- max_num_seqs: 1
- gpu_memory_utilization: 0.6
- enforce_eager: false
- max_num_batched_tokens: 32768
- devices: "0"
- default_sampling_params:
- temperature: 0.9
- top_p: 0.75
- top_k: 16512
- stop_token_ids: [16385]
- max_tokens: 4353
- seed: 42
- detokenize: false
-
- # Stage 1: Diffusion (DiT + VAE)
- # Receives prior_token_ids from AR, performs denoising + VAE decode.
- - stage_id: 1
- max_num_seqs: 1
- enforce_eager: true
- devices: "1"
- default_sampling_params:
- seed: 42
- num_inference_steps: 50
- guidance_scale: 1.5
- height: 1024
- width: 1024
diff --git a/vllm_omni/deploy/mimo_audio.yaml b/vllm_omni/deploy/mimo_audio.yaml
deleted file mode 100644
index f5e704f9bd4..00000000000
--- a/vllm_omni/deploy/mimo_audio.yaml
+++ /dev/null
@@ -1,56 +0,0 @@
-# MiMo Audio deploy: 2-stage thinker+talker → code2wav.
-#
-# Default mode is async-chunk streaming on a single GPU (both stages on
-# device 0) through SharedMemoryConnector. For the legacy 2-GPU sync
-# pipeline (stage 1 on a second card), pass ``--no-async-chunk
-# --stage-1-devices 1 --stage-1-max-model-len 18192
-# --stage-1-max-num-batched-tokens 18192`` to the serve CLI.
-async_chunk: true
-dtype: bfloat16
-
-connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 365536
- codec_streaming: true
- connector_get_sleep_s: 0.001
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 3
- codec_left_context_frames: 3
-
-stages:
- - stage_id: 0
- max_num_seqs: 1
- gpu_memory_utilization: 0.3
- enforce_eager: true
- max_num_batched_tokens: 8192
- max_model_len: 8192
- devices: "0"
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.6
- top_p: 0.95
- top_k: 50
- max_tokens: 18192
- seed: 42
- repetition_penalty: 1.1
-
- - stage_id: 1
- max_num_seqs: 1
- gpu_memory_utilization: 0.2
- enforce_eager: true
- async_scheduling: false
- max_num_batched_tokens: 8192
- max_model_len: 8192
- devices: "0"
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 18192
- seed: 42
diff --git a/vllm_omni/deploy/qwen2_5_omni.yaml b/vllm_omni/deploy/qwen2_5_omni.yaml
deleted file mode 100644
index 41aef0df6f6..00000000000
--- a/vllm_omni/deploy/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,92 +0,0 @@
-# Qwen2.5-Omni deploy: CUDA defaults + platform overrides, verified on 2x H100.
-# Stage 2 disables flashinfer autotune because its DiT block never invokes
-# flashinfer; the autotune dummy run OOMs the shared cuda:0 device otherwise.
-#
-# Fields omitted from a stage fall back to StageDeployConfig dataclass
-# defaults (see vllm_omni/config/stage_config.py). For instance, every
-# stage here uses vLLM's default max_num_batched_tokens=32768 because
-# chat-sized prefill comfortably fits; only models with codec prefill
-# (Qwen3-Omni, Qwen3-TTS) need to bump it above 32k.
-#
-# enforce_eager policy across the three deploy YAMLs:
-# * code2wav / generation stages: always true (cudagraph incompatible with
-# the custom generation loop — set explicitly everywhere).
-# * AR stages (thinker, talker): model-dependent. Qwen2.5-Omni runs eager
-# on CUDA (thinker uses custom ops that don't trace cleanly); NPU / XPU
-# platform overrides flip back to false where cudagraph is verified.
-# Qwen3-Omni / Qwen3-TTS AR stages use the default (false = cudagraph on).
-async_chunk: false
-
-stages:
- - stage_id: 0
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- enforce_eager: true
- mm_processor_cache_gb: 0
- devices: "0"
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- repetition_penalty: 1.1
-
- - stage_id: 1
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- enforce_eager: true
- devices: "1"
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- repetition_penalty: 1.05
-
- - stage_id: 2
- max_num_seqs: 1
- gpu_memory_utilization: 0.15
- enforce_eager: true
- enable_flashinfer_autotune: false
- async_scheduling: false
- devices: "0"
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- repetition_penalty: 1.1
-
-platforms:
- npu:
- stages:
- # NPU has cudagraph support for the thinker, unlike GPU which still
- # only runs eager.
- - stage_id: 0
- enforce_eager: false
- - stage_id: 2
- # 3-NPU layout: stage 2 lives on its own card.
- devices: "2"
-
- rocm:
- stages:
- - stage_id: 2
- # 3-GPU MI325 layout: stage 2 on a separate card.
- devices: "2"
-
- xpu:
- stages:
- # Verified on 2x Intel Arc Pro B60. Both AR stages use cudagraphs.
- - stage_id: 0
- gpu_memory_utilization: 0.9
- enforce_eager: false
- - stage_id: 1
- gpu_memory_utilization: 0.5
- enforce_eager: false
- - stage_id: 2
- gpu_memory_utilization: 0.3
- # Stage 2 colocates with stage 1's device on XPU.
- devices: "1"
diff --git a/vllm_omni/deploy/qwen3_omni_moe.yaml b/vllm_omni/deploy/qwen3_omni_moe.yaml
deleted file mode 100644
index 39baed6bd7b..00000000000
--- a/vllm_omni/deploy/qwen3_omni_moe.yaml
+++ /dev/null
@@ -1,98 +0,0 @@
-# Qwen3-Omni-MoE production deploy, verified on 2x H100 (stage 0 on cuda:0,
-# stages 1+2 on cuda:1).
-#
-# Fields omitted from a stage fall back to StageDeployConfig defaults (see
-# vllm_omni/config/stage_config.py). Notable implicit defaults for this
-# model:
-# * Stages 0/1 (thinker, talker) do not set max_num_batched_tokens —
-# chat-sized prefill fits in the 32768 default.
-# * Stages 0/1 do not set enforce_eager — cudagraph runs by default
-# (false). Stage 2 (code2wav) sets true because its generation loop
-# is cudagraph-incompatible.
-# * Platform sections flip enforce_eager per-stage where platform
-# cudagraph support differs.
-async_chunk: true
-
-connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- codec_chunk_frames: 25
- codec_left_context_frames: 25
-
-stages:
- - stage_id: 0
- gpu_memory_utilization: 0.9
- devices: "0"
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- repetition_penalty: 1.05
-
- - stage_id: 1
- gpu_memory_utilization: 0.6
- devices: "1"
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- repetition_penalty: 1.05
-
- - stage_id: 2
- gpu_memory_utilization: 0.1
- enforce_eager: true
- async_scheduling: false
- max_num_batched_tokens: 51200
- devices: "1"
- input_connectors:
- from_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- repetition_penalty: 1.1
-
-platforms:
- npu:
- stages:
- - stage_id: 0
- gpu_memory_utilization: 0.6
- tensor_parallel_size: 2
- max_num_batched_tokens: 8192
- devices: "0,1"
- - stage_id: 1
- gpu_memory_utilization: 0.6
- max_num_batched_tokens: 8192
- devices: "2"
- - stage_id: 2
- gpu_memory_utilization: 0.3
- devices: "2"
-
- rocm:
- stages:
- - stage_id: 0
- enforce_eager: true
-
- xpu:
- stages:
- - stage_id: 0
- tensor_parallel_size: 4
- enforce_eager: true
- max_cudagraph_capture_size: 0
- devices: "0,1,2,3"
- - stage_id: 1
- enforce_eager: true
- max_cudagraph_capture_size: 0
- devices: "4"
- - stage_id: 2
- gpu_memory_utilization: 0.3
- max_cudagraph_capture_size: 0
- devices: "4"
diff --git a/vllm_omni/deploy/qwen3_tts.yaml b/vllm_omni/deploy/qwen3_tts.yaml
deleted file mode 100644
index 522ea7c58c8..00000000000
--- a/vllm_omni/deploy/qwen3_tts.yaml
+++ /dev/null
@@ -1,73 +0,0 @@
-# Qwen3-TTS deploy: talker → code2wav via shared-memory chunk streaming.
-# Verified on 1x H100.
-#
-# Fields omitted from a stage fall back to StageDeployConfig defaults (see
-# vllm_omni/config/stage_config.py). Notable choices for this model:
-# * Stage 0 (talker) sets max_num_batched_tokens=512 for async-chunk
-# latency tuning (not correctness) — small per-step batches keep
-# first-chunk latency low.
-# * Stage 1 (code2wav) sets max_num_batched_tokens=65536 for correctness:
-# codec prefill length (Q * num_frames) exceeds the 32k default.
-# * Stage 0 does not set enforce_eager — talker runs cudagraph by default.
-# Stage 1 sets true because its codec generation loop is not
-# cudagraph-compatible. NPU platform flips stage 0 to true where
-# cudagraph is not yet verified.
-async_chunk: true
-
-connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- # Must match the decoder sliding attention window.
- codec_chunk_frames: 25
- codec_left_context_frames: 72
-
-stages:
- - stage_id: 0
- max_num_seqs: 10
- gpu_memory_utilization: 0.3
- async_scheduling: true
- max_num_batched_tokens: 512
- max_model_len: 4096
- devices: "0"
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- repetition_penalty: 1.05
- subtalker_sampling_params:
- do_sample: true
- temperature: 0.9
- top_k: 50
- top_p: 1.0
-
- - stage_id: 1
- max_num_seqs: 1
- gpu_memory_utilization: 0.3
- enforce_eager: true
- async_scheduling: true
- # Must be divisible by num_code_groups and cover (left_context + chunk).
- # Prefill length is Q * num_frames (e.g. 16 * 2148 = 34368); keep
- # headroom past 32k.
- max_num_batched_tokens: 65536
- # async_chunk appends windows per step; max_model_len must cover the
- # accumulated flat codec stream.
- max_model_len: 65536
- devices: "0"
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- repetition_penalty: 1.0
diff --git a/vllm_omni/deploy/voxcpm2.yaml b/vllm_omni/deploy/voxcpm2.yaml
deleted file mode 100644
index b49906710df..00000000000
--- a/vllm_omni/deploy/voxcpm2.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-# VoxCPM2 deploy: single-stage AR pipeline with per-request state batching.
-# Verified on 1x H20 141GB.
-#
-# Fields omitted from a stage fall back to StageDeployConfig defaults (see
-# vllm_omni/config/stage_config.py). Notable choices:
-# * enforce_eager=true because the talker's KV-cache save/restore loop is
-# not cudagraph-compatible (captured separately via voxcpm2_talker's own
-# _CapturedGraph path).
-# * max_num_seqs=4 matches the legacy max_batch_size=4 contract: concurrent
-# requests are supported via per-request StaticKVCache save/restore.
-async_chunk: false
-dtype: bfloat16
-
-stages:
- - stage_id: 0
- max_num_seqs: 4
- gpu_memory_utilization: 0.9
- enforce_eager: true
- async_scheduling: true
- max_num_batched_tokens: 4096
- max_model_len: 4096
- devices: "0"
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 4096
- seed: 42
- repetition_penalty: 1.0
diff --git a/vllm_omni/deploy/voxtral_tts.yaml b/vllm_omni/deploy/voxtral_tts.yaml
deleted file mode 100644
index 87d999c67e0..00000000000
--- a/vllm_omni/deploy/voxtral_tts.yaml
+++ /dev/null
@@ -1,67 +0,0 @@
-# Voxtral TTS deploy: async-chunk streaming generator → audio tokenizer.
-# Verified on 1× H20.
-#
-# Mistral tokenizer/config/load flags are declared per-stage here because
-# DeployConfig does not yet expose them as pipeline-wide fields. They must
-# match on every stage so the engine args stay consistent across workers.
-async_chunk: true
-
-connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_chunk_frames_at_begin: 5
- codec_left_context_frames: 25
-
-stages:
- - stage_id: 0
- max_num_seqs: 32
- gpu_memory_utilization: 0.8
- enforce_eager: false
- async_scheduling: true
- max_model_len: 4096
- devices: "0"
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- repetition_penalty: 1.1
- extra_args:
- cfg_alpha: 1.2
- tokenizer_mode: mistral
- config_format: mistral
- load_format: mistral
- skip_mm_profiling: true
- enable_chunked_prefill: false
-
- - stage_id: 1
- max_num_seqs: 32
- gpu_memory_utilization: 0.1
- enforce_eager: true
- async_scheduling: false
- max_num_batched_tokens: 65536
- max_model_len: 65536
- devices: "0"
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- repetition_penalty: 1.05
- tokenizer_mode: mistral
- config_format: mistral
- load_format: mistral
- skip_mm_profiling: true
diff --git a/vllm_omni/diffusion/attention/backends/flash_attn.py b/vllm_omni/diffusion/attention/backends/flash_attn.py
index d38ea4f6eaa..5c586c0631e 100644
--- a/vllm_omni/diffusion/attention/backends/flash_attn.py
+++ b/vllm_omni/diffusion/attention/backends/flash_attn.py
@@ -96,7 +96,7 @@ def forward_cuda(
value: torch.Tensor,
attn_metadata: AttentionMetadata = None,
) -> torch.Tensor:
- """CUDA/ROCm/MUSA flash attention implementation."""
+ """CUDA/ROCm flash attention implementation."""
from vllm_omni.diffusion.attention.backends.utils.fa import (
HAS_FLASH_ATTN,
flash_attn_func,
diff --git a/vllm_omni/diffusion/attention/backends/utils/fa.py b/vllm_omni/diffusion/attention/backends/utils/fa.py
index 18886871a22..77596a10333 100644
--- a/vllm_omni/diffusion/attention/backends/utils/fa.py
+++ b/vllm_omni/diffusion/attention/backends/utils/fa.py
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py
-from functools import lru_cache
-
import torch
import torch.nn.functional as F
@@ -40,10 +38,8 @@
except (ImportError, ModuleNotFoundError):
pass
elif current_omni_platform.is_musa():
- try:
- from flash_attn_interface import flash_attn_func, flash_attn_varlen_func # noqa: F401
- except (ImportError, ModuleNotFoundError):
- pass
+ # XXX (MUSA): Add MUSA-specific Flash Attention when available
+ pass
else:
# CUDA: try FA3 -> FA2 fallback chain
# Try FA3 from fa3-fwd PyPI package
@@ -80,12 +76,6 @@
HAS_FLASH_ATTN = flash_attn_func is not None or flash_attn_varlen_func is not None
-@lru_cache(maxsize=1)
-def is_mate_available() -> bool:
- """Check if MATE (MUSA AI Tensor Engine) is available."""
- return current_omni_platform.is_musa() and flash_attn_func is not None or flash_attn_varlen_func is not None
-
-
def _index_first_axis(tensor, indices):
"""
A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
diff --git a/vllm_omni/diffusion/attention/parallel/ulysses.py b/vllm_omni/diffusion/attention/parallel/ulysses.py
index 326b5d45671..5d860b3350e 100644
--- a/vllm_omni/diffusion/attention/parallel/ulysses.py
+++ b/vllm_omni/diffusion/attention/parallel/ulysses.py
@@ -414,6 +414,10 @@ def pre_attention(
def post_attention(self, attn_output: torch.Tensor, ctx: ParallelAttentionContext | None) -> torch.Tensor:
assert isinstance(ctx, _UlyssesCtx), f"Unexpected ctx type: {type(ctx)!r}"
+ # If we have joint tensors (Text), they were Head-Sliced.
+ # The main sequence (Image) was Sequence-Sliced.
+ # attn_output contains [Joint_Sliced | Image_Sliced] (if strategy='front').
+
if ctx.joint_len > 0:
joint_len = ctx.joint_len
diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index 1b1800dac5f..a5055a0688e 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -281,7 +281,6 @@ def enable_cache_for_longcat_image(pipeline: Any, cache_config: Any) -> Callable
],
forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
params_modifiers=[modifier],
- has_separate_cfg=True,
)
),
cache_config=db_cache_config,
@@ -465,77 +464,6 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
return refresh_cache_context
-def enable_cache_for_stable_audio_open(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
- """Enable cache-dit for Stable Audio Open pipeline.
-
- Args:
- pipeline: The StableAudioPipeline instance.
- cache_config: DiffusionCacheConfig instance with cache configuration.
-
- Returns:
- A refresh function that can be called to update cache context with new num_inference_steps.
- """
- db_cache_config = _build_db_cache_config(cache_config)
-
- calibrator_config = None
- if cache_config.enable_taylorseer:
- taylorseer_order = cache_config.taylorseer_order
- calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
- logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
-
- # StableAudio is officially registered in CacheDiT as Pattern_3:
- # https://github.com/vipshop/cache-dit/blob/69e82bd1/src/cache_dit/caching/block_adapters/__init__.py#L562
- #
- # Pattern_3 is required because StableAudioDiT uses cross-attention
- # with static encoder_hidden_states that do not change inside the
- # transformer block loop.
- cache_dit.enable_cache(
- BlockAdapter(
- transformer=pipeline.transformer,
- blocks=pipeline.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- params_modifiers=[
- ParamsModifier(
- cache_config=db_cache_config,
- calibrator_config=calibrator_config,
- )
- ],
- ),
- cache_config=db_cache_config,
- )
-
- def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
- """Refresh cache context for the transformer with new num_inference_steps.
-
- Args:
- pipeline: The StableAudioPipeline instance.
- num_inference_steps: New number of inference steps.
- verbose: Whether to log refresh operations.
- """
- # Bypass SCM for step counts that don't support predefined masks (e.g., vLLM's 1-step dummy run)
- scm_supported_steps = num_inference_steps >= 8 or num_inference_steps in (4, 6)
-
- if cache_config.scm_steps_mask_policy is None or not scm_supported_steps:
- cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
- else:
- updated_scm_config = DBCacheConfig().reset(
- num_inference_steps=num_inference_steps,
- steps_computation_mask=cache_dit.steps_mask(
- mask_policy=cache_config.scm_steps_mask_policy,
- total_steps=num_inference_steps,
- ),
- steps_computation_policy=cache_config.scm_steps_policy,
- )
-
- cache_dit.refresh_context(
- pipeline.transformer,
- cache_config=updated_scm_config,
- verbose=verbose,
- )
-
- return refresh_cache_context
-
-
def enable_cache_for_sd3(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
"""Enable cache-dit for StableDiffusion3Pipeline.
@@ -633,7 +561,6 @@ def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], N
forward_pattern=ForwardPattern.Pattern_0,
# Treat audio_hidden_states as encoder_hidden_states in Pattern_0
check_forward_pattern=False,
- has_separate_cfg=True,
),
cache_config=db_cache_config,
calibrator_config=calibrator_config,
@@ -1170,85 +1097,41 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
return refresh_cache_context
-def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
- """Enable cache-dit for Flux.2-dev pipeline.
+def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+ """Enable cache-dit for GLM-Image pipeline.
- Args:
- pipeline: The Flux2 pipeline instance.
- cache_config: DiffusionCacheConfig instance with cache configuration.
- Returns:
- A refresh function that can be called with a new ``num_inference_steps``
- to update the cache context for the pipeline.
+ GLM-Image processes prompt and image by calling the transformer before the
+ denoising loop. When an input image is provided (editing mode), the cache must
+ be force-refreshed after the preprocessing step so stale hidden states are
+ discarded. Set force_refresh_step_hint = 1 for editing, None for text-to-image.
"""
- # Build DBCacheConfig for transformer
db_cache_config = _build_db_cache_config(cache_config)
- calibrator = None
+ calibrator_config = None
if cache_config.enable_taylorseer:
- taylorseer_order = cache_config.taylorseer_order
- calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
- logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
-
- # Build ParamsModifier for transformer
- modifier = ParamsModifier(
- cache_config=db_cache_config,
- calibrator_config=calibrator,
- )
+ calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=cache_config.taylorseer_order)
+ logger.info(f"TaylorSeer enabled with order={cache_config.taylorseer_order}")
logger.info(
- f"Enabling cache-dit on Flux transformer with BlockAdapter: "
+ f"Enabling cache-dit on GLM-Image transformer: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
+ f"force_refresh_step_hint={db_cache_config.force_refresh_step_hint}, "
)
- # Enable cache-dit using BlockAdapter for transformer
cache_dit.enable_cache(
- (
- BlockAdapter(
- transformer=pipeline.transformer,
- blocks=[
- pipeline.transformer.transformer_blocks,
- pipeline.transformer.single_transformer_blocks,
- ],
- forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
- params_modifiers=[modifier],
- )
- ),
+ pipeline.transformer,
cache_config=db_cache_config,
+ calibrator_config=calibrator_config,
)
- def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
- """Refresh cache context for the transformer with new num_inference_steps.
-
- Args:
- pipeline: The Flux2 pipeline instance.
- num_inference_steps: New number of inference steps.
- """
- if cache_config.scm_steps_mask_policy is None:
- cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
- else:
- cache_dit.refresh_context(
- pipeline.transformer,
- cache_config=DBCacheConfig().reset(
- num_inference_steps=num_inference_steps,
- steps_computation_mask=cache_dit.steps_mask(
- mask_policy=cache_config.scm_steps_mask_policy,
- total_steps=num_inference_steps,
- ),
- steps_computation_policy=cache_config.scm_steps_policy,
- ),
- verbose=verbose,
- )
-
- return refresh_cache_context
-
-def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
- """Enable cache-dit for GlmImage pipeline.
+def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+ """Enable cache-dit for Flux.2-dev pipeline.
Args:
- pipeline: The GlmImage pipeline instance.
+ pipeline: The Flux2 pipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.
Returns:
A refresh function that can be called with a new ``num_inference_steps``
@@ -1270,25 +1153,23 @@ def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[in
)
logger.info(
- f"Enabling cache-dit on GlmImage transformer with BlockAdapter: "
+ f"Enabling cache-dit on Flux transformer with BlockAdapter: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
)
# Enable cache-dit using BlockAdapter for transformer
- # Note: We don't use patch_functor here because it's designed for diffusers' GlmImage,
- # and our vllm-omni implementation has a different forward signature.
- # We use ForwardPattern.Pattern_0 because our block returns (hidden_states, encoder_hidden_states)
cache_dit.enable_cache(
(
BlockAdapter(
transformer=pipeline.transformer,
- blocks=pipeline.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_0,
+ blocks=[
+ pipeline.transformer.transformer_blocks,
+ pipeline.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
params_modifiers=[modifier],
- patch_functor=None,
- has_separate_cfg=True,
)
),
cache_config=db_cache_config,
@@ -1298,7 +1179,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"""Refresh cache context for the transformer with new num_inference_steps.
Args:
- pipeline: The GlmImage pipeline instance.
+ pipeline: The Flux2 pipeline instance.
num_inference_steps: New number of inference steps.
"""
if cache_config.scm_steps_mask_policy is None:
@@ -1331,14 +1212,11 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"Flux2KleinPipeline": enable_cache_for_flux2_klein,
"LongCatImagePipeline": enable_cache_for_longcat_image,
"LongCatImageEditPipeline": enable_cache_for_longcat_image,
- "StableAudioPipeline": enable_cache_for_stable_audio_open,
"StableDiffusion3Pipeline": enable_cache_for_sd3,
"LTX2Pipeline": enable_cache_for_ltx2,
"LTX2ImageToVideoPipeline": enable_cache_for_ltx2,
"LTX2TwoStagesPipeline": enable_cache_for_ltx2,
"LTX2ImageToVideoTwoStagesPipeline": enable_cache_for_ltx2,
- "LTX23Pipeline": enable_cache_for_ltx2,
- "LTX23ImageToVideoPipeline": enable_cache_for_ltx2,
"BagelPipeline": enable_cache_for_bagel,
"GlmImagePipeline": enable_cache_for_glm_image,
"Flux2Pipeline": enable_cache_for_flux2,
diff --git a/vllm_omni/diffusion/cache/teacache/backend.py b/vllm_omni/diffusion/cache/teacache/backend.py
index 772dec78913..a5087fe0c24 100644
--- a/vllm_omni/diffusion/cache/teacache/backend.py
+++ b/vllm_omni/diffusion/cache/teacache/backend.py
@@ -48,7 +48,16 @@ def enable_bagel_teacache(pipeline: Any, config: DiffusionCacheConfig) -> None:
coefficients=config.coefficients,
)
transformer = pipeline.bagel
+ original_forward_flow = transformer._forward_flow
+
+ import types
+
+ def forward_alias(self, *args, **kwargs):
+ return original_forward_flow(*args, **kwargs)
+
+ transformer.forward = types.MethodType(forward_alias, transformer)
apply_teacache_hook(transformer, teacache_config)
+ transformer._forward_flow = transformer.forward
pipeline.transformer = transformer
logger.info(
diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
index 38c805c28db..f3a278b2174 100644
--- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
+++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
@@ -1,18 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import os
+import types
from typing import Any
import numpy as np
import torch
from vllm.config import LoadConfig
-from vllm.transformers_utils.config import get_hf_file_to_dict
+from vllm.utils.torch_utils import set_default_torch_dtype
from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
-from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig
+from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline
+from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -34,7 +36,6 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
ctx = self.extractor_fn(module, *args, **kwargs)
- # NOTE: We upcast to float32 to also handle bfloat16.
modulated_input_cpu = ctx.modulated_input.detach().float().cpu().numpy()
outputs = ctx.run_transformer_blocks()
@@ -53,57 +54,56 @@ def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]:
return list(self.current_trajectory)
-class DefaultAdapter:
- """Default adapter for standard diffusers pipelines."""
-
- model_class_name = None
- uses_tf_config = True
-
- @classmethod
- def load_pipeline(cls, model_path: str, device: str, dtype: torch.dtype) -> Any:
- if cls.model_class_name is None:
- raise ValueError("Adapter doesn't have a set class name.")
-
- od_config = OmniDiffusionConfig.from_kwargs(
- model_class_name=cls.model_class_name,
- model=model_path,
- dtype=dtype,
- )
+class BagelAdapter:
+ """Adapter for Bagel model."""
- if cls.uses_tf_config:
- # TODO (Alex): Refactor to handle tf_model_config in OmniDiffusionConfig
- # instead of OmniDiffusion and remove the manual population here
- tf_config_dict = get_hf_file_to_dict(
- os.path.join("transformer", "config.json"),
- od_config.model,
- )
- od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict)
+ @staticmethod
+ def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline:
+ od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
+ od_config.model_class_name = "BagelPipeline"
- loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config)
- # load_model will handle dtypes / device placement, put in .eval() mode
- return loader.load_model(od_config=od_config, load_device=device)
+ pipeline = BagelPipeline(od_config=od_config)
+ loader = DiffusersPipelineLoader(LoadConfig())
+ loader.load_weights(pipeline)
+ pipeline.to(device)
+ return pipeline
@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
- return pipeline.transformer, pipeline.transformer.__class__.__name__
+ return pipeline.bagel, "Bagel"
@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
+ original_forward_flow = transformer._forward_flow
+
+ def forward_alias(self, *args, **kwargs):
+ return original_forward_flow(*args, **kwargs)
+
+ transformer.forward = types.MethodType(forward_alias, transformer)
registry = HookRegistry.get_or_create(transformer)
registry.register_hook(hook._HOOK_NAME, hook)
+ transformer._forward_flow = transformer.forward
-class BagelAdapter(DefaultAdapter):
- """Adapter for Bagel model."""
+class StableAudioAdapter:
+ """Adapter for Stable Audio Open 1.0 coefficient estimation."""
+
+ @staticmethod
+ def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.float16) -> Any:
+ od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
- model_class_name = "BagelPipeline"
- # Skip the hack for loading the tf model config,
- # because bagel doesn't use it.
- uses_tf_config = False
+ # Strictly necessary because we bypass loader.load_model()
+ with set_default_torch_dtype(dtype):
+ pipeline = StableAudioPipeline(od_config=od_config)
+
+ loader = DiffusersPipelineLoader(LoadConfig())
+ loader.load_weights(pipeline)
+ pipeline.to(device)
+ return pipeline
@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
- return pipeline.bagel, "Bagel"
+ return pipeline.transformer, "StableAudioDiTModel"
@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
@@ -111,32 +111,26 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry.register_hook(hook._HOOK_NAME, hook)
-class Flux2Adapter(DefaultAdapter):
- """Adapter for Flux2 model coefficient estimation."""
-
- model_class_name = "Flux2Pipeline"
-
-
-class LongCatAdapter(DefaultAdapter):
- """Adapter for LongCat Image - NOTE: currently this model needs the vLLM
- context to be correctly configured to actually run the estimation, since it
- uses vLLM norm layers etc.
- """
-
- model_class_name = "LongCatImagePipeline"
+class DefaultAdapter:
+ """Default adapter for standard diffusers pipelines."""
+ @staticmethod
+ def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
+ raise NotImplementedError("DefaultAdapter.load_pipeline not implemented")
-class StableAudioAdapter(DefaultAdapter):
- """Adapter for Stable Audio Open 1.0 coefficient estimation."""
+ @staticmethod
+ def get_transformer(pipeline: Any) -> tuple[Any, str]:
+ return pipeline.transformer, pipeline.transformer.__class__.__name__
- model_class_name = "StableAudioPipeline"
+ @staticmethod
+ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
+ registry = HookRegistry.get_or_create(transformer)
+ registry.register_hook(hook._HOOK_NAME, hook)
_MODEL_ADAPTERS: dict[str, type] = {
"Bagel": BagelAdapter,
"StableAudio": StableAudioAdapter,
- "Flux2": Flux2Adapter,
- "LongCat": LongCatAdapter,
}
_EPSILON = 1e-6
@@ -183,6 +177,7 @@ def __init__(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
+ # Add validation here ⬇️
if model_type not in _MODEL_ADAPTERS:
available_types = list(_MODEL_ADAPTERS.keys())
raise ValueError(
@@ -191,7 +186,7 @@ def __init__(
f"To add support for a new model, add an entry to _MODEL_ADAPTERS."
)
- adapter = _MODEL_ADAPTERS[model_type]
+ adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter)
self.pipeline = adapter.load_pipeline(model_path, device, dtype)
self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
self.hook = DataCollectionHook(self.transformer_type)
diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py
index 7efdd418e12..96cf3f03eec 100644
--- a/vllm_omni/diffusion/cache/teacache/config.py
+++ b/vllm_omni/diffusion/cache/teacache/config.py
@@ -64,17 +64,6 @@
-1.04182570e01,
6.78098549e-01,
],
- # Flux2 transformer coefficients
- # Copied from Qwen-Image, need to be tuned specifically for Flux2 in future
- "Flux2Transformer2DModel": [
- -4.50000000e02,
- 2.80000000e02,
- -4.50000000e01,
- 3.20000000e00,
- -2.00000000e-02,
- ],
- # LongCat Image transformer coefficients
- "LongCatImageTransformer2DModel": [652.5980, -424.1615, 84.5526, -4.5923, 0.1694],
}
diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py
index d0da0d9df3f..bdb3f6a7865 100644
--- a/vllm_omni/diffusion/cache/teacache/extractors.py
+++ b/vllm_omni/diffusion/cache/teacache/extractors.py
@@ -19,12 +19,8 @@
import torch
import torch.nn as nn
-from vllm.logger import init_logger
from vllm_omni.diffusion.forward_context import get_forward_context
-from vllm_omni.platforms import current_omni_platform
-
-logger = init_logger(__name__)
@dataclass
@@ -225,8 +221,7 @@ def extract_qwen_context(
block = module.transformer_blocks[0]
img_mod_params = block.img_mod(temb)
img_mod1, _ = img_mod_params.chunk(2, dim=-1)
- img_scale1, img_shift1, _ = block._modulate(img_mod1)
- img_modulated = block.img_norm1(hidden_states, img_scale1, img_shift1)
+ img_modulated, _ = block.img_norm1(hidden_states, img_mod1)
# ============================================================================
# DEFINE TRANSFORMER EXECUTION (Qwen-specific)
@@ -726,105 +721,6 @@ def postprocess(h):
)
-def extract_longcat_context(
- module: nn.Module, # LongCatImageTransformer2DModel
- hidden_states,
- timestep,
- guidance,
- encoder_hidden_states,
- txt_ids,
- img_ids,
- **kwargs,
-) -> CacheContext:
- """Extract the cache context for LongCat Image.
-
- Similar to other extractors, this is currently the only code needed
- for TeaCache support for LongCat image, and encapsulates preprocessing,
- modulated input extraction, transformer execution, and postprocessing
- logic.
-
- Args & kawrgs are identical to the inputs to LongCat Image's forward.
-
- Returns:
- CacheContext with all information needed for generic caching
- """
- # TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
-
- # 1. Model specific preprocessing
- fwd_context = get_forward_context()
- sp_size = module.parallel_config.sequence_parallel_size
- if sp_size is not None and sp_size > 1:
- # NOTE: For now, we set this to False on the forward context
- # to be consistent with LongCat Image's current behavior when
- # TeaCache is enabled. We do not need to reset it in post process
- # since we should never split text embed in sp for this model.
- fwd_context.split_text_embed_in_sp = False
-
- hidden_states = module.x_embedder(hidden_states)
-
- timestep = timestep.to(hidden_states.dtype) * 1000
-
- temb = module.time_embed(timestep, hidden_states.dtype)
- encoder_hidden_states = module.context_embedder(encoder_hidden_states)
-
- # Compute RoPE embeddings via rope_preparer module
- # _sp_plan will automatically shard img_cos/img_sin (outputs 2, 3)
- # txt_cos/txt_sin (outputs 0, 1) remain replicated for dual-stream attention
- txt_cos, txt_sin, img_cos, img_sin = module.rope_preparer(txt_ids, img_ids)
-
- # Reconstruct image_rotary_emb with chunked values
- # Final shape: (txt_seq_len + img_seq_len // SP, head_dim)
- image_rotary_emb = (
- torch.cat([txt_cos, img_cos], dim=0),
- torch.cat([txt_sin, img_sin], dim=0),
- )
-
- # 2. Extract the modulated output from the first mm-DiT block
- first_block = module.transformer_blocks[0]
- img_modulated = first_block.norm1(hidden_states, emb=temb)[0]
-
- # 3. Define the transformer execution
- def run_transformer_blocks():
- """Execute all Longcat transformer blocks."""
- h = hidden_states
- e = encoder_hidden_states
- for block in module.transformer_blocks:
- e, h = block(
- hidden_states=h,
- encoder_hidden_states=e,
- temb=temb,
- image_rotary_emb=image_rotary_emb,
- )
-
- for block in module.single_transformer_blocks:
- e, h = block(
- hidden_states=h,
- encoder_hidden_states=e,
- temb=temb,
- image_rotary_emb=image_rotary_emb,
- )
- # Hook expects hidden states to be first
- return (h, e)
-
- # 4. Postprocessing
- def postprocess(h):
- """Apply Longcat-specific output postprocessing."""
- h = module.norm_out(h, temb)
- output = module.proj_out(h)
- return Transformer2DModelOutput(sample=output)
-
- # 5. Return the CacheContext
- return CacheContext(
- modulated_input=img_modulated,
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- temb=temb,
- run_transformer_blocks=run_transformer_blocks,
- postprocess=postprocess,
- )
-
-
def extract_stable_audio_context(
module: nn.Module,
hidden_states: torch.Tensor,
@@ -931,144 +827,6 @@ def postprocess(h: torch.Tensor) -> Any:
)
-def extract_flux2_context(
- module: nn.Module,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor = None,
- timestep: torch.LongTensor = None,
- img_ids: torch.Tensor = None,
- txt_ids: torch.Tensor = None,
- guidance: torch.Tensor | None = None,
- joint_attention_kwargs: dict[str, Any] | None = None,
- return_dict: bool = True,
- **kwargs: Any,
-) -> CacheContext:
- """
- Extract cache context for Flux2Transformer2DModel.
-
- This is the ONLY Flux2-specific code needed for TeaCache support.
- It encapsulates preprocessing, modulated input extraction, transformer execution,
- and postprocessing logic.
-
- Args:
- module: Flux2Transformer2DModel instance
- hidden_states: Input hidden states tensor
- encoder_hidden_states: Text encoder outputs
- timestep: Current diffusion timestep
- img_ids: Image inputs for position embedding
- txt_ids: Text inputs for position embedding
- guidance: Optional guidance scale for CFG
- joint_attention_kwargs: Additional attention arguments
- return_dict: Whether to return a Transformer2DModelOutput instead of a plain tensor
- **kwargs: Additional keyword arguments ignored by this extractor
-
- Returns:
- CacheContext with all information needed for generic caching
- """
-
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
-
- if not hasattr(module, "transformer_blocks") or len(module.transformer_blocks) == 0:
- raise ValueError("Module must have transformer_blocks")
-
- # ============================================================================
- # PREPROCESSING (Flux2-specific)
- # ============================================================================
- num_txt_tokens = encoder_hidden_states.shape[1]
-
- timestep = timestep.to(hidden_states.dtype) * 1000
- if guidance is not None:
- guidance = guidance.to(hidden_states.dtype) * 1000
-
- temb = module.time_guidance_embed(timestep, guidance)
-
- double_stream_mod_img = module.double_stream_modulation_img(temb)
- double_stream_mod_txt = module.double_stream_modulation_txt(temb)
- single_stream_mod = module.single_stream_modulation(temb)[0]
-
- hidden_states = module.x_embedder(hidden_states)
- encoder_hidden_states = module.context_embedder(encoder_hidden_states)
-
- if img_ids.ndim == 3:
- img_ids = img_ids[0]
- if txt_ids.ndim == 3:
- txt_ids = txt_ids[0]
-
- if current_omni_platform.is_npu():
- freqs_cos_image, freqs_sin_image = module.pos_embed(img_ids.cpu())
- image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
- freqs_cos_text, freqs_sin_text = module.pos_embed(txt_ids.cpu())
- text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
- else:
- image_rotary_emb = module.pos_embed(img_ids)
- text_rotary_emb = module.pos_embed(txt_ids)
- concat_rotary_emb = (
- torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
- torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
- )
-
- # ============================================================================
- # EXTRACT MODULATED INPUT (for cache decision)
- # ============================================================================
- block = module.transformer_blocks[0]
- (shift_msa, scale_msa, gate_msa), _ = double_stream_mod_img
- modulated_input = block.norm1(hidden_states)
- modulated_input = (1 + scale_msa) * modulated_input + shift_msa
-
- # ============================================================================
- # DEFINE TRANSFORMER EXECUTION (Flux2-specific)
- # ============================================================================
- def run_transformer_blocks():
- """Execute all Flux2 transformer blocks."""
- h = hidden_states
- e = encoder_hidden_states
-
- for transformer_block in module.transformer_blocks:
- e, h = transformer_block(
- hidden_states=h,
- encoder_hidden_states=e,
- temb_mod_params_img=double_stream_mod_img,
- temb_mod_params_txt=double_stream_mod_txt,
- image_rotary_emb=concat_rotary_emb,
- joint_attention_kwargs=joint_attention_kwargs,
- )
- h = torch.cat([e, h], dim=1)
-
- for single_transformer_block in module.single_transformer_blocks:
- h = single_transformer_block(
- hidden_states=h,
- encoder_hidden_states=None,
- temb_mod_params=single_stream_mod,
- image_rotary_emb=concat_rotary_emb,
- joint_attention_kwargs=joint_attention_kwargs,
- )
-
- h = h[:, num_txt_tokens:, ...]
- return (h,)
-
- # ============================================================================
- # DEFINE POSTPROCESSING
- # ============================================================================
- def postprocess(h):
- h = module.norm_out(h, temb)
- output = module.proj_out(h)
- if not return_dict:
- return (output,)
- return Transformer2DModelOutput(sample=output)
-
- # ============================================================================
- # RETURN CONTEXT
- # ============================================================================
- return CacheContext(
- modulated_input=modulated_input,
- hidden_states=hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- temb=temb,
- run_transformer_blocks=run_transformer_blocks,
- postprocess=postprocess,
- )
-
-
# Registry for model-specific extractors
# Key: Transformer class name
# Value: extractor function with signature (module, *args, **kwargs) -> CacheContext
@@ -1081,8 +839,6 @@ def postprocess(h):
"ZImageTransformer2DModel": extract_zimage_context,
"Flux2Klein": extract_flux2_klein_context,
"StableAudioDiTModel": extract_stable_audio_context,
- "Flux2Transformer2DModel": extract_flux2_context,
- "LongCatImageTransformer2DModel": extract_longcat_context,
# Future models:
# "FluxTransformer2DModel": extract_flux_context,
# "CogVideoXTransformer3DModel": extract_cogvideox_context,
diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py
index cf6841fd21d..3071fd9d56a 100644
--- a/vllm_omni/diffusion/data.py
+++ b/vllm_omni/diffusion/data.py
@@ -9,7 +9,6 @@
from typing import TYPE_CHECKING, Any
import torch
-from PIL import Image
from pydantic import model_validator
from typing_extensions import Self
from vllm.config.utils import config
@@ -18,7 +17,6 @@
QuantizationConfig,
)
-from vllm_omni.diffusion.model_metadata import get_diffusion_model_metadata
from vllm_omni.diffusion.utils.network_utils import is_port_available
from vllm_omni.quantization import build_quant_config
@@ -353,8 +351,6 @@ def __getattr__(self, item: str) -> Any:
@dataclass
class OmniDiffusionConfig:
# Model and path configuration (for convenience)
- stage_id: int = 0
-
model: str | None = None
model_class_name: str | None = None
@@ -452,14 +448,7 @@ class OmniDiffusionConfig:
custom_pipeline_args: dict[str, Any] | None = None
# Diffusion model loading format
- # "default", "custom_pipeline", "dummy", "diffusers" (HF diffusers adapter)
- diffusion_load_format: str = "default"
-
- # Diffusers adapter kwargs
- # kwargs forwarded to DiffusionPipeline.from_pretrained()
- diffusers_load_kwargs: dict[str, Any] = field(default_factory=dict)
- # kwargs forwarded to pipeline.__call__()
- diffusers_call_kwargs: dict[str, Any] = field(default_factory=dict)
+ diffusion_load_format: str = "default" # "default", "custom_pipeline", "dummy"
# http server endpoint config, would be ignored in local mode
host: str | None = None
@@ -491,10 +480,8 @@ class OmniDiffusionConfig:
# Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p)
flow_shift: float | None = None
- # Support multi-image inputs and expose any model-specific request limit
- # through a generic config field so serving code stays model-agnostic.
+ # support multi images input
supports_multimodal_inputs: bool = False
- max_multimodal_image_inputs: int | None = None
log_level: str = "info"
@@ -517,8 +504,6 @@ class OmniDiffusionConfig:
# Step mode settings
step_execution: bool = False
- # sleep mode
- enable_sleep_mode: bool = False
# Maximum number of sequences to generate in a batch
max_num_seqs: int = 1
@@ -655,12 +640,6 @@ def __post_init__(self):
elif self.max_cpu_loras < 1:
raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA")
- if self.diffusion_load_format != "diffusers" and (self.diffusers_load_kwargs or self.diffusers_call_kwargs):
- raise ValueError(
- "diffusers_load_kwargs and diffusers_call_kwargs are only "
- "valid together with diffusion_load_format=diffusers"
- )
-
def set_tf_model_config(self, tf_config: "TransformerConfig") -> None:
"""Assign `tf_model_config` and propagate quantization if detected.
@@ -684,68 +663,7 @@ def set_tf_model_config(self, tf_config: "TransformerConfig") -> None:
)
def update_multimodal_support(self) -> None:
- # Resolve serving-visible multimodal behavior from shared metadata
- # instead of importing concrete pipeline modules into the config layer.
- metadata = get_diffusion_model_metadata(self.model_class_name)
- self.supports_multimodal_inputs = metadata.supports_multimodal_inputs
- self.max_multimodal_image_inputs = metadata.max_multimodal_image_inputs
-
- def enrich_config(self) -> None:
- """Load model metadata from HuggingFace and populate config fields.
-
- Diffusers-style models expose ``model_index.json`` with ``_class_name``.
- Non-diffusers models (e.g. Bagel, NextStep) only have ``config.json``,
- so we fall back to reading that and mapping model_type manually.
- """
- from vllm.transformers_utils.config import get_hf_file_to_dict
-
- # Default model_class_name for diffusers adapter
- if self.model_class_name is None and self.diffusion_load_format == "diffusers":
- self.model_class_name = "DiffusersAdapterPipeline"
-
- try:
- config_dict = get_hf_file_to_dict("model_index.json", self.model)
- if config_dict is not None:
- if self.model_class_name is None:
- self.model_class_name = config_dict.get("_class_name", None)
- self.update_multimodal_support()
-
- # Skip transformer config loading for diffusers adapter
- # (non-DiT models don't have a separate transformer folder/config)
- if self.diffusion_load_format == "diffusers":
- self.tf_model_config = TransformerConfig()
- else:
- tf_config_dict = get_hf_file_to_dict("transformer/config.json", self.model)
- self.tf_model_config = TransformerConfig.from_dict(tf_config_dict)
- else:
- raise FileNotFoundError("model_index.json not found")
- except (AttributeError, OSError, ValueError, FileNotFoundError):
- # Skip transformer config loading for diffusers adapter
- # (non-DiT models don't have a separate transformer folder/config)
- if self.diffusion_load_format == "diffusers":
- self.tf_model_config = TransformerConfig()
- else:
- cfg = get_hf_file_to_dict("config.json", self.model)
- if cfg is None:
- raise ValueError(f"Could not find config.json or model_index.json for model {self.model}")
-
- self.tf_model_config = TransformerConfig.from_dict(cfg)
- model_type = cfg.get("model_type")
- architectures = cfg.get("architectures") or []
-
- if model_type == "bagel" or "BagelForConditionalGeneration" in architectures:
- self.model_class_name = "BagelPipeline"
- self.tf_model_config = TransformerConfig()
- self.update_multimodal_support()
- elif model_type == "nextstep":
- if self.model_class_name is None:
- self.model_class_name = "NextStep11Pipeline"
- self.tf_model_config = TransformerConfig()
- self.update_multimodal_support()
- elif architectures and len(architectures) == 1:
- self.model_class_name = architectures[0]
- else:
- raise
+ self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"}
@classmethod
def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig":
@@ -759,7 +677,7 @@ def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig":
# Backwards-compatibility: map "quantization" to "quantization_config"
# so callers using the old field name still work.
- if "quantization" in kwargs and kwargs.get("quantization_config", None) is None:
+ if "quantization" in kwargs and "quantization_config" not in kwargs:
kwargs["quantization_config"] = kwargs.pop("quantization")
else:
kwargs.pop("quantization", None)
@@ -770,12 +688,6 @@ def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig":
cache_backend = os.environ.get("DIFFUSION_CACHE_BACKEND") or os.environ.get("DIFFUSION_CACHE_ADAPTER")
kwargs["cache_backend"] = cache_backend.lower() if cache_backend else "none"
- # Falsy-value check for not-None fields (convert potential None values in YAML config to empty containers)
- if "diffusers_load_kwargs" in kwargs and kwargs["diffusers_load_kwargs"] is None:
- kwargs["diffusers_load_kwargs"] = {}
- if "diffusers_call_kwargs" in kwargs and kwargs["diffusers_call_kwargs"] is None:
- kwargs["diffusers_call_kwargs"] = {}
-
# Filter kwargs to only include valid fields
valid_fields = {f.name for f in fields(cls)}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_fields}
@@ -789,12 +701,10 @@ class DiffusionOutput:
Final output (after pipeline completion)
"""
- # Fields may be replaced with SHM handle dicts by ipc.pack_diffusion_output_shm
- output: torch.Tensor | dict | None = None
- trajectory_timesteps: torch.Tensor | dict | None = None
- trajectory_latents: torch.Tensor | dict | None = None
- trajectory_log_probs: torch.Tensor | dict | None = None
- trajectory_decoded: list[Image.Image] | None = None
+ output: torch.Tensor | None = None
+ trajectory_timesteps: list[torch.Tensor] | None = None
+ trajectory_latents: torch.Tensor | None = None
+ trajectory_decoded: list[torch.Tensor] | None = None
error: str | None = None
aborted: bool = False
abort_message: str | None = None
@@ -834,43 +744,5 @@ def __str__(self):
return self.name.lower()
-@dataclass
-class OmniACK:
- """
- Handshake payload from Workers to Orchestrator.
- """
-
- task_id: str
- status: str
- stage_id: int | None = None
- rank: int | None = None
- freed_bytes: int = 0
- metadata: dict[str, Any] = field(default_factory=dict)
- """
- Additional telemetry such as:
- - max_contiguous_block: for fragmentation analysis.
- - cuda_graph_recalled: boolean if graphs were successfully destroyed/rebuilt.
- - latency_ms: time taken for the D2H/H2D transfer.
- """
- error_msg: str | None = None
-
-
-@dataclass
-class OmniSleepTask:
- """Structured sleep instruction."""
-
- task_id: str
- level: int = 2
- metadata: dict[str, Any] = field(default_factory=dict)
-
-
-@dataclass
-class OmniWakeTask:
- """Structured wake-up instruction."""
-
- task_id: str
- tags: list[str] | None = None
-
-
# Special message broadcast via scheduler queues to signal worker shutdown.
SHUTDOWN_MESSAGE = {"type": "shutdown"}
diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py
index d0215385a9d..9cf069daaa0 100644
--- a/vllm_omni/diffusion/diffusion_engine.py
+++ b/vllm_omni/diffusion/diffusion_engine.py
@@ -3,7 +3,6 @@
from __future__ import annotations
-import inspect
import queue
import threading
import time
@@ -14,7 +13,6 @@
import PIL.Image
import torch
from vllm.logger import init_logger
-from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.diffusion.data import (
DiffusionOutput,
@@ -80,12 +78,6 @@ def __init__(
self.post_process_func = get_diffusion_post_process_func(od_config)
self.pre_process_func = get_diffusion_pre_process_func(od_config)
- # Cache whether the model-specific postprocess accepts request-level
- # sampling params so step() can support both legacy and extended hooks.
- self._post_process_accepts_sampling_params = bool(
- self.post_process_func is not None
- and "sampling_params" in inspect.signature(self.post_process_func).parameters
- )
executor_class = DiffusionExecutor.get_class(od_config)
self.executor = executor_class(od_config)
@@ -123,7 +115,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
if output.aborted:
raise DiffusionRequestAbortedError(output.abort_message or "Diffusion request aborted.")
if output.error:
- raise RuntimeError(output.error)
+ raise RuntimeError(f"{output.error}")
logger.info("Generation completed successfully.")
if output.output is None:
@@ -151,24 +143,10 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
output_data = output_data.cpu()
postprocess_start_time = time.perf_counter()
- if self.post_process_func is not None:
- # Some video pipelines need request-level controls during
- # postprocess (for example worker-side frame interpolation).
- if self._post_process_accepts_sampling_params:
- outputs = self.post_process_func(output_data, sampling_params=request.sampling_params)
- else:
- outputs = self.post_process_func(output_data)
- else:
- outputs = output_data
+ outputs = self.post_process_func(output_data) if self.post_process_func is not None else output_data
audio_payload = None
- custom_output = output.custom_output or {}
- model_audio_sample_rate = None
- model_fps = None
if isinstance(outputs, dict):
audio_payload = outputs.get("audio")
- custom_output.update(outputs.get("custom_output") or {})
- model_audio_sample_rate = outputs.get("audio_sample_rate")
- model_fps = outputs.get("fps")
outputs = outputs.get("video", outputs)
postprocess_time = time.perf_counter() - postprocess_start_time
logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds")
@@ -214,10 +192,6 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
- trajectory_latents=output.trajectory_latents,
- trajectory_timesteps=output.trajectory_timesteps,
- trajectory_log_probs=output.trajectory_log_probs,
- trajectory_decoded=output.trajectory_decoded,
multimodal_output={"audio": request_audio_payload},
final_output_type="audio",
stage_durations=output.stage_durations,
@@ -228,10 +202,6 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
mm_output = {}
if audio_payload is not None:
mm_output["audio"] = audio_payload
- if model_audio_sample_rate is not None:
- mm_output["audio_sample_rate"] = model_audio_sample_rate
- if model_fps is not None:
- mm_output["fps"] = model_fps
return [
OmniRequestOutput.from_diffusion(
request_id=request_id,
@@ -239,11 +209,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
- trajectory_latents=output.trajectory_latents,
- trajectory_timesteps=output.trajectory_timesteps,
- trajectory_log_probs=output.trajectory_log_probs,
- trajectory_decoded=output.trajectory_decoded,
- custom_output=custom_output,
+ custom_output=output.custom_output or {},
multimodal_output=mm_output,
stage_durations=output.stage_durations,
peak_memory_mb=output.peak_memory_mb,
@@ -274,10 +240,6 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
- trajectory_latents=output.trajectory_latents,
- trajectory_timesteps=output.trajectory_timesteps,
- trajectory_log_probs=output.trajectory_log_probs,
- trajectory_decoded=output.trajectory_decoded,
multimodal_output={"audio": request_audio_payload},
final_output_type="audio",
stage_durations=output.stage_durations,
@@ -298,10 +260,6 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
if num_outputs == 1:
sliced_audio = sliced_audio[0]
mm_output["audio"] = sliced_audio
- if model_audio_sample_rate is not None:
- mm_output["audio_sample_rate"] = model_audio_sample_rate
- if model_fps is not None:
- mm_output["fps"] = model_fps
results.append(
OmniRequestOutput.from_diffusion(
request_id=request_id,
@@ -309,11 +267,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
prompt=prompt,
metrics=metrics,
latents=output.trajectory_latents,
- trajectory_latents=output.trajectory_latents,
- trajectory_timesteps=output.trajectory_timesteps,
- trajectory_log_probs=output.trajectory_log_probs,
- trajectory_decoded=output.trajectory_decoded,
- custom_output=custom_output,
+ custom_output=output.custom_output or {},
multimodal_output=mm_output,
stage_durations=output.stage_durations,
peak_memory_mb=output.peak_memory_mb,
@@ -359,8 +313,6 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus
sched_req_id = sched_output.scheduled_req_ids[0]
try:
runner_output = self.execute_fn(sched_output)
- except EngineDeadError:
- raise
except Exception as exc:
logger.error("Execution failed for diffusion request %s", sched_req_id, exc_info=True)
runner_output = RunnerOutput(
@@ -381,11 +333,15 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus
)
def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None:
- """Start or stop profiling on all diffusion workers.
+ """Start or stop torch profiling on all diffusion workers.
Args:
is_start: True to start profiling, False to stop.
- profile_prefix: Optional prefix for trace filename.
+ profile_prefix: Optional prefix for trace filename (vLLM compat).
+
+ Note:
+ Matches vLLM's worker.profile() signature for consistency.
+ Traces are saved automatically via on_trace_ready callback.
"""
if is_start:
if profile_prefix is None:
diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py
index 0084719a8ab..7df2d6a8add 100644
--- a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py
+++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl.py
@@ -93,7 +93,7 @@ def patch_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
_, _, latent_h, latent_w = z.shape
scale = int(2 ** (len(self.config.block_out_channels) - 1))
- max_parallel_size = self.distributed_executor.parallel_size
+ max_parallel_size = self.distributed_decoder.parallel_size
root = int(math.sqrt(max_parallel_size))
for rows in range(root, 0, -1):
@@ -187,7 +187,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True, *args: Any, **kwargs
if split is not None:
strategy = "tile" if split == self.tile_split else "patch"
logger.info(f"Decode run with distributed executor, split strategy is {strategy}")
- result = self.distributed_executor.execute(
+ result = self.distributed_decoder.execute(
z, DistributedOperator(split=split, exec=exec, merge=merge), broadcast_result=False
)
if not return_dict:
diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py
index f9dea8a36d9..7549bbd3d5a 100644
--- a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py
+++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_qwenimage.py
@@ -108,8 +108,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):
if not self.is_distributed_enabled():
return super().tiled_decode(z, return_dict=return_dict)
- logger.debug("Decode running with distributed executor")
- result = self.distributed_executor.execute(
+ logger.info("Decode run with distributed executor")
+ result = self.distributed_decoder.execute(
z,
DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge),
broadcast_result=True,
diff --git a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py
index 35c9434d063..7defbae79b7 100644
--- a/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py
+++ b/vllm_omni/diffusion/distributed/autoencoders/autoencoder_kl_wan.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from contextlib import nullcontext
from typing import Any
import torch
@@ -16,38 +15,11 @@
GridSpec,
TileTask,
)
-from vllm_omni.platforms import current_omni_platform
logger = init_logger(__name__)
-class OmniAutoencoderKLWan(AutoencoderKLWan):
- def _execution_context(self):
- try:
- first_param = next(self.parameters())
- except StopIteration:
- return nullcontext()
-
- dtype = first_param.dtype
- if dtype not in (torch.float16, torch.bfloat16):
- return nullcontext()
-
- return current_omni_platform.create_autocast_context(
- device_type=first_param.device.type,
- dtype=dtype,
- enabled=True,
- )
-
- def encode(self, x: torch.Tensor, return_dict: bool = True):
- with self._execution_context():
- return super().encode(x, return_dict=return_dict)
-
- def decode(self, z: torch.Tensor, return_dict: bool = True):
- with self._execution_context():
- return super().decode(z, return_dict=return_dict)
-
-
-class DistributedAutoencoderKLWan(OmniAutoencoderKLWan, DistributedVaeMixin):
+class DistributedAutoencoderKLWan(AutoencoderKLWan, DistributedVaeMixin):
@classmethod
def from_pretrained(cls, *args: Any, **kwargs: Any):
model = super().from_pretrained(*args, **kwargs)
@@ -112,128 +84,14 @@ def tile_exec(self, task: TileTask) -> torch.Tensor:
"""Decode a single latent tile into RGB space."""
self.clear_cache()
time = []
- with self._execution_context():
- for k in range(len(task.tensor)):
- self._conv_idx = [0]
- tile = self.post_quant_conv(task.tensor[k])
- decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0))
- time.append(decoded)
+ for k in range(len(task.tensor)):
+ self._conv_idx = [0]
+ tile = self.post_quant_conv(task.tensor[k])
+ decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0))
+ time.append(decoded)
result = torch.cat(time, dim=2)
return result
- def encode_tile_split(self, x: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
- _, _, num_frames, height, width = x.shape
- encode_spatial_compression_ratio = self.spatial_compression_ratio
- # Scale tile parameters for patchified coordinate system
- tile_sample_min_height = self.tile_sample_min_height
- tile_sample_min_width = self.tile_sample_min_width
- tile_sample_stride_height = self.tile_sample_stride_height
- tile_sample_stride_width = self.tile_sample_stride_width
- if self.config.patch_size is not None:
- assert encode_spatial_compression_ratio % self.config.patch_size == 0
- encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size
- # When input is patchified, scale tile parameters accordingly
- tile_sample_min_height = tile_sample_min_height // self.config.patch_size
- tile_sample_min_width = tile_sample_min_width // self.config.patch_size
- tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
- tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
-
- latent_height = height // encode_spatial_compression_ratio
- latent_width = width // encode_spatial_compression_ratio
-
- tile_latent_min_height = tile_sample_min_height // encode_spatial_compression_ratio
- tile_latent_min_width = tile_sample_min_width // encode_spatial_compression_ratio
- tile_latent_stride_height = tile_sample_stride_height // encode_spatial_compression_ratio
- tile_latent_stride_width = tile_sample_stride_width // encode_spatial_compression_ratio
-
- blend_height = tile_latent_min_height - tile_latent_stride_height
- blend_width = tile_latent_min_width - tile_latent_stride_width
-
- tiletask_list = []
- temporal_compression = self.config.scale_factor_temporal
- for i in range(0, height, tile_sample_stride_height):
- for j in range(0, width, tile_sample_stride_width):
- time_list = []
- frame_range = 1 + (num_frames - 1) // temporal_compression
- for k in range(frame_range):
- if k == 0:
- tile = x[:, :, :1, i : i + tile_sample_min_height, j : j + tile_sample_min_width]
- else:
- tile = x[
- :,
- :,
- 1 + temporal_compression * (k - 1) : 1 + temporal_compression * k,
- i : i + tile_sample_min_height,
- j : j + tile_sample_min_width,
- ]
- time_list.append(tile)
- tiletask_list.append(
- TileTask(
- len(tiletask_list),
- (i // tile_sample_stride_height, j // tile_sample_stride_width),
- time_list,
- workload=time_list[0].shape[3] * time_list[0].shape[4],
- )
- )
-
- grid_spec = GridSpec(
- split_dims=(3, 4),
- grid_shape=(tiletask_list[-1].grid_coord[0] + 1, tiletask_list[-1].grid_coord[1] + 1),
- tile_spec={
- "latent_height": latent_height,
- "latent_width": latent_width,
- "blend_height": blend_height,
- "blend_width": blend_width,
- "tile_latent_stride_height": tile_latent_stride_height,
- "tile_latent_stride_width": tile_latent_stride_width,
- },
- output_dtype=self.dtype,
- )
- return tiletask_list, grid_spec
-
- def encode_tile_exec(self, task: TileTask) -> torch.Tensor:
- """Encode a single sample tile into latent space."""
- self.clear_cache()
- time = []
- for k, tile in enumerate(task.tensor):
- self._enc_conv_idx = [0]
- encoded = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
- encoded = self.quant_conv(encoded)
- time.append(encoded)
- result = torch.cat(time, dim=2)
- self.clear_cache()
- return result
-
- def encode_tile_merge(
- self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec
- ) -> torch.Tensor:
- """Merge encoded tiles into a full latent tensor."""
- grid_h, grid_w = grid_spec.grid_shape
- result_rows = []
- for i in range(grid_h):
- result_row = []
- for j in range(grid_w):
- tile = coord_tensor_map[(i, j)]
- if i > 0:
- tile = self.blend_v(coord_tensor_map[(i - 1, j)], tile, grid_spec.tile_spec["blend_height"])
- if j > 0:
- tile = self.blend_h(coord_tensor_map[(i, j - 1)], tile, grid_spec.tile_spec["blend_width"])
- result_row.append(
- tile[
- :,
- :,
- :,
- : grid_spec.tile_spec["tile_latent_stride_height"],
- : grid_spec.tile_spec["tile_latent_stride_width"],
- ]
- )
- result_rows.append(torch.cat(result_row, dim=-1))
-
- enc = torch.cat(result_rows, dim=3)[
- :, :, :, : grid_spec.tile_spec["latent_height"], : grid_spec.tile_spec["latent_width"]
- ]
- return enc
-
def tile_merge(self, coord_tensor_map: dict[tuple[int, ...], torch.Tensor], grid_spec: GridSpec) -> torch.Tensor:
"""Merge decoded tiles into a full image."""
grid_h, grid_w = grid_spec.grid_shape
@@ -272,8 +130,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):
if not self.is_distributed_enabled():
return super().tiled_decode(z, return_dict=return_dict)
- logger.debug("Decode running with distributed executor")
- result = self.distributed_executor.execute(
+ logger.info("Decode run with distributed executor")
+ result = self.distributed_decoder.execute(
z,
DistributedOperator(split=self.tile_split, exec=self.tile_exec, merge=self.tile_merge),
broadcast_result=False,
@@ -282,26 +140,3 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True):
return (result,)
return DecoderOutput(sample=result)
-
- def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
- """
- Encode using distributed VAE executor.
-
- Note: x is already patchified by parent's _encode() before calling this method.
- """
- if not self.is_distributed_enabled():
- return super().tiled_encode(x)
-
- logger.debug("Encode running with distributed executor")
- self.clear_cache()
- result = self.distributed_executor.execute(
- x,
- DistributedOperator(
- split=self.encode_tile_split,
- exec=self.encode_tile_exec,
- merge=self.encode_tile_merge,
- ),
- broadcast_result=True,
- )
- self.clear_cache()
- return result
diff --git a/vllm_omni/diffusion/distributed/autoencoders/distributed_vae_executor.py b/vllm_omni/diffusion/distributed/autoencoders/distributed_vae_executor.py
index 209f6562552..bdf664741db 100644
--- a/vllm_omni/diffusion/distributed/autoencoders/distributed_vae_executor.py
+++ b/vllm_omni/diffusion/distributed/autoencoders/distributed_vae_executor.py
@@ -54,9 +54,9 @@ def set_parallel_size(self, parallel_size: int):
self.parallel_size = parallel_size
def gather_tensors(self, tensor: torch.Tensor):
- gather_list = [torch.empty_like(tensor) for _ in range(self.world_size)]
- dist.all_gather(gather_list, tensor, group=self.group)
- return gather_list if self.rank == 0 else None
+ gather_list = [torch.empty_like(tensor) for _ in range(self.world_size)] if self.rank == 0 else None
+ dist.gather(tensor, gather_list=gather_list, dst=0, group=self.group)
+ return gather_list
def broadcast_tensor(self, tensor: torch.Tensor):
dist.broadcast(tensor, src=0, group=self.group)
@@ -168,25 +168,25 @@ def _sync_final_result(self, rank0_result, output_ndim, output_device, output_dt
class DistributedVaeMixin:
def init_distributed(self):
- self.distributed_executor = DistributedVaeExecutor()
+ self.distributed_decoder = DistributedVaeExecutor()
- def set_parallel_size(self, parallel_size: int) -> None:
- self.distributed_executor.set_parallel_size(parallel_size)
+ def set_parallel_size(self, parallel_size: int) -> bool:
+ return self.distributed_decoder.set_parallel_size(parallel_size)
def is_distributed_enabled(self) -> bool:
if (
- self.distributed_executor.parallel_size <= 1
+ self.distributed_decoder.parallel_size <= 1
or not dist.is_initialized()
or not getattr(self, "use_tiling", False)
):
return False
- world_size = dist.get_world_size(group=self.distributed_executor.group)
- pp_size = min(int(self.distributed_executor.parallel_size), int(world_size))
+ world_size = dist.get_world_size(group=self.distributed_decoder.group)
+ pp_size = min(int(self.distributed_decoder.parallel_size), int(world_size))
if pp_size <= 1:
return False
- if self.distributed_executor.parallel_size > pp_size:
+ if self.distributed_decoder.parallel_size > pp_size:
logger.warning(
- f"vae_patch_parallel_size={self.distributed_executor.parallel_size} "
+ f"vae_patch_parallel_size={self.distributed_decoder.parallel_size} "
f"is greater than dit_group={world_size};"
f" using dit_group size={world_size}"
)
diff --git a/vllm_omni/diffusion/distributed/cfg_parallel.py b/vllm_omni/diffusion/distributed/cfg_parallel.py
index 98757006bfe..a8b0012f665 100644
--- a/vllm_omni/diffusion/distributed/cfg_parallel.py
+++ b/vllm_omni/diffusion/distributed/cfg_parallel.py
@@ -9,7 +9,6 @@
from typing import Any
import torch
-from vllm.logger import init_logger
from vllm_omni.diffusion.distributed.parallel_state import (
get_cfg_group,
@@ -17,8 +16,6 @@
get_classifier_free_guidance_world_size,
)
-logger = init_logger(__name__)
-
def _wrap(pred: torch.Tensor | tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
"""Normalize prediction to tuple form."""
@@ -35,24 +32,6 @@ def _slice_pred(pred: tuple[torch.Tensor, ...], output_slice: int) -> tuple[torc
return tuple(p[:, :output_slice] for p in pred)
-def _dispatch_branches(n_branches: int, n_ranks: int) -> list[list[int]]:
- """
- Round-robin dispatch N branches to M ranks.
-
- Rule: branch i → rank (i % n_ranks).
-
- Examples:
- _dispatch_branches(3, 2) -> [[0, 2], [1]]
- _dispatch_branches(3, 3) -> [[0], [1], [2]]
- _dispatch_branches(4, 2) -> [[0, 2], [1, 3]]
- _dispatch_branches(4, 4) -> [[0], [1], [2], [3]]
- """
- assignments: list[list[int]] = [[] for _ in range(n_ranks)]
- for i in range(n_branches):
- assignments[i % n_ranks].append(i)
- return assignments
-
-
class CFGParallelMixin(metaclass=ABCMeta):
"""
Base Mixin class for Diffusion pipelines providing shared CFG methods.
@@ -210,165 +189,6 @@ def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, nor
results.append(comb)
return _unwrap(tuple(results))
- # ── N-branch CFG interface (for 3+ branch models) ──
-
- def predict_noise_with_multi_branch_cfg(
- self,
- do_true_cfg: bool,
- true_cfg_scale: float | dict[str, float],
- branches_kwargs: list[dict[str, Any]],
- cfg_normalize: bool = False,
- output_slice: int | None = None,
- ) -> torch.Tensor | tuple[torch.Tensor, ...]:
- """
- Predict noise with N-branch CFG dispatch across M GPUs.
-
- This is the multi-branch counterpart of predict_noise_maybe_with_cfg().
- Use this for models with 3 or more CFG branches (e.g., OmniGen2, Bagel,
- DreamID). Existing 2-branch models should continue using
- predict_noise_maybe_with_cfg().
-
- Args:
- do_true_cfg: Whether to apply CFG.
- true_cfg_scale: CFG scale factor (passed to combine_multi_branch_cfg_noise).
- branches_kwargs: List of N dicts, each containing kwargs for one
- predict_noise() call. branches_kwargs[0] is always the
- positive/conditional branch.
- cfg_normalize: Whether to normalize (passed to combine_multi_branch_cfg_noise).
- output_slice: If set, slice each output to [:, :output_slice].
-
- Returns:
- Combined noise prediction, identical on all ranks in CFG parallel.
- """
- if do_true_cfg:
- n_branches = len(branches_kwargs)
- cfg_world_size = get_classifier_free_guidance_world_size()
- cfg_parallel_ready = cfg_world_size > 1
-
- if cfg_parallel_ready:
- return self._predict_multi_branch_parallel(
- branches_kwargs,
- n_branches,
- cfg_world_size,
- true_cfg_scale,
- cfg_normalize,
- output_slice,
- )
- else:
- # Sequential: run all N branches on single device
- preds: list[torch.Tensor | tuple[torch.Tensor, ...]] = []
- for kw in branches_kwargs:
- pred = _wrap(self.predict_noise(**kw))
- if output_slice is not None:
- pred = _slice_pred(pred, output_slice)
- preds.append(_unwrap(pred))
- return self.combine_multi_branch_cfg_noise(preds, true_cfg_scale, cfg_normalize)
- else:
- # No CFG: only compute positive/conditional prediction
- pred = self.predict_noise(**branches_kwargs[0])
- if output_slice is not None:
- pred = _unwrap(_slice_pred(_wrap(pred), output_slice))
- return pred
-
- def _predict_multi_branch_parallel(
- self,
- branches_kwargs: list[dict[str, Any]],
- n_branches: int,
- cfg_world_size: int,
- true_cfg_scale: float,
- cfg_normalize: bool,
- output_slice: int | None,
- ) -> torch.Tensor | tuple[torch.Tensor, ...]:
- """Dispatch N branches across M ranks, all_gather, then combine."""
- cfg_group = get_cfg_group()
- cfg_rank = get_classifier_free_guidance_rank()
-
- if cfg_world_size > n_branches:
- logger.warning_once(
- "cfg_parallel_size=%d > n_branches=%d, %d GPU(s) will be idle for CFG",
- cfg_world_size,
- n_branches,
- cfg_world_size - n_branches,
- )
-
- # Assign branches to ranks via round-robin
- assignments = _dispatch_branches(n_branches, cfg_world_size)
- my_branch_ids = assignments[cfg_rank]
- max_per_rank = max(len(a) for a in assignments)
-
- # Run assigned branches
- my_preds: list[tuple[torch.Tensor, ...]] = []
- for bid in my_branch_ids:
- pred = _wrap(self.predict_noise(**branches_kwargs[bid]))
- if output_slice is not None:
- pred = _slice_pred(pred, output_slice)
- my_preds.append(pred)
-
- # Idle ranks (cfg_world_size > n_branches) run a forward pass to get the output shape for all_gather.
- # Output shape cannot be inferred from kwargs — may be tuple, sliced, etc.
- if not my_preds:
- pred = _wrap(self.predict_noise(**branches_kwargs[0]))
- if output_slice is not None:
- pred = _slice_pred(pred, output_slice)
- my_preds.append(pred)
-
- # Pad to max_per_rank with zeros so all ranks have same size
- ref_pred = my_preds[0]
- while len(my_preds) < max_per_rank:
- my_preds.append(tuple(torch.zeros_like(t) for t in ref_pred))
-
- # All-gather each output element separately (like predict_noise_maybe_with_cfg)
- # For each slot, gather across ranks; then pick valid results by owner_rank
- # all_slots[slot][elem_idx] = [rank0_tensor, rank1_tensor, ...]
- all_slots: list[list[list[torch.Tensor]]] = []
- for slot in range(max_per_rank):
- slot_results: list[list[torch.Tensor]] = []
- for p in my_preds[slot]:
- gathered = cfg_group.all_gather(p, separate_tensors=True)
- slot_results.append(gathered)
- all_slots.append(slot_results)
-
- # Reconstruct final_preds in branch order
- final_preds: list[torch.Tensor | tuple[torch.Tensor, ...]] = []
- for bid in range(n_branches):
- owner_rank = bid % cfg_world_size
- slot_idx = bid // cfg_world_size
- elements = tuple(all_slots[slot_idx][elem_idx][owner_rank] for elem_idx in range(len(ref_pred)))
- final_preds.append(_unwrap(elements))
-
- return self.combine_multi_branch_cfg_noise(final_preds, true_cfg_scale, cfg_normalize)
-
- def combine_multi_branch_cfg_noise(
- self,
- predictions: list[torch.Tensor | tuple[torch.Tensor, ...]],
- true_cfg_scale: float | dict[str, float],
- cfg_normalize: bool = False,
- ) -> torch.Tensor | tuple[torch.Tensor, ...]:
- """
- Combine N branch predictions. Default: standard 2-branch CFG formula.
-
- Override this method for custom multi-branch combine logic.
-
- Args:
- predictions: List of N predictions, where predictions[0] is always
- the positive/conditional branch.
- true_cfg_scale: CFG scale factor (float for 2-branch, dict for multi-branch).
- cfg_normalize: Whether to normalize the combined prediction.
-
- Returns:
- Combined noise prediction.
- """
- positive = _wrap(predictions[0])
- negative = _wrap(predictions[1])
-
- results = []
- for p, n in zip(positive, negative):
- comb = n + true_cfg_scale * (p - n)
- if cfg_normalize:
- comb = self.cfg_normalize_function(p, comb)
- results.append(comb)
- return _unwrap(tuple(results))
-
def predict_noise(self, *args: Any, **kwargs: Any) -> torch.Tensor | tuple[torch.Tensor, ...]:
"""
Forward pass through transformer to predict noise.
diff --git a/vllm_omni/diffusion/distributed/group_coordinator.py b/vllm_omni/diffusion/distributed/group_coordinator.py
index 5294e6c9ed6..8ab38f2a651 100644
--- a/vllm_omni/diffusion/distributed/group_coordinator.py
+++ b/vllm_omni/diffusion/distributed/group_coordinator.py
@@ -104,7 +104,6 @@ def __init__(
self.local_rank = local_rank
self.device_group = None
self.cpu_group = None
- self.shm_broadcaster = None
for ranks in group_ranks:
device_group = torch.distributed.new_group(ranks, backend=torch_distributed_backend)
@@ -317,7 +316,7 @@ def send_object(self, obj: Any, dst: int) -> None:
assert dst < self.world_size, f"Invalid dst rank ({dst})"
- assert dst != self.rank_in_group, "Invalid destination rank. Destination rank is the same as the current rank."
+ assert dst != self.rank, "Invalid destination rank. Destination rank is the same as the current rank."
# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
@@ -339,7 +338,7 @@ def recv_object(self, src: int) -> Any:
assert src < self.world_size, f"Invalid src rank ({src})"
- assert src != self.rank_in_group, "Invalid source rank. Source rank is the same as the current rank."
+ assert src != self.rank, "Invalid source rank. Source rank is the same as the current rank."
size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
diff --git a/vllm_omni/diffusion/envs.py b/vllm_omni/diffusion/envs.py
index ea7b2c24c8c..a71dc2e8e13 100644
--- a/vllm_omni/diffusion/envs.py
+++ b/vllm_omni/diffusion/envs.py
@@ -7,7 +7,6 @@
from vllm.logger import init_logger
-from vllm_omni.diffusion.attention.backends.utils.fa import is_mate_available
from vllm_omni.platforms import current_omni_platform
if TYPE_CHECKING:
@@ -53,10 +52,6 @@ def _check_flash_attn(self, packages_info) -> bool:
"""Check if flash attention is available and compatible."""
platform = current_omni_platform
- # MUSA uses MATE for flash attention
- if platform.is_musa():
- return is_mate_available()
-
# Flash attention requires CUDA-like platforms (CUDA or ROCm)
if not platform.is_cuda_alike():
return False
diff --git a/vllm_omni/diffusion/executor/abstract.py b/vllm_omni/diffusion/executor/abstract.py
index 81eba172c36..564980f6601 100644
--- a/vllm_omni/diffusion/executor/abstract.py
+++ b/vllm_omni/diffusion/executor/abstract.py
@@ -22,10 +22,6 @@ class DiffusionExecutor(ABC):
def get_class(od_config: OmniDiffusionConfig) -> type[DiffusionExecutor]:
executor_class: type[DiffusionExecutor]
distributed_executor_backend = od_config.distributed_executor_backend
- # Keep backward-compatible behavior for callers/configs that omit this
- # field and rely on the historical diffusion default backend.
- if distributed_executor_backend is None:
- distributed_executor_backend = "mp"
if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, DiffusionExecutor):
diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py
index dcb35cfde1f..e55a464fb4a 100644
--- a/vllm_omni/diffusion/executor/multiproc_executor.py
+++ b/vllm_omni/diffusion/executor/multiproc_executor.py
@@ -1,18 +1,14 @@
from __future__ import annotations
import multiprocessing as mp
-import multiprocessing.connection
-import threading
import time
import weakref
-from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import zmq
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.logger import init_logger
-from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, DiffusionOutput
from vllm_omni.diffusion.executor.abstract import DiffusionExecutor
@@ -26,8 +22,6 @@
logger = init_logger(__name__)
-_DEQUEUE_TIMEOUT_S = 5.0
-
@dataclass
class BackgroundResources:
@@ -42,14 +36,10 @@ class BackgroundResources:
def __call__(self):
"""Clean up background resources."""
- if hasattr(self, "wake_events") and self.wake_events:
- for ev in self.wake_events:
- ev.set()
-
if self.broadcast_mq is not None:
try:
for _ in range(self.num_workers):
- self.broadcast_mq.enqueue(SHUTDOWN_MESSAGE, timeout=1.0)
+ self.broadcast_mq.enqueue(SHUTDOWN_MESSAGE)
self.broadcast_mq = None
self.result_mq = None
@@ -73,17 +63,13 @@ class MultiprocDiffusionExecutor(DiffusionExecutor):
def _init_executor(self) -> None:
self._processes: list[mp.Process] = []
self._closed = False
- self.is_failed = False
- self._failure_callbacks: list[Callable[[], None]] = []
num_workers = self.od_config.num_gpus
- self.wake_events = [mp.Event() for _ in range(num_workers)]
-
self._broadcast_mq = self._init_broadcast_queue(num_workers)
broadcast_handle = self._broadcast_mq.export_handle()
# Launch workers
- processes, result_handle = self._launch_workers(broadcast_handle, self.wake_events)
+ processes, result_handle = self._launch_workers(broadcast_handle)
self._result_mq = self._init_result_queue(result_handle)
self._processes = processes
@@ -95,8 +81,6 @@ def _init_executor(self) -> None:
)
self._finalizer = weakref.finalize(self, self.resources)
- self.start_worker_monitor()
-
def _init_broadcast_queue(self, num_workers: int) -> MessageQueue:
return MessageQueue(
n_reader=num_workers,
@@ -116,24 +100,7 @@ def _ensure_open(self) -> None:
if self._result_mq is None:
raise RuntimeError("Result queue not initialized")
- def _dequeue_one_with_failure_polling(self, deadline: float | None, method: str) -> Any:
- """Block until one result message, polling ``is_failed`` between chunk timeouts."""
- while True:
- if deadline is None:
- chunk_timeout = _DEQUEUE_TIMEOUT_S
- else:
- remaining = deadline - time.monotonic()
- if remaining <= 0:
- raise TimeoutError(f"RPC call to {method} timed out.")
- chunk_timeout = min(_DEQUEUE_TIMEOUT_S, remaining)
- try:
- return self._result_mq.dequeue(timeout=chunk_timeout)
- except (TimeoutError, zmq.error.Again):
- if self.is_failed:
- raise EngineDeadError()
- continue
-
- def _launch_workers(self, broadcast_handle, wake_events):
+ def _launch_workers(self, broadcast_handle):
od_config = self.od_config
logger.info("Starting server...")
@@ -159,7 +126,6 @@ def _launch_workers(self, broadcast_handle, wake_events):
od_config,
writer,
broadcast_handle,
- wake_events[i],
worker_extension_cls,
custom_pipeline_args,
),
@@ -198,49 +164,6 @@ def _launch_workers(self, broadcast_handle, wake_events):
return processes, result_handle
- def start_worker_monitor(self) -> None:
- # Monitors worker process liveness. If any die unexpectedly,
- # logs an error, shuts down the executor and invokes the failure
- # callback to inform the engine.
- sentinels = [p.sentinel for p in self._processes]
- if not sentinels:
- return
-
- def _monitor() -> None:
- try:
- finished = multiprocessing.connection.wait(sentinels)
- except OSError:
- return
-
- if self._closed:
- return
-
- dead = [p.name for p in self._processes if p.sentinel in finished]
- if dead:
- logger.error(
- "Diffusion worker(s) died unexpectedly: %s",
- dead,
- )
- self.is_failed = True
-
- self.shutdown()
-
- for cb in self._failure_callbacks:
- try:
- cb()
- except Exception:
- logger.exception("failure_callback raised")
-
- t = threading.Thread(target=_monitor, daemon=True, name="diffusion-worker-monitor")
- t.start()
-
- def register_failure_callback(
- self,
- callback: Callable[[], None],
- ) -> None:
- """Register a callback invoked when a worker process dies."""
- self._failure_callbacks.append(callback)
-
def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput:
self._ensure_open()
rpc_request = {
@@ -363,21 +286,27 @@ def collective_rpc(
responses = []
for _ in range(num_responses):
- response = self._dequeue_one_with_failure_polling(deadline, method)
-
+ dequeue_timeout = None if deadline is None else max(0, deadline - time.monotonic())
try:
- unpack_diffusion_output_shm(response)
- except Exception as e:
- logger.warning("SHM unpack failed (data may already be inline): %s", e)
-
- # Check if response indicates an error
- if isinstance(response, dict) and response.get("status") == "error":
- raise RuntimeError(
- f"Worker failed with error '{response.get('error')}', "
- "please check the stack trace above for the root cause"
- )
-
- responses.append(response)
+ response = self._result_mq.dequeue(timeout=dequeue_timeout)
+
+ try:
+ unpack_diffusion_output_shm(response)
+ except Exception as e:
+ logger.warning("SHM unpack failed (data may already be inline): %s", e)
+
+ # Check if response indicates an error
+ if isinstance(response, dict) and response.get("status") == "error":
+ raise RuntimeError(
+ f"Worker failed with error '{response.get('error')}', "
+ "please check the stack trace above for the root cause"
+ )
+
+ responses.append(response)
+ except zmq.error.Again as e:
+ raise TimeoutError(f"RPC call to {method} timed out.") from e
+ except TimeoutError as e:
+ raise TimeoutError(f"RPC call to {method} timed out.") from e
return responses[0] if unique_reply_rank is not None else responses
except Exception as e:
@@ -385,13 +314,10 @@ def collective_rpc(
raise
def check_health(self) -> None:
- self._ensure_open()
- if self.is_failed:
- raise EngineDeadError()
+ # Simple check if processes are alive
for p in self._processes:
if not p.is_alive():
- self.is_failed = True
- raise EngineDeadError(f"Worker process {p.name} is dead")
+ raise RuntimeError(f"Worker process {p.name} is dead")
def shutdown(self) -> None:
self._closed = True
diff --git a/vllm_omni/diffusion/hooks/base.py b/vllm_omni/diffusion/hooks/base.py
index 517c6615877..cda4201ccf3 100644
--- a/vllm_omni/diffusion/hooks/base.py
+++ b/vllm_omni/diffusion/hooks/base.py
@@ -8,7 +8,6 @@
from __future__ import annotations
-import functools
import inspect
from collections.abc import Callable
from dataclasses import dataclass
@@ -95,9 +94,10 @@ def post_forward(self, module: nn.Module, output: Any) -> Any:
return output
def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any:
- """Override the module's forward pass. This should be overridden for more complex
- cases, e.g., TeaCache. If this method is overridden in a subclass, it will be called
- instead of self.module._omni_original_forward when executing the hooks.
+ """Override the module's forward pass completely.
+
+ The default implementation calls pre_forward, then the original forward,
+ then post_forward. Override this method for more complex behavior.
Args:
module: The module being called.
@@ -105,9 +105,11 @@ def new_forward(self, module: nn.Module, *args: Any, **kwargs: Any) -> Any:
**kwargs: Keyword arguments to forward.
Returns:
- The output of the replacement for the forward pass.
+ The output of the forward pass.
"""
- raise NotImplementedError("By default, hooks do not implement new_forward")
+ args, kwargs = self.pre_forward(module, *args, **kwargs)
+ output = module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
+ return self.post_forward(module, output)
def reset_state(self, module: nn.Module) -> nn.Module:
"""Reset any state associated with this hook.
@@ -134,21 +136,6 @@ def __call__(self, *args: Any, **kwargs: Any):
return registry.dispatch(*args, **kwargs)
-def sort_hooks_after_call(func):
- """Calls the method on the hook registry, then sorts the hooks.
-
- This should be added to methods that mutate add or remove hooks.
- """
-
- @functools.wraps(func)
- def wrapper(self: HookRegistry, *args, **kwargs):
- res = func(self, *args, **kwargs)
- self.update_sorted_hooks()
- return res
-
- return wrapper
-
-
class HookRegistry:
"""Registry of hooks attached to a module.
@@ -159,10 +146,6 @@ class HookRegistry:
def __init__(self, module: nn.Module):
self.module = module
self._hooks: dict[str, ModelHook] = {}
- # Hooks sorted by execution order
- self._sorted_hooks: list[ModelHook] = []
- # Hooks overriding new_forward (if any)
- self._new_fwd_impl_hook: ModelHook | None = None
@classmethod
def get_or_create(cls, module: nn.Module) -> HookRegistry:
@@ -190,14 +173,6 @@ def get_or_create(cls, module: nn.Module) -> HookRegistry:
return registry
- def update_sorted_hooks(self):
- """Sort hooks by name, which dictates pre/post process order."""
- sorted_hooks = [self._hooks[k] for k in sorted(self._hooks) if self._hooks[k] != self._new_fwd_impl_hook]
- if self._new_fwd_impl_hook is not None:
- sorted_hooks.append(self._new_fwd_impl_hook)
- self._sorted_hooks = sorted_hooks
-
- @sort_hooks_after_call
def register_hook(self, name: str, hook: ModelHook) -> None:
"""Register a hook with the given name.
@@ -207,14 +182,7 @@ def register_hook(self, name: str, hook: ModelHook) -> None:
"""
hook.initialize_hook(self.module)
self._hooks[name] = hook
- # We can only have one hook that overrides new_forward,
- # since we don't currently have a mechanism for combining them.
- if type(hook).new_forward is not ModelHook.new_forward:
- if self._new_fwd_impl_hook is not None:
- raise RuntimeError("Cannot have multiple hooks that override forward active simultaneously")
- self._new_fwd_impl_hook = hook
-
- @sort_hooks_after_call
+
def remove_hook(self, name: str) -> None:
"""Remove a hook by name.
@@ -222,9 +190,6 @@ def remove_hook(self, name: str) -> None:
name: The name of the hook to remove.
"""
if name in self._hooks:
- # clear the forward hook if it's the one to delete
- if self._new_fwd_impl_hook is self._hooks[name]:
- self._new_fwd_impl_hook = None
del self._hooks[name]
def get_hook(self, name: str) -> ModelHook | None:
@@ -241,18 +206,8 @@ def get_hook(self, name: str) -> ModelHook | None:
def dispatch(self, *args: Any, **kwargs: Any) -> Any:
"""Dispatch a forward call through registered hooks.
- Multiple hooks may be used with the caveat that only one hook
- may override new_forward. While it is assumed that pre/post process
- on hooks are composable, the execution flow is as follows for determinism:
-
- - Run preprocess on all hooks in their sorted order; hooks are sorted alphabetically,
- except for the hook overriding forward (`self._new_fwd_impl_hook`), which is last
- if it exists.
-
- - If `self._new_fwd_impl_hook` isn't None, call its forward. Otherwise call the
- original model forward.
-
- - Run post process on all hooks in the reverse sorted order.
+ Currently supports a single active hook. Multiple hooks are called
+ in sorted order by name, with each hook's output passed to the next.
Args:
*args: Positional arguments to forward.
@@ -264,19 +219,24 @@ def dispatch(self, *args: Any, **kwargs: Any) -> Any:
if not self._hooks:
return self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
- # Apply all pre_forward hooks; if _new_fwd_impl_hook is set, it's last
- for hook in self._sorted_hooks:
+ # For single hook case, call directly
+ if len(self._hooks) == 1:
+ hook = next(iter(self._hooks.values()))
+ return hook.new_forward(self.module, *args, **kwargs)
+
+ # For multiple hooks, chain them in sorted order
+ # Each hook can modify args/kwargs via pre_forward
+ sorted_hooks = sorted(self._hooks.items(), key=lambda x: x[0])
+
+ # Apply all pre_forward hooks
+ for _, hook in sorted_hooks:
args, kwargs = hook.pre_forward(self.module, *args, **kwargs)
- # If we have a hook that overrides new_forward, call it directly
- if self._new_fwd_impl_hook is not None:
- output = self._new_fwd_impl_hook.new_forward(self.module, *args, **kwargs)
- # Otherwise just call the original forward.
- else:
- output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
+ # Call original forward
+ output = self.module._omni_original_forward(*args, **kwargs) # type: ignore[attr-defined]
- # Apply all post_forward hooks in reverse order; if _new_fwd_impl_hook is set, it's first
- for hook in reversed(self._sorted_hooks):
+ # Apply all post_forward hooks in reverse order
+ for _, hook in reversed(sorted_hooks):
output = hook.post_forward(self.module, output)
return output
diff --git a/vllm_omni/diffusion/inline_stage_diffusion_client.py b/vllm_omni/diffusion/inline_stage_diffusion_client.py
deleted file mode 100644
index a33a3e95619..00000000000
--- a/vllm_omni/diffusion/inline_stage_diffusion_client.py
+++ /dev/null
@@ -1,348 +0,0 @@
-"""Inline Stage Diffusion Client for vLLM-Omni multi-stage runtime.
-
-Runs DiffusionEngine in a ThreadPoolExecutor inside the Orchestrator process
-instead of spawning a separate StageDiffusionProc subprocess, eliminating ZMQ
-IPC overhead. Used when there is only a single diffusion stage.
-"""
-
-from __future__ import annotations
-
-import asyncio
-import time
-from concurrent.futures import ThreadPoolExecutor
-from typing import TYPE_CHECKING, Any
-
-import torch
-from PIL import Image
-from vllm.logger import init_logger
-
-from vllm_omni.diffusion.data import DiffusionRequestAbortedError
-from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
-from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.engine.stage_init_utils import StageMetadata
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
-
-if TYPE_CHECKING:
- from vllm_omni.diffusion.data import OmniDiffusionConfig
- from vllm_omni.inputs.data import OmniPromptType
-
-logger = init_logger(__name__)
-
-
-class InlineStageDiffusionClient:
- """Runs DiffusionEngine in a thread executor inside the Orchestrator."""
-
- stage_type: str = "diffusion"
-
- def __init__(
- self,
- model: str,
- od_config: OmniDiffusionConfig,
- metadata: StageMetadata,
- batch_size: int = 1,
- ) -> None:
- self.model = model
- self.od_config = od_config
- self.stage_id = metadata.stage_id
- self.final_output = metadata.final_output
- self.final_output_type = metadata.final_output_type
- self.default_sampling_params = metadata.default_sampling_params
- self.custom_process_input_func = metadata.custom_process_input_func
- self.engine_input_source = metadata.engine_input_source
- self.batch_size = batch_size
-
- self._enrich_config()
- self._engine = DiffusionEngine.make_engine(self.od_config)
- self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="inline-diffusion")
-
- self._output_queue: asyncio.Queue[OmniRequestOutput] = asyncio.Queue()
- self._tasks: dict[str, asyncio.Task] = {}
- self._shutting_down = False
-
- logger.info(
- "[InlineStageDiffusionClient] Stage-%s initialized inline (batch_size=%d)",
- self.stage_id,
- self.batch_size,
- )
-
- def _enrich_config(self) -> None:
- """Load model metadata from HuggingFace and populate od_config fields."""
- self.od_config.enrich_config()
-
- # ------------------------------------------------------------------
- # Request processing
- # ------------------------------------------------------------------
-
- async def add_request_async(
- self,
- request_id: str,
- prompt: OmniPromptType,
- sampling_params: OmniDiffusionSamplingParams,
- kv_sender_info: dict[int, dict[str, Any]] | None = None,
- ) -> None:
- task = asyncio.create_task(
- self._dispatch_request(
- request_id,
- prompt,
- sampling_params,
- kv_sender_info,
- )
- )
- self._tasks[request_id] = task
-
- async def _dispatch_request(
- self,
- request_id: str,
- prompt: Any,
- sampling_params: OmniDiffusionSamplingParams,
- kv_sender_info: dict[str, Any] | None = None,
- ) -> None:
- try:
- request = OmniDiffusionRequest(
- prompts=[prompt],
- sampling_params=sampling_params,
- request_ids=[request_id],
- request_id=request_id,
- kv_sender_info=kv_sender_info,
- )
-
- loop = asyncio.get_running_loop()
- results = await loop.run_in_executor(self._executor, self._engine.step, request)
- result = results[0]
- if not result.request_id:
- result.request_id = request_id
-
- self._output_queue.put_nowait(result)
- except DiffusionRequestAbortedError as e:
- logger.info("request_id: %s aborted: %s", request_id, str(e))
- except Exception as e:
- logger.exception("Diffusion request %s failed: %s", request_id, e)
- error_output = OmniRequestOutput.from_diffusion(
- request_id=request_id,
- images=[],
- )
- error_output.error = str(e)
- self._output_queue.put_nowait(error_output)
- finally:
- self._tasks.pop(request_id, None)
-
- async def add_batch_request_async(
- self,
- request_id: str,
- prompts: list[OmniPromptType],
- sampling_params: OmniDiffusionSamplingParams,
- kv_sender_info: dict[int, dict[str, Any]] | None = None,
- ) -> None:
- task = asyncio.create_task(
- self._dispatch_batch(
- request_id,
- prompts,
- sampling_params,
- kv_sender_info,
- )
- )
- self._tasks[request_id] = task
-
- async def _dispatch_batch(
- self,
- request_id: str,
- prompts: list[Any],
- sampling_params: OmniDiffusionSamplingParams,
- kv_sender_info: dict[str, Any] | None = None,
- ) -> None:
- try:
- request = OmniDiffusionRequest(
- prompts=prompts,
- sampling_params=sampling_params,
- request_ids=[f"{request_id}-{i}" for i in range(len(prompts))],
- request_id=request_id,
- kv_sender_info=kv_sender_info,
- )
-
- loop = asyncio.get_running_loop()
- results = await loop.run_in_executor(self._executor, self._engine.step, request)
-
- all_images: list = []
- merged_mm: dict[str, Any] = {}
- merged_metrics: dict[str, Any] = {}
- merged_durations: dict[str, float] = {}
- merged_custom: dict[str, Any] = {}
- peak_mem = 0.0
- latents = None
- trajectory_latents: list[torch.Tensor] | None = None
- trajectory_timesteps: list[torch.Tensor] | None = None
- trajectory_log_probs: torch.Tensor | None = None
- trajectory_decoded: list[Image.Image] | None = None
- final_output_type = "image"
-
- for r in results:
- all_images.extend(r.images)
- merged_mm.update(r._multimodal_output)
- merged_metrics.update(r.metrics)
- merged_durations.update(r.stage_durations)
- merged_custom.update(r._custom_output)
- peak_mem = max(peak_mem, r.peak_memory_mb)
- if latents is None and r.latents is not None:
- latents = r.latents
- if trajectory_latents is None:
- trajectory_latents = r.trajectory_latents
- if trajectory_timesteps is None:
- trajectory_timesteps = r.trajectory_timesteps
- if trajectory_log_probs is None:
- trajectory_log_probs = r.trajectory_log_probs
- if trajectory_decoded is None:
- trajectory_decoded = r.trajectory_decoded
- if r.final_output_type != "image":
- final_output_type = r.final_output_type
-
- result = OmniRequestOutput.from_diffusion(
- request_id=request_id,
- images=all_images,
- prompt=prompts[0] if len(prompts) == 1 else None,
- metrics=merged_metrics,
- latents=latents,
- trajectory_latents=trajectory_latents,
- trajectory_timesteps=trajectory_timesteps,
- trajectory_log_probs=trajectory_log_probs,
- trajectory_decoded=trajectory_decoded,
- custom_output=merged_custom or None,
- multimodal_output=merged_mm or None,
- final_output_type=final_output_type,
- stage_durations=merged_durations,
- peak_memory_mb=peak_mem,
- )
-
- self._output_queue.put_nowait(result)
- except DiffusionRequestAbortedError as e:
- logger.info("request_id: %s aborted: %s", request_id, str(e))
- except Exception as e:
- logger.exception("Batch diffusion request %s failed: %s", request_id, e)
- error_output = OmniRequestOutput.from_diffusion(
- request_id=request_id,
- images=[],
- )
- error_output.error = str(e)
- self._output_queue.put_nowait(error_output)
- finally:
- self._tasks.pop(request_id, None)
-
- def get_diffusion_output_nowait(self) -> OmniRequestOutput | None:
- try:
- return self._output_queue.get_nowait()
- except asyncio.QueueEmpty:
- return None
-
- async def abort_requests_async(self, request_ids: list[str]) -> None:
- for rid in request_ids:
- task = self._tasks.pop(rid, None)
- if task:
- task.cancel()
- self._engine.abort(rid)
-
- async def collective_rpc_async(
- self,
- method: str,
- timeout: float | None = None,
- args: tuple[Any, ...] = (),
- kwargs: dict[str, Any] | None = None,
- ) -> Any:
- loop = asyncio.get_running_loop()
-
- if method == "profile":
- is_start = args[0] if args else True
- profile_prefix = args[1] if len(args) > 1 else None
- if is_start and profile_prefix is None:
- profile_prefix = f"stage_{self.stage_id}_diffusion_{int(time.time())}"
- return await loop.run_in_executor(
- self._executor,
- self._engine.profile,
- is_start,
- profile_prefix,
- )
-
- kwargs = kwargs or {}
-
- # LoRA methods
- if method == "add_lora":
- lora_request = args[0] if args else kwargs.get("lora_request")
- results = await loop.run_in_executor(
- self._executor,
- self._engine.collective_rpc,
- "add_lora",
- timeout,
- (),
- {"lora_request": lora_request},
- None,
- )
- return all(results) if isinstance(results, list) else results
-
- if method == "remove_lora":
- results = await loop.run_in_executor(
- self._executor,
- self._engine.collective_rpc,
- "remove_lora",
- timeout,
- args,
- kwargs,
- None,
- )
- return all(results) if isinstance(results, list) else results
-
- if method == "list_loras":
- results = await loop.run_in_executor(
- self._executor,
- self._engine.collective_rpc,
- "list_loras",
- timeout,
- (),
- {},
- None,
- )
- if not isinstance(results, list):
- return results or []
- merged: set[int] = set()
- for part in results:
- merged.update(part or [])
- return sorted(merged)
-
- if method == "pin_lora":
- lora_id = args[0] if args else kwargs.get("adapter_id")
- results = await loop.run_in_executor(
- self._executor,
- self._engine.collective_rpc,
- "pin_lora",
- timeout,
- (),
- {"adapter_id": lora_id},
- None,
- )
- return all(results) if isinstance(results, list) else results
-
- return await loop.run_in_executor(
- self._executor,
- self._engine.collective_rpc,
- method,
- timeout,
- args,
- kwargs,
- None,
- )
-
- def shutdown(self) -> None:
- self._shutting_down = True
-
- # Cancel all pending tasks
- for task in self._tasks.values():
- task.cancel()
-
- try:
- # Cancel queued futures and wait for the running one to complete deterministically
- self._executor.shutdown(wait=True, cancel_futures=True)
- except Exception:
- pass
-
- try:
- self._engine.close()
- except Exception:
- pass
diff --git a/vllm_omni/diffusion/ipc.py b/vllm_omni/diffusion/ipc.py
index 6a96533fd40..9aafc1cf17f 100644
--- a/vllm_omni/diffusion/ipc.py
+++ b/vllm_omni/diffusion/ipc.py
@@ -78,29 +78,13 @@ def _tensor_from_shm(handle: dict[str, Any]) -> torch.Tensor:
return tensor
-def _pack_tensor_if_large(val: torch.Tensor) -> torch.Tensor | dict:
- """Replace a tensor with an SHM handle if it exceeds the threshold."""
- if val.nelement() * val.element_size() > _SHM_TENSOR_THRESHOLD:
- return _tensor_to_shm(val)
- return val
-
-
-def _unpack_if_shm_handle(val: object) -> object:
- """Reconstruct a tensor from an SHM handle dict, or return as-is."""
- if isinstance(val, dict) and val.get("__tensor_shm__"):
- return _tensor_from_shm(val)
- return val
-
-
def _pack_diffusion_fields(output: DiffusionOutput) -> DiffusionOutput:
if output.output is not None and isinstance(output.output, torch.Tensor):
- output.output = _pack_tensor_if_large(output.output)
+ if output.output.nelement() * output.output.element_size() > _SHM_TENSOR_THRESHOLD:
+ output.output = _tensor_to_shm(output.output)
if output.trajectory_latents is not None and isinstance(output.trajectory_latents, torch.Tensor):
- output.trajectory_latents = _pack_tensor_if_large(output.trajectory_latents)
- if output.trajectory_timesteps is not None and isinstance(output.trajectory_timesteps, torch.Tensor):
- output.trajectory_timesteps = _pack_tensor_if_large(output.trajectory_timesteps)
- if output.trajectory_log_probs is not None and isinstance(output.trajectory_log_probs, torch.Tensor):
- output.trajectory_log_probs = _pack_tensor_if_large(output.trajectory_log_probs)
+ if output.trajectory_latents.nelement() * output.trajectory_latents.element_size() > _SHM_TENSOR_THRESHOLD:
+ output.trajectory_latents = _tensor_to_shm(output.trajectory_latents)
return output
@@ -120,10 +104,10 @@ def pack_diffusion_output_shm(output: object) -> object:
def _unpack_diffusion_fields(output: DiffusionOutput) -> DiffusionOutput:
- output.output = _unpack_if_shm_handle(output.output)
- output.trajectory_latents = _unpack_if_shm_handle(output.trajectory_latents)
- output.trajectory_timesteps = _unpack_if_shm_handle(output.trajectory_timesteps)
- output.trajectory_log_probs = _unpack_if_shm_handle(output.trajectory_log_probs)
+ if isinstance(output.output, dict) and output.output.get("__tensor_shm__"):
+ output.output = _tensor_from_shm(output.output)
+ if isinstance(output.trajectory_latents, dict) and output.trajectory_latents.get("__tensor_shm__"):
+ output.trajectory_latents = _tensor_from_shm(output.trajectory_latents)
return output
diff --git a/vllm_omni/diffusion/layers/adalayernorm.py b/vllm_omni/diffusion/layers/adalayernorm.py
index d147bdcfeb6..35f63e2fc91 100644
--- a/vllm_omni/diffusion/layers/adalayernorm.py
+++ b/vllm_omni/diffusion/layers/adalayernorm.py
@@ -7,7 +7,6 @@
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm_omni.diffusion.layers.custom_op import CustomOp
-from vllm_omni.diffusion.layers.norm import LayerNorm
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -28,63 +27,107 @@ def __init__(self, hidden_size: int, elementwise_affine: bool = False, eps: floa
self.eps = eps
self.elementwise_affine = elementwise_affine
self.hidden_size = hidden_size
- self.layernorm = LayerNorm(self.hidden_size, elementwise_affine=self.elementwise_affine, eps=self.eps)
+ self.layernorm = nn.LayerNorm(self.hidden_size, elementwise_affine=self.elementwise_affine, eps=self.eps)
+
+ def preprocess(
+ self,
+ mod_params: torch.Tensor,
+ index: torch.Tensor = None,
+ ) -> torch.Tensor:
+ # shift: b d, scale: b d, gate: b d
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
+
+ if index is not None:
+ # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts)
+ # So shift, scale, gate have shape [2*actual_batch, d]
+ actual_batch = shift.size(0) // 2
+ shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d]
+ scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
+ gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
+
+ # index: [b, l] where b is actual batch size
+ # Expand to [b, l, 1] to match feature dimension
+ index_expanded = index.unsqueeze(-1) # [b, l, 1]
+
+ # Expand chunks to [b, 1, d] then broadcast to [b, l, d]
+ shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d]
+ shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d]
+ scale_0_exp = scale_0.unsqueeze(1)
+ scale_1_exp = scale_1.unsqueeze(1)
+ gate_0_exp = gate_0.unsqueeze(1)
+ gate_1_exp = gate_1.unsqueeze(1)
+
+ # Use torch.where to select based on index
+ shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
+ scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
+ gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
+ else:
+ shift_result = shift.unsqueeze(1)
+ scale_result = scale.unsqueeze(1)
+ gate_result = gate.unsqueeze(1)
+
+ return shift_result, scale_result, gate_result
def forward_cuda(
self,
x: torch.Tensor,
- scale: torch.Tensor,
- shift: torch.Tensor,
+ mod_params: torch.Tensor,
+ index: torch.Tensor = None,
) -> torch.Tensor:
- return self.forward_native(x, scale, shift)
+ return self.forward_native(x, mod_params, index)
def forward_hip(
self,
x: torch.Tensor,
- scale: torch.Tensor,
- shift: torch.Tensor,
+ mod_params: torch.Tensor,
+ index: torch.Tensor = None,
) -> torch.Tensor:
- return self.forward_native(x, scale, shift)
+ return self.forward_native(x, mod_params, index)
def forward_npu(
self,
x: torch.Tensor,
- scale: torch.Tensor,
- shift: torch.Tensor,
+ mod_params: torch.Tensor,
+ index: torch.Tensor = None,
) -> torch.Tensor:
+ shift_result, scale_result, gate_result = self.preprocess(mod_params, index)
+
if _HAS_MINDIESD:
try:
from mindiesd import layernorm_scale_shift
- output = layernorm_scale_shift(self.layernorm, x, scale, shift, fused=True)
+ output = layernorm_scale_shift(self.layernorm, x, scale_result, shift_result, fused=True)
- return output
+ return output, gate_result
except ImportError as e:
logger.warning_once(f"mindiesd import failed, falling back to torch_npu: {e}")
import torch_npu
output = (
- torch_npu.npu_layer_norm_eval(x, normalized_shape=[self.hidden_size], eps=self.eps) * (1 + scale) + shift
+ torch_npu.npu_layer_norm_eval(x, normalized_shape=[self.hidden_size], eps=self.eps) * (1 + scale_result)
+ + shift_result
)
- return output
+ return output, gate_result
def forward_xpu(
self,
x: torch.Tensor,
- scale: torch.Tensor,
- shift: torch.Tensor,
+ mod_params: torch.Tensor,
+ index: torch.Tensor = None,
) -> torch.Tensor:
- return self.forward_native(x, scale, shift)
+ return self.forward_native(x, mod_params, index)
def forward_native(
self,
x: torch.Tensor,
- scale: torch.Tensor,
- shift: torch.Tensor,
+ mod_params: torch.Tensor,
+ index: torch.Tensor = None,
) -> torch.Tensor:
- return self.layernorm(x) * (1 + scale) + shift
+ shift_result, scale_result, gate_result = self.preprocess(mod_params, index)
+
+ return self.layernorm(x) * (1 + scale_result) + shift_result, gate_result
class AdaLayerNormZero(nn.Module):
diff --git a/vllm_omni/diffusion/layers/norm.py b/vllm_omni/diffusion/layers/norm.py
deleted file mode 100644
index f397c1a855d..00000000000
--- a/vllm_omni/diffusion/layers/norm.py
+++ /dev/null
@@ -1,188 +0,0 @@
-from importlib.util import find_spec
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from vllm.logger import init_logger
-
-from vllm_omni.diffusion.layers.custom_op import CustomOp
-
-logger = init_logger(__name__)
-
-_HAS_MINDIESD = find_spec("mindiesd") is not None
-
-
-class LayerNorm(nn.LayerNorm, CustomOp):
- """
- LayerNorm implementation that inherits from both ``nn.LayerNorm`` and ``CustomOp``.
- NPU:
- Uses ``mindiesd.fast_layernorm(self, x)`` when MindIE-SD is installed.
- CUDA / HIP / XPU / native:
- Falls back to FP32 nn.LayerNorm implementation.
- """
-
- def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
- super().__init__(normalized_shape=dim, eps=eps, elementwise_affine=elementwise_affine)
- # CustomOp.__init__ cannot be called here because it would re-run
- # nn.Module initialization and clear LayerNorm parameters.
- self._forward_method = CustomOp.dispatch_forward(self)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self._forward_method(x)
-
- def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
- return self.forward_native(x)
-
- def forward_hip(self, x: torch.Tensor) -> torch.Tensor:
- return self.forward_native(x)
-
- def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
- return self.forward_native(x)
-
- def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
- if _HAS_MINDIESD:
- try:
- from mindiesd import fast_layernorm
-
- return fast_layernorm(self, x)
- except ImportError as e:
- logger.warning_once(
- "mindiesd.fast_layernorm import failed, falling back to FP32 layer_norm: %s",
- e,
- )
-
- return self.forward_native(x)
-
- def forward_native(self, x: torch.Tensor) -> torch.Tensor:
- origin_dtype = x.dtype
- return F.layer_norm(
- x.float(),
- self.normalized_shape,
- self.weight.float() if self.weight is not None else None,
- self.bias.float() if self.bias is not None else None,
- self.eps,
- ).to(origin_dtype)
-
-
-class RMSNorm(CustomOp):
- def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward_cuda(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- return self.forward_native(x)
-
- def forward_hip(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- return self.forward_native(x)
-
- def forward_npu(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- import torch_npu
-
- output = torch_npu.npu_rms_norm(x, gamma=self.weight, epsilon=self.variance_epsilon)[0]
-
- return output
-
- def forward_xpu(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- return self.forward_native(x)
-
- def forward_native(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- input_dtype = x.dtype
- x = x.to(torch.float32)
- variance = x.pow(2).mean(-1, keepdim=True)
- out = x * torch.rsqrt(variance + self.variance_epsilon)
- out = self.weight.to(torch.float32) * out
- return out.to(input_dtype)
-
-
-class RMSNormVAE(CustomOp):
- """Root Mean Square Layer Normalization for Channel-First or Last"""
-
- def __init__(
- self,
- dim: int,
- channel_first: bool = True,
- images: bool = True,
- bias: bool = False,
- epsilon: float = 1e-6,
- ) -> None:
- super().__init__()
- broadcastable_dims = (1, 1, 1) if not images else (1, 1)
- shape = (dim, *broadcastable_dims) if channel_first else (dim,)
-
- self.channel_first = channel_first
- self.scale = dim**0.5
- self.gamma = nn.Parameter(torch.ones(shape))
- self.bias = nn.Parameter(torch.zeros(shape)) if bias else None
- self.epsilon = epsilon
-
- self.gamma_rmsnorm = None
-
- def forward_cuda(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- return self.forward_native(x)
-
- def forward_hip(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- return self.forward_native(x)
-
- def forward_npu(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- import torch_npu
-
- if self.gamma_rmsnorm is None:
- self.gamma_rmsnorm = self.gamma.reshape(-1)
-
- if self.channel_first:
- x = x.transpose(1, -1)
- out = torch_npu.npu_rms_norm(x, self.gamma_rmsnorm, epsilon=self.epsilon)[0].transpose(1, -1)
- else:
- out = torch_npu.npu_rms_norm(x, self.gamma_rmsnorm, epsilon=self.epsilon)[0]
-
- if self.bias is not None:
- out = out + self.bias
- return out
-
- def forward_xpu(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- return self.forward_native(x)
-
- def forward_native(
- self,
- x: torch.Tensor,
- ) -> torch.Tensor:
- out = (
- F.normalize(
- x,
- dim=(1 if self.channel_first else -1),
- eps=self.epsilon,
- )
- * self.scale
- * self.gamma
- )
- if self.bias is not None:
- out = out + self.bias
- return out
diff --git a/vllm_omni/diffusion/layers/rope.py b/vllm_omni/diffusion/layers/rope.py
index 127e1b2cbdf..65d37d0b017 100644
--- a/vllm_omni/diffusion/layers/rope.py
+++ b/vllm_omni/diffusion/layers/rope.py
@@ -72,18 +72,18 @@ class RotaryEmbedding(CustomOp):
of 1st half and 2nd half (GPT-NeoX style).
"""
- def __init__(self, is_neox_style: bool = False) -> None:
+ def __init__(
+ self,
+ is_neox_style: bool = False,
+ ) -> None:
super().__init__()
self.is_neox_style = is_neox_style
self.interleaved = not is_neox_style
self.apply_rotary_emb_flash_attn = None
- self.has_mindie = False
if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary
self.apply_rotary_emb_flash_attn = apply_rotary
- if find_spec("mindiesd") is not None:
- self.has_mindie = True
def forward_cuda(
self,
@@ -132,7 +132,7 @@ def forward_npu(
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
- if self.has_mindie:
+ if find_spec("mindiesd"):
return apply_rotary_emb_mindiesd(x, cos, sin, self.interleaved)
else:
return self.forward_native(x, cos, sin)
@@ -145,14 +145,6 @@ def forward_xpu(
) -> torch.Tensor:
return self.forward_native(x, cos, sin)
- def forward_musa(
- self,
- x: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- ) -> torch.Tensor:
- return self.forward_native(x, cos, sin)
-
def forward_native(
self,
x: torch.Tensor,
@@ -167,56 +159,6 @@ def forward_native(
)
-class RotaryEmbeddingWan(RotaryEmbedding):
- """
- rotary positional embedding for Wan.
- interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
- of 1st half and 2nd half (GPT-NeoX style).
- """
-
- def __init__(self, is_neox_style: bool = False, half_head_dim: bool = False) -> None:
- super().__init__(is_neox_style=is_neox_style)
- self.half_head_dim = half_head_dim
-
- def forward_cuda(
- self,
- x: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- ) -> torch.Tensor:
- return self.forward_native(x, cos, sin)
-
- def forward_npu(
- self,
- x: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- ) -> torch.Tensor:
- if self.has_mindie:
- if cos.dim() > 2:
- cos = cos.reshape(-1, cos.shape[-1])
- sin = sin.reshape(-1, sin.shape[-1])
- return apply_rotary_emb_mindiesd(x, cos, sin, self.interleaved, self.half_head_dim)
- else:
- return self.forward_native(x, cos, sin)
-
- def forward_native(
- self,
- x: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- ) -> torch.Tensor:
- x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1)
- rotated = torch.stack(
- (
- x1 * cos - x2 * sin,
- x1 * sin + x2 * cos,
- ),
- dim=-1,
- )
- return rotated.flatten(-2, -1).to(x.dtype)
-
-
def apply_rope_to_qk(
rope: RotaryEmbedding,
query: torch.Tensor,
diff --git a/vllm_omni/diffusion/lora/manager.py b/vllm_omni/diffusion/lora/manager.py
index 63e8d9a96f5..5f75e26cb16 100644
--- a/vllm_omni/diffusion/lora/manager.py
+++ b/vllm_omni/diffusion/lora/manager.py
@@ -366,17 +366,13 @@ def _matches_target(module_name: str) -> bool:
fully_sharded_loras=False,
)
- for component_name in ("transformer", "transformer_2", "dit", "bagel"):
+ for component_name in ("transformer", "transformer_2", "dit"):
if not hasattr(self.pipeline, component_name):
continue
component = getattr(self.pipeline, component_name)
if not isinstance(component, nn.Module):
continue
- # Collect replacements first to avoid mutating the module tree
- # while iterating over named_modules().
- pending_replacements: list[tuple[str, str, nn.Module, list[str]]] = []
-
for module_name, module in component.named_modules(remove_duplicate=False):
# Don't recurse into already-replaced LoRA wrappers. Their
# original LinearBase lives under "base_layer", and replacing
@@ -405,9 +401,6 @@ def _matches_target(module_name: str) -> bool:
if not should_replace:
continue
- pending_replacements.append((module_name, full_module_name, module, packed_modules_list))
-
- for module_name, full_module_name, module, packed_modules_list in pending_replacements:
lora_layer = from_layer_diffusion(
layer=module,
max_loras=1,
diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py
index 91f3574b185..146afb26fbc 100644
--- a/vllm_omni/diffusion/model_loader/diffusers_loader.py
+++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py
@@ -32,7 +32,6 @@
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.distributed.hsdp import HSDPInferenceConfig
from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter
-from vllm_omni.diffusion.models.diffusers_adapter.pipeline_diffusers_adapter import DiffusersAdapterPipeline
from vllm_omni.diffusion.registry import initialize_model
if TYPE_CHECKING:
@@ -258,14 +257,11 @@ def load_model(
self,
od_config: OmniDiffusionConfig,
load_device: str,
- load_format: str | None = "default",
+ load_format: str = "default",
custom_pipeline_name: str | None = None,
device: torch.device | None = None,
) -> nn.Module:
"""Load a model with the given configurations."""
- if load_format is None:
- load_format = "default"
-
# CPU offload + FP8: load weights on device for FP8 quantization
if load_device == "cpu" and od_config.quantization_config is not None:
load_device = device.type
@@ -281,21 +277,11 @@ def load_model(
with target_device:
if load_format == "default":
model = initialize_model(od_config)
- elif load_format == "diffusers":
- model = DiffusersAdapterPipeline(od_config=od_config, device=target_device)
elif load_format == "custom_pipeline":
model_cls = resolve_obj_by_qualname(custom_pipeline_name)
model = model_cls(od_config=od_config)
- else:
- # 'dummy' format should not call this function at all
- raise ValueError(f"Unknown load_format: {load_format}")
logger.debug("Loading weights on %s ...", load_device)
- if load_format == "diffusers":
- # DiffusersAdapterPipeline.load_weights() calls
- # DiffusionPipeline.from_pretrained() internally — it does
- # NOT use our native (customized) pipeline classes.
- cast(DiffusersAdapterPipeline, model).load_weights()
- elif self._is_gguf_quantization(od_config):
+ if self._is_gguf_quantization(od_config):
self._load_weights_with_gguf(model, od_config)
else:
# Quantization does not happen in `load_weights` but after it
diff --git a/vllm_omni/diffusion/model_metadata.py b/vllm_omni/diffusion/model_metadata.py
deleted file mode 100644
index ec133e7380e..00000000000
--- a/vllm_omni/diffusion/model_metadata.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from dataclasses import dataclass
-
-
-@dataclass(frozen=True)
-class DiffusionModelMetadata:
- # Keep serving-facing capability metadata in a lightweight shared module so
- # config/model plumbing can read it without importing concrete pipelines.
- supports_multimodal_inputs: bool = False
- max_multimodal_image_inputs: int | None = None
-
-
-QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES = 4
-
-
-_DIFFUSION_MODEL_METADATA: dict[str, DiffusionModelMetadata] = {
- "QwenImageEditPlusPipeline": DiffusionModelMetadata(
- supports_multimodal_inputs=True,
- max_multimodal_image_inputs=QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES,
- ),
-}
-
-
-def get_diffusion_model_metadata(model_class_name: str | None) -> DiffusionModelMetadata:
- # Unknown models fall back to "no special multimodal capabilities" so new
- # pipelines do not accidentally inherit limits meant for other models.
- if model_class_name is None:
- return DiffusionModelMetadata()
- return _DIFFUSION_MODEL_METADATA.get(model_class_name, DiffusionModelMetadata())
diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py
index cff09db3c0c..685d14729e5 100644
--- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py
+++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py
@@ -25,7 +25,6 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
- MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
@@ -158,12 +157,21 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()
- self.gate_up_proj = MergedColumnParallelLinear(
+ self.gate_proj = ColumnParallelLinear(
hidden_size,
- [intermediate_size, intermediate_size],
+ intermediate_size,
+ bias=False,
+ gather_output=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_proj",
+ )
+ self.up_proj = ColumnParallelLinear(
+ hidden_size,
+ intermediate_size,
bias=False,
+ gather_output=False,
quant_config=quant_config,
- prefix=f"{prefix}.gate_up_proj",
+ prefix=f"{prefix}.up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
@@ -178,8 +186,8 @@ def __init__(
self.act_fn = nn.SiLU()
def forward(self, x):
- gate_up, _ = self.gate_up_proj(x)
- gate, up = gate_up.chunk(2, dim=-1)
+ gate, _ = self.gate_proj(x)
+ up, _ = self.up_proj(x)
x = self.act_fn(gate) * up
x, _ = self.down_proj(x)
return x
@@ -738,8 +746,6 @@ def forward(
class Qwen2MoTModel(Qwen2PreTrainedModel):
- _layerwise_offload_blocks_attrs = ["layers"]
-
def __init__(
self,
config,
@@ -856,7 +862,6 @@ def __init__(
config, parallel_config=parallel_config, quant_config=quant_config, prefix=f"{prefix}.model"
)
self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
@@ -867,12 +872,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
- def get_output_embeddings(self):
- return self.lm_head
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
def set_decoder(self, decoder):
self.model = decoder
@@ -930,38 +929,27 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
- # MLP gate/up projections — fused into MergedColumnParallelLinear.
- # HF checkpoints store separate gate_proj / up_proj weights;
- # these entries remap them to the fused gate_up_proj parameter.
- (".gate_up_proj", ".gate_proj", 0),
- (".gate_up_proj", ".up_proj", 1),
]
- self.stacked_params_mapping = stacked_params_mapping
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- loaded = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
- stacked_name = name.replace(weight_name, param_name)
- param = params_dict.get(stacked_name)
+ name = name.replace(weight_name, param_name)
+ param = params_dict.get(name)
if param is None:
break
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, shard_id)
- name = stacked_name
- loaded = True
break
-
- if not loaded:
+ else:
param = params_dict.get(name)
if param is None:
continue
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
-
loaded_params.add(name)
return loaded_params
@@ -1216,7 +1204,7 @@ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
- text_ids = tokenizer.encode(prompt, add_special_tokens=False)
+ text_ids = tokenizer.encode(prompt)
text_ids = [new_token_ids["bos_token_id"]] + text_ids + [new_token_ids["eos_token_id"]]
text_token_lens.append(len(text_ids))
packed_text_ids.extend(text_ids)
@@ -1628,110 +1616,10 @@ def _merge_naive_caches(caches: list) -> NaiveCache:
num_layers = len(caches[0].key_cache)
merged = NaiveCache(num_layers)
for layer_idx in range(num_layers):
- key_parts = [c.key_cache[layer_idx] for c in caches if c.key_cache[layer_idx] is not None]
- val_parts = [c.value_cache[layer_idx] for c in caches if c.value_cache[layer_idx] is not None]
- merged.key_cache[layer_idx] = torch.cat(key_parts, dim=0) if key_parts else None
- merged.value_cache[layer_idx] = torch.cat(val_parts, dim=0) if val_parts else None
+ merged.key_cache[layer_idx] = torch.cat([c.key_cache[layer_idx] for c in caches], dim=0)
+ merged.value_cache[layer_idx] = torch.cat([c.value_cache[layer_idx] for c in caches], dim=0)
return merged
- def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
- """Prepare start tokens for autoregressive text generation.
-
- Ported from the original BAGEL ``Bagel.prepare_start_tokens``.
- """
- packed_start_tokens, packed_key_value_indexes = list(), list()
- packed_query_position_ids = list()
-
- curr = 0
- for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
- packed_start_tokens.append(new_token_ids["bos_token_id"])
- packed_query_position_ids.append(curr_position_id)
- curr += curr_kvlen
-
- generation_input = {
- "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
- "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
- }
- return generation_input
-
- @torch.no_grad()
- def generate_text(
- self,
- past_key_values: NaiveCache,
- packed_key_value_indexes: torch.LongTensor,
- key_values_lens: torch.IntTensor,
- packed_start_tokens: torch.LongTensor,
- packed_query_position_ids: torch.LongTensor,
- max_length: int,
- do_sample: bool = False,
- temperature: float = 1.0,
- end_token_id: int | None = None,
- ):
- """Autoregressive text generation (ported from original BAGEL).
-
- Decodes tokens one at a time, appending to ``past_key_values``
- until ``max_length`` is reached or ``end_token_id`` is generated.
- """
- step = 0
- generated_sequence = []
- curr_tokens = packed_start_tokens
- while step < max_length:
- generated_sequence.append(curr_tokens)
- packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
- query_lens = torch.ones_like(curr_tokens)
- packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
- 0,
- len(key_values_lens),
- device=key_values_lens.device,
- dtype=key_values_lens.dtype,
- )
-
- uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
- for i in range(len(uppacked)):
- uppacked[i] += i
- packed_key_value_indexes = torch.cat(uppacked, dim=0)
-
- output = self.language_model(
- packed_query_sequence=packed_text_embedding,
- query_lens=query_lens,
- packed_query_position_ids=packed_query_position_ids,
- packed_query_indexes=packed_query_indexes,
- past_key_values=past_key_values,
- key_values_lens=key_values_lens,
- packed_key_value_indexes=packed_key_value_indexes,
- update_past_key_values=True,
- is_causal=True,
- mode="und",
- )
- past_key_values = output.past_key_values
- packed_query_sequence = output.packed_query_sequence
- pred_logits = self.language_model.lm_head(packed_query_sequence)
-
- if do_sample:
- probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
- curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
- else:
- curr_tokens = torch.argmax(pred_logits, dim=-1)
-
- uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
- for i in range(len(uppacked)):
- uppacked[i] = torch.cat(
- [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0
- )
- packed_key_value_indexes = torch.cat(uppacked, dim=0)
- key_values_lens = key_values_lens + 1
- packed_query_position_ids = packed_query_position_ids + 1
- step += 1
-
- if end_token_id is not None and curr_tokens[0] == end_token_id:
- break
-
- output_device = generated_sequence[0].device
- return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
-
def generate_image(
self,
packed_text_ids: torch.LongTensor,
@@ -1764,9 +1652,6 @@ def generate_image(
cfg_img_past_key_values: NaiveCache | None = None,
cfg_img_key_values_lens: torch.IntTensor | None = None,
cfg_img_packed_key_value_indexes: torch.LongTensor | None = None,
- return_trajectory_latents: bool = False,
- scheduler: object | None = None,
- scheduler_kwargs: dict | None = None,
):
x_t = packed_init_noises
@@ -1775,14 +1660,6 @@ def generate_image(
dts = timesteps[:-1] - timesteps[1:]
timesteps = timesteps[:-1]
- # Optional trajectory recording for RL rollout data collection
- trajectory_latents: list[torch.Tensor] | None = [] if return_trajectory_latents else None
- trajectory_timesteps: list[torch.Tensor] | None = [] if return_trajectory_latents else None
- trajectory_log_probs: list[torch.Tensor] | None = (
- [] if (return_trajectory_latents and scheduler is not None) else None
- )
- _sched_kw = scheduler_kwargs or {}
-
use_cfg_text = cfg_text_scale > 1.0
use_cfg_img = cfg_img_scale > 1.0
@@ -1819,9 +1696,6 @@ def generate_image(
cfg_img_past_key_values=cfg_img_past_key_values,
cfg_img_key_values_lens=cfg_img_key_values_lens,
cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
- return_trajectory_latents=return_trajectory_latents,
- scheduler=scheduler,
- scheduler_kwargs=scheduler_kwargs,
)
# ── SP + CFG: sequential single-branch forwards ──
@@ -1843,7 +1717,7 @@ def generate_image(
packed_seqlens=packed_seqlens,
)
- v_t = self.forward_single_branch(
+ v_t = self._forward_flow_single_branch(
**common,
packed_indexes=packed_indexes,
packed_position_ids=packed_position_ids,
@@ -1853,7 +1727,7 @@ def generate_image(
)
if cfg_text_scale_ > 1.0:
- cfg_text_v_t = self.forward_single_branch(
+ cfg_text_v_t = self._forward_flow_single_branch(
**common,
packed_indexes=cfg_text_packed_query_indexes,
packed_position_ids=cfg_text_packed_position_ids,
@@ -1863,7 +1737,7 @@ def generate_image(
)
cfg_img_v_t = None
if cfg_img_scale_ > 1.0:
- cfg_img_v_t = self.forward_single_branch(
+ cfg_img_v_t = self._forward_flow_single_branch(
**common,
packed_indexes=cfg_img_packed_query_indexes,
packed_position_ids=cfg_img_packed_position_ids,
@@ -1881,25 +1755,16 @@ def generate_image(
cfg_renorm_min,
)
- if scheduler is not None:
- out = scheduler.step(v_t.to(x_t.device), timesteps[i], x_t, dts[i], **_sched_kw)
- x_t = out.prev_sample
- if trajectory_log_probs is not None and out.log_prob is not None:
- trajectory_log_probs.append(out.log_prob)
- else:
- x_t = x_t - v_t.to(x_t.device) * dts[i]
- if return_trajectory_latents:
- trajectory_latents.append(x_t.clone())
- trajectory_timesteps.append(timesteps[i] - dts[i])
+ x_t = x_t - v_t.to(x_t.device) * dts[i]
unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
- return unpacked_latent, trajectory_latents, trajectory_timesteps, trajectory_log_probs
+ return unpacked_latent
# ── SP without CFG: direct single-branch loop ──
if use_sp:
for i, t in enumerate(timesteps):
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
- v_t = self.forward_single_branch(
+ v_t = self._forward_flow_single_branch(
x_t=x_t,
timestep=timestep,
packed_vae_token_indexes=packed_vae_token_indexes,
@@ -1913,20 +1778,10 @@ def generate_image(
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
)
- if scheduler is not None:
- out = scheduler.step(v_t.to(x_t.device), timesteps[i], x_t, dts[i], **_sched_kw)
- x_t = out.prev_sample
- out_log_prob = getattr(out, "log_prob", None)
- if trajectory_log_probs is not None and out_log_prob is not None:
- trajectory_log_probs.append(out_log_prob)
- else:
- x_t = x_t - v_t.to(x_t.device) * dts[i]
- if return_trajectory_latents:
- trajectory_latents.append(x_t.clone())
- trajectory_timesteps.append(timesteps[i] - dts[i])
+ x_t = x_t - v_t.to(x_t.device) * dts[i]
unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
- return unpacked_latent, trajectory_latents, trajectory_timesteps, trajectory_log_probs
+ return unpacked_latent
# ── Batched CFG mode (cfg_parallel_size=1, no SP) ──
cfg_batched = None
@@ -1992,7 +1847,7 @@ def generate_image(
else:
cfg_text_scale_ = 1.0
cfg_img_scale_ = 1.0
- v_t = self.forward(
+ v_t = self._forward_flow(
x_t=x_t,
timestep=timestep,
packed_vae_token_indexes=packed_vae_token_indexes,
@@ -2012,19 +1867,10 @@ def generate_image(
cfg_batched=cfg_batched,
)
- if scheduler is not None:
- out = scheduler.step(v_t.to(x_t.device), timesteps[i], x_t, dts[i], **_sched_kw)
- x_t = out.prev_sample
- if trajectory_log_probs is not None and out.log_prob is not None:
- trajectory_log_probs.append(out.log_prob)
- else:
- x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise
- if return_trajectory_latents:
- trajectory_latents.append(x_t.clone())
- trajectory_timesteps.append(timesteps[i] - dts[i])
+ x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise
unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
- return unpacked_latent, trajectory_latents, trajectory_timesteps, trajectory_log_probs
+ return unpacked_latent
def _generate_image_parallel(
self,
@@ -2056,9 +1902,6 @@ def _generate_image_parallel(
cfg_img_past_key_values: NaiveCache | None,
cfg_img_key_values_lens: torch.IntTensor | None,
cfg_img_packed_key_value_indexes: torch.LongTensor | None,
- return_trajectory_latents: bool = False,
- scheduler: object | None = None,
- scheduler_kwargs: dict | None = None,
):
"""CFG parallel denoising loop: each rank computes one CFG branch.
@@ -2115,20 +1958,13 @@ def _generate_image_parallel(
else:
raise RuntimeError(f"Unexpected cfg_rank={cfg_rank} for Bagel 3-branch CFG parallel")
- trajectory_latents: list[torch.Tensor] | None = [] if return_trajectory_latents else None
- trajectory_timesteps: list[torch.Tensor] | None = [] if return_trajectory_latents else None
- trajectory_log_probs: list[torch.Tensor] | None = (
- [] if (return_trajectory_latents and scheduler is not None) else None
- )
- _sched_kw = scheduler_kwargs or {}
-
for i, t in enumerate(timesteps):
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
use_cfg_this_step = t > cfg_interval[0] and t <= cfg_interval[1] and cfg_text_scale > 1.0
if use_cfg_this_step:
# CFG interval: each rank computes its own branch
- local_v_t = self.forward_single_branch(
+ local_v_t = self._forward_flow_single_branch(
x_t=x_t,
timestep=timestep,
packed_vae_token_indexes=packed_vae_token_indexes,
@@ -2155,7 +1991,7 @@ def _generate_image_parallel(
)
else:
# Outside CFG interval: all ranks compute with gen inputs, no comm
- v_t = self.forward_single_branch(
+ v_t = self._forward_flow_single_branch(
x_t=x_t,
timestep=timestep,
packed_vae_token_indexes=packed_vae_token_indexes,
@@ -2170,19 +2006,10 @@ def _generate_image_parallel(
packed_key_value_indexes=packed_key_value_indexes,
)
- if scheduler is not None:
- out = scheduler.step(v_t.to(x_t.device), timesteps[i], x_t, dts[i], **_sched_kw)
- x_t = out.prev_sample
- if trajectory_log_probs is not None and out.log_prob is not None:
- trajectory_log_probs.append(out.log_prob)
- else:
- x_t = x_t - v_t.to(x_t.device) * dts[i]
- if return_trajectory_latents:
- trajectory_latents.append(x_t.clone())
- trajectory_timesteps.append(timesteps[i] - dts[i])
+ x_t = x_t - v_t.to(x_t.device) * dts[i]
unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
- return unpacked_latent, trajectory_latents, trajectory_timesteps, trajectory_log_probs
+ return unpacked_latent
@staticmethod
def _combine_cfg(
@@ -2237,7 +2064,7 @@ def _combine_cfg(
return v_t
- def forward_single_branch(
+ def _forward_flow_single_branch(
self,
x_t: torch.Tensor,
timestep: torch.LongTensor,
@@ -2367,7 +2194,7 @@ def forward_single_branch(
v_t = v_t[packed_vae_token_indexes]
return v_t
- def forward(
+ def _forward_flow(
self,
x_t: torch.Tensor,
timestep: torch.LongTensor,
diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
index d08a2cdc80c..3e053cbda50 100644
--- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
+++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
@@ -12,7 +12,6 @@
from copy import deepcopy
from dataclasses import dataclass
from math import isqrt
-from typing import ClassVar
import numpy as np
import torch
@@ -27,7 +26,6 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.models.interface import SupportsModuleOffload
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
@@ -150,33 +148,17 @@ def forward(self, packed_pixel_values, packed_flattened_position_ids, cu_seqlens
return outputs.last_hidden_state.squeeze(0)
-class BagelPipeline(nn.Module, SupportsModuleOffload, DiffusionPipelineProfilerMixin):
+class BagelPipeline(nn.Module, DiffusionPipelineProfilerMixin):
"""Bagel generation pipeline (MoT) packaged for vllm-omni diffusion engine.
This pipeline is self-contained and uses the ported Bagel core files.
"""
- _dit_modules: ClassVar[list[str]] = ["language_model.model"]
- _encoder_modules: ClassVar[list[str]] = []
- _vae_modules: ClassVar[list[str]] = ["vae"]
- _resident_modules: ClassVar[list[str]] = [
- "bagel.time_embedder",
- "bagel.vae2llm",
- "bagel.llm2vae",
- "bagel.latent_pos_embed",
- "bagel.vit_model",
- "bagel.connector",
- "bagel.vit_pos_embed",
- ]
-
def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
super().__init__()
self.od_config = od_config
self.device = get_local_device()
- self.scheduler: object | None = None
- self.scheduler_kwargs: dict = {}
-
model = od_config.model
local_files_only = os.path.exists(model)
if local_files_only:
@@ -288,7 +270,7 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
# device until the weight loader materializes them. Calling
# .to(device) would fail on those meta tensors, so we skip it
# entirely and let the weight loader handle device placement.
- if quant_config is None and not od_config.enable_layerwise_offload:
+ if quant_config is None:
self.to(self.device)
self.setup_diffusion_pipeline_profiler(
enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
@@ -380,65 +362,35 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
if req.sampling_params.kv_metadata and "image_shape" in req.sampling_params.kv_metadata:
image_shape = tuple(req.sampling_params.kv_metadata["image_shape"])
- branch_kvs = getattr(req.sampling_params, "cfg_branch_past_key_values", None) or {}
- branch_metadata = getattr(req.sampling_params, "cfg_branch_kv_metadata", None) or {}
- active_branch = getattr(req.sampling_params, "cfg_active_branch", None)
- branch_roles = getattr(req.sampling_params, "cfg_branch_roles", None) or list(branch_kvs.keys())
-
- cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None) or branch_kvs.get("cfg_text")
- cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None) or branch_metadata.get(
- "cfg_text"
- )
- cfg_img_kv = getattr(req.sampling_params, "cfg_img_past_key_values", None) or branch_kvs.get("cfg_img")
- cfg_img_metadata = getattr(req.sampling_params, "cfg_img_kv_metadata", None) or branch_metadata.get(
- "cfg_img"
- )
-
- cfg_parallel_contract = (
- active_branch is not None or bool(branch_roles) or cfg_text_kv is not None or cfg_img_kv is not None
- )
- if cfg_parallel_contract:
- logger.info(
- "CFG enabled with injected branch KV context roles=%s active=%s",
- branch_roles,
- active_branch,
- )
-
+ cfg_text_kv = getattr(req.sampling_params, "cfg_text_past_key_values", None)
if cfg_text_kv is not None:
+ logger.info("CFG enabled with multi-KV: using injected cfg_text KV Cache")
cfg_text_seq_len = cfg_text_kv.key_cache[0].shape[0]
cfg_text_context["past_key_values"] = cfg_text_kv
cfg_text_context["kv_lens"] = [cfg_text_seq_len]
+ cfg_text_metadata = getattr(req.sampling_params, "cfg_text_kv_metadata", None)
if cfg_text_metadata and "ropes" in cfg_text_metadata:
cfg_text_context["ropes"] = cfg_text_metadata["ropes"]
else:
cfg_text_context["ropes"] = [cfg_text_seq_len]
- else:
- # No cfg_text companion received. For text2img this is the
- # expected path: original BAGEL uses an empty KV cache (0
- # tokens) as the text-unconditional branch. Keep the default
- # empty NaiveCache in cfg_text_context and preserve the
- # original cfg_text_scale so CFG still applies.
- pass
-
- if cfg_img_kv is None:
- # text2img multi-stage: cfg_img reuses gen KV (positive prompt,
- # no image), mirroring forward_cache_update_text on cfg_img_context
- # in the single-stage path.
- cfg_img_seq_len = injected_kv.key_cache[0].shape[0]
- cfg_img_context["past_key_values"] = injected_kv
- cfg_img_context["kv_lens"] = [cfg_img_seq_len]
- if req.sampling_params.kv_metadata and "ropes" in req.sampling_params.kv_metadata:
- cfg_img_context["ropes"] = req.sampling_params.kv_metadata["ropes"]
- else:
- cfg_img_context["ropes"] = [cfg_img_seq_len]
- else:
+
+ cfg_img_kv = getattr(req.sampling_params, "cfg_img_past_key_values", None) or injected_kv
cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0]
cfg_img_context["past_key_values"] = cfg_img_kv
cfg_img_context["kv_lens"] = [cfg_img_seq_len]
+ cfg_img_metadata = getattr(req.sampling_params, "cfg_img_kv_metadata", None)
if cfg_img_metadata and "ropes" in cfg_img_metadata:
cfg_img_context["ropes"] = cfg_img_metadata["ropes"]
else:
cfg_img_context["ropes"] = [cfg_img_seq_len]
+ else:
+ logger.warning("CFG is disabled: only single KV cache available")
+ gen_params = BagelGenParams(
+ num_timesteps=gen_params.num_timesteps,
+ timestep_shift=gen_params.timestep_shift,
+ cfg_text_scale=1.0,
+ cfg_img_scale=1.0,
+ )
else:
image_input = (
@@ -540,15 +492,11 @@ def vae_transforms(img):
cfg_text_context = deepcopy(gen_context)
- # Strip <|im_start|>/<|im_end|> wrappers that end2end.py may have
- # already added, so prepare_prompts doesn't double-add bos/eos.
- clean_prompt = prompt.removeprefix("<|im_start|>").removesuffix("<|im_end|>")
-
# Update gen_context with text prompt
generation_input, newlens, new_rope = self.bagel.prepare_prompts(
curr_kvlens=gen_context["kv_lens"],
curr_rope=gen_context["ropes"],
- prompts=[clean_prompt],
+ prompts=[prompt],
tokenizer=self.tokenizer,
new_token_ids=self.new_token_ids,
)
@@ -576,37 +524,34 @@ def vae_transforms(img):
gen_context["kv_lens"] = newlens
gen_context["ropes"] = new_rope
- # cfg_text_context: update with negative prompt (no text condition).
- # When empty, keep cfg_text_context as-is (kv_lens=0) to match
- # original BAGEL; _merge_naive_caches handles None KV entries.
+ # cfg_text_context: update with negative prompt (no text condition)
neg_prompt = extra_args.get("negative_prompt", "")
- if neg_prompt:
- neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts(
- curr_kvlens=cfg_text_context["kv_lens"],
- curr_rope=cfg_text_context["ropes"],
- prompts=[neg_prompt],
- tokenizer=self.tokenizer,
- new_token_ids=self.new_token_ids,
+ neg_input, neg_newlens, neg_rope = self.bagel.prepare_prompts(
+ curr_kvlens=cfg_text_context["kv_lens"],
+ curr_rope=cfg_text_context["ropes"],
+ prompts=[neg_prompt],
+ tokenizer=self.tokenizer,
+ new_token_ids=self.new_token_ids,
+ )
+ for k, v in neg_input.items():
+ if torch.is_tensor(v):
+ neg_input[k] = v.to(self.device)
+ with torch.autocast(
+ device_type=self.device.type,
+ enabled=self.device.type != "cpu",
+ dtype=self.od_config.dtype,
+ ):
+ cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text(
+ cfg_text_context["past_key_values"], **neg_input
)
- for k, v in neg_input.items():
- if torch.is_tensor(v):
- neg_input[k] = v.to(self.device)
- with torch.autocast(
- device_type=self.device.type,
- enabled=self.device.type != "cpu",
- dtype=self.od_config.dtype,
- ):
- cfg_text_context["past_key_values"] = self.bagel.forward_cache_update_text(
- cfg_text_context["past_key_values"], **neg_input
- )
- cfg_text_context["kv_lens"] = neg_newlens
- cfg_text_context["ropes"] = neg_rope
+ cfg_text_context["kv_lens"] = neg_newlens
+ cfg_text_context["ropes"] = neg_rope
# cfg_img_context: update with text prompt (no image condition)
cfg_img_generation_input, cfg_img_newlens, cfg_img_new_rope = self.bagel.prepare_prompts(
curr_kvlens=cfg_img_context["kv_lens"],
curr_rope=cfg_img_context["ropes"],
- prompts=[clean_prompt],
+ prompts=[prompt],
tokenizer=self.tokenizer,
new_token_ids=self.new_token_ids,
)
@@ -624,96 +569,6 @@ def vae_transforms(img):
cfg_img_context["kv_lens"] = cfg_img_newlens
cfg_img_context["ropes"] = cfg_img_new_rope
- # ---- Detect output modality and think mode ----
- modalities = first_prompt.get("modalities", []) if isinstance(first_prompt, dict) else []
- is_text_output = "text" in modalities
- think_enabled = extra_args.get("think", False)
- think_text = None
-
- if think_enabled and injected_kv is None:
- max_think_tokens = int(extra_args.get("max_think_tokens", 1000))
- do_sample = bool(extra_args.get("do_sample", False))
- text_temperature = float(extra_args.get("text_temperature", 0.3))
-
- with torch.autocast(
- device_type=self.device.type,
- enabled=self.device.type != "cpu",
- dtype=self.od_config.dtype,
- ):
- start_input = self.bagel.prepare_start_tokens(
- gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids
- )
- for k, v in start_input.items():
- if torch.is_tensor(v):
- start_input[k] = v.to(self.device)
-
- gen_ctx_copy = deepcopy(gen_context)
- token_ids = self.bagel.generate_text(
- past_key_values=gen_ctx_copy["past_key_values"],
- max_length=max_think_tokens,
- do_sample=do_sample,
- temperature=text_temperature,
- end_token_id=self.new_token_ids["eos_token_id"],
- **start_input,
- )
- # token_ids shape: (seq_len, batch=1)
- decoded = self.tokenizer.decode(token_ids[:, 0].tolist())
- # Strip chat markers to get clean text
- think_text = decoded.split("<|im_end|>")[0]
- if "<|im_start|>" in think_text:
- think_text = think_text.split("<|im_start|>")[-1]
- logger.info("Think mode generated %d tokens", token_ids.shape[0])
-
- if not is_text_output:
- # Use the autoregressive KV cache from think generation
- # directly, instead of decode→re-encode which adds extra
- # bos/eos and may alter tokenization.
- num_think_tokens = token_ids.shape[0]
- gen_context["past_key_values"] = gen_ctx_copy["past_key_values"]
- gen_context["kv_lens"] = [kl + num_think_tokens for kl in gen_context["kv_lens"]]
- gen_context["ropes"] = [r + num_think_tokens for r in gen_context["ropes"]]
-
- # ---- Text-only output (text2text / img2text) ----
- if is_text_output and injected_kv is None:
- if think_text is not None:
- # Think mode already generated the text (including reasoning)
- text_output = think_text
- else:
- max_text_tokens = int(extra_args.get("max_think_tokens", 500))
- do_sample = bool(extra_args.get("do_sample", False))
- text_temperature = float(extra_args.get("text_temperature", 0.3))
-
- with torch.autocast(
- device_type=self.device.type,
- enabled=self.device.type != "cpu",
- dtype=self.od_config.dtype,
- ):
- start_input = self.bagel.prepare_start_tokens(
- gen_context["kv_lens"], gen_context["ropes"], self.new_token_ids
- )
- for k, v in start_input.items():
- if torch.is_tensor(v):
- start_input[k] = v.to(self.device)
- token_ids = self.bagel.generate_text(
- past_key_values=gen_context["past_key_values"],
- max_length=max_text_tokens,
- do_sample=do_sample,
- temperature=text_temperature,
- end_token_id=self.new_token_ids["eos_token_id"],
- **start_input,
- )
- decoded = self.tokenizer.decode(token_ids[:, 0].tolist())
- text_output = decoded.split("<|im_end|>")[0]
- if "<|im_start|>" in text_output:
- text_output = text_output.split("<|im_start|>")[-1]
-
- return DiffusionOutput(
- output=text_output,
- custom_output={"text_output": text_output},
- stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None,
- )
-
- # ---- Image generation (text2img / img2img) ----
if req.sampling_params.seed is not None:
torch.manual_seed(req.sampling_params.seed)
if self.device.type == "cuda":
@@ -775,7 +630,7 @@ def vae_transforms(img):
enabled=self.device.type != "cpu",
dtype=self.od_config.dtype,
):
- latents, trajectory_latents, trajectory_timesteps, trajectory_log_probs = self.bagel.generate_image(
+ latents = self.bagel.generate_image(
past_key_values=gen_context["past_key_values"],
cfg_text_past_key_values=cfg_text_context["past_key_values"],
cfg_img_past_key_values=cfg_img_context["past_key_values"],
@@ -795,41 +650,11 @@ def vae_transforms(img):
cfg_img_packed_query_indexes=generation_input_cfg_img["cfg_packed_query_indexes"],
cfg_img_key_values_lens=generation_input_cfg_img["cfg_key_values_lens"],
cfg_img_packed_key_value_indexes=generation_input_cfg_img["cfg_packed_key_value_indexes"],
- return_trajectory_latents=req.sampling_params.return_trajectory_latents,
- scheduler=self.scheduler,
- scheduler_kwargs=self.scheduler_kwargs,
)
img = self._decode_image_from_latent(self.bagel, self.vae, latents[0], image_shape)
-
- # Build trajectory output when requested
- trajectory_latents_stacked: torch.Tensor | None = None
- trajectory_timesteps_stacked: torch.Tensor | None = None
- trajectory_decoded: list[Image.Image] | None = None
- if trajectory_latents:
- trajectory_latents_stacked = torch.stack(trajectory_latents)
- trajectory_timesteps_stacked = torch.stack(trajectory_timesteps)
- if req.sampling_params.return_trajectory_decoded:
- trajectory_decoded = [
- self._decode_image_from_latent(self.bagel, self.vae, lat, image_shape) for lat in trajectory_latents
- ]
-
- trajectory_log_probs_stacked: torch.Tensor | None = None
- if trajectory_log_probs:
- trajectory_log_probs_stacked = torch.stack(trajectory_log_probs)
-
- custom = {}
- if think_text is not None:
- custom["think_text"] = think_text
-
return DiffusionOutput(
- output=img,
- trajectory_latents=trajectory_latents_stacked,
- trajectory_timesteps=trajectory_timesteps_stacked,
- trajectory_log_probs=trajectory_log_probs_stacked,
- trajectory_decoded=trajectory_decoded,
- custom_output=custom,
- stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None,
+ output=img, stage_durations=self.stage_durations if hasattr(self, "stage_durations") else None
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -850,8 +675,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
(".qkv_proj_moe_gen", ".q_proj_moe_gen"),
(".qkv_proj_moe_gen", ".k_proj_moe_gen"),
(".qkv_proj_moe_gen", ".v_proj_moe_gen"),
- (".gate_up_proj", ".gate_proj"),
- (".gate_up_proj", ".up_proj"),
]
stacked_source_names: set[str] = set()
for name in list(allowed):
diff --git a/vllm_omni/diffusion/models/diffusers_adapter/__init__.py b/vllm_omni/diffusion/models/diffusers_adapter/__init__.py
deleted file mode 100644
index c8dd51c8e7c..00000000000
--- a/vllm_omni/diffusion/models/diffusers_adapter/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Diffusers backend adapter for vLLM-Omni."""
-
-from vllm_omni.diffusion.models.diffusers_adapter.pipeline_diffusers_adapter import (
- DiffusersAdapterPipeline,
-)
-
-__all__ = [
- "DiffusersAdapterPipeline",
-]
diff --git a/vllm_omni/diffusion/models/diffusers_adapter/pipeline_diffusers_adapter.py b/vllm_omni/diffusion/models/diffusers_adapter/pipeline_diffusers_adapter.py
deleted file mode 100644
index 8a1fdfc08f3..00000000000
--- a/vllm_omni/diffusion/models/diffusers_adapter/pipeline_diffusers_adapter.py
+++ /dev/null
@@ -1,358 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Diffusers backend adapter for vLLM-Omni.
-
-Provides a black-box wrapper around any 🤗 Diffusers pipeline, enabling
-vLLM-Omni to directly serve Diffusers models with near-zero per-model code.
-
-The adapter delegates full pipeline execution to diffusers' ``__call__()``.
-It does NOT support:
-- CFG parallel (diffusers handles CFG via guidance_scale internally)
-- Sequence parallel (requires model-specific attention surgery)
-- TeaCache / Cache-DiT (requires hooking into transformer blocks)
-- Step-wise execution (continuous batching)
-"""
-
-import logging
-import os
-from typing import Any
-
-import torch
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from torch import nn
-
-from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
-from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
-from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.inputs.data import OmniPromptType
-from vllm_omni.platforms import current_omni_platform
-
-logger = logging.getLogger(__name__)
-
-
-class DiffusersAdapterPipeline(nn.Module, DiffusionPipelineProfilerMixin):
- """Black-box adapter that delegates full pipeline execution to a diffusers pipeline.
-
- Usage::
-
- adapter = DiffusersAdapterPipeline(od_config=od_config)
- adapter.load_weights() # calls DiffusionPipeline.from_pretrained()
- output = adapter.forward(req)
-
- Step-wise execution is explicitly rejected — diffusers encapsulates the
- full denoising loop internally. Use native pipelines for continuous
- batching mode.
- """
-
- supports_step_execution: bool = False
-
- def __init__(self, *, od_config: OmniDiffusionConfig, device: torch.device | None = None):
- super().__init__()
- self._pipeline: DiffusionPipeline
- self.od_config = od_config
- self.device = device
- self._capabilities: dict[str, Any] = {}
- self._raise_unsupported_features()
-
- self.setup_diffusion_pipeline_profiler(
- enable_diffusion_pipeline_profiler=od_config.enable_diffusion_pipeline_profiler,
- profiler_targets=["forward"],
- )
- if od_config.enable_diffusion_pipeline_profiler:
- logger.info("Profiling enabled for DiffusersAdapterPipeline. Only 'forward' is supported.")
-
- # ------------------------------------------------------------------
- # Weight loading
- # ------------------------------------------------------------------
-
- def load_weights(self) -> None:
- """Load the diffusers pipeline via ``DiffusionPipeline.from_pretrained()``."""
-
- model_id = self.od_config.model
- dtype = self.od_config.dtype
-
- load_kwargs = {
- "torch_dtype": dtype,
- **self.od_config.diffusers_load_kwargs,
- }
- logger.debug(f"Loading diffusers pipeline with kwargs: {load_kwargs}")
-
- self._pipeline = DiffusionPipeline.from_pretrained(
- model_id,
- **load_kwargs,
- ).to(self.device)
-
- # CPU offloading
- if self.od_config.enable_layerwise_offload:
- self._pipeline.enable_sequential_cpu_offload()
- elif self.od_config.enable_cpu_offload:
- self._pipeline.enable_model_cpu_offload()
-
- # VAE slicing and tiling: try-catch because not all models have VAE
- if self.od_config.vae_use_slicing:
- try:
- self._pipeline.enable_vae_slicing()
- except Exception as e:
- logger.warning(
- f"Failed to enable VAE slicing for diffusers pipeline {self._pipeline.__class__.__name__}: {e}"
- )
- if self.od_config.vae_use_tiling:
- try:
- self._pipeline.enable_vae_tiling()
- except Exception as e:
- logger.warning(
- f"Failed to enable VAE tiling for diffusers pipeline {self._pipeline.__class__.__name__}: {e}"
- )
-
- # Attention backend
- self._set_attention_backend()
-
- # ------------------------------------------------------------------
- # Step-wise execution — explicitly rejected
- # ------------------------------------------------------------------
-
- def prepare_encode(self, **_: Any) -> Any:
- raise NotImplementedError(
- "Step-wise execution is not yet supported with the diffusers backend. "
- "Use a native pipeline for continuous batching mode."
- )
-
- def denoise_step(self, **_: Any) -> torch.Tensor | None:
- raise NotImplementedError(
- "Step-wise execution is not yet supported with the diffusers backend. "
- "Use a native pipeline for continuous batching mode."
- )
-
- def step_scheduler(self, **_: Any) -> None:
- raise NotImplementedError(
- "Step-wise execution is not yet supported with the diffusers backend. "
- "Use a native pipeline for continuous batching mode."
- )
-
- def post_decode(self, **_: Any) -> Any:
- raise NotImplementedError(
- "Step-wise execution is not yet supported with the diffusers backend. "
- "Use a native pipeline for continuous batching mode."
- )
-
- # ------------------------------------------------------------------
- # Forward pass
- # ------------------------------------------------------------------
-
- def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
- """Full delegation to diffusers ``pipeline.__call__()``."""
-
- kwargs = self._build_call_kwargs(req)
- logger.debug(f"Calling diffusers pipeline with kwargs: {kwargs}")
-
- with torch.inference_mode():
- output = self._pipeline(**kwargs) # pyright: ignore[reportCallIssue]
-
- return self._wrap_output(output)
-
- # ------------------------------------------------------------------
- # Validation guards
- # ------------------------------------------------------------------
-
- def _raise_unsupported_features(self) -> None:
- """Raise an error for incompatible feature switches."""
- pc = self.od_config.parallel_config
- if pc.cfg_parallel_size > 1:
- raise NotImplementedError(
- "CFG parallel is not supported with the diffusers backend. "
- "Diffusers handles CFG internally via guidance_scale."
- )
- if pc.sequence_parallel_size is not None and pc.sequence_parallel_size > 1:
- raise NotImplementedError(
- "Sequence parallel is not supported with the diffusers backend. "
- "It requires model-specific attention surgery."
- )
- if self.od_config.cache_backend not in ("none", None):
- raise NotImplementedError(
- f"Cache backend '{self.od_config.cache_backend}' is not supported "
- "with the diffusers backend. TeaCache/Cache-DiT require hooking "
- "into individual transformer blocks."
- )
- if self.od_config.enforce_eager:
- raise NotImplementedError(
- "Eager execution is not supported with the diffusers backend. "
- "Use a native pipeline for continuous batching mode."
- )
- if self.od_config.quantization_config is not None:
- raise NotImplementedError(
- "Quantization is not supported with the diffusers backend. Use a native pipeline for quantization."
- )
-
- # ------------------------------------------------------------------
- # Wrap settings, inputs, and outputs
- # ------------------------------------------------------------------
-
- def _set_attention_backend(self) -> None:
- """Set the attention backend.
-
- Roughly follow the logic in vllm_omni/diffusion/attention/backends/utils/fa.py,
- But also consider the available attention backends in diffusers.
- (See: https://huggingface.co/docs/diffusers/optimization/attention_backends)
- """
- if not hasattr(self._pipeline, "transformer"):
- logging.info("No transformer found in diffusers pipeline. Skipping attention backend setting.")
- return
-
- attention_backend_config = self.od_config.attention_backend or os.environ.get("DIFFUSION_ATTENTION_BACKEND")
- attention_backend_attempts: list[str] = []
- match attention_backend_config:
- case "FLASH_ATTN" | None:
- if current_omni_platform.is_rocm():
- attention_backend_attempts.append("aiter")
- elif current_omni_platform.is_xpu():
- attention_backend_attempts.append("_native_xla")
- elif current_omni_platform.is_musa():
- logger.warning(
- "Unknown diffusers attention backend option for MUSA platform. Falling back to SDPA."
- )
- attention_backend_attempts.append("native")
- else:
- attention_backend_attempts.extend(
- [
- "_flash_3_hub",
- "_flash_3_varlen_hub",
- "_flash_3",
- "_flash_varlen_3",
- "flash_hub",
- "flash_varlen_hub",
- "flash",
- "flash_varlen",
- "_native_flash",
- ]
- )
- case "SAGE_ATTN":
- attention_backend_attempts.extend(["sage_hub", "sage", "sage", "sage_varlen"])
- case "ASCEND":
- attention_backend_attempts.append("_native_npu")
- case "TORCH_SDPA":
- attention_backend_attempts.append("native")
- case _:
- logger.warning(f"Invalid attention backend: {attention_backend_config}. Falling back to SDPA.")
- attention_backend_attempts.append("native")
-
- attempt_errors: list[str] = []
- set_backend: str | None = None
- for backend in attention_backend_attempts:
- try:
- self._pipeline.transformer.set_attention_backend(backend)
- set_backend = backend
- break
- except Exception as e:
- attempt_errors.append(str(e))
-
- # If all attempts fail, fallback to SDPA and warn the user about the failures
- if len(attempt_errors) == len(attention_backend_attempts):
- self._pipeline.transformer.set_attention_backend("native")
- logger.warning(
- f"Failed to set attention backend '{attention_backend_config}' for "
- f"diffusers pipeline {self._pipeline.__class__.__name__}. "
- "Falling back to SDPA. "
- f"The following attempts were made: {dict(zip(attention_backend_attempts, attempt_errors))}"
- )
- return
-
- # If some attempts fail, only warn the user about the failures
- logger.info(
- f"Set diffusers attention backend to '{set_backend}', adapted from "
- f"user config value '{attention_backend_config}'."
- )
- if len(attempt_errors) > 0:
- logger.warning(
- f"The following failed attempts were made before choosing this diffusers backend: "
- f"{dict(zip(attention_backend_attempts, attempt_errors))}"
- )
-
- def _build_call_kwargs(self, req: OmniDiffusionRequest) -> dict[str, Any]:
- """Translate ``OmniDiffusionRequest`` into diffusers ``__call__`` kwargs."""
- sampling = req.sampling_params
- prompt, neg_prompt = self._extract_prompt(req.prompts)
-
- # Merge user-provided call kwargs from stage/CLI defaults.
- # Request-time parameters take precedence over stage-config defaults
- call_kwargs = self.od_config.diffusers_call_kwargs
- kwargs: dict[str, Any] = {
- **call_kwargs,
- "prompt": prompt,
- "num_inference_steps": sampling.num_inference_steps,
- "guidance_scale": sampling.guidance_scale,
- "output_type": sampling.output_type or self.od_config.output_type,
- }
-
- if sampling.height is not None:
- kwargs["height"] = sampling.height
- if sampling.width is not None:
- kwargs["width"] = sampling.width
- if sampling.num_frames is not None and sampling.num_frames > 1:
- kwargs["num_frames"] = sampling.num_frames
- if sampling.num_outputs_per_prompt is not None and sampling.num_outputs_per_prompt > 1:
- kwargs["num_images_per_prompt"] = sampling.num_outputs_per_prompt
-
- if neg_prompt is not None:
- kwargs["negative_prompt"] = neg_prompt
-
- if sampling.generator is not None:
- kwargs["generator"] = sampling.generator
- elif sampling.seed is not None:
- kwargs["generator"] = torch.Generator(device=sampling.generator_device).manual_seed(sampling.seed)
- else:
- kwargs["generator"] = torch.Generator(device=sampling.generator_device)
-
- if sampling.latents is not None:
- kwargs["latents"] = sampling.latents
-
- return kwargs
-
- @staticmethod
- def _extract_prompt(prompt_obj: list[OmniPromptType]) -> tuple[str | list[str], str | list[str] | None]:
- """Extract the text prompts and negative prompts from a list of prompt objects."""
- if len(prompt_obj) == 1:
- if isinstance(prompt_obj[0], str):
- return prompt_obj[0], None
- else:
- return prompt_obj[0].get("prompt", ""), prompt_obj[0].get("negative_prompt", None)
-
- prompts = []
- negative_prompts: list[str] | None = []
- for prompt in prompt_obj:
- if isinstance(prompt, str):
- prompts.append(prompt)
- else:
- prompts.append(prompt.get("prompt", ""))
- negative_prompts.append(prompt.get("negative_prompt", ""))
- if all(not np for np in negative_prompts):
- negative_prompts = None
- return prompts, negative_prompts
-
- @staticmethod
- def _extract_negative_prompt(prompt_obj: Any) -> str | None:
- """Extract the negative prompt from a prompt object, if present."""
- if isinstance(prompt_obj, dict):
- return prompt_obj.get("negative_prompt")
- return getattr(prompt_obj, "negative_prompt", None)
-
- def _wrap_output(self, output: Any) -> DiffusionOutput:
- """Convert diffusers pipeline output to ``DiffusionOutput``.
-
- Diffusers output types:
- - ``ImagePipelineOutput(images=...)`` — text2img, img2img
- - ``VideoPipelineOutput(frames=...)`` — text2vid, img2vid
- """
- from vllm_omni.diffusion.data import DiffusionOutput
-
- if hasattr(output, "images"):
- # Preserve diffusers image format (`output_type`)
- return DiffusionOutput(output=output.images)
-
- if hasattr(output, "frames"):
- # Preserve diffusers video format (`output_type`)
- return DiffusionOutput(output=output.frames)
-
- if hasattr(output, "audios"):
- return DiffusionOutput(output=output.audios)
-
- return DiffusionOutput(output=output)
diff --git a/vllm_omni/diffusion/models/dmd2/__init__.py b/vllm_omni/diffusion/models/dmd2/__init__.py
deleted file mode 100644
index d0c8219d4d1..00000000000
--- a/vllm_omni/diffusion/models/dmd2/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from vllm_omni.diffusion.models.dmd2.mixin import DMD2PipelineMixin
-
-__all__ = [
- "DMD2PipelineMixin",
-]
diff --git a/vllm_omni/diffusion/models/dmd2/mixin.py b/vllm_omni/diffusion/models/dmd2/mixin.py
deleted file mode 100644
index 60c4b95baff..00000000000
--- a/vllm_omni/diffusion/models/dmd2/mixin.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from __future__ import annotations
-
-import logging
-import os
-
-from vllm_omni.diffusion.data import DiffusionOutput
-from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler
-from vllm_omni.diffusion.models.utils import _load_json
-from vllm_omni.diffusion.request import OmniDiffusionRequest
-
-logger = logging.getLogger(__name__)
-
-
-class DMD2PipelineMixin:
- """Mixin for FastGen DMD2-distilled models. Must appear before the base pipeline in MRO."""
-
- def __init_dmd2__(self) -> None:
- """Call after super().__init__() to apply DMD2 scheduler and read model_index."""
- local_files_only = os.path.exists(self.od_config.model)
- try:
- model_index = _load_json(self.od_config.model, "model_index.json", local_files_only)
- except Exception:
- model_index = {}
-
- dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624])
- self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4)
- shift = model_index.get("dmd2_scheduler_shift", 1.0)
- self.dmd2_guidance_scale = model_index.get("dmd2_guidance_scale", 1.0)
-
- self.scheduler = DMD2EulerScheduler(
- num_train_timesteps=1000,
- shift=shift,
- dmd2_timesteps=dmd2_timesteps,
- )
-
- def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None:
- """Sanitize CFG-related fields in-place. Mutates req.sampling_params and req.prompts."""
- sp = req.sampling_params
-
- if sp.num_inference_steps and sp.num_inference_steps != self.num_inference_steps:
- logger.warning(
- "DMD2: ignoring num_inference_steps=%d, forcing %d.",
- sp.num_inference_steps,
- self.num_inference_steps,
- )
- sp.num_inference_steps = self.num_inference_steps
-
- if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_guidance_scale:
- logger.warning(
- "DMD2: ignoring guidance_scale=%.2f, forcing %.2f.",
- sp.guidance_scale,
- self.dmd2_guidance_scale,
- )
- sp.guidance_scale = self.dmd2_guidance_scale
- sp.guidance_scale_provided = False
-
- if sp.guidance_scale_2 is not None:
- logger.warning("DMD2: ignoring guidance_scale_2.")
- sp.guidance_scale_2 = None
-
- if sp.true_cfg_scale is not None:
- logger.warning("DMD2: ignoring true_cfg_scale.")
- sp.true_cfg_scale = None
-
- sp.do_classifier_free_guidance = False
- sp.is_cfg_negative = False
-
- fixed = []
- for p in req.prompts:
- if isinstance(p, dict) and "negative_prompt" in p:
- logger.warning("DMD2: ignoring negative_prompt.")
- p = {k: v for k, v in p.items() if k != "negative_prompt"}
- fixed.append(p)
- req.prompts = fixed
-
- def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput:
- self._sanitize_dmd2_request(req)
- kwargs.pop("guidance_scale", None)
- kwargs.pop("num_inference_steps", None)
- return super().forward(
- req,
- guidance_scale=self.dmd2_guidance_scale,
- num_inference_steps=self.num_inference_steps,
- **kwargs,
- )
diff --git a/vllm_omni/diffusion/models/dreamid_omni/fusion.py b/vllm_omni/diffusion/models/dreamid_omni/fusion.py
index abca4c9474f..a534f5a76fa 100644
--- a/vllm_omni/diffusion/models/dreamid_omni/fusion.py
+++ b/vllm_omni/diffusion/models/dreamid_omni/fusion.py
@@ -1,5 +1,3 @@
-import re
-
import torch
import torch.nn as nn
from vllm.logger import init_logger
@@ -17,26 +15,78 @@
logger = init_logger(__name__)
-class FusedBlock(nn.Module):
- """Wrapper pairing a video block and audio block for layerwise offloading.
+class FusionModel(nn.Module):
+ def __init__(self, video_config=None, audio_config=None):
+ super().__init__()
+ has_video = True
+ has_audio = True
+ if video_config is not None:
+ self.video_model = WanModel(**video_config)
+ else:
+ has_video = False
+ self.video_model = None
+ logger.warning("No video model is provided!")
- Registers both blocks as submodules so their parameters are visible to the offload hooks.
- """
+ if audio_config is not None:
+ self.audio_model = WanModel(**audio_config)
+ else:
+ has_audio = False
+ self.audio_model = None
+ logger.warning("No audio model is provided!")
- def __init__(
- self,
- vid_block: nn.Module,
- audio_block: nn.Module,
- device: torch.device,
- ):
- super().__init__()
- self.vid_block = vid_block
- self.audio_block = audio_block
- self.device = device
+ if has_video and has_audio:
+ assert len(self.video_model.blocks) == len(self.audio_model.blocks)
+ self.num_blocks = len(self.video_model.blocks)
+
+ self.inject_cross_attention_kv_projections()
+ self.device = get_local_device()
+
+ self.num_heads = self.video_model.num_heads
+ self.head_dim = self.video_model.dim // self.video_model.num_heads
+ self.attn = Attention(
+ num_heads=self.num_heads,
+ head_size=self.head_dim,
+ num_kv_heads=self.num_heads,
+ softmax_scale=1.0 / (self.head_dim**0.5),
+ causal=False,
+ )
+
+ def inject_cross_attention_kv_projections(self):
+ for vid_block in self.video_model.blocks:
+ vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+ vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+ vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
+ vid_block.cross_attn.norm_k_fusion = (
+ WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
+ )
+
+ for audio_block in self.audio_model.blocks:
+ audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+ audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+ audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
+ audio_block.cross_attn.norm_k_fusion = (
+ WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
+ )
+
+ def merge_kwargs(self, vid_kwargs, audio_kwargs):
+ """
+ keys in each kwarg:
+ e
+ seq_lens
+ grid_sizes
+ freqs
+ context
+ context_lens
+ """
+ merged_kwargs = {}
+ for key in vid_kwargs:
+ merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
+ for key in audio_kwargs:
+ merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
+ return merged_kwargs
- def _cross_attention_forward(
+ def single_fusion_cross_attention_forward(
self,
- attn: Attention,
cross_attn_block,
src_seq,
src_grid_sizes,
@@ -54,17 +104,21 @@ def _cross_attention_forward(
):
b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
if hasattr(cross_attn_block, "k_img"):
+ ## means is i2v block
q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
else:
+ ## means is t2v block
q, k, v = cross_attn_block.qkv_fn(src_seq, context)
k_img = v_img = None
- x = attn(q, k, v)
+ x = self.attn(q, k, v)
if k_img is not None:
- img_x = attn(q, k_img, v_img)
+ img_x = self.attn(q, k_img, v_img)
x = x + img_x
+ # is_vid = src_grid_sizes.shape[1] > 1
+ # compute target attention
target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
@@ -78,16 +132,17 @@ def _cross_attention_forward(
freqs_scaling=target_freqs_scaling,
)
- target_x = attn(q, k_target, v_target)
+ target_x = self.attn(q, k_target, v_target)
x = x + target_x
- x = x.flatten(2)
+
+ x = x.flatten(2) # [B, L/P, C]
+
x = cross_attn_block.o(x)
return x
- def _cross_attention_ffn_forward(
+ def single_fusion_cross_attention_ffn_forward(
self,
- attn: Attention,
attn_block,
src_seq,
src_grid_sizes,
@@ -104,8 +159,7 @@ def _cross_attention_ffn_forward(
target_ref_lengths=None,
target_freqs_scaling=None,
):
- src_seq = src_seq + self._cross_attention_forward(
- attn,
+ src_seq = src_seq + self.single_fusion_cross_attention_forward(
attn_block.cross_attn,
attn_block.norm3(src_seq),
src_grid_sizes=src_grid_sizes,
@@ -126,11 +180,12 @@ def _cross_attention_ffn_forward(
src_seq = src_seq + y * src_e[5].squeeze(2)
return src_seq
- def forward(
+ def single_fusion_block_forward(
self,
+ vid_block,
+ audio_block,
vid,
audio,
- attn: Attention,
vid_e,
vid_seq_lens,
vid_grid_sizes,
@@ -148,9 +203,6 @@ def forward(
audio_ref_lengths,
audio_freqs_scaling,
):
- vid_block = self.vid_block
- audio_block = self.audio_block
-
## audio modulation
assert audio_e.dtype == torch.bfloat16
assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], (
@@ -194,8 +246,7 @@ def forward(
og_audio = audio
# audio cross-attention
- audio = self._cross_attention_ffn_forward(
- attn,
+ audio = self.single_fusion_cross_attention_ffn_forward(
audio_block,
audio,
audio_grid_sizes,
@@ -216,8 +267,7 @@ def forward(
assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"
# video cross-attention
- vid = self._cross_attention_ffn_forward(
- attn,
+ vid = self.single_fusion_cross_attention_ffn_forward(
vid_block,
vid,
vid_grid_sizes,
@@ -237,128 +287,6 @@ def forward(
return vid, audio
-
-class FusionModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["fused_blocks"]
-
- def __init__(self, video_config=None, audio_config=None):
- super().__init__()
- has_video = True
- has_audio = True
- self.device = get_local_device()
- if video_config is not None:
- self.video_model = WanModel(**video_config)
- else:
- has_video = False
- self.video_model = None
- logger.warning("No video model is provided!")
-
- if audio_config is not None:
- self.audio_model = WanModel(**audio_config)
- else:
- has_audio = False
- self.audio_model = None
- logger.warning("No audio model is provided!")
-
- if has_video and has_audio:
- assert len(self.video_model.blocks) == len(self.audio_model.blocks)
- self.num_blocks = len(self.video_model.blocks)
-
- self.inject_cross_attention_kv_projections()
-
- self.num_heads = self.video_model.num_heads
- self.head_dim = self.video_model.dim // self.video_model.num_heads
- # Make a single shared instance to pass in at forward time
- self.attn = Attention(
- num_heads=self.num_heads,
- head_size=self.head_dim,
- num_kv_heads=self.num_heads,
- softmax_scale=1.0 / (self.head_dim**0.5),
- causal=False,
- )
-
- if has_video and has_audio:
- self.fused_blocks = nn.ModuleList(
- [
- FusedBlock(
- self.video_model.blocks[i],
- self.audio_model.blocks[i],
- self.device,
- )
- for i in range(self.num_blocks)
- ]
- )
-
- def load_state_dict(self, state_dict, strict=True, assign=False):
- """Remap checkpoints where blocks are stored under
- `video_model.blocks.N.*` / `audio_model.blocks.N.*` to the current
- `fused_blocks.N.vid_block.*` / `fused_blocks.N.audio_block.*`.
- """
- needs_remap = any(re.match(r"^(video_model|audio_model)\.blocks\.\d+\.", k) for k in state_dict)
- if needs_remap:
- remapped = {}
- for k, v in state_dict.items():
- new_k = re.sub(r"^video_model\.blocks\.(\d+)\.", r"fused_blocks.\1.vid_block.", k)
- new_k = re.sub(r"^audio_model\.blocks\.(\d+)\.", r"fused_blocks.\1.audio_block.", new_k)
- remapped[new_k] = v
- state_dict = remapped
-
- self._detach_blocks_from_backbones()
-
- return super().load_state_dict(state_dict, strict=strict, assign=assign)
-
- def inject_cross_attention_kv_projections(self):
- for vid_block in self.video_model.blocks:
- vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
- vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
- vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
- vid_block.cross_attn.norm_k_fusion = (
- WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
- )
-
- for audio_block in self.audio_model.blocks:
- audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
- audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
- audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
- audio_block.cross_attn.norm_k_fusion = (
- WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
- )
-
- def _detach_blocks_from_backbones(self) -> None:
- """Keep offloadable blocks owned only by a single place.
-
- NOTE: This is a special workaround to support layerwise offloading.
- The model registers the same Wan blocks under both the video/audio
- backbones and `fused_blocks` which is a wrapper for unified blocks
- walking through. However, layerwise offloading will only consider
- `fused_blocks` as offloadable components and will materialize all
- other modules onto device, including the same blocks owned by both
- `fused_blocks` and `video_model` and `audio_model`.
- """
- video_blocks = list(self.video_model.blocks)
- audio_blocks = list(self.audio_model.blocks)
- self.video_model._modules.pop("blocks", None)
- self.audio_model._modules.pop("blocks", None)
- self.video_model.blocks = tuple(video_blocks)
- self.audio_model.blocks = tuple(audio_blocks)
-
- def merge_kwargs(self, vid_kwargs, audio_kwargs):
- """
- keys in each kwarg:
- e
- seq_lens
- grid_sizes
- freqs
- context
- context_lens
- """
- merged_kwargs = {}
- for key in vid_kwargs:
- merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
- for key in audio_kwargs:
- merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
- return merged_kwargs
-
def forward(
self,
vid,
@@ -388,8 +316,17 @@ def forward(
kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)
- for fused_block in self.fused_blocks:
- vid, audio = fused_block(vid, audio, self.attn, **kwargs)
+ for i in range(self.num_blocks):
+ """
+ 1 fusion block refers to 1 audio block with 1 video block.
+ """
+
+ vid_block = self.video_model.blocks[i]
+ audio_block = self.audio_model.blocks[i]
+
+ vid, audio = self.single_fusion_block_forward(
+ vid_block=vid_block, audio_block=audio_block, vid=vid, audio=audio, **kwargs
+ )
vid = self.video_model.post_transformer_block_out(vid, vid_kwargs["grid_sizes"], vid_e)
audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs["grid_sizes"], audio_e)
diff --git a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py
index cc932f8c1f8..f8074fee229 100644
--- a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py
+++ b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py
@@ -4,7 +4,6 @@
import logging
import math
import os
-from collections.abc import Iterable
import torch
import torch.distributed
@@ -16,8 +15,12 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
+from vllm_omni.diffusion.distributed.parallel_state import (
+ get_cfg_group,
+ get_classifier_free_guidance_rank,
+ get_classifier_free_guidance_world_size,
+)
from vllm_omni.diffusion.distributed.utils import get_local_device
-from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportAudioInput, SupportImageInput
from vllm_omni.diffusion.request import OmniDiffusionRequest
@@ -29,6 +32,7 @@
init_mmaudio_vae,
init_text_model,
init_wan_vae_2_2,
+ load_fusion_checkpoint,
)
from dreamid_omni.utils.rearrange import Rearrange
from dreamid_omni.utils.resize import NaResize
@@ -39,21 +43,6 @@
logger = logging.getLogger(__name__)
-def get_dreamid_omni_post_process_func(*args, **kwargs):
- def post_process(output):
- if isinstance(output, tuple) and len(output) == 2:
- video, audio = output
- return {
- "video": video,
- "audio": audio,
- "audio_sample_rate": 16000,
- "fps": 24,
- }
- return output
-
- return post_process
-
-
AUDIO_CONFIG = {
"patch_size": [1],
"model_type": "t2a",
@@ -123,23 +112,15 @@ def __init__(
self.text_model = init_text_model(model, rank=self.device)
self.text_encoder = self.text_model.model
- # Fusion model — weights are loaded later via load_weights()
- self.model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG)
- self.transformer = self.model
+ # Fusion model
+ ## load audio/video model config
+ Fusion_model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG)
- fusion_path = self.od_config.tf_model_config.get("fusion", None)
- assert fusion_path is not None, "fusion checkpoint path is None in transformer config"
- fusion_subfolder = os.path.dirname(fusion_path) or None
- fusion_filename = os.path.basename(fusion_path)
- self.weights_sources = [
- DiffusersPipelineLoader.ComponentSource(
- model_or_path=model,
- subfolder=fusion_subfolder,
- revision=None,
- prefix="model.",
- allow_patterns_overrides=[fusion_filename],
- )
- ]
+ checkpoint_path = self.od_config.model_config.get("fusion", None)
+ assert checkpoint_path is not None, "fusion checkpoint path is None"
+ load_fusion_checkpoint(Fusion_model, checkpoint_path=os.path.join(model, checkpoint_path))
+ self.model = Fusion_model
+ self.transformer = self.model
# Fixed attributes, non-configurable
self.audio_latent_channel = AUDIO_CONFIG.get("in_dim")
@@ -235,11 +216,8 @@ def load_image_latent_ref_ip_video(
return ref_vae_latents, ref_audio_lengths
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- prefix = "model."
- state_dict = {name[len(prefix) :]: tensor for name, tensor in weights if name.startswith(prefix)}
- self.model.load_state_dict(state_dict, strict=True)
- return {prefix + k for k in state_dict}
+ def load_weights(self, weights):
+ pass
def get_scheduler_time_steps(self, sampling_steps, solver_name="unipc", device=0, shift=5.0):
torch.manual_seed(4)
@@ -271,28 +249,6 @@ def get_scheduler_time_steps(self, sampling_steps, solver_name="unipc", device=0
return sample_scheduler, timesteps
- def predict_noise(self, **kwargs):
- pred_vid, pred_audio = self.model(**kwargs)
- return (pred_vid[0], pred_audio[0])
-
- def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False):
- vid_pos, audio_pos = predictions[0]
- vid_neg, audio_neg = predictions[1]
- vid_ip_neg, _ = predictions[2]
- _, refaudio_neg = predictions[3]
-
- pred_video = (
- vid_neg
- + true_cfg_scale["video_cfg_scale"] * (vid_pos - vid_neg)
- + true_cfg_scale["video_ref_cfg_scale"] * (vid_pos - vid_ip_neg)
- )
- pred_audio = (
- audio_neg
- + true_cfg_scale["audio_cfg_scale"] * (audio_pos - audio_neg)
- + true_cfg_scale["audio_ref_cfg_scale"] * (audio_pos - refaudio_neg)
- )
- return (pred_video, pred_audio)
-
def diffuse(
self,
video_noise: torch.Tensor,
@@ -350,22 +306,72 @@ def diffuse(
"vid_context": [text_embeddings_video_neg],
}
- branches_kwargs = [
- {"vid": [model_input_video], "audio": [model_input_audio], "t": timestep_input, **pos_args},
- {"vid": [model_input_video], "audio": [model_input_audio], "t": timestep_input, **neg_args},
- {"vid": [model_input_video_neg], "audio": [model_input_audio], "t": timestep_input, **pos_args},
- {"vid": [model_input_video], "audio": [model_input_audio_neg], "t": timestep_input, **pos_args},
- ]
-
- pred_video_guided, pred_audio_guided = self.predict_noise_with_multi_branch_cfg(
- do_true_cfg=True,
- true_cfg_scale={
- "video_cfg_scale": self.video_cfg_scale,
- "video_ref_cfg_scale": self.video_ref_cfg_scale,
- "audio_cfg_scale": self.audio_cfg_scale,
- "audio_ref_cfg_scale": self.audio_ref_cfg_scale,
- },
- branches_kwargs=branches_kwargs,
+ if get_classifier_free_guidance_world_size() > 1:
+ # Enable CFG-parallel: rank0 computes positive, rank1 computes negative.
+ cfg_group = get_cfg_group()
+ cfg_rank = get_classifier_free_guidance_rank()
+
+ if cfg_rank == 0:
+ pred_vid, pred_audio = self.model(
+ vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **pos_args
+ )
+ pre_vid_ip_neg, _ = self.model(
+ vid=[model_input_video_neg], audio=[model_input_audio], t=timestep_input, **pos_args
+ )
+ pred_vid_0 = pred_vid[0]
+ pred_audio_0 = pred_audio[0]
+ pre_vid_ip_0 = pre_vid_ip_neg[0]
+ pred_refaudio_0 = torch.zeros_like(pred_audio_0) # dummy tensor
+ else:
+ pred_vid, pred_audio = self.model(
+ vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **neg_args
+ )
+ _, pred_refaudio_neg = self.model(
+ vid=[model_input_video], audio=[model_input_audio_neg], t=timestep_input, **pos_args
+ )
+ pred_vid_0 = pred_vid[0]
+ pred_audio_0 = pred_audio[0]
+ pre_vid_ip_0 = torch.zeros_like(pred_vid_0) # dummy tensor
+ pred_refaudio_0 = pred_refaudio_neg[0]
+
+ pred_vid_gathered = cfg_group.all_gather(pred_vid_0, separate_tensors=True)
+ pred_audio_gathered = cfg_group.all_gather(pred_audio_0, separate_tensors=True)
+ pre_vid_ip_gathered = cfg_group.all_gather(pre_vid_ip_0, separate_tensors=True)
+ pred_refaudio_gathered = cfg_group.all_gather(pred_refaudio_0, separate_tensors=True)
+
+ pred_vid_pos = [pred_vid_gathered[0]]
+ pred_vid_neg = [pred_vid_gathered[1]]
+ pred_audio_pos = [pred_audio_gathered[0]]
+ pred_audio_neg = [pred_audio_gathered[1]]
+ pre_vid_ip_neg = [pre_vid_ip_gathered[0]]
+ pred_refaudio_neg = [pred_refaudio_gathered[1]]
+ else:
+ pred_vid_pos, pred_audio_pos = self.model(
+ vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **pos_args
+ )
+
+ pred_vid_neg, pred_audio_neg = self.model(
+ vid=[model_input_video], audio=[model_input_audio], t=timestep_input, **neg_args
+ )
+
+ pre_vid_ip_neg, _ = self.model(
+ vid=[model_input_video_neg], audio=[model_input_audio], t=timestep_input, **pos_args
+ )
+
+ _, pred_refaudio_neg = self.model(
+ vid=[model_input_video], audio=[model_input_audio_neg], t=timestep_input, **pos_args
+ )
+
+ pred_video_guided = (
+ pred_vid_neg[0]
+ + self.video_cfg_scale * (pred_vid_pos[0] - pred_vid_neg[0])
+ + self.video_ref_cfg_scale * (pred_vid_pos[0] - pre_vid_ip_neg[0])
+ )
+
+ pred_audio_guided = (
+ pred_audio_neg[0]
+ + self.audio_cfg_scale * (pred_audio_pos[0] - pred_audio_neg[0])
+ + self.audio_ref_cfg_scale * (pred_audio_pos[0] - pred_refaudio_neg[0])
)
video_noise = scheduler_video.step(
pred_video_guided.unsqueeze(0), t_v, video_noise.unsqueeze(0), return_dict=False
diff --git a/vllm_omni/diffusion/models/flux/flux_transformer.py b/vllm_omni/diffusion/models/flux/flux_transformer.py
index dd88ee76c15..362fb4446fc 100644
--- a/vllm_omni/diffusion/models/flux/flux_transformer.py
+++ b/vllm_omni/diffusion/models/flux/flux_transformer.py
@@ -381,9 +381,7 @@ def __init__(
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
- # Modulation linear kept full precision; shift/scale/gate outputs
- # are multiplied into the residual stream every block (see #2728).
- self.norm = AdaLayerNormZeroSingle(dim, quant_config=None, prefix=f"{prefix}.norm")
+ self.norm = AdaLayerNormZeroSingle(dim, quant_config=quant_config, prefix=f"{prefix}.norm")
self.proj_mlp = ReplicatedLinear(
dim,
self.mlp_hidden_dim,
@@ -512,7 +510,6 @@ class FluxTransformer2DModel(nn.Module):
# -- typically a transformer layer
# used for torch compile optimizations
_repeated_blocks = ["FluxTransformerBlock"]
- _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]
@staticmethod
def _is_transformer_block(name: str, module) -> bool:
@@ -526,10 +523,10 @@ def _is_transformer_block(name: str, module) -> bool:
def __init__(
self,
- od_config: OmniDiffusionConfig | None = None,
+ od_config: OmniDiffusionConfig = None,
patch_size: int = 1,
in_channels: int = 64,
- out_channels: int | None = None,
+ out_channels: int = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
@@ -565,16 +562,13 @@ def __init__(
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
- # Dual-stream blocks kept full precision — FP8 on their joint
- # attention path causes noise on FLUX (#2728). Single-stream
- # blocks (38 vs 19) still get FP8 for memory savings.
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
- quant_config=None,
+ quant_config=quant_config,
prefix=f"transformer_blocks.{i}",
)
for i in range(num_layers)
@@ -594,13 +588,12 @@ def __init__(
]
)
- # Final modulation feeds proj_out; keep full precision (see #2728).
self.norm_out = AdaLayerNormContinuous(
self.inner_dim,
self.inner_dim,
elementwise_affine=False,
eps=1e-6,
- quant_config=None,
+ quant_config=quant_config,
prefix="norm_out",
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux.py b/vllm_omni/diffusion/models/flux/pipeline_flux.py
index 70d572d9a65..6f43e8dbb58 100644
--- a/vllm_omni/diffusion/models/flux/pipeline_flux.py
+++ b/vllm_omni/diffusion/models/flux/pipeline_flux.py
@@ -30,7 +30,6 @@
from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
logger = logging.getLogger(__name__)
@@ -107,11 +106,7 @@ def __init__(
self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
-
- transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, FluxTransformer2DModel)
- self.transformer = FluxTransformer2DModel(
- **transformer_kwargs, od_config=od_config, quant_config=od_config.quantization_config
- )
+ self.transformer = FluxTransformer2DModel(od_config=od_config, quant_config=od_config.quantization_config)
self.tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.tokenizer_2 = T5TokenizerFast.from_pretrained(
diff --git a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py
index c3bea7dd1c4..3232b436d60 100644
--- a/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py
+++ b/vllm_omni/diffusion/models/flux/pipeline_flux_kontext.py
@@ -31,8 +31,6 @@
)
from vllm_omni.diffusion.models.flux.flux_pipeline_mixin import FluxPipelineMixin
from vllm_omni.diffusion.models.interface import SupportImageInput
-from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
-from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.logger import init_logger
@@ -69,9 +67,7 @@ def post_process_func(images: torch.Tensor) -> list[PIL.Image.Image]:
return post_process_func
-class FluxKontextPipeline(
- nn.Module, FluxPipelineMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin
-):
+class FluxKontextPipeline(nn.Module, FluxPipelineMixin, SupportImageInput):
"""FLUX.1-Kontext pipeline for image editing with text guidance."""
support_image_input = True
@@ -152,10 +148,6 @@ def __init__(
self._callback_tensor_inputs = ["latents", "prompt_embeds"]
self.latent_channels = self.vae.config.latent_channels if hasattr(self.vae, "config") else 16
- self.setup_diffusion_pipeline_profiler(
- enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
- )
-
def _get_t5_prompt_embeds(
self,
prompt: str | list[str] = None,
@@ -643,56 +635,58 @@ def forward(
# 5. Denoising loop
self.scheduler.set_begin_index(0)
- with self.progress_bar(total=len(timesteps)) as pbar:
- for i, t in enumerate(timesteps):
- if self.interrupt:
- continue
-
- latent_model_input = latents
- if image_latents is not None:
- latent_model_input = torch.cat([latents, image_latents], dim=1)
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
- noise_pred = self.transformer(
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
- pooled_projections=pooled_prompt_embeds,
- encoder_hidden_states=prompt_embeds,
- txt_ids=text_ids,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
- noise_pred = noise_pred[:, : latents.size(1)]
-
- if do_true_cfg:
- neg_noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep / 1000,
- guidance=guidance,
- pooled_projections=negative_pooled_prompt_embeds,
- encoder_hidden_states=negative_prompt_embeds,
- txt_ids=negative_text_ids,
- img_ids=latent_ids,
- joint_attention_kwargs=self.joint_attention_kwargs,
- return_dict=False,
- )[0]
- neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
- noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
-
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
- pbar.update()
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
if output_type == "latent":
image = latents
else:
diff --git a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py
index 404f05b606c..c5bf9b77d9e 100644
--- a/vllm_omni/diffusion/models/flux2/pipeline_flux2.py
+++ b/vllm_omni/diffusion/models/flux2/pipeline_flux2.py
@@ -25,14 +25,10 @@
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
-from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
-from vllm_omni.diffusion.distributed.parallel_state import get_classifier_free_guidance_world_size
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.flux2 import Flux2Transformer2DModel
from vllm_omni.diffusion.models.interface import SupportImageInput
-from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
-from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
@@ -335,7 +331,7 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: torch.Generator =
raise AttributeError("Could not access latents of provided encoder_output")
-class Flux2Pipeline(nn.Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin):
+class Flux2Pipeline(nn.Module, SupportImageInput):
"""Flux2 pipeline for text-to-image generation."""
_callback_tensor_inputs = ["latents", "prompt_embeds"]
@@ -393,10 +389,6 @@ def __init__(
self._guidance_scale = None
self._attention_kwargs = None
self._num_timesteps = None
-
- self.setup_diffusion_pipeline_profiler(
- enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
- )
self._current_timestep = None
self._interrupt = False
@@ -856,21 +848,6 @@ def current_timestep(self):
def interrupt(self):
return self._interrupt
- def check_cfg_parallel_validity(self, true_cfg_scale: float, has_neg_prompt: bool):
- if get_classifier_free_guidance_world_size() == 1:
- return True
-
- if true_cfg_scale <= 1:
- logger.warning("CFG parallel is NOT working correctly when true_cfg_scale <= 1.")
- return False
-
- if not has_neg_prompt:
- logger.warning(
- "CFG parallel is NOT working correctly when there is no negative prompt or negative prompt embeddings."
- )
- return False
- return True
-
def forward(
self,
req: OmniDiffusionRequest,
@@ -938,14 +915,6 @@ def forward(
# And `torch.stack` automatically raises an exception for us
prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore # intentionally expect TypeError
- req_negative_prompt_embeds = [
- p.get("negative_prompt_embeds") if not isinstance(p, str) else None for p in req.prompts
- ]
- if all(p is not None for p in req_negative_prompt_embeds):
- negative_prompt_embeds = torch.stack(req_negative_prompt_embeds) # type: ignore # intentionally expect TypeError
-
- req_negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]
-
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
@@ -983,22 +952,6 @@ def forward(
text_encoder_out_layers=text_encoder_out_layers,
)
- has_neg_prompt = negative_prompt_embeds is not None or any(req_negative_prompt)
- do_true_cfg = self.guidance_scale > 1 and has_neg_prompt
-
- self.check_cfg_parallel_validity(self.guidance_scale, has_neg_prompt)
- negative_text_ids = None
- if do_true_cfg:
- negative_prompt = req_negative_prompt
- negative_prompt_embeds, negative_text_ids = self.encode_prompt(
- prompt=negative_prompt,
- prompt_embeds=negative_prompt_embeds,
- device=device,
- num_images_per_prompt=num_images_per_prompt,
- max_sequence_length=max_sequence_length,
- text_encoder_out_layers=text_encoder_out_layers,
- )
-
# 4. process images
if image is not None and not isinstance(image, list):
image = [image]
@@ -1070,74 +1023,52 @@ def forward(
guidance_tensor = torch.full([1], self.guidance_scale, device=device, dtype=torch.float32)
guidance_tensor = guidance_tensor.expand(latents.shape[0])
- # For editing pipelines, we need to slice the output to remove condition latents
- output_slice = latents.size(1) if image_latents is not None else None
-
# 7. Denoising loop
# We set the index here to remove DtoH sync, helpful especially during compilation.
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
self.scheduler.set_begin_index(0)
- with self.progress_bar(total=len(timesteps)) as pbar:
- for i, t in enumerate(timesteps):
- if self.interrupt:
- continue
-
- self._current_timestep = t
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
- latent_model_input = latents.to(self.transformer.dtype)
- latent_image_ids = latent_ids
-
- if image_latents is not None:
- latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
- latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
-
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep / 1000,
- "guidance": guidance_tensor,
- "encoder_hidden_states": prompt_embeds,
- "txt_ids": text_ids,
- "img_ids": latent_image_ids,
- "joint_attention_kwargs": self.attention_kwargs,
- "return_dict": False,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep / 1000,
- "guidance": guidance_tensor,
- "encoder_hidden_states": negative_prompt_embeds,
- "txt_ids": negative_text_ids,
- "img_ids": latent_image_ids,
- "joint_attention_kwargs": self.attention_kwargs,
- "return_dict": False,
- }
- else:
- negative_kwargs = None
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=self.guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- output_slice=output_slice,
- )
-
- # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
-
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
- pbar.update()
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ latent_model_input = latents.to(self.transformer.dtype)
+ latent_image_ids = latent_ids
+
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
+ latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input, # (B, image_seq_len, C)
+ timestep=timestep / 1000,
+ guidance=guidance_tensor,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids, # B, text_seq_len, 4
+ img_ids=latent_image_ids, # B, image_seq_len, 4
+ joint_attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = noise_pred[:, : latents.size(1) :]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype and torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
self._current_timestep = None
diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
index 9cf2fb7568b..1d375ca8d2e 100644
--- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
+++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
@@ -742,7 +742,6 @@ class Flux2Transformer2DModel(nn.Module):
"""
_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
- _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]
@staticmethod
def _is_transformer_block(name: str, module) -> bool:
diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
index 2da0038bb48..437dd58d0c4 100644
--- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
+++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
@@ -234,15 +234,7 @@ def __init__(
self.transformer = Flux2Transformer2DModel(quant_config=od_config.quantization_config, **transformer_kwargs)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
- self.latent_channels = self.vae.config.latent_channels if hasattr(self.vae, "config") else 16
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
- self.mask_processor = VaeImageProcessor(
- vae_scale_factor=self.vae_scale_factor * 2,
- vae_latent_channels=self.latent_channels,
- do_normalize=False,
- do_binarize=True,
- do_convert_grayscale=True,
- )
self.tokenizer_max_length = 512
self.default_sample_size = 128
@@ -255,14 +247,6 @@ def __init__(
enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
)
- def get_timesteps(self, num_inference_steps, strength, device):
- init_timestep = min(num_inference_steps * strength, num_inference_steps)
- t_start = int(max(num_inference_steps - init_timestep, 0))
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
- if hasattr(self.scheduler, "set_begin_index"):
- self.scheduler.set_begin_index(t_start * self.scheduler.order)
- return timesteps, num_inference_steps - t_start
-
@staticmethod
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
@@ -598,54 +582,6 @@ def prepare_image_latents(
return image_latents, image_latent_ids
- def prepare_mask_latents(
- self,
- mask,
- masked_image,
- batch_size,
- num_channels_latents,
- num_images_per_prompt,
- height,
- width,
- dtype,
- device,
- generator,
- ):
- height = 2 * (int(height) // (self.vae_scale_factor * 2))
- width = 2 * (int(width) // (self.vae_scale_factor * 2))
- mask = torch.nn.functional.interpolate(mask, size=(height, width))
- mask = mask.to(device=device, dtype=dtype)
-
- batch_size = batch_size * num_images_per_prompt
-
- if masked_image is not None:
- masked_image = masked_image.to(device=device, dtype=dtype)
- if masked_image.shape[1] != num_channels_latents:
- masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator)
- else:
- masked_image_latents = masked_image
- else:
- masked_image_latents = None
-
- if mask.shape[0] < batch_size:
- if not batch_size % mask.shape[0] == 0:
- raise ValueError("The passed mask and the required batch size don't match.")
- mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
- if masked_image_latents is not None and masked_image_latents.shape[0] < batch_size:
- if not batch_size % masked_image_latents.shape[0] == 0:
- raise ValueError("The passed mask and the required batch size don't match.")
- masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
-
- if masked_image_latents is not None:
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
- masked_image_latents = self._pack_latents(masked_image_latents)
-
- mask = mask.repeat(1, self.latent_channels, 1, 1)
- mask = self._patchify_latents(mask)
- mask = self._pack_latents(mask)
-
- return mask, masked_image_latents
-
def check_inputs(
self,
prompt,
@@ -654,15 +590,7 @@ def check_inputs(
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
guidance_scale=None,
- strength=None,
- num_inference_steps=None,
):
- if strength is not None:
- if strength < 0 or strength > 1:
- raise ValueError(f"strength must be between 0 and 1, got {strength}")
- if num_inference_steps is not None and num_inference_steps <= 0:
- raise ValueError(f"num_inference_steps must be positive, got {num_inference_steps}")
-
if (
height is not None
and height % (self.vae_scale_factor * 2) != 0
@@ -694,7 +622,7 @@ def check_inputs(
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
- if guidance_scale is not None and guidance_scale > 1.0 and self.is_distilled:
+ if guidance_scale > 1.0 and self.is_distilled:
logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.")
@property
@@ -725,14 +653,11 @@ def forward(
self,
req: OmniDiffusionRequest,
image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
- reference_image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
- mask_image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
prompt: str | list[str] | None = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 50,
sigmas: list[float] | None = None,
- strength: float = 1.0,
guidance_scale: float | None = 4.0,
num_images_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
@@ -746,7 +671,6 @@ def forward(
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
text_encoder_out_layers: tuple[int, ...] = (9, 18, 27),
- padding_mask_crop: int | None = None,
) -> DiffusionOutput:
r"""
Function invoked when calling the pipeline for generation.
@@ -829,21 +753,14 @@ def forward(
"""Taking only the first image for now.""",
)
first_prompt = req.prompts[0]
- if isinstance(first_prompt, str):
- multi_modal_data = {}
- prompt = first_prompt
- raw_image = None
- mask_image = None
- reference_image = None
- else:
- multi_modal_data = first_prompt.get("multi_modal_data", {})
- prompt = first_prompt.get("prompt") or ""
- raw_image = multi_modal_data.get("image")
- mask_image = multi_modal_data.get("mask_image")
- reference_image = multi_modal_data.get("reference_image")
-
- if raw_image is None:
- image = None
+ prompt = first_prompt if isinstance(first_prompt, str) else (first_prompt.get("prompt") or "")
+
+ if (
+ raw_image := None
+ if isinstance(first_prompt, str)
+ else first_prompt.get("multi_modal_data", {}).get("image")
+ ) is None:
+ pass # use image from param list
elif isinstance(raw_image, list):
image = [PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image, im) for im in raw_image]
else:
@@ -887,8 +804,6 @@ def forward(
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
guidance_scale=guidance_scale,
- strength=strength,
- num_inference_steps=num_inference_steps,
)
self._guidance_scale = guidance_scale
@@ -933,9 +848,6 @@ def forward(
if image is not None and not isinstance(image, list):
image = [image]
- multiple_of = self.vae_scale_factor * 2
- crops_coords = None
- resize_mode = "crop"
condition_images = None
if image is not None:
for img in image:
@@ -948,14 +860,10 @@ def forward(
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
image_width, image_height = img.size
+ multiple_of = self.vae_scale_factor * 2
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
- if padding_mask_crop is not None:
- crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
- resize_mode = "fill"
- img = self.image_processor.preprocess(
- img, height=image_height, width=image_width, crops_coords=crops_coords, resize_mode=resize_mode
- )
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
condition_images.append(img)
height = height or image_height
width = width or image_width
@@ -963,11 +871,6 @@ def forward(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
- if mask_image is not None and (image is None or (isinstance(image, list) and len(image) == 0)):
- raise ValueError("image must be provided when using mask_image for inpainting")
-
- init_image = condition_images[0] if condition_images else None
-
# 5. prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_ids = self.prepare_latents(
@@ -981,8 +884,6 @@ def forward(
latents=latents,
)
- original_latent_ids = latent_ids
-
image_latents = None
image_latent_ids = None
if condition_images is not None:
@@ -994,71 +895,6 @@ def forward(
dtype=self.vae.dtype,
)
- # Preprocess reference_image
- if reference_image is not None and not (
- isinstance(reference_image, torch.Tensor) and reference_image.size(1) == self.latent_channels
- ):
- if (
- isinstance(reference_image, list)
- and isinstance(reference_image[0], torch.Tensor)
- and reference_image[0].ndim == 4
- ):
- reference_image = torch.cat(reference_image, dim=0)
- img_reference = reference_image[0] if isinstance(reference_image, list) else reference_image
- reference_image_height, reference_image_width = self.image_processor.get_default_height_width(img_reference)
-
- reference_image_width = reference_image_width // multiple_of * multiple_of
- reference_image_height = reference_image_height // multiple_of * multiple_of
- reference_image = self.image_processor.resize(
- reference_image, reference_image_height, reference_image_width
- )
- reference_image = self.image_processor.preprocess(
- reference_image,
- reference_image_height,
- reference_image_width,
- crops_coords=crops_coords,
- resize_mode=resize_mode,
- )
- else:
- pass
-
- reference_image_latents = None
- reference_image_latent_ids = None
- if reference_image is not None:
- reference_image_latents, reference_image_latent_ids = self.prepare_image_latents(
- images=[reference_image],
- batch_size=batch_size * num_images_per_prompt,
- generator=generator,
- device=device,
- dtype=self.vae.dtype,
- )
-
- if reference_image_latent_ids is not None:
- latent_ids = torch.cat([latent_ids, reference_image_latent_ids], dim=1)
- elif image_latent_ids is not None:
- latent_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
-
- mask = None
- if mask_image is not None:
- mask_condition = self.mask_processor.preprocess(
- mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
- )
- if init_image is not None:
- masked_image = init_image * (mask_condition < 0.5).to(init_image.dtype)
- else:
- masked_image = None
- mask, _ = self.prepare_mask_latents(
- mask_condition,
- masked_image,
- batch_size,
- self.latent_channels,
- num_images_per_prompt,
- height,
- width,
- prompt_embeds.dtype,
- device,
- generator,
- )
# 6. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
@@ -1072,13 +908,6 @@ def forward(
sigmas=sigmas,
mu=mu,
)
- if reference_image is not None or mask_image is not None:
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
- if num_inference_steps < 1:
- raise ValueError(
- f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
- f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
- )
self._num_timesteps = len(timesteps)
# 7. Denoising loop
@@ -1095,10 +924,7 @@ def forward(
latent_model_input = latents.to(self.transformer.dtype)
latent_image_ids = latent_ids
- if reference_image_latents is not None:
- latent_model_input = torch.cat([latents, reference_image_latents], dim=1)
- latent_image_ids = latent_ids
- elif image_latents is not None:
+ if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
@@ -1127,9 +953,7 @@ def forward(
negative_kwargs = None
# For editing pipelines, we need to slice the output to remove condition latents
- output_slice = (
- latents.size(1) if (image_latents is not None or reference_image_latents is not None) else None
- )
+ output_slice = latents.size(1) if image_latents is not None else None
noise_pred = self.predict_noise_maybe_with_cfg(
do_true_cfg=self.do_classifier_free_guidance,
@@ -1140,22 +964,9 @@ def forward(
output_slice=output_slice,
)
- latents_dtype = latents.dtype
# Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, self.do_classifier_free_guidance)
- if mask is not None and image_latents is not None:
- init_latents_proper = image_latents
- if i < len(timesteps) - 1:
- noise_timestep = timesteps[i + 1]
- init_latents_proper = self.scheduler.scale_noise(
- init_latents_proper, torch.tensor([noise_timestep], device=device), latents
- )
- latents = (1 - mask) * init_latents_proper + mask * latents
-
- if latents.dtype != latents_dtype and torch.backends.mps.is_available():
- latents = latents.to(latents_dtype)
-
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
@@ -1167,7 +978,7 @@ def forward(
self._current_timestep = None
- latents = self._unpack_latents_with_ids(latents, original_latent_ids)
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
index ddb32aa2025..490e0198b93 100644
--- a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
+++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
@@ -4,7 +4,7 @@
import math
from collections.abc import Iterable
from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import Any
import torch
import torch.nn as nn
@@ -19,19 +19,10 @@
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-if TYPE_CHECKING:
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-
-from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.cache.base import CachedTransformer
-from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig
+from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.distributed.hsdp_utils import is_transformer_block_module
-from vllm_omni.diffusion.distributed.sp_plan import (
- SequenceParallelInput,
- SequenceParallelOutput,
-)
-from vllm_omni.diffusion.forward_context import get_forward_context
logger = init_logger(__name__)
@@ -117,8 +108,8 @@ def __init__(
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
- post_patch_height = torch.tensor(height // self.patch_size, device=hidden_states.device, dtype=torch.int64)
- post_patch_width = torch.tensor(width // self.patch_size, device=hidden_states.device, dtype=torch.int64)
+ post_patch_height = height // self.patch_size
+ post_patch_width = width // self.patch_size
# Reshape: [B, C, H, W] -> [B, H', W', C*p*p] -> [B, H'*W', C*p*p]
hidden_states = hidden_states.reshape(
@@ -168,65 +159,6 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
return (freqs.cos(), freqs.sin())
-class GlmImagePrepare(nn.Module):
- """Prepare module for GLM-Image that handles patch embedding and RoPE computation.
-
- This module encapsulates the input processing pipeline to create a module boundary
- where _sp_plan can shard outputs via split_output=True.
-
- Similar to Qwen-Image's ImageRopePrepare, this ensures hidden_states and RoPE
- embeddings are sharded together to maintain dimension alignment.
- """
-
- def __init__(
- self,
- image_projector: nn.Module,
- rope: GlmImageRotaryPosEmbed,
- patch_size: int,
- ):
- super().__init__()
- self.image_projector = image_projector
- self.rope = rope
- self.patch_size = patch_size
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- prior_hidden_states: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- """Process hidden_states and compute RoPE embeddings.
-
- Args:
- hidden_states: Input latent tensor [B, C, H, W]
- prior_hidden_states: Optional prior embedding to add
-
- Returns:
- hidden_states: Patched hidden states [B, seq_len, D]
- rope_cos: RoPE cos embeddings [seq_len, dim]
- rope_sin: RoPE sin embeddings [seq_len, dim]
- post_patch_height: Scalar tensor for height after patching
- post_patch_width: Scalar tensor for width after patching
- """
- batch_size, num_channels, height, width = hidden_states.shape
-
- post_patch_height = torch.tensor(height // self.patch_size, device=hidden_states.device, dtype=torch.int64)
- post_patch_width = torch.tensor(width // self.patch_size, device=hidden_states.device, dtype=torch.int64)
-
- # Compute RoPE (uses original 4D hidden_states shape)
- image_rotary_emb = self.rope(hidden_states)
- rope_cos = image_rotary_emb[0].to(hidden_states.device)
- rope_sin = image_rotary_emb[1].to(hidden_states.device)
-
- # Patch embedding: [B, C, H, W] -> [B, seq_len, D]
- hidden_states = self.image_projector(hidden_states)
-
- # Add prior embedding if provided
- if prior_hidden_states is not None:
- hidden_states = hidden_states + prior_hidden_states
-
- return hidden_states, rope_cos, rope_sin, post_patch_height, post_patch_width
-
-
class GlmImageAdaLayerNormZero(nn.Module):
"""Adaptive LayerNorm with zero initialization for both image and text streams."""
@@ -465,16 +397,13 @@ def __init__(
dim: int,
num_heads: int,
head_dim: int,
- parallel_config: DiffusionParallelConfig | None = None,
out_bias: bool = True,
eps: float = 1e-5,
- quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
self.dim = dim
self.total_num_heads = num_heads
self.head_dim = head_dim
- self.parallel_config = parallel_config
# QKV projection (fused for efficiency)
self.to_qkv = QKVParallelLinear(
@@ -484,7 +413,6 @@ def __init__(
total_num_kv_heads=num_heads,
bias=True,
return_bias=False,
- quant_config=quant_config,
)
# QK normalization (LayerNorm, not RMSNorm for GLM-Image)
@@ -500,7 +428,6 @@ def __init__(
bias=out_bias,
input_is_parallel=True,
return_bias=False,
- quant_config=quant_config,
),
nn.Dropout(0.0),
]
@@ -523,19 +450,16 @@ def forward(
attention_mask: torch.Tensor | None = None,
kv_cache: GlmImageLayerKVCache | None = None,
kv_cache_mode: KVCacheMode | None = None,
- hidden_states_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for joint attention.
Args:
- hidden_states: Image hidden states [B, img_seq_len, D] (sharded in SP mode)
- encoder_hidden_states: Text hidden states [B, text_seq_len, D] (full in SP mode)
- image_rotary_emb: Tuple of (cos, sin) for RoPE (sharded in SP mode)
- attention_mask: Optional attention mask
+ hidden_states: Image hidden states [B, img_seq_len, D]
+ encoder_hidden_states: Text hidden states [B, text_seq_len, D]
+ image_rotary_emb: Tuple of (cos, sin) for RoPE
kv_cache: Optional layer KV cache for image editing
kv_cache_mode: Cache mode (WRITE, READ, SKIP)
- hidden_states_mask: Mask for SP padding (True=valid, False=padding)
Returns:
Tuple of (image_hidden_states, text_hidden_states)
@@ -543,13 +467,6 @@ def forward(
dtype = encoder_hidden_states.dtype
batch_size, text_seq_length, _ = encoder_hidden_states.shape
- # Check if SP is enabled
- sp_size = self.parallel_config.sequence_parallel_size if self.parallel_config else None
- use_sp = sp_size is not None and sp_size > 1
- if use_sp:
- forward_ctx = get_forward_context()
- use_sp = not forward_ctx.split_text_embed_in_sp
-
# Concatenate text and image: [text, image]
hidden_states_combined = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -568,88 +485,41 @@ def forward(
query = self.norm_q(query).to(dtype=dtype)
key = self.norm_k(key).to(dtype=dtype)
- if use_sp:
- # SP mode: use joint attention mechanism
- # Split Q/K/V into text and image parts
- text_query = query[:, :text_seq_length, :, :]
- text_key = key[:, :text_seq_length, :, :]
- text_value = value[:, :text_seq_length, :, :]
- img_query = query[:, text_seq_length:, :, :]
- img_key = key[:, text_seq_length:, :, :]
- img_value = value[:, text_seq_length:, :, :]
-
- # Apply RoPE only to image part
- if image_rotary_emb is not None:
- from diffusers.models.embeddings import apply_rotary_emb
-
- img_query = apply_rotary_emb(img_query, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
- img_key = apply_rotary_emb(img_key, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
-
- # Create attention metadata for joint attention
- attn_metadata = AttentionMetadata(
- joint_query=text_query,
- joint_key=text_key,
- joint_value=text_value,
- joint_strategy="front",
- )
-
- # Add padding mask for SP if available
- if hidden_states_mask is not None:
- attn_metadata.attn_mask = hidden_states_mask
-
- # Attention computation with joint text/image
- # Note: Ulysses post_attention returns [text, image] concatenated
- joint_hidden_states_out = self.attn(img_query, img_key, img_value, attn_metadata)
+ # Apply RoPE only to image tokens (not text tokens)
+ if image_rotary_emb is not None:
+ # Only apply RoPE to image part (after text_seq_length)
+ query_img = query[:, text_seq_length:, :, :]
+ key_img = key[:, text_seq_length:, :, :]
+ from diffusers.models.embeddings import apply_rotary_emb
+
+ query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
+ key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
+ query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1)
+ key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1)
+
+ # Handle KV cache for image editing
+ if kv_cache is not None and kv_cache_mode is not None:
+ if kv_cache_mode == KVCacheMode.WRITE:
+ kv_cache.store(key, value)
+ elif kv_cache_mode == KVCacheMode.READ:
+ k_cached, v_cached = kv_cache.get()
+ if k_cached is not None:
+ key = torch.cat([k_cached, key], dim=1)
+ value = torch.cat([v_cached, value], dim=1)
+ # KVCacheMode.SKIP: do nothing
+
+ # Attention computation
+ hidden_states_out = self.attn(query, key, value)
+ hidden_states_out = hidden_states_out.flatten(2, 3)
+ hidden_states_out = hidden_states_out.to(dtype)
- # Project combined [text, image] outputs, then split.
- # This keeps SP numerically aligned with the non-SP path.
- joint_hidden_states_out = joint_hidden_states_out.flatten(2, 3).to(dtype)
- for module in self.to_out:
- joint_hidden_states_out = module(joint_hidden_states_out)
+ # Output projection
+ for module in self.to_out:
+ hidden_states_out = module(hidden_states_out)
- encoder_hidden_states_out = joint_hidden_states_out[:, :text_seq_length, :]
- hidden_states_out = joint_hidden_states_out[:, text_seq_length:, :]
- else:
- # Non-SP mode: original logic
- # Apply RoPE only to image tokens (not text tokens)
- if image_rotary_emb is not None:
- query_img = query[:, text_seq_length:, :, :]
- key_img = key[:, text_seq_length:, :, :]
- from diffusers.models.embeddings import apply_rotary_emb
-
- query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
- key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
- query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1)
- key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1)
-
- # Handle KV cache for image editing
- if kv_cache is not None and kv_cache_mode is not None:
- if kv_cache_mode == KVCacheMode.WRITE:
- kv_cache.store(key, value)
- elif kv_cache_mode == KVCacheMode.READ:
- k_cached, v_cached = kv_cache.get()
- if k_cached is not None:
- key = torch.cat([k_cached, key], dim=1)
- value = torch.cat([v_cached, value], dim=1)
-
- # Attention computation
- attn_metadata = None
- if attention_mask is not None:
- if attention_mask.dim() == 3:
- attention_mask = attention_mask.unsqueeze(1)
- attn_metadata = AttentionMetadata(attn_mask=attention_mask)
-
- hidden_states_out = self.attn(query, key, value, attn_metadata)
- hidden_states_out = hidden_states_out.flatten(2, 3)
- hidden_states_out = hidden_states_out.to(dtype)
-
- # Output projection
- for module in self.to_out:
- hidden_states_out = module(hidden_states_out)
-
- # Split back to text and image
- encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :]
- hidden_states_out = hidden_states_out[:, text_seq_length:, :]
+ # Split back to text and image
+ encoder_hidden_states_out = hidden_states_out[:, :text_seq_length, :]
+ hidden_states_out = hidden_states_out[:, text_seq_length:, :]
return hidden_states_out, encoder_hidden_states_out
@@ -662,7 +532,6 @@ def __init__(
*,
approximate: str = "none",
bias: bool = True,
- quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
self.proj = ColumnParallelLinear(
@@ -671,7 +540,6 @@ def __init__(
bias=bias,
gather_output=False,
return_bias=False,
- quant_config=quant_config,
)
self.approximate = approximate
@@ -687,7 +555,6 @@ def __init__(
dim_out: int,
*,
bias: bool = True,
- quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
self.proj = ColumnParallelLinear(
@@ -696,7 +563,6 @@ def __init__(
bias=bias,
gather_output=False,
return_bias=False,
- quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -713,7 +579,6 @@ def __init__(
inner_dim: int | None = None,
bias: bool = True,
activation_fn: str = "gelu",
- quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
inner_dim = inner_dim or int(dim * mult)
@@ -721,7 +586,7 @@ def __init__(
if activation_fn == "linear-silu":
layers: list[nn.Module] = [
- ColumnParallelSiLU(dim, inner_dim, bias=bias, quant_config=quant_config),
+ ColumnParallelSiLU(dim, inner_dim, bias=bias),
nn.Identity(),
RowParallelLinear(
inner_dim,
@@ -729,13 +594,12 @@ def __init__(
bias=bias,
input_is_parallel=True,
return_bias=False,
- quant_config=quant_config,
),
]
else:
approximate = "tanh" if activation_fn == "gelu-approximate" else "none"
layers = [
- ColumnParallelGELU(dim, inner_dim, approximate=approximate, bias=bias, quant_config=quant_config),
+ ColumnParallelGELU(dim, inner_dim, approximate=approximate, bias=bias),
nn.Identity(),
RowParallelLinear(
inner_dim,
@@ -743,7 +607,6 @@ def __init__(
bias=bias,
input_is_parallel=True,
return_bias=False,
- quant_config=quant_config,
),
]
@@ -765,8 +628,6 @@ def __init__(
attention_head_dim: int = 40,
time_embed_dim: int = 512,
ffn_hidden_dim: int | None = None,
- quant_config: "QuantizationConfig | None" = None,
- parallel_config: DiffusionParallelConfig | None = None,
) -> None:
super().__init__()
@@ -776,20 +637,12 @@ def __init__(
dim=dim,
num_heads=num_attention_heads,
head_dim=attention_head_dim,
- quant_config=quant_config,
- parallel_config=parallel_config,
)
# 2. Feedforward
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
- self.ff = GlmImageFeedForward(
- dim=dim,
- dim_out=dim,
- inner_dim=ffn_hidden_dim,
- activation_fn="gelu-approximate",
- quant_config=quant_config,
- )
+ self.ff = GlmImageFeedForward(dim=dim, dim_out=dim, inner_dim=ffn_hidden_dim, activation_fn="gelu-approximate")
def forward(
self,
@@ -801,7 +654,6 @@ def forward(
attention_kwargs: dict[str, Any] | None = None,
kv_cache: GlmImageLayerKVCache | None = None,
kv_cache_mode: KVCacheMode | None = None,
- hidden_states_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for transformer block.
@@ -815,7 +667,6 @@ def forward(
attention_kwargs: Additional attention arguments
kv_cache: Layer-specific KV cache for image editing
kv_cache_mode: Cache mode (WRITE, READ, SKIP)
- hidden_states_mask: Mask for SP padding (True=valid, False=padding)
Returns:
Tuple of (image_hidden_states, text_hidden_states)
@@ -842,7 +693,6 @@ def forward(
attention_mask=attention_mask,
kv_cache=kv_cache,
kv_cache_mode=kv_cache_mode,
- hidden_states_mask=hidden_states_mask,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -874,33 +724,12 @@ class GlmImageTransformer2DModel(CachedTransformer):
"""
_repeated_blocks = ["GlmImageTransformerBlock"]
- # SP plan using GlmImagePrepare module for sharding hidden_states and RoPE together.
- # Similar to Qwen-Image's ImageRopePrepare, this creates a module boundary where
- # _sp_plan can shard outputs via split_output=True.
- #
- # Key insight: hidden_states and RoPE embeddings MUST be sharded together
- # to maintain dimension alignment for RoPE computation in attention layers.
- _sp_plan = {
- # Shard GlmImagePrepare outputs (hidden_states and RoPE must be sharded together)
- "prepare": {
- # hidden_states: [B, seq_len, D] - shard along sequence dimension
- 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True),
- # RoPE cos: [seq_len, dim] - shard along sequence dimension
- 1: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
- # RoPE sin: [seq_len, dim] - shard along sequence dimension
- 2: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
- # post_patch_height and post_patch_width are scalars, not sharded
- },
- # Gather output at proj_out
- "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
- }
_hsdp_shard_conditions = [is_transformer_block_module]
def __init__(
self,
od_config: OmniDiffusionConfig,
- quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
@@ -954,24 +783,13 @@ def __init__(
# 2. Patch & Text-timestep embedding
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
self.glyph_projector = GlmImageFeedForward(
- dim=text_embed_dim,
- dim_out=inner_dim,
- inner_dim=inner_dim,
- activation_fn="gelu",
- quant_config=quant_config,
+ dim=text_embed_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="gelu"
)
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
self.prior_projector = GlmImageFeedForward(
- dim=inner_dim,
- dim_out=inner_dim,
- inner_dim=inner_dim,
- activation_fn="linear-silu",
- quant_config=quant_config,
+ dim=inner_dim, dim_out=inner_dim, inner_dim=inner_dim, activation_fn="linear-silu"
)
- # Prepare module for SP (encapsulates patch embedding and RoPE for _sp_plan)
- self.prepare = GlmImagePrepare(self.image_projector, self.rope, patch_size)
-
self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
@@ -988,8 +806,6 @@ def __init__(
attention_head_dim,
time_embed_dim,
ffn_hidden_dim=ffn_hidden_dim,
- quant_config=quant_config,
- parallel_config=self.parallel_config,
)
for _ in range(num_layers)
]
@@ -1043,51 +859,33 @@ def forward(
# Get KV cache mode
kv_cache_mode = kv_cache.mode if kv_cache is not None else None
- # Set SP context if enabled
- sp_size = self.parallel_config.sequence_parallel_size
- if sp_size is not None and sp_size > 1:
- get_forward_context().split_text_embed_in_sp = False
+ # 1. RoPE
+ if image_rotary_emb is None:
+ image_rotary_emb = self.rope(hidden_states)
+ # Move to correct device
+ image_rotary_emb = (
+ image_rotary_emb[0].to(hidden_states.device),
+ image_rotary_emb[1].to(hidden_states.device),
+ )
+
+ # 2. Patch & Timestep embeddings
+ p = self.patch_size
+ post_patch_height = height // p
+ post_patch_width = width // p
- # Text embedding projection
+ hidden_states = self.image_projector(hidden_states)
encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
# Prior embedding with dropout
prior_embedding = self.prior_token_embedding(prior_token_id)
prior_embedding[prior_token_drop] *= 0.0
prior_hidden_states = self.prior_projector(prior_embedding)
-
- # 1. Prepare hidden_states and RoPE via GlmImagePrepare module
- # _sp_plan will shard hidden_states and RoPE together via split_output=True
- hidden_states, rope_cos, rope_sin, post_patch_height_t, post_patch_width_t = self.prepare(
- hidden_states, prior_hidden_states
- )
- image_rotary_emb = (rope_cos, rope_sin)
- post_patch_height = int(post_patch_height_t.item())
- post_patch_width = int(post_patch_width_t.item())
+ hidden_states = hidden_states + prior_hidden_states
# Timestep conditioning
temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype)
- # Create padding mask for SP if needed (after _sp_plan hooks have run)
- hidden_states_mask = None
- if sp_size is not None and sp_size > 1:
- from vllm_omni.diffusion.forward_context import is_forward_context_available
-
- if is_forward_context_available():
- ctx = get_forward_context()
- if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
- img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
- hidden_states_mask = torch.ones(
- batch_size,
- img_padded_seq_len,
- dtype=torch.bool,
- device=hidden_states.device,
- )
- hidden_states_mask[:, ctx.sp_original_seq_len :] = False
- if hidden_states_mask.all():
- hidden_states_mask = None
-
- # 2. Transformer blocks
+ # 3. Transformer blocks
for layer_idx, block in enumerate(self.transformer_blocks):
# Get layer-specific KV cache if available
layer_kv_cache = kv_cache[layer_idx] if kv_cache is not None else None
@@ -1101,16 +899,13 @@ def forward(
attention_kwargs,
kv_cache=layer_kv_cache,
kv_cache_mode=kv_cache_mode,
- hidden_states_mask=hidden_states_mask,
)
- # 3. Output norm & projection
- # _sp_plan will gather hidden_states via proj_out hook
+ # 4. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
- # 4. Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W]
- p = self.patch_size
+ # 5. Unpatchify: [B, H'*W', C*p*p] -> [B, C, H, W]
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
diff --git a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py
index 97cba18c234..375f7e7b80d 100644
--- a/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py
+++ b/vllm_omni/diffusion/models/glm_image/pipeline_glm_image.py
@@ -301,10 +301,7 @@ def __init__(
# Load transformer (DiT)
logger.info("Loading GlmImageTransformer2DModel (DiT)...")
- self.transformer = GlmImageTransformer2DModel(
- od_config=od_config,
- quant_config=od_config.quantization_config,
- )
+ self.transformer = GlmImageTransformer2DModel(od_config=od_config)
# Weight sources for DiT loading
self.weights_sources = [
@@ -715,14 +712,6 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
if img is not None:
preprocessed_images = [img]
- # Priority: prompt dict (from ar2diffusion) > sampling_params
- # ar2diffusion returns adjusted height/width that matches prior_token_ids
- if not isinstance(first_prompt, str):
- ar_height = first_prompt.get("height")
- ar_width = first_prompt.get("width")
- else:
- ar_height = ar_width = None
-
img_height = req.sampling_params.height
img_width = req.sampling_params.width
@@ -730,19 +719,12 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
# Treat that as t2i warmup to avoid requiring i2i-only KV-cache inputs.
is_image_edit = (preprocessed_images is not None) and (not is_dummy_warmup)
- # Use prompt dict dimensions (from ar2diffusion) as priority, then sampling_params
- height = (
- ar_height or req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor
- )
- width = ar_width or req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor
+ # Use image dimensions as default if available
+ height = req.sampling_params.height or img_height or self.default_sample_size * self.vae_scale_factor
+ width = req.sampling_params.width or img_width or self.default_sample_size * self.vae_scale_factor
num_inference_steps = req.sampling_params.num_inference_steps or 50
guidance_scale = req.sampling_params.guidance_scale or 1.5
- # Ensure dimensions are multiples of vae_scale_factor * patch_size
- multiple_of = self.vae_scale_factor * self._patch_size
- height = height // multiple_of * multiple_of
- width = width // multiple_of * multiple_of
-
self.check_inputs(prompt=prompt, height=height, width=width, prompt_embeds=prompt_embeds)
batch_size = 1
@@ -771,20 +753,6 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
prior_token_id = prior_token_id.to(device=self.device, dtype=torch.long)
if prior_token_id.dim() == 1:
prior_token_id = prior_token_id.unsqueeze(0)
-
- # Validate that prior_token_id seq_len matches dimensions
- prior_seq_len = prior_token_id.shape[1]
- expected_seq_len = (height // self.vae_scale_factor // self._patch_size) * (
- width // self.vae_scale_factor // self._patch_size
- )
- if prior_seq_len != expected_seq_len:
- raise ValueError(
- f"prior_token_ids seq_len ({prior_seq_len}) doesn't match dimensions "
- f"({height}x{width}, expected seq_len={expected_seq_len}). "
- f"This indicates a mismatch between AR output and Diffusion input. "
- f"Please ensure ar2diffusion returns correct height/width."
- )
-
prior_token_image_ids = None
if external_prior_image_ids is not None:
if isinstance(external_prior_image_ids, torch.Tensor):
diff --git a/vllm_omni/diffusion/models/helios/helios_transformer.py b/vllm_omni/diffusion/models/helios/helios_transformer.py
index 5e7934c3ba6..812da7db149 100644
--- a/vllm_omni/diffusion/models/helios/helios_transformer.py
+++ b/vllm_omni/diffusion/models/helios/helios_transformer.py
@@ -62,16 +62,10 @@ def apply_rotary_emb_helios(
"""
x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
- # Use stack+flatten instead of strided slice assignment for contiguous
- # memory layout and better performance on GPU/NPU (#2436, cf. PR #2393).
- rotated = torch.stack(
- (
- x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2],
- x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2],
- ),
- dim=-1,
- )
- return rotated.flatten(-2, -1).type_as(hidden_states)
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
+ out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
+ return out.type_as(hidden_states)
class DistributedRMSNorm(nn.Module):
@@ -582,7 +576,7 @@ class HeliosTransformer3DModel(nn.Module):
"""
_repeated_blocks = ["HeliosTransformerBlock"]
- _layerwise_offload_blocks_attrs = ["blocks"]
+ _layerwise_offload_blocks_attr = "blocks"
packed_modules_mapping = {
"to_qkv": ["to_q", "to_k", "to_v"],
}
diff --git a/vllm_omni/diffusion/models/hunyuan_image3/system_prompt.py b/vllm_omni/diffusion/models/hunyuan_image3/system_prompt.py
deleted file mode 100644
index 29494fad419..00000000000
--- a/vllm_omni/diffusion/models/hunyuan_image3/system_prompt.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# ruff: noqa: E501
-# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-t2i_system_prompt_en_vanilla = """
-You are an advanced AI text-to-image generation system. Given a detailed text prompt, your task is to create a high-quality, visually compelling image that accurately represents the described scene, characters, or objects. Pay careful attention to style, color, lighting, perspective, and any specific instructions provided.
-"""
-
-# 775
-t2i_system_prompt_en_recaption = """
-You are a world-class image generation prompt expert. Your task is to rewrite a user's simple description into a **structured, objective, and detail-rich** professional-level prompt.
-
-The final output must be wrapped in `` tags.
-
-### **Universal Core Principles**
-
-When rewriting the prompt (inside the `` tags), you must adhere to the following principles:
-
-1. **Absolute Objectivity**: Describe only what is visually present. Avoid subjective words like "beautiful" or "sad". Convey aesthetic qualities through specific descriptions of color, light, shadow, and composition.
-2. **Physical and Logical Consistency**: All scene elements (e.g., gravity, light, shadows, reflections, spatial relationships, object proportions) must strictly adhere to real-world physics and common sense. For example, tennis players must be on opposite sides of the net; objects cannot float without a cause.
-3. **Structured Description**: Strictly follow a logical order: from general to specific, background to foreground, and primary to secondary elements. Use directional terms like "foreground," "mid-ground," "background," and "left side of the frame" to clearly define the spatial layout.
-4. **Use Present Tense**: Describe the scene from an observer's perspective using the present tense, such as "A man stands..." or "Light shines on..."
-5. **Use Rich and Specific Descriptive Language**: Use precise adjectives to describe the quantity, size, shape, color, and other attributes of objects, subjects, and text. Vague expressions are strictly prohibited.
-
-If the user specifies a style (e.g., oil painting, anime, UI design, text rendering), strictly adhere to that style. Otherwise, first infer a suitable style from the user's input. If there is no clear stylistic preference, default to an **ultra-realistic photographic style**. Then, generate the detailed rewritten prompt according to the **Style-Specific Creation Guide** below:
-
-### **Style-Specific Creation Guide**
-
-Based on the determined artistic style, apply the corresponding professional knowledge.
-
-**1. Photography and Realism Style**
-* Utilize professional photography terms (e.g., lighting, lens, composition) and meticulously detail material textures, physical attributes of subjects, and environmental details.
-
-**2. Illustration and Painting Style**
-* Clearly specify the artistic school (e.g., Japanese Cel Shading, Impasto Oil Painting) and focus on describing its unique medium characteristics, such as line quality, brushstroke texture, or paint properties.
-
-**3. Graphic/UI/APP Design Style**
-* Objectively describe the final product, clearly defining the layout, elements, and color palette. All text on the interface must be enclosed in double quotes `""` to specify its exact content (e.g., "Login"). Vague descriptions are strictly forbidden.
-
-**4. Typographic Art**
-* The text must be described as a complete physical object. The description must begin with the text itself. Use a straightforward front-on or top-down perspective to ensure the entire text is visible without cropping.
-
-### **Final Output Requirements**
-
-1. **Output the Final Prompt Only**: Do not show any thought process, Markdown formatting, or line breaks.
-2. **Adhere to the Input**: You must retain the core concepts, attributes, and any specified text from the user's input.
-3. **Style Reinforcement**: Mention the core style 3-5 times within the prompt and conclude with a style declaration sentence.
-4. **Avoid Self-Reference**: Describe the image content directly. Remove redundant phrases like "This image shows..." or "The scene depicts..."
-5. **The final output must be wrapped in `xxxx ` tags.**
-
-The user will now provide an input prompt. You will provide the expanded prompt.
-"""
-
-# 890
-t2i_system_prompt_en_think_recaption = """
-You will act as a top-tier Text-to-Image AI. Your core task is to deeply analyze the user's text input and transform it into a detailed, artistic, and fully user-intent-compliant image.
-
-Your workflow is divided into two phases:
-
-1. Thinking Phase (): In the tag, you need to conduct a structured thinking process, progressively breaking down and enriching the constituent elements of the image. This process must include, but is not limited to, the following dimensions:
-
-Subject: Clearly define the core character(s) or object(s) in the scene, including their appearance, posture, expression, and emotion.
-Composition: Set the camera angle and layout, such as close-up, long shot, bird's-eye view, golden ratio composition, etc.
-Environment/Background: Describe the scene where the subject is located, including the location, time of day, weather, and other elements in the background.
-Lighting: Define the type, direction, and quality of the light source, such as soft afternoon sunlight, cool tones of neon lights, dramatic Rembrandt lighting, etc., to create a specific atmosphere.
-Color Palette: Set the main color tone and color scheme of the image, such as vibrant and saturated, low-saturation Morandi colors, black and white, etc.
-Quality/Style: Determine the artistic style and technical details of the image. This includes user-specified styles (e.g., anime, oil painting) or the default realistic style, as well as camera parameters (e.g., focal length, aperture, depth of field).
-Details: Add minute elements that enhance the realism and narrative quality of the image, such as a character's accessories, the texture of a surface, dust particles in the air, etc.
-
-
-2. Recaption Phase (): In the tag, merge all the key details from the thinking process into a coherent, precise, and visually evocative final description. This description is the direct instruction for generating the image, so it must be clear, unambiguous, and organized in a way that is most suitable for an image generation engine to understand.
-
-Absolutely Objective: Describe only what is visually present. Avoid subjective words like "beautiful" or "sad." Convey aesthetic sense through concrete descriptions of colors, light, shadow, and composition.
-
-Physical and Logical Consistency: All scene elements (e.g., gravity, light and shadow, reflections, spatial relationships, object proportions) must strictly adhere to the physical laws of the real world and common sense. For example, in a tennis match, players must be on opposite sides of the net; objects cannot float without reason.
-
-Structured Description: Strictly follow a logical order: from whole to part, background to foreground, and primary to secondary. Use directional words like "foreground," "mid-ground," "background," "left side of the frame" to clearly define the spatial layout.
-
-Use Present Tense: Describe from an observer's perspective using the present tense, such as "a man stands," "light shines on..."
-Use Rich and Specific Descriptive Language: Use precise adjectives to describe the quantity, size, shape, color, and other attributes of objects/characters/text. Absolutely avoid any vague expressions.
-
-
-Output Format:
-Thinking process Refined image description Generate Image
-
-
-You must strictly adhere to the following rules:
-
-1. Faithful to Intent, Reasonable Expansion: You can creatively add details to the user's description to enhance the image's realism and artistic quality. However, all additions must be highly consistent with the user's core intent and never introduce irrelevant or conflicting elements.
-2. Style Handling: When the user does not specify a style, you must default to an "Ultra-realistic, Photorealistic" style. If the user explicitly specifies a style (e.g., anime, watercolor, oil painting, cyberpunk, etc.), both your thinking process and final description must strictly follow and reflect that specified style.
-3. Text Rendering: If specific text needs to appear in the image (such as words on a sign, a book title), you must enclose this text in English double quotes (""). Descriptive text must not use double quotes.
-4. Design-related Images: You need to specify all text and graphical elements that appear in the image and clearly describe their design details, including font, color, size, position, arrangement, visual effects, etc.
-"""
-
-t2i_system_prompts = {
- "en_vanilla": [t2i_system_prompt_en_vanilla],
- "en_recaption": [t2i_system_prompt_en_recaption],
- "en_think_recaption": [t2i_system_prompt_en_think_recaption],
-}
-
-
-unified_system_prompt_en = """You are an advanced multimodal model whose core mission is to analyze user intent and generate high-quality text and images.
-
-#### Four Core Capabilities
-1. **Text-to-Text (T2T):** Generate coherent text responses from text prompts.
-2. **Text-to-Image (T2I):** Generate high-quality images from text prompts.
-3. **Text & Image to Text (TI2T):** Generate accurate text responses based on a combination of images and text.
-4. **Text & Image to Image (TI2I):** Generate modified images based on a reference image and editing instructions.
-
----
-### Image Generation Protocol (for T2I & TI2I)
-You will operate in one of two modes, determined by the user's starting tag:
-#### ** Mode (Prompt Rewriting)**:
-* **Trigger:** Input begins with ``.
-* **Task:** Immediately rewrite the user's text into a structured, objective, and detail-rich professional-grade prompt.
-* **Output:** Output only the rewritten prompt within `` tags: `Rewritten professional-grade prompt `
-
-#### ** Mode (Think + Rewrite)**:
-* **Trigger:** Input begins with ``.
-* **Task:** First, conduct a structured analysis of the request within `` tags. Then, output the professional prompt, rewritten based on the analysis, within `` tags.
-* **Output:** Strictly adhere to the format: `Analysis process Rewritten prompt `
-
----
-### Execution Standards and Guidelines
-#### **`` Phase: Analysis Guidelines**
-**For T2I (New Image Generation):**
-Deconstruct the user's request into the following core visual components:
-* **Subject:** Key features of the main character/object, including appearance, pose, expression, and emotion.
-* **Composition:** Camera angle, lens type, and layout.
-* **Environment/Background:** The setting, time of day, weather, and background elements.
-* **Lighting:** Technical details such as light source type, direction, and quality.
-* **Color Palette:** The dominant hues and overall color scheme.
-* **Style/Quality:** The artistic style, clarity, depth of field, and other technical details.
-* **Text:** Identify any text to be rendered in the image, including its content, style, and position.
-* **Details:** Small elements that add narrative depth and realism.
-
-**For TI2I (Image Editing):**
-Adopt a task-diagnostic approach:
-1. **Diagnose Task:** Identify the edit type and analyze key requirements.
-2. **Prioritize Analysis:**
- * **Adding:** Analyze the new element's position and appearance, ensuring seamless integration with the original image's lighting, shadows, and style.
- * **Removing:** Identify the target for removal and determine how to logically fill the resulting space using surrounding textures and lighting.
- * **Modifying:** Analyze what to change and what it should become, while emphasizing which elements must remain unchanged.
- * **Style Transfer:** Deconstruct the target style into specific features (e.g., brushstrokes, color palette) and apply them to the original image.
- * **Text Editing:** Ensure correct content and format. Consider the text's visual style (e.g., font, color, material) and how it adapts to the surface's perspective, curvature, and lighting.
- * **Reference Editing:** Extract specific visual elements (e.g., appearance, posture, composition, lines, depth) from the reference image to generate an image that aligns with the text description while also incorporating the referenced content.
- * **Inferential Editing:** Identify vague requests (e.g., "make it more professional") and translate them into concrete visual descriptions.
-
-#### `` Phase: Professional-Grade Prompt Generation Rules
-**General Rewriting Principles (for T2I & TI2I):**
-1. **Structure & Logic:** Start with a global description. Use positional words (e.g., "foreground", "background") to define the layout.
-2. **Absolute Objectivity:** Avoid subjective terms. Convey aesthetics through precise descriptions of color, light, shadow, and materials.
-3. **Physical & Logical Consistency:** Ensure all descriptions adhere to the laws of physics and common sense.
-4. **Fidelity to User Intent:** Preserve the user's core concepts, subjects, and attributes. Text to be rendered in the image **must be enclosed in double quotes ("")**.
-5. **Camera & Resolution:** Translate camera parameters into descriptions of visual effects. Convert resolution information into natural language.
-
-**T2I-Specific Guidelines:**
-* **Style Adherence & Inference:** Strictly follow the specified style. If none is given, infer the most appropriate style and detail it using professional terminology.
-* **Style Detailing:**
- * **Photography/Realism:** Use professional photography terms to describe lighting, lens effects, and material textures.
- * **Painting/Illustration:** Specify the art movement or medium's characteristics.
- * **UI/Design:** Objectively describe the final product. Define layout, elements, and typography. Text content must be specific and unambiguous.
-
-**TI2I-Specific Guidelines:**
-* **Preserve Unchanged Elements:** Emphasize elements that **remain unchanged**. Unless explicitly instructed, never alter a character's identity/appearance, the core background, camera angle, or overall style.
-* **Clear Editing Instructions:**
- * **Replacement:** Use the logic "**replace B with A**," and provide a detailed description of A.
- * **Addition:** Clearly state what to add, where, and what it looks like.
-* **Unambiguous Referencing:** Avoid vague references (e.g., "that person"). Use specific descriptions of appearance.
-"""
-
-
-def get_system_prompt(sys_type, bot_task, system_prompt=None):
- # No system prompt, return None directly
- if sys_type == "None":
- return None
- # Use the unified English system prompt (combined T2I and TI2I guidelines)
- elif sys_type == "en_unified":
- return unified_system_prompt_en
- # Use predefined English system prompts: vanilla (basic), recaption, think_recaption
- elif sys_type in ["en_vanilla", "en_recaption", "en_think_recaption"]:
- return t2i_system_prompts[sys_type][0]
- # Dynamic mode: automatically select system prompt based on bot_task type
- elif sys_type == "dynamic":
- # Think task: use chain-of-thought recaption prompt
- if bot_task == "think":
- return t2i_system_prompts["en_think_recaption"][0]
- # Recaption task: use recaption prompt
- elif bot_task == "recaption":
- return t2i_system_prompts["en_recaption"][0]
- # Image generation task: use vanilla prompt
- elif bot_task == "image":
- return t2i_system_prompts["en_vanilla"][0].strip("\n")
- # Other tasks: use user-provided custom prompt
- else:
- return system_prompt
- # Custom mode: use the user-provided system_prompt parameter directly
- elif sys_type == "custom":
- return system_prompt
- # Unsupported type: raise NotImplementedError
- else:
- raise NotImplementedError(f"Unsupported system prompt type: {sys_type}")
-
-
-__all__ = ["get_system_prompt"]
diff --git a/vllm_omni/diffusion/models/hunyuan_image3/__init__.py b/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py
similarity index 58%
rename from vllm_omni/diffusion/models/hunyuan_image3/__init__.py
rename to vllm_omni/diffusion/models/hunyuan_image_3/__init__.py
index 6612bd855ba..cbc6a8ad1f4 100644
--- a/vllm_omni/diffusion/models/hunyuan_image3/__init__.py
+++ b/vllm_omni/diffusion/models/hunyuan_image_3/__init__.py
@@ -2,12 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Hunyuan Image 3 diffusion model components."""
-from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import HunyuanFusedMoE
-from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_image3_transformer import (
+from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE
+from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_image_3_transformer import (
HunyuanImage3Model,
HunyuanImage3Text2ImagePipeline,
)
-from vllm_omni.diffusion.models.hunyuan_image3.pipeline_hunyuan_image3 import (
+from vllm_omni.diffusion.models.hunyuan_image_3.pipeline_hunyuan_image_3 import (
HunyuanImage3Pipeline,
)
diff --git a/vllm_omni/diffusion/models/hunyuan_image3/autoencoder.py b/vllm_omni/diffusion/models/hunyuan_image_3/autoencoder.py
similarity index 100%
rename from vllm_omni/diffusion/models/hunyuan_image3/autoencoder.py
rename to vllm_omni/diffusion/models/hunyuan_image_3/autoencoder.py
diff --git a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py
similarity index 100%
rename from vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py
rename to vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_fused_moe.py
diff --git a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_tokenizer.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_tokenizer.py
similarity index 99%
rename from vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_tokenizer.py
rename to vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_tokenizer.py
index 4a29e9df93e..ce563f71159 100644
--- a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_tokenizer.py
+++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_tokenizer.py
@@ -13,7 +13,7 @@
from transformers import AutoTokenizer
from vllm.logger import init_logger
-from .hunyuan_image3_transformer import ImageInfo, JointImageInfo, default
+from .hunyuan_image_3_transformer import ImageInfo, JointImageInfo, default
logger = init_logger(__name__)
diff --git a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
similarity index 99%
rename from vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py
rename to vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
index 0f3c33389c5..bc81ca9c3ed 100644
--- a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py
+++ b/vllm_omni/diffusion/models/hunyuan_image_3/hunyuan_image_3_transformer.py
@@ -74,7 +74,7 @@
)
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.layers.rope import RotaryEmbedding
-from vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe import HunyuanFusedMoE
+from vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe import HunyuanFusedMoE
logger = logging.getLogger(__name__)
@@ -1484,7 +1484,7 @@ def __init__(
config.hidden_size,
config.num_experts,
bias=False,
- quant_config=quant_config,
+ quant_config=None,
prefix=f"{prefix}.gate",
)
if config.use_mixed_mlp_moe > 0:
@@ -1658,10 +1658,8 @@ def forward(
custom_pos_emb: tuple[torch.FloatTensor] | None = None,
**kwargs,
) -> torch.Tensor:
- bsz, q_len, hidden_size = hidden_states.size()
- hidden_states = hidden_states.reshape(-1, hidden_size)
+ bsz, q_len, _ = hidden_states.size()
qkv, _ = self.qkv_proj(hidden_states)
- qkv = qkv.reshape(bsz, q_len, -1)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
past_key_value: Cache | None = kwargs.get("past_key_value", None)
@@ -1686,8 +1684,7 @@ def forward(
else:
attn_output = self.attn(q, k, v)
# For o_proj
- # image_attn may return a non-contiguous tensor; reshape is safe here.
- attn_output = attn_output.reshape(q.shape[0], -1)
+ attn_output = attn_output.view(q.shape[0], -1)
output, _ = self.o_proj(attn_output)
output = output.reshape(bsz, q_len, -1)
return output, None, past_key_value
@@ -1725,7 +1722,7 @@ def __init__(
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
- quant_config=quant_config,
+ quant_config=None,
bias=attention_bias,
cache_config=None,
prefix=f"{prefix}.self_attn",
@@ -1935,7 +1932,7 @@ def __init__(self, config: HunyuanImage3Config, quant_config=None, prefix: str =
layer_idx=int(prefix.split(".")[-1]),
prefix=prefix,
),
- prefix=f"{prefix}.layers" if prefix else "layers",
+ prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1950,7 +1947,7 @@ def _split_qkv_weight(self, qkv: torch.Tensor):
num_attention_heads = self.config.num_attention_heads
num_kv_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
num_key_value_groups = num_attention_heads // num_kv_heads
- hidden_size = qkv.shape[1]
+ hidden_size = self.config.hidden_size
if hasattr(self.config, "head_dim"):
attention_head_dim = self.config.head_dim
@@ -2003,15 +2000,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
split_params_mapping = [
(".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
(
- ".qkv_proj.weight",
- ".qkv_proj.weight",
- num_attention_heads + num_kv_heads * 2,
- [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
- self._split_qkv_weight,
- ),
- (
- ".qkv_proj.weight_scale",
- ".qkv_proj.weight_scale",
+ ".qkv_proj",
+ ".qkv_proj",
num_attention_heads + num_kv_heads * 2,
[("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
self._split_qkv_weight,
@@ -2110,8 +2100,6 @@ def contains_unexpected_keyword(name, keywords):
continue
if "mlp.experts" in name:
continue
- if ".qkv_proj" in name and not name.endswith(weight_name):
- continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
diff --git a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
similarity index 97%
rename from vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py
rename to vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
index 84a7787ad11..ba24818dc93 100644
--- a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py
+++ b/vllm_omni/diffusion/models/hunyuan_image_3/pipeline_hunyuan_image_3.py
@@ -6,6 +6,7 @@
from collections.abc import Iterable
from typing import Any
+import numpy as np
import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
@@ -14,7 +15,7 @@
from transformers.models.siglip2 import Siglip2VisionConfig, Siglip2VisionModel
from transformers.utils.generic import ModelOutput
from vllm.config.vllm import get_current_vllm_config
-from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
+from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.transformers_utils.config import get_config
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
@@ -24,8 +25,8 @@
from vllm_omni.diffusion.request import OmniDiffusionRequest
from .autoencoder import AutoencoderKLConv3D
-from .hunyuan_image3_tokenizer import TokenizerWrapper
-from .hunyuan_image3_transformer import (
+from .hunyuan_image_3_tokenizer import TokenizerWrapper
+from .hunyuan_image_3_transformer import (
CausalMMOutputWithPast,
HunyuanImage3ImageProcessor,
HunyuanImage3Model,
@@ -40,7 +41,6 @@
build_batch_2d_rope,
real_batched_index_select,
)
-from .system_prompt import get_system_prompt
logger = logging.getLogger(__name__)
@@ -64,15 +64,6 @@ def to_device(data, device):
class HunyuanImage3Pipeline(HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin):
- hf_to_vllm_mapper = WeightsMapper(
- orig_to_new_prefix={
- "model.": "",
- },
- orig_to_new_substr={
- "mlp.gate.wg.": "mlp.gate.",
- "gate_and_up_proj.": "gate_up_proj.",
- },
- )
_PROFILER_TARGETS = [
"model.forward",
"model.layers[0].forward",
@@ -552,7 +543,7 @@ def prepare_model_inputs(
generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
# 3. apply chat template
- cfg_factor = {"gen_text": 1, "gen_image": 1 + int(guidance_scale > 1.0)}
+ cfg_factor = {"gen_text": 1, "gen_image": 2}
bot_task = kwargs.pop("bot_task", "auto")
# If `drop_think` enabled, always drop parts in the context.
drop_think = kwargs.get("drop_think", self.generation_config.drop_think)
@@ -1000,15 +991,10 @@ def forward(
width: int = 1024,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
+ system_prompt: str | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
**kwargs,
) -> DiffusionOutput:
- extra_args = getattr(getattr(req, "sampling_params", None), "extra_args", {}) or {}
- use_system_prompt = extra_args.get("use_system_prompt")
- system_prompt = extra_args.get("system_prompt")
- if use_system_prompt is not None:
- system_prompt = get_system_prompt(use_system_prompt, "image", system_prompt)
- system_prompt = system_prompt.strip() if system_prompt is not None else ""
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
generator = req.sampling_params.generator or generator
height = req.sampling_params.height or height
@@ -1017,7 +1003,8 @@ def forward(
if req.sampling_params.guidance_scale_provided:
guidance_scale = req.sampling_params.guidance_scale
if guidance_scale <= 1.0:
- logger.info("HunyuanImage3.0 runs without classifier-free guidance when guidance_scale <= 1.0.")
+ logger.warning("HunyuanImage3.0 does not support guidance_scale <= 1.0, will set it to 1.0 + epsilon.")
+ guidance_scale = 1.0 + np.finfo(float).eps
image_size = (height, width)
model_inputs = self.prepare_model_inputs(
prompt=prompt,
diff --git a/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py b/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py
index 6600b17d5cd..263e39e0189 100644
--- a/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py
+++ b/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py
@@ -539,7 +539,7 @@ class HunyuanVideo15Transformer3DModel(nn.Module):
"""
_repeated_blocks = ["HunyuanVideo15TransformerBlock"]
- _layerwise_offload_blocks_attrs = ["transformer_blocks"]
+ _layerwise_offload_blocks_attr = "transformer_blocks"
packed_modules_mapping = {
"to_qkv": ["to_q", "to_k", "to_v"],
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py
index 6445bfee215..0b68676e8dc 100644
--- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py
+++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py
@@ -24,9 +24,7 @@
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.hunyuan_video.hunyuan_video_15_transformer import HunyuanVideo15Transformer3DModel
-from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel
-from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.platforms import current_omni_platform
@@ -83,7 +81,7 @@ def post_process_func(video: torch.Tensor, output_type: str = "pil"):
return post_process_func
-class HunyuanVideo15Pipeline(nn.Module, CFGParallelMixin, ProgressBarMixin, DiffusionPipelineProfilerMixin):
+class HunyuanVideo15Pipeline(nn.Module, CFGParallelMixin):
def __init__(
self,
*,
@@ -175,10 +173,6 @@ def __init__(
self._num_timesteps = None
self._current_timestep = None
- self.setup_diffusion_pipeline_profiler(
- enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
- )
-
@property
def guidance_scale(self):
return self._guidance_scale
@@ -451,63 +445,60 @@ def forward(
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
- with self.progress_bar(total=len(timesteps)) as pbar:
- for i, t in enumerate(timesteps):
- self._current_timestep = t
-
- latent_model_input = torch.cat([latents, cond_latents, mask], dim=1)
- timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
-
- timestep_r = None
- if self.use_meanflow:
- if i == len(timesteps) - 1:
- timestep_r = torch.tensor([0.0], device=device)
- else:
- timestep_r = timesteps[i + 1]
- timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
-
- positive_kwargs = {
+ for i, t in enumerate(timesteps):
+ self._current_timestep = t
+
+ latent_model_input = torch.cat([latents, cond_latents, mask], dim=1)
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+
+ timestep_r = None
+ if self.use_meanflow:
+ if i == len(timesteps) - 1:
+ timestep_r = torch.tensor([0.0], device=device)
+ else:
+ timestep_r = timesteps[i + 1]
+ timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
+
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "timestep_r": timestep_r,
+ "encoder_hidden_states": prompt_embeds,
+ "encoder_attention_mask": prompt_embeds_mask,
+ "encoder_hidden_states_2": prompt_embeds_2,
+ "encoder_attention_mask_2": prompt_embeds_mask_2,
+ "image_embeds": image_embeds,
+ "return_dict": False,
+ }
+
+ negative_kwargs = None
+ if do_cfg and negative_prompt_embeds is not None:
+ negative_kwargs = {
"hidden_states": latent_model_input,
"timestep": timestep,
"timestep_r": timestep_r,
- "encoder_hidden_states": prompt_embeds,
- "encoder_attention_mask": prompt_embeds_mask,
- "encoder_hidden_states_2": prompt_embeds_2,
- "encoder_attention_mask_2": prompt_embeds_mask_2,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "encoder_attention_mask": negative_prompt_embeds_mask,
+ "encoder_hidden_states_2": negative_prompt_embeds_2,
+ "encoder_attention_mask_2": negative_prompt_embeds_mask_2,
"image_embeds": image_embeds,
"return_dict": False,
}
- negative_kwargs = None
- if do_cfg and negative_prompt_embeds is not None:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "timestep_r": timestep_r,
- "encoder_hidden_states": negative_prompt_embeds,
- "encoder_attention_mask": negative_prompt_embeds_mask,
- "encoder_hidden_states_2": negative_prompt_embeds_2,
- "encoder_attention_mask_2": negative_prompt_embeds_mask_2,
- "image_embeds": image_embeds,
- "return_dict": False,
- }
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_cfg and negative_kwargs is not None,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=req.sampling_params.cfg_normalize,
- )
-
- latents = self.scheduler_step_maybe_with_cfg(
- noise_pred,
- t,
- latents,
- do_true_cfg=do_cfg and negative_kwargs is not None,
- )
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_cfg and negative_kwargs is not None,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=req.sampling_params.cfg_normalize,
+ )
- pbar.update()
+ latents = self.scheduler_step_maybe_with_cfg(
+ noise_pred,
+ t,
+ latents,
+ do_true_cfg=do_cfg and negative_kwargs is not None,
+ )
self._current_timestep = None
diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py
index c1acd1a895a..d68c43125c5 100644
--- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py
+++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py
@@ -38,9 +38,7 @@
retrieve_latents,
)
from vllm_omni.diffusion.models.interface import SupportImageInput
-from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.models.t5_encoder import T5EncoderModel
-from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.platforms import current_omni_platform
@@ -100,9 +98,7 @@ def pre_process_func(req: OmniDiffusionRequest) -> OmniDiffusionRequest:
return pre_process_func
-class HunyuanVideo15I2VPipeline(
- nn.Module, CFGParallelMixin, SupportImageInput, ProgressBarMixin, DiffusionPipelineProfilerMixin
-):
+class HunyuanVideo15I2VPipeline(nn.Module, CFGParallelMixin, SupportImageInput):
support_image_input = True
color_format = "RGB"
@@ -203,10 +199,6 @@ def __init__(
self._num_timesteps = None
self._current_timestep = None
- self.setup_diffusion_pipeline_profiler(
- enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
- )
-
@property
def guidance_scale(self):
return self._guidance_scale
@@ -528,64 +520,61 @@ def forward(
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
- with self.progress_bar(total=len(timesteps)) as pbar:
- for i, t in enumerate(timesteps):
- self._current_timestep = t
-
- latent_model_input = torch.cat([latents, cond_latents, mask], dim=1)
- timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
-
- timestep_r = None
- if self.use_meanflow:
- if i == len(timesteps) - 1:
- timestep_r = torch.tensor([0.0], device=device)
- else:
- timestep_r = timesteps[i + 1]
- timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
-
- positive_kwargs = {
+ for i, t in enumerate(timesteps):
+ self._current_timestep = t
+
+ latent_model_input = torch.cat([latents, cond_latents, mask], dim=1)
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+
+ timestep_r = None
+ if self.use_meanflow:
+ if i == len(timesteps) - 1:
+ timestep_r = torch.tensor([0.0], device=device)
+ else:
+ timestep_r = timesteps[i + 1]
+ timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
+
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "timestep_r": timestep_r,
+ "encoder_hidden_states": prompt_embeds,
+ "encoder_attention_mask": prompt_embeds_mask,
+ "encoder_hidden_states_2": prompt_embeds_2,
+ "encoder_attention_mask_2": prompt_embeds_mask_2,
+ "image_embeds": image_embeds,
+ "return_dict": False,
+ }
+
+ negative_kwargs = None
+ if do_cfg and negative_prompt_embeds is not None:
+ # For I2V CFG, negative still uses image embeds (only text is unconditional)
+ negative_kwargs = {
"hidden_states": latent_model_input,
"timestep": timestep,
"timestep_r": timestep_r,
- "encoder_hidden_states": prompt_embeds,
- "encoder_attention_mask": prompt_embeds_mask,
- "encoder_hidden_states_2": prompt_embeds_2,
- "encoder_attention_mask_2": prompt_embeds_mask_2,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "encoder_attention_mask": negative_prompt_embeds_mask,
+ "encoder_hidden_states_2": negative_prompt_embeds_2,
+ "encoder_attention_mask_2": negative_prompt_embeds_mask_2,
"image_embeds": image_embeds,
"return_dict": False,
}
- negative_kwargs = None
- if do_cfg and negative_prompt_embeds is not None:
- # For I2V CFG, negative still uses image embeds (only text is unconditional)
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "timestep_r": timestep_r,
- "encoder_hidden_states": negative_prompt_embeds,
- "encoder_attention_mask": negative_prompt_embeds_mask,
- "encoder_hidden_states_2": negative_prompt_embeds_2,
- "encoder_attention_mask_2": negative_prompt_embeds_mask_2,
- "image_embeds": image_embeds,
- "return_dict": False,
- }
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_cfg and negative_kwargs is not None,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=req.sampling_params.cfg_normalize,
- )
-
- latents = self.scheduler_step_maybe_with_cfg(
- noise_pred,
- t,
- latents,
- do_true_cfg=do_cfg and negative_kwargs is not None,
- )
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_cfg and negative_kwargs is not None,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=req.sampling_params.cfg_normalize,
+ )
- pbar.update()
+ latents = self.scheduler_step_maybe_with_cfg(
+ noise_pred,
+ t,
+ latents,
+ do_true_cfg=do_cfg and negative_kwargs is not None,
+ )
self._current_timestep = None
diff --git a/vllm_omni/diffusion/models/interface.py b/vllm_omni/diffusion/models/interface.py
index 00d54420dfe..ef906472bd0 100644
--- a/vllm_omni/diffusion/models/interface.py
+++ b/vllm_omni/diffusion/models/interface.py
@@ -58,27 +58,6 @@ def post_decode(self, state: DiffusionRequestState, **kwargs: Any) -> DiffusionO
"""Decode output after denoise loop."""
-@runtime_checkable
-class SupportsModuleOffload(Protocol):
- """Declares which submodules participate in CPU offload.
-
- All attribute names support dotted paths for nested submodules
- (e.g. ``"pipe.transformer"``).
-
- Attributes:
- _dit_modules: Denoising submodules (on GPU during diffusion).
- _encoder_modules: Encoder submodules (offloaded during diffusion).
- _vae_modules: VAE(s) (always on GPU).
- _resident_modules: Extra modules pinned on GPU during layerwise
- offloading. Optional, defaults to ``[]``.
- """
-
- _dit_modules: ClassVar[list[str]]
- _encoder_modules: ClassVar[list[str]]
- _vae_modules: ClassVar[list[str]]
- _resident_modules: ClassVar[list[str]] = []
-
-
def supports_step_execution(pipeline: object) -> bool:
"""Return whether `pipeline` implements :class:`SupportsStepExecution`."""
diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py
index 8f0ff446afd..8d8e523d60e 100644
--- a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py
+++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py
@@ -582,7 +582,6 @@ class LongCatImageTransformer2DModel(nn.Module):
"""
_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
- _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]
# Sequence Parallelism for LongCat (following diffusers' _cp_plan pattern)
_sp_plan = {
diff --git a/vllm_omni/diffusion/models/ltx2/__init__.py b/vllm_omni/diffusion/models/ltx2/__init__.py
index bf57d3f9e84..9f9d70f0106 100644
--- a/vllm_omni/diffusion/models/ltx2/__init__.py
+++ b/vllm_omni/diffusion/models/ltx2/__init__.py
@@ -4,18 +4,12 @@
from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import (
LTX2Pipeline,
- LTX2T2VDMD2Pipeline,
LTX2TwoStagesPipeline,
create_transformer_from_config,
get_ltx2_post_process_func,
load_transformer_config,
)
-from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_3 import (
- LTX23ImageToVideoPipeline,
- LTX23Pipeline,
-)
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import (
- LTX2I2VDMD2Pipeline,
LTX2ImageToVideoPipeline,
LTX2ImageToVideoTwoStagesPipeline,
)
@@ -23,14 +17,10 @@
__all__ = [
"LTX2Pipeline",
- "LTX2T2VDMD2Pipeline",
"LTX2ImageToVideoPipeline",
- "LTX2I2VDMD2Pipeline",
"LTX2LatentUpsamplePipeline",
"LTX2TwoStagesPipeline",
"LTX2ImageToVideoTwoStagesPipeline",
- "LTX23Pipeline",
- "LTX23ImageToVideoPipeline",
"get_ltx2_post_process_func",
"load_transformer_config",
"create_transformer_from_config",
diff --git a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
index cd8eee9e0d2..a1bf7f7809c 100644
--- a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
+++ b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
@@ -41,7 +41,6 @@
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
-from vllm_omni.diffusion.distributed.hsdp_utils import is_transformer_block_module
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelInput, SequenceParallelOutput
from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available
@@ -436,10 +435,6 @@ def __call__(
sequence_length=sequence_length,
)
- # Compute gate logits from original hidden_states (before attention)
- if attn.to_gate_logits is not None:
- gate_logits = attn.to_gate_logits(hidden_states)
-
if is_self_attention:
encoder_hidden_states = hidden_states
@@ -475,14 +470,6 @@ def __call__(
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
- # LTX-2.3: per-head gated attention
- if attn.to_gate_logits is not None:
- hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
- # 2.0 * sigmoid so zero-init gates produce 1.0 (identity)
- gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
- hidden_states = hidden_states * gates.unsqueeze(-1)
- hidden_states = hidden_states.flatten(2, 3)
-
hidden_states = attn.to_out[0](hidden_states)
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
@@ -513,16 +500,11 @@ def __init__(
norm_eps: float = 1e-6,
norm_elementwise_affine: bool = True,
rope_type: str = "interleaved",
- apply_gated_attention: bool = False,
processor=None,
):
super().__init__()
- # LTX-2 uses "rms_norm_across_heads", LTX-2.3 uses "rms_norm" -- both
- # map to the same RMSNorm implementation applied across Q/K heads.
- if qk_norm not in ("rms_norm_across_heads", "rms_norm"):
- raise NotImplementedError(
- f"Only 'rms_norm_across_heads' and 'rms_norm' are supported for `qk_norm`, got {qk_norm!r}."
- )
+ if qk_norm != "rms_norm_across_heads":
+ raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
kv_heads = heads if kv_heads is None else kv_heads
@@ -580,34 +562,18 @@ def __init__(
self.heads = self.query_num_heads
tp_size = get_tensor_model_parallel_world_size()
- # At TP > 1 with rms_norm_across_heads, use TensorParallelRMSNorm
- # which all-reduces squared sums to match global RMS statistics.
- # At TP=1, use torch.nn.RMSNorm which is numerically identical to
- # the diffusers reference (verified via hook-based comparison).
- if tp_size > 1 and qk_norm == "rms_norm_across_heads":
- self.norm_q = TensorParallelRMSNorm(
- dim_head * self.query_num_heads,
- eps=norm_eps,
- elementwise_affine=norm_elementwise_affine,
- tp_size=tp_size,
- )
- self.norm_k = TensorParallelRMSNorm(
- dim_head * self.kv_num_heads,
- eps=norm_eps,
- elementwise_affine=norm_elementwise_affine,
- tp_size=tp_size,
- )
- else:
- self.norm_q = torch.nn.RMSNorm(
- dim_head * self.query_num_heads,
- eps=norm_eps,
- elementwise_affine=norm_elementwise_affine,
- )
- self.norm_k = torch.nn.RMSNorm(
- dim_head * self.kv_num_heads,
- eps=norm_eps,
- elementwise_affine=norm_elementwise_affine,
- )
+ self.norm_q = TensorParallelRMSNorm(
+ dim_head * self.query_num_heads,
+ eps=norm_eps,
+ elementwise_affine=norm_elementwise_affine,
+ tp_size=tp_size,
+ )
+ self.norm_k = TensorParallelRMSNorm(
+ dim_head * self.kv_num_heads,
+ eps=norm_eps,
+ elementwise_affine=norm_elementwise_affine,
+ tp_size=tp_size,
+ )
self.to_out = torch.nn.ModuleList(
[
@@ -629,12 +595,6 @@ def __init__(
causal=False,
)
- # LTX-2.3: per-head gated attention
- if apply_gated_attention:
- self.to_gate_logits = nn.Linear(query_dim, heads, bias=True)
- else:
- self.to_gate_logits = None
-
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
@@ -736,10 +696,6 @@ def __init__(
audio_num_attention_heads: int,
audio_attention_head_dim,
audio_cross_attention_dim: int,
- video_gated_attn: bool = False,
- video_cross_attn_adaln: bool = False,
- audio_gated_attn: bool = False,
- audio_cross_attn_adaln: bool = False,
qk_norm: str = "rms_norm_across_heads",
activation_fn: str = "gelu-approximate",
attention_bias: bool = True,
@@ -747,12 +703,8 @@ def __init__(
eps: float = 1e-6,
elementwise_affine: bool = False,
rope_type: str = "interleaved",
- perturbed_attn: bool = False,
):
super().__init__()
- self.video_cross_attn_adaln = video_cross_attn_adaln
- self.audio_cross_attn_adaln = audio_cross_attn_adaln
- self.perturbed_attn = perturbed_attn
# 1. Self-Attention (video and audio)
self.norm1 = _make_rms_norm(dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -766,7 +718,6 @@ def __init__(
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
- apply_gated_attention=video_gated_attn,
)
self.audio_norm1 = _make_rms_norm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -780,7 +731,6 @@ def __init__(
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
- apply_gated_attention=audio_gated_attn,
)
# 2. Prompt Cross-Attention
@@ -795,7 +745,6 @@ def __init__(
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
- apply_gated_attention=video_gated_attn,
)
self.audio_norm2 = _make_rms_norm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -809,10 +758,10 @@ def __init__(
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
- apply_gated_attention=audio_gated_attn,
)
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
+ # Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio
self.audio_to_video_norm = _make_rms_norm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.audio_to_video_attn = LTX2Attention(
query_dim=dim,
@@ -824,9 +773,9 @@ def __init__(
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
- apply_gated_attention=video_gated_attn,
)
+ # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
self.video_to_audio_norm = _make_rms_norm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
self.video_to_audio_attn = LTX2Attention(
query_dim=audio_dim,
@@ -838,7 +787,6 @@ def __init__(
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
- apply_gated_attention=audio_gated_attn,
)
# 4. Feedforward layers
@@ -849,33 +797,14 @@ def __init__(
self.audio_ff = LTX2FeedForward(audio_dim, activation_fn=activation_fn)
# 5. Per-Layer Modulation Parameters
- # LTX-2.3 with cross_attn_adaln uses 9 params (extra 3 for cross-attn modulation);
- # LTX-2 uses 6.
- video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
- audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
- self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
- self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
-
- # Prompt cross-attn additional modulation params (LTX-2.3)
- self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
- if self.cross_attn_adaln:
- self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
- self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, audio_dim))
+ # Self-Attention / Feedforward AdaLayerNorm-Zero mod params
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+ self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
# Per-layer a2v, v2a Cross-Attention mod params
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
- @staticmethod
- def get_mod_params(
- scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
- ) -> tuple[torch.Tensor, ...]:
- num_ada_params = scale_shift_table.shape[0]
- ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(
- batch_size, temb.shape[1], num_ada_params, -1
- )
- return ada_values.unbind(dim=2)
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -888,170 +817,143 @@ def forward(
temb_ca_audio_scale_shift: torch.Tensor,
temb_ca_gate: torch.Tensor,
temb_ca_audio_gate: torch.Tensor,
- temb_prompt: torch.Tensor | None = None,
- temb_prompt_audio: torch.Tensor | None = None,
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
encoder_attention_mask: torch.Tensor | None = None,
audio_encoder_attention_mask: torch.Tensor | None = None,
- self_attention_mask: torch.Tensor | None = None,
- audio_self_attention_mask: torch.Tensor | None = None,
a2v_cross_attention_mask: torch.Tensor | None = None,
v2a_cross_attention_mask: torch.Tensor | None = None,
- use_a2v_cross_attention: bool = True,
- use_v2a_cross_attention: bool = True,
- perturbation_mask: torch.Tensor | None = None,
- all_perturbed: bool | None = None,
) -> torch.Tensor:
batch_size = hidden_states.size(0)
# 1. Video and Audio Self-Attention
- # 1.1. Video Self-Attention
- video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
- if self.video_cross_attn_adaln:
- shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
-
norm_hidden_states = self.norm1(hidden_states)
+
+ num_ada_params = self.scale_shift_table.shape[0]
+ ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+ batch_size, temb.size(1), num_ada_params, -1
+ )
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
query_rotary_emb=video_rotary_emb,
- attention_mask=self_attention_mask,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa
- # 1.2. Audio Self-Attention
- audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
+ norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
+
+ num_audio_ada_params = self.audio_scale_shift_table.shape[0]
+ audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
+ batch_size, temb_audio.size(1), num_audio_ada_params, -1
+ )
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
- audio_ada_params[:6]
+ audio_ada_values.unbind(dim=2)
)
- if self.audio_cross_attn_adaln:
- audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
-
- norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
attn_audio_hidden_states = self.audio_attn1(
hidden_states=norm_audio_hidden_states,
encoder_hidden_states=None,
query_rotary_emb=audio_rotary_emb,
- attention_mask=audio_self_attention_mask,
)
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
- # 2. Video and Audio Cross-Attention with text embeddings (Q: Video/Audio; K,V: Text)
- # LTX-2.3: compute prompt modulation params for K/V
- if self.cross_attn_adaln and temb_prompt is not None:
- video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
- shift_text_kv, scale_text_kv = video_prompt_ada_params
-
- audio_prompt_ada_params = self.get_mod_params(
- self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size
- )
- audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
-
- # 2.1. Video-Text Cross-Attention
+ # 2. Video and Audio Cross-Attention with the text embeddings
norm_hidden_states = self.norm2(hidden_states)
- if self.video_cross_attn_adaln:
- norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
- if self.cross_attn_adaln and temb_prompt is not None:
- encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
-
attn_hidden_states = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
query_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
- if self.video_cross_attn_adaln:
- attn_hidden_states = attn_hidden_states * gate_text_q
hidden_states = hidden_states + attn_hidden_states
- # 2.2. Audio-Text Cross-Attention
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
- if self.audio_cross_attn_adaln:
- norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
- if self.cross_attn_adaln and temb_prompt is not None:
- audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
-
attn_audio_hidden_states = self.audio_attn2(
norm_audio_hidden_states,
encoder_hidden_states=audio_encoder_hidden_states,
query_rotary_emb=None,
attention_mask=audio_encoder_attention_mask,
)
- if self.audio_cross_attn_adaln:
- attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
- if use_a2v_cross_attention or use_v2a_cross_attention:
- norm_hidden_states = self.audio_to_video_norm(hidden_states)
- norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
-
- # Combine global and per-layer cross attention modulation parameters
- # Video
- video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
- video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
-
- video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
- video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
+ norm_hidden_states = self.audio_to_video_norm(hidden_states)
+ norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
+
+ # Combine global and per-layer cross attention modulation parameters
+ # Video
+ video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
+ video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
+
+ video_ca_scale_shift_table = (
+ video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
+ + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
+ ).unbind(dim=2)
+ video_ca_gate = (
+ video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
+ + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
+ ).unbind(dim=2)
+
+ video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
+ a2v_gate = video_ca_gate[0].squeeze(2)
+
+ # Audio
+ audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
+ audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
+
+ audio_ca_scale_shift_table = (
+ audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
+ + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
+ ).unbind(dim=2)
+ audio_ca_gate = (
+ audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
+ + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
+ ).unbind(dim=2)
+
+ audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
+ v2a_gate = audio_ca_gate[0].squeeze(2)
+
+ # Audio-to-Video Cross Attention: Q: Video; K,V: Audio
+ mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
+ 2
+ )
+ mod_norm_audio_hidden_states = norm_audio_hidden_states * (
+ 1 + audio_a2v_ca_scale.squeeze(2)
+ ) + audio_a2v_ca_shift.squeeze(2)
+
+ a2v_attn_hidden_states = self.audio_to_video_attn(
+ mod_norm_hidden_states,
+ encoder_hidden_states=mod_norm_audio_hidden_states,
+ query_rotary_emb=ca_video_rotary_emb,
+ key_rotary_emb=ca_audio_rotary_emb,
+ attention_mask=a2v_cross_attention_mask,
+ )
- video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
- a2v_gate = video_ca_gate_param[0].squeeze(2)
+ hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
- # Audio
- audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
- audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
+ # Video-to-Audio Cross Attention: Q: Audio; K,V: Video
+ mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
+ 2
+ )
+ mod_norm_audio_hidden_states = norm_audio_hidden_states * (
+ 1 + audio_v2a_ca_scale.squeeze(2)
+ ) + audio_v2a_ca_shift.squeeze(2)
+
+ v2a_attn_hidden_states = self.video_to_audio_attn(
+ mod_norm_audio_hidden_states,
+ encoder_hidden_states=mod_norm_hidden_states,
+ query_rotary_emb=ca_audio_rotary_emb,
+ key_rotary_emb=ca_video_rotary_emb,
+ attention_mask=v2a_cross_attention_mask,
+ )
- audio_ca_ada_params = self.get_mod_params(
- audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size
- )
- audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
-
- audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
- v2a_gate = audio_ca_gate_param[0].squeeze(2)
-
- # 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
- if use_a2v_cross_attention:
- mod_norm_hidden_states = norm_hidden_states * (
- 1 + video_a2v_ca_scale.squeeze(2)
- ) + video_a2v_ca_shift.squeeze(2)
- mod_norm_audio_hidden_states = norm_audio_hidden_states * (
- 1 + audio_a2v_ca_scale.squeeze(2)
- ) + audio_a2v_ca_shift.squeeze(2)
-
- a2v_attn_hidden_states = self.audio_to_video_attn(
- mod_norm_hidden_states,
- encoder_hidden_states=mod_norm_audio_hidden_states,
- query_rotary_emb=ca_video_rotary_emb,
- key_rotary_emb=ca_audio_rotary_emb,
- attention_mask=a2v_cross_attention_mask,
- )
- hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
-
- # 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
- if use_v2a_cross_attention:
- mod_norm_hidden_states = norm_hidden_states * (
- 1 + video_v2a_ca_scale.squeeze(2)
- ) + video_v2a_ca_shift.squeeze(2)
- mod_norm_audio_hidden_states = norm_audio_hidden_states * (
- 1 + audio_v2a_ca_scale.squeeze(2)
- ) + audio_v2a_ca_shift.squeeze(2)
-
- v2a_attn_hidden_states = self.video_to_audio_attn(
- mod_norm_audio_hidden_states,
- encoder_hidden_states=mod_norm_hidden_states,
- query_rotary_emb=ca_audio_rotary_emb,
- key_rotary_emb=ca_video_rotary_emb,
- attention_mask=v2a_cross_attention_mask,
- )
- audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
+ audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
# 4. Feedforward
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
@@ -1362,8 +1264,6 @@ class LTX2VideoTransformer3DModel(nn.Module):
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTX2VideoTransformerBlock"]
- _layerwise_offload_blocks_attrs = ["transformer_blocks"]
- _hsdp_shard_conditions = [is_transformer_block_module]
_sp_plan: dict[str, Any] | None = None
@staticmethod
@@ -1447,15 +1347,8 @@ def __init__(
timestep_scale_multiplier: int = 1000,
cross_attn_timestep_scale_multiplier: int = 1000,
rope_type: str = "interleaved",
- use_prompt_embeddings: bool = True,
- perturbed_attn: bool = False,
- gated_attn: bool = False,
- cross_attn_mod: bool = False,
- audio_gated_attn: bool = False,
- audio_cross_attn_mod: bool = False,
) -> None:
super().__init__()
- self.perturbed_attn = perturbed_attn
out_channels = out_channels or in_channels
audio_out_channels = audio_out_channels or audio_in_channels
@@ -1505,34 +1398,19 @@ def __init__(
self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim)
# 2. Prompt embeddings
- # LTX-2 (use_prompt_embeddings=True): caption projection in the transformer
- # LTX-2.3 (use_prompt_embeddings=False): caption projection in the connectors
- if use_prompt_embeddings:
- self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
- self.audio_caption_projection = PixArtAlphaTextProjection(
- in_features=caption_channels, hidden_size=audio_inner_dim
- )
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+ self.audio_caption_projection = PixArtAlphaTextProjection(
+ in_features=caption_channels, hidden_size=audio_inner_dim
+ )
# 3. Timestep Modulation Params and Embedding
- # 3.1. Global Timestep Modulation Parameters
- # LTX-2.3 with cross_attn_mod uses 9 mod params (extra 3 for cross-attn); LTX-2 uses 6.
- video_num_mod_params = 9 if cross_attn_mod else 6
- audio_num_mod_params = 9 if audio_cross_attn_mod else 6
- self.time_embed = LTX2AdaLayerNormSingle(
- inner_dim, num_mod_params=video_num_mod_params, use_additional_conditions=False
- )
+ # 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding
+ # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters
+ self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False)
self.audio_time_embed = LTX2AdaLayerNormSingle(
- audio_inner_dim, num_mod_params=audio_num_mod_params, use_additional_conditions=False
+ audio_inner_dim, num_mod_params=6, use_additional_conditions=False
)
- # 3.3. LTX-2.3: Prompt modulation from sigma
- self.prompt_modulation = cross_attn_mod or audio_cross_attn_mod
- if self.prompt_modulation:
- self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False)
- self.audio_prompt_adaln = LTX2AdaLayerNormSingle(
- audio_inner_dim, num_mod_params=2, use_additional_conditions=False
- )
-
# 3.2. Global Cross Attention Modulation Parameters
# Used in the audio-to-video and video-to-audio cross attention layers as a global set of modulation params,
# which are then further modified by per-block modulaton params in each transformer block.
@@ -1635,10 +1513,6 @@ def __init__(
audio_num_attention_heads=audio_num_attention_heads,
audio_attention_head_dim=audio_attention_head_dim,
audio_cross_attention_dim=audio_cross_attention_dim,
- video_gated_attn=gated_attn,
- video_cross_attn_adaln=cross_attn_mod,
- audio_gated_attn=audio_gated_attn,
- audio_cross_attn_adaln=audio_cross_attn_mod,
qk_norm=qk_norm,
activation_fn=activation_fn,
attention_bias=attention_bias,
@@ -1646,7 +1520,6 @@ def __init__(
eps=norm_eps,
elementwise_affine=norm_elementwise_affine,
rope_type=rope_type,
- perturbed_attn=perturbed_attn,
)
for _ in range(num_layers)
]
@@ -1682,8 +1555,6 @@ def forward(
audio_encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
audio_timestep: torch.LongTensor | None = None,
- sigma: torch.Tensor | None = None,
- audio_sigma: torch.Tensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
audio_encoder_attention_mask: torch.Tensor | None = None,
num_frames: int | None = None,
@@ -1695,7 +1566,6 @@ def forward(
audio_coords: torch.Tensor | None = None,
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
- **kwargs, # Accept extra diffusers kwargs (isolate_modalities, perturbation_mask, etc.)
) -> torch.Tensor:
"""
Forward pass for LTX-2.0 audiovisual video transformer.
@@ -1838,55 +1708,54 @@ def forward(
)
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
- # 3.3. LTX-2.3: Compute prompt modulation from sigma
- audio_sigma = audio_sigma if audio_sigma is not None else sigma
- if self.prompt_modulation and sigma is not None:
- temb_prompt, _ = self.prompt_adaln(sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype)
- temb_prompt_audio, _ = self.audio_prompt_adaln(
- audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype
- )
- temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1))
- temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1))
- else:
- temb_prompt = temb_prompt_audio = None
-
# 4. Prepare prompt embeddings
- # LTX-2: caption projection is in the transformer
- # LTX-2.3: caption projection is in the connectors (already applied)
- if hasattr(self, "caption_projection"):
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
- if hasattr(self, "audio_caption_projection"):
- audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
- audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
+ audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
+ audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
# 5. Run transformer blocks
for block in self.transformer_blocks:
- block_kwargs = {
- "hidden_states": hidden_states,
- "audio_hidden_states": audio_hidden_states,
- "encoder_hidden_states": encoder_hidden_states,
- "audio_encoder_hidden_states": audio_encoder_hidden_states,
- "temb": temb,
- "temb_audio": temb_audio,
- "temb_ca_scale_shift": video_cross_attn_scale_shift,
- "temb_ca_audio_scale_shift": audio_cross_attn_scale_shift,
- "temb_ca_gate": video_cross_attn_a2v_gate,
- "temb_ca_audio_gate": audio_cross_attn_v2a_gate,
- "temb_prompt": temb_prompt,
- "temb_prompt_audio": temb_prompt_audio,
- "video_rotary_emb": video_rotary_emb,
- "audio_rotary_emb": audio_rotary_emb,
- "ca_video_rotary_emb": video_cross_attn_rotary_emb,
- "ca_audio_rotary_emb": audio_cross_attn_rotary_emb,
- "encoder_attention_mask": encoder_attention_mask,
- "audio_encoder_attention_mask": audio_encoder_attention_mask,
- }
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states, audio_hidden_states = self._gradient_checkpointing_func(block, **block_kwargs)
+ hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ audio_hidden_states,
+ encoder_hidden_states,
+ audio_encoder_hidden_states,
+ temb,
+ temb_audio,
+ video_cross_attn_scale_shift,
+ audio_cross_attn_scale_shift,
+ video_cross_attn_a2v_gate,
+ audio_cross_attn_v2a_gate,
+ video_rotary_emb,
+ audio_rotary_emb,
+ video_cross_attn_rotary_emb,
+ audio_cross_attn_rotary_emb,
+ encoder_attention_mask,
+ audio_encoder_attention_mask,
+ )
else:
- hidden_states, audio_hidden_states = block(**block_kwargs)
+ hidden_states, audio_hidden_states = block(
+ hidden_states,
+ audio_hidden_states,
+ encoder_hidden_states,
+ audio_encoder_hidden_states,
+ temb,
+ temb_audio,
+ video_cross_attn_scale_shift,
+ audio_cross_attn_scale_shift,
+ video_cross_attn_a2v_gate,
+ audio_cross_attn_v2a_gate,
+ video_rotary_emb,
+ audio_rotary_emb,
+ video_cross_attn_rotary_emb,
+ audio_cross_attn_rotary_emb,
+ encoder_attention_mask,
+ audio_encoder_attention_mask,
+ )
# 6. Output layers (including unpatchification)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
@@ -1954,14 +1823,6 @@ def _maybe_shard_weight(weight: torch.Tensor, param: torch.Tensor) -> torch.Tens
weight_loader(param, loaded_weight, shard_id)
break
else:
- if name not in params_dict:
- logger.warning(
- "Skipping transformer weight %s -- not found in model "
- "parameters. This may indicate an incomplete "
- "implementation or checkpoint mismatch.",
- name,
- )
- continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", None)
if weight_loader is not None:
diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
index 4f62d72c9b6..efc342e9327 100644
--- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
+++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
@@ -9,7 +9,7 @@
import os
from collections.abc import Iterable
from contextlib import nullcontext
-from typing import Any, ClassVar
+from typing import Any
import numpy as np
import torch
@@ -28,13 +28,13 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.parallel_state import (
+ get_cfg_group,
+ get_classifier_free_guidance_rank,
get_classifier_free_guidance_world_size,
)
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin
-from vllm_omni.diffusion.models.interface import SupportsModuleOffload
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.lora.request import LoRARequest
@@ -122,31 +122,6 @@ def calculate_shift(
return mu
-class _VideoAudioScheduler:
- """Composite scheduler dispatching to video and audio schedulers."""
-
- def __init__(self, video_scheduler, audio_scheduler):
- self.video_scheduler = video_scheduler
- self.audio_scheduler = audio_scheduler
-
- def step(self, noise_pred, t, latents, return_dict=False, generator=None):
- video_out = self.video_scheduler.step(
- noise_pred[0],
- t[0],
- latents[0],
- return_dict=False,
- generator=generator,
- )[0]
- audio_out = self.audio_scheduler.step(
- noise_pred[1],
- t[1],
- latents[1],
- return_dict=False,
- generator=generator,
- )[0]
- return ((video_out, audio_out),)
-
-
class LTX2Pipeline(nn.Module, CFGParallelMixin, ProgressBarMixin):
def __init__(
self,
@@ -567,10 +542,6 @@ def _unpack_audio_latents(
latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
return latents
- @staticmethod
- def _unpad_audio_latents(latents: torch.Tensor, num_frames: int) -> torch.Tensor:
- return latents[:, :num_frames]
-
def prepare_latents(
self,
batch_size: int = 1,
@@ -626,49 +597,25 @@ def prepare_audio_latents(
noise_scale: float = 0.0,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
- generator: torch.Generator | list[torch.Generator] | None = None,
+ generator: torch.Generator | None = None,
latents: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, int, int]:
- original_latent_length = audio_latent_length
- padded_latent_length = original_latent_length
-
- latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
-
- sp_size = getattr(self.od_config.parallel_config, "sequence_parallel_size", 1)
- if sp_size > 1:
- padded_latent_length += (sp_size - (original_latent_length % sp_size)) % sp_size
-
+ ) -> tuple[torch.Tensor, int]:
if latents is not None:
if latents.ndim == 4:
# latents are of shape [B, C, L, M], need to be packed
latents = self._pack_audio_latents(latents)
if latents.ndim != 3:
raise ValueError(
- f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is "
- "[batch_size, num_seq, num_features] or [batch_size, num_channels, audio_length, mel_bins]."
+ f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." # noqa
)
latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
latents = self._create_noised_state(latents, noise_scale, generator)
+ return latents.to(device=device, dtype=dtype)
- if latents.shape[1] not in {original_latent_length, padded_latent_length}:
- raise ValueError(
- "Provided `audio_latents` has incompatible audio frame count "
- f"{latents.shape[1]}; expected {original_latent_length} or {padded_latent_length}."
- )
-
- if latents.shape[1] == original_latent_length and padded_latent_length > original_latent_length:
- padding = torch.zeros(
- latents.shape[0],
- padded_latent_length - original_latent_length,
- latents.shape[2],
- dtype=latents.dtype,
- device=latents.device,
- )
- latents = torch.cat([latents, padding], dim=1)
-
- return latents.to(device=device, dtype=dtype), original_latent_length, padded_latent_length
+ # TODO: confirm whether this logic is correct
+ latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
- shape = (batch_size, num_channels_latents, padded_latent_length, latent_mel_bins)
+ shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -678,7 +625,7 @@ def prepare_audio_latents(
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_audio_latents(latents)
- return latents, original_latent_length, padded_latent_length
+ return latents
@property
def guidance_scale(self):
@@ -708,44 +655,147 @@ def attention_kwargs(self):
def interrupt(self):
return self._interrupt
+ def _is_cfg_parallel_enabled(self, do_true_cfg: bool) -> bool:
+ return do_true_cfg and get_classifier_free_guidance_world_size() > 1
+
def _transformer_cache_context(self, context_name: str):
cache_context = getattr(self.transformer, "cache_context", None)
if callable(cache_context):
return cache_context(context_name)
return nullcontext()
- def predict_noise(self, **kwargs):
+ def _predict_noise_av(self, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
with self._transformer_cache_context("cond_uncond"):
noise_pred_video, noise_pred_audio = self.transformer(**kwargs)
+ return noise_pred_video, noise_pred_audio
+
+ def predict_noise_av_maybe_with_cfg(
+ self,
+ do_true_cfg: bool,
+ true_cfg_scale: float,
+ positive_kwargs: dict[str, Any],
+ negative_kwargs: dict[str, Any] | None,
+ guidance_rescale: float = 0.0,
+ cfg_normalize: bool = False,
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+ if do_true_cfg:
+ cfg_parallel_ready = get_classifier_free_guidance_world_size() > 1
+
+ if cfg_parallel_ready:
+ cfg_group = get_cfg_group()
+ cfg_rank = get_classifier_free_guidance_rank()
+
+ if cfg_rank == 0:
+ noise_pred_video, noise_pred_audio = self._predict_noise_av(**positive_kwargs)
+ else:
+ noise_pred_video, noise_pred_audio = self._predict_noise_av(**negative_kwargs)
+
+ noise_pred_video = noise_pred_video.float()
+ noise_pred_audio = noise_pred_audio.float()
+
+ gathered_video = cfg_group.all_gather(noise_pred_video, separate_tensors=True)
+ gathered_audio = cfg_group.all_gather(noise_pred_audio, separate_tensors=True)
+
+ if cfg_rank == 0:
+ noise_pred_video_text = gathered_video[0]
+ noise_pred_video_uncond = gathered_video[1]
+ noise_pred_audio_text = gathered_audio[0]
+ noise_pred_audio_uncond = gathered_audio[1]
+
+ noise_pred_video = self.combine_cfg_noise(
+ noise_pred_video_text,
+ noise_pred_video_uncond,
+ true_cfg_scale,
+ cfg_normalize,
+ )
+ noise_pred_audio = self.combine_cfg_noise(
+ noise_pred_audio_text,
+ noise_pred_audio_uncond,
+ true_cfg_scale,
+ cfg_normalize,
+ )
+
+ if guidance_rescale > 0:
+ noise_pred_video = rescale_noise_cfg(
+ noise_pred_video,
+ noise_pred_video_text,
+ guidance_rescale=guidance_rescale,
+ )
+ noise_pred_audio = rescale_noise_cfg(
+ noise_pred_audio,
+ noise_pred_audio_text,
+ guidance_rescale=guidance_rescale,
+ )
+ return noise_pred_video, noise_pred_audio
+ return None, None
+
+ noise_pred_video_text, noise_pred_audio_text = self._predict_noise_av(**positive_kwargs)
+ noise_pred_video_uncond, noise_pred_audio_uncond = self._predict_noise_av(**negative_kwargs)
+
+ noise_pred_video_text = noise_pred_video_text.float()
+ noise_pred_audio_text = noise_pred_audio_text.float()
+ noise_pred_video_uncond = noise_pred_video_uncond.float()
+ noise_pred_audio_uncond = noise_pred_audio_uncond.float()
+
+ noise_pred_video = self.combine_cfg_noise(
+ noise_pred_video_text,
+ noise_pred_video_uncond,
+ true_cfg_scale,
+ cfg_normalize,
+ )
+ noise_pred_audio = self.combine_cfg_noise(
+ noise_pred_audio_text,
+ noise_pred_audio_uncond,
+ true_cfg_scale,
+ cfg_normalize,
+ )
+
+ if guidance_rescale > 0:
+ noise_pred_video = rescale_noise_cfg(
+ noise_pred_video,
+ noise_pred_video_text,
+ guidance_rescale=guidance_rescale,
+ )
+ noise_pred_audio = rescale_noise_cfg(
+ noise_pred_audio,
+ noise_pred_audio_text,
+ guidance_rescale=guidance_rescale,
+ )
+
+ return noise_pred_video, noise_pred_audio
+
+ noise_pred_video, noise_pred_audio = self._predict_noise_av(**positive_kwargs)
return noise_pred_video.float(), noise_pred_audio.float()
- def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, true_cfg_scale, cfg_normalize=False):
- """Per-element CFG combine with guidance_rescale support."""
- (video_pos, audio_pos) = positive_noise_pred
- (video_neg, audio_neg) = negative_noise_pred
- video_combined = super().combine_cfg_noise(video_pos, video_neg, true_cfg_scale, cfg_normalize)
- audio_combined = super().combine_cfg_noise(audio_pos, audio_neg, true_cfg_scale, cfg_normalize)
- if self._guidance_rescale and self._guidance_rescale > 0:
- video_combined = rescale_noise_cfg(video_combined, video_pos, guidance_rescale=self._guidance_rescale)
- audio_combined = rescale_noise_cfg(audio_combined, audio_pos, guidance_rescale=self._guidance_rescale)
- return (video_combined, audio_combined)
-
- def _synchronize_cfg_parallel_step_output(
+ def _scheduler_step_video_audio_maybe_with_cfg(
self,
- latents: tuple[torch.Tensor, torch.Tensor],
+ noise_pred_video: torch.Tensor | None,
+ noise_pred_audio: torch.Tensor | None,
+ t: torch.Tensor,
+ latents: torch.Tensor,
+ audio_latents: torch.Tensor,
+ audio_scheduler: FlowMatchEulerDiscreteScheduler,
do_true_cfg: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
- if not (do_true_cfg and get_classifier_free_guidance_world_size() > 1):
- return latents
-
- # Without this sync, CUDA async execution causes non-deterministic
- # numerical drift across denoising steps in CFG parallel mode,
- # producing different video outputs across runs.
- latents = tuple(tensor.contiguous() for tensor in latents)
- device = next((tensor.device for tensor in latents if tensor.is_cuda), None)
- if device is not None:
- torch.cuda.current_stream(device).synchronize()
- return latents
+ cfg_parallel_ready = self._is_cfg_parallel_enabled(do_true_cfg)
+
+ if cfg_parallel_ready:
+ cfg_group = get_cfg_group()
+ cfg_rank = get_classifier_free_guidance_rank()
+
+ if cfg_rank == 0:
+ latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
+ audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
+
+ latents = latents.contiguous()
+ audio_latents = audio_latents.contiguous()
+ cfg_group.broadcast(latents, src=0)
+ cfg_group.broadcast(audio_latents, src=0)
+ return latents, audio_latents
+
+ latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
+ audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
+ return latents, audio_latents
@torch.no_grad()
def forward(
@@ -778,8 +828,6 @@ def forward(
attention_kwargs: dict[str, Any] | None = None,
max_sequence_length: int | None = None,
) -> DiffusionOutput:
- # Extract prompt/negative_prompt from request.
- # Input format: req.prompts is a list of str or dict with "prompt"/"negative_prompt" keys.
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts):
negative_prompt = None
@@ -821,7 +869,6 @@ def forward(
else req.sampling_params.extra_args.get("audio_latents", audio_latents)
)
- # Override with pre-computed embeddings if provided in request.
req_prompt_embeds = [_get_prompt_field(p, "prompt_embeds") for p in req.prompts]
if any(p is not None for p in req_prompt_embeds):
prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore[arg-type]
@@ -892,17 +939,20 @@ def forward(
max_sequence_length=max_sequence_length,
device=device,
)
- # Compute positive prompt connectors
+ cfg_parallel_ready = self._is_cfg_parallel_enabled(self.do_classifier_free_guidance)
+ if self.do_classifier_free_guidance and not cfg_parallel_ready:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
prompt_embeds, additive_attention_mask, additive_mask=True
)
- # Compute negative prompt connectors when CFG is enabled
negative_connector_prompt_embeds = None
negative_connector_audio_prompt_embeds = None
negative_connector_attention_mask = None
- if self.do_classifier_free_guidance:
+ if cfg_parallel_ready:
negative_additive_attention_mask = (
1 - negative_prompt_attention_mask.to(negative_prompt_embeds.dtype)
) * -1000000.0
@@ -977,7 +1027,20 @@ def forward(
num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
- audio_latents, original_audio_num_frames, padded_audio_num_frames = self.prepare_audio_latents(
+
+ # Pad audio frames to a multiple of sp_size (zero-pads audio_latents if given)
+ sp_size = getattr(self.od_config.parallel_config, "sequence_parallel_size", 1)
+ if sp_size > 1:
+ pad_len = (sp_size - (audio_num_frames % sp_size)) % sp_size
+ if pad_len > 0:
+ if audio_latents is not None:
+ pad_shape = list(audio_latents.shape)
+ pad_shape[2] = pad_len
+ padding = torch.zeros(pad_shape, dtype=audio_latents.dtype, device=audio_latents.device)
+ audio_latents = torch.cat([audio_latents, padding], dim=2)
+ audio_num_frames += pad_len
+
+ audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
@@ -998,7 +1061,6 @@ def forward(
self.scheduler.config.get("max_shift", 2.05),
)
audio_scheduler = copy.deepcopy(self.scheduler)
- video_audio_scheduler = _VideoAudioScheduler(self.scheduler, audio_scheduler)
_ = retrieve_timesteps(
audio_scheduler,
num_inference_steps,
@@ -1021,10 +1083,12 @@ def forward(
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
)
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
- audio_latents.shape[0], padded_audio_num_frames, audio_latents.device
+ audio_latents.shape[0], audio_num_frames, audio_latents.device
)
- # No coord duplication needed: mixin handles CFG via separate forward calls,
- # not batch=2. Each forward gets batch=1 coords directly.
+ # Batched (non-parallel) CFG doubles the batch, so duplicate the positional coords too
+ if self.do_classifier_free_guidance and not cfg_parallel_ready:
+ video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim
+ audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1))
with self.progress_bar(total=len(timesteps)) as pbar:
for i, t in enumerate(timesteps):
@@ -1033,60 +1097,119 @@ def forward(
self._current_timestep = t
- latent_model_input = latents.to(prompt_embeds.dtype)
- audio_latent_model_input = audio_latents.to(prompt_embeds.dtype)
- timestep = t.expand(latent_model_input.shape[0])
- do_true_cfg = self.do_classifier_free_guidance
-
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "audio_hidden_states": audio_latent_model_input,
- "encoder_hidden_states": connector_prompt_embeds,
- "audio_encoder_hidden_states": connector_audio_prompt_embeds,
- "timestep": timestep,
- "encoder_attention_mask": connector_attention_mask,
- "audio_encoder_attention_mask": connector_attention_mask,
- "num_frames": latent_num_frames,
- "height": latent_height,
- "width": latent_width,
- "fps": frame_rate,
- "audio_num_frames": padded_audio_num_frames,
- "video_coords": video_coords,
- "audio_coords": audio_coords,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- }
- negative_kwargs = (
- {
- **positive_kwargs,
+ if cfg_parallel_ready:
+ latent_model_input = latents.to(prompt_embeds.dtype)
+ audio_latent_model_input = audio_latents.to(prompt_embeds.dtype)
+ timestep = t.expand(latent_model_input.shape[0])
+
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "audio_hidden_states": audio_latent_model_input,
+ "encoder_hidden_states": connector_prompt_embeds,
+ "audio_encoder_hidden_states": connector_audio_prompt_embeds,
+ "timestep": timestep,
+ "encoder_attention_mask": connector_attention_mask,
+ "audio_encoder_attention_mask": connector_attention_mask,
+ "num_frames": latent_num_frames,
+ "height": latent_height,
+ "width": latent_width,
+ "fps": frame_rate,
+ "audio_num_frames": audio_num_frames,
+ "video_coords": video_coords,
+ "audio_coords": audio_coords,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ }
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "audio_hidden_states": audio_latent_model_input,
"encoder_hidden_states": negative_connector_prompt_embeds,
"audio_encoder_hidden_states": negative_connector_audio_prompt_embeds,
+ "timestep": timestep,
"encoder_attention_mask": negative_connector_attention_mask,
"audio_encoder_attention_mask": negative_connector_attention_mask,
+ "num_frames": latent_num_frames,
+ "height": latent_height,
+ "width": latent_width,
+ "fps": frame_rate,
+ "audio_num_frames": audio_num_frames,
+ "video_coords": video_coords,
+ "audio_coords": audio_coords,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
}
- if do_true_cfg
- else None
- )
- noise_pred_video, noise_pred_audio = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- latents, audio_latents = self.scheduler_step_maybe_with_cfg(
- (noise_pred_video, noise_pred_audio),
- (t, t),
- (latents, audio_latents),
- do_true_cfg=do_true_cfg,
- per_request_scheduler=video_audio_scheduler,
- )
- latents, audio_latents = self._synchronize_cfg_parallel_step_output(
- (latents, audio_latents),
- do_true_cfg=do_true_cfg,
- )
+ noise_pred_video, noise_pred_audio = self.predict_noise_av_maybe_with_cfg(
+ do_true_cfg=True,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ guidance_rescale=guidance_rescale,
+ cfg_normalize=False,
+ )
+
+ latents, audio_latents = self._scheduler_step_video_audio_maybe_with_cfg(
+ noise_pred_video,
+ noise_pred_audio,
+ t,
+ latents,
+ audio_latents,
+ audio_scheduler,
+ do_true_cfg=True,
+ )
+ else:
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+ audio_latent_model_input = (
+ torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
+ )
+ audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
+
+ timestep = t.expand(latent_model_input.shape[0])
+
+ with self._transformer_cache_context("cond_uncond"):
+ noise_pred_video, noise_pred_audio = self.transformer(
+ hidden_states=latent_model_input,
+ audio_hidden_states=audio_latent_model_input,
+ encoder_hidden_states=connector_prompt_embeds,
+ audio_encoder_hidden_states=connector_audio_prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=connector_attention_mask,
+ audio_encoder_attention_mask=connector_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ fps=frame_rate,
+ audio_num_frames=audio_num_frames,
+ video_coords=video_coords,
+ audio_coords=audio_coords,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )
+ noise_pred_video = noise_pred_video.float()
+ noise_pred_audio = noise_pred_audio.float()
+
+ if self.do_classifier_free_guidance:
+ noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
+ noise_pred_video = noise_pred_video_uncond + guidance_scale * (
+ noise_pred_video_text - noise_pred_video_uncond
+ )
+
+ noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
+ noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (
+ noise_pred_audio_text - noise_pred_audio_uncond
+ )
+
+ if guidance_rescale > 0:
+ noise_pred_video = rescale_noise_cfg(
+ noise_pred_video, noise_pred_video_text, guidance_rescale=guidance_rescale
+ )
+ noise_pred_audio = rescale_noise_cfg(
+ noise_pred_audio, noise_pred_audio_text, guidance_rescale=guidance_rescale
+ )
+
+ latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
+ audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
pbar.update()
@@ -1102,15 +1225,10 @@ def forward(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
- audio_latents = self._unpad_audio_latents(audio_latents, original_audio_num_frames)
audio_latents = self._denormalize_audio_latents(
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
)
- audio_latents = self._unpack_audio_latents(
- audio_latents,
- original_audio_num_frames,
- num_mel_bins=latent_mel_bins,
- )
+ audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
if output_type == "latent":
video = latents
@@ -1153,13 +1271,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
return loader.load_weights(weights)
-class LTX2TwoStagesPipeline(nn.Module, SupportsModuleOffload):
+class LTX2TwoStagesPipeline(nn.Module):
"""LTX2TwoStagesPipeline is for two stages image to video generation"""
- _dit_modules: ClassVar[list[str]] = ["pipe.transformer"]
- _encoder_modules: ClassVar[list[str]] = ["pipe.text_encoder"]
- _vae_modules: ClassVar[list[str]] = ["pipe.vae", "pipe.audio_vae"]
-
def __init__(
self,
*,
@@ -1310,11 +1424,3 @@ def forward(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
-
-
-class LTX2T2VDMD2Pipeline(DMD2PipelineMixin, LTX2Pipeline):
- """LTX-2 T2V pipeline for FastGen DMD2-distilled models."""
-
- def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
- super().__init__(od_config=od_config, prefix=prefix)
- self.__init_dmd2__()
diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_3.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_3.py
deleted file mode 100644
index dd0a5717951..00000000000
--- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_3.py
+++ /dev/null
@@ -1,998 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""
-Fully independent LTX-2.3 pipeline for vLLM-Omni.
-
-This pipeline does NOT inherit from LTX2Pipeline because:
-- LTX-2.3 uses a different text encoding strategy (flatten ALL 49 hidden states
- vs. LTX-2's _pack_text_embeds with per-layer normalization and pooling)
-- LTX-2.3 connectors expect the padding_side API (not additive_mask)
-- LTX-2.3 uses a BWE vocoder outputting 48kHz audio (not 16kHz)
-- LTX-2.3 transformer requires the sigma parameter for prompt modulation
-- CPU offloading is required for the 22B transformer (~44GB VRAM)
-"""
-
-from __future__ import annotations
-
-import copy
-import json
-import os
-from collections.abc import Iterable
-from contextlib import nullcontext
-from typing import Any
-
-import numpy as np
-import torch
-from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler
-from diffusers.pipelines.ltx2 import LTX2TextConnectors
-from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers.video_processor import VideoProcessor
-from huggingface_hub import hf_hub_download
-from torch import nn
-from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
-from vllm.logger import init_logger
-from vllm.model_executor.models.utils import AutoWeightsLoader
-
-from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
-from vllm_omni.diffusion.distributed.utils import get_local_device
-from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
-from vllm_omni.diffusion.request import OmniDiffusionRequest
-
-from .pipeline_ltx2 import (
- _get_prompt_field,
- calculate_shift,
- create_transformer_from_config,
- load_transformer_config,
-)
-
-logger = init_logger(__name__)
-
-# Try to import LTX2VocoderWithBWE (diffusers >= 0.38.0)
-try:
- from diffusers.pipelines.ltx2.vocoder import LTX2VocoderWithBWE
-except ImportError:
- LTX2VocoderWithBWE = None
-
-
-def _detect_vocoder_output_sample_rate(model: str) -> int | None:
- """Detect the vocoder output sample rate from vocoder/config.json.
-
- This runs at factory time (engine process) so the rate is captured in
- the post-process closure and doesn't need cross-process communication.
-
- Returns:
- Output sample rate (e.g. 48000 for LTX-2.3 BWE vocoder) or None.
- """
- vocoder_config_path = os.path.join(model, "vocoder", "config.json")
- if not os.path.exists(vocoder_config_path):
- try:
- vocoder_config_path = hf_hub_download(model, "vocoder/config.json")
- except Exception:
- return None
- try:
- with open(vocoder_config_path) as f:
- cfg = json.load(f)
- return cfg.get("output_sampling_rate")
- except Exception:
- return None
-
-
-def get_ltx2_post_process_func(od_config: OmniDiffusionConfig):
- """Factory for the LTX-2.3 post-process function.
-
- Detects the vocoder output sample rate at factory time and captures it
- in the closure so that the audio_sample_rate flows through
- DiffusionEngine -> OmniRequestOutput -> serving_video.
- """
- output_sr = _detect_vocoder_output_sample_rate(od_config.model)
-
- def post_process_func(output: tuple[torch.Tensor, torch.Tensor] | torch.Tensor):
- if isinstance(output, tuple) and len(output) == 2:
- video, audio = output
- if isinstance(audio, torch.Tensor):
- audio = audio.detach().cpu()
- result: dict[str, Any] = {"video": video, "audio": audio}
- if output_sr is not None:
- result["audio_sample_rate"] = output_sr
- return result
- return output
-
- return post_process_func
-
-
-class LTX23Pipeline(nn.Module, ProgressBarMixin):
- """Fully independent LTX-2.3 pipeline.
-
- Key differences from LTX2Pipeline:
- - Text encoding: uses ALL 49 hidden states from Gemma-3-12B, flattened
- - Connectors: uses padding_side API (not additive_mask)
- - Vocoder: uses LTX2VocoderWithBWE (48kHz output)
- - Transformer: passes sigma for prompt_adaln
- - CPU offloading: text encoder, connectors, VAE, vocoder stay on CPU
- """
-
- def __init__(
- self,
- *,
- od_config: OmniDiffusionConfig,
- prefix: str = "",
- ):
- super().__init__()
- self.od_config = od_config
- self.device = get_local_device()
- dtype = getattr(od_config, "dtype", torch.bfloat16)
- model = od_config.model
- local_files_only = os.path.exists(model)
-
- # Weight sources for transformer (loaded via AutoWeightsLoader)
- self.weights_sources = [
- DiffusersPipelineLoader.ComponentSource(
- model_or_path=od_config.model,
- subfolder="transformer",
- revision=None,
- prefix="transformer.",
- fall_back_to_pt=True,
- ),
- ]
-
- # --- Tokenizer (lightweight, stays wherever) ---
- self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
-
- # --- Text encoder: load on CPU, move to GPU only during encoding ---
- with torch.device("cpu"):
- self.text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
- model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
- )
-
- # --- Connectors: CPU (LTX-2.3 connectors include caption projection) ---
- self.connectors = LTX2TextConnectors.from_pretrained(
- model, subfolder="connectors", torch_dtype=dtype, local_files_only=local_files_only
- )
-
- # --- VAE, Audio VAE: CPU ---
- self.vae = AutoencoderKLLTX2Video.from_pretrained(
- model, subfolder="vae", torch_dtype=dtype, local_files_only=local_files_only
- )
- self.audio_vae = AutoencoderKLLTX2Audio.from_pretrained(
- model, subfolder="audio_vae", torch_dtype=dtype, local_files_only=local_files_only
- )
-
- # --- Vocoder: prefer BWE vocoder (48kHz) for LTX-2.3 ---
- vocoder_cls = LTX2VocoderWithBWE or LTX2Vocoder
- try:
- self.vocoder = vocoder_cls.from_pretrained(
- model, subfolder="vocoder", torch_dtype=dtype, local_files_only=local_files_only
- )
- except (TypeError, OSError, ValueError):
- self.vocoder = LTX2Vocoder.from_pretrained(
- model, subfolder="vocoder", torch_dtype=dtype, local_files_only=local_files_only
- )
-
- # --- Transformer: created empty, weights loaded via AutoWeightsLoader ---
- transformer_config = load_transformer_config(model, "transformer", local_files_only)
- self.transformer = create_transformer_from_config(transformer_config)
-
- # --- Scheduler ---
- self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
- model, subfolder="scheduler", local_files_only=local_files_only
- )
-
- # --- Derived compression ratios ---
- self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if self.vae is not None else 32
- self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if self.vae is not None else 8
- self.audio_vae_mel_compression_ratio = self.audio_vae.mel_compression_ratio if self.audio_vae is not None else 4
- self.audio_vae_temporal_compression_ratio = (
- self.audio_vae.temporal_compression_ratio if self.audio_vae is not None else 4
- )
- self.transformer_spatial_patch_size = self.transformer.config.patch_size if self.transformer is not None else 1
- self.transformer_temporal_patch_size = (
- self.transformer.config.patch_size_t if self.transformer is not None else 1
- )
- self.audio_sampling_rate = self.audio_vae.config.sample_rate if self.audio_vae is not None else 16000
- self.audio_hop_length = self.audio_vae.config.mel_hop_length if self.audio_vae is not None else 160
-
- self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
-
- # Tokenizer max length
- tokenizer_max_length = 1024
- if self.tokenizer is not None:
- tokenizer_max_length = self.tokenizer.model_max_length
- if tokenizer_max_length is None or tokenizer_max_length > 100000:
- encoder_config = getattr(self.text_encoder, "config", None)
- config_max_len = getattr(encoder_config, "max_position_embeddings", None)
- if config_max_len is None:
- config_max_len = getattr(encoder_config, "max_seq_len", None)
- tokenizer_max_length = config_max_len or 1024
- self.tokenizer_max_length = int(tokenizer_max_length)
-
- # Pipeline state
- self._guidance_scale = None
- self._attention_kwargs = None
- self._interrupt = False
- self._num_timesteps = None
- self._current_timestep = None
-
- # ------------------------------------------------------------------
- # Text Encoding (LTX-2.3 specific)
- # ------------------------------------------------------------------
-
- def _get_gemma_prompt_embeds(
- self,
- prompt: str | list[str],
- num_videos_per_prompt: int = 1,
- max_sequence_length: int = 1024,
- device: torch.device | None = None,
- dtype: torch.dtype | None = None,
- ):
- """Encode prompts using Gemma-3-12B, returning ALL 49 hidden states flattened.
-
- LTX-2.3 differs from LTX-2 in text encoding:
- - LTX-2: uses _pack_text_embeds (layer selection + pooling)
- - LTX-2.3: stacks ALL 49 hidden states and flattens to [B, seq, 188160]
- The connectors unflatten, apply per_token_rms_norm, and project internally.
- """
- device = device or self.device
- dtype = dtype or self.text_encoder.dtype
-
- prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
-
- if self.tokenizer is not None:
- self.tokenizer.padding_side = "left"
- if self.tokenizer.pad_token is None:
- self.tokenizer.pad_token = self.tokenizer.eos_token
-
- prompt = [p.strip() for p in prompt]
- text_inputs = self.tokenizer(
- prompt,
- padding="max_length",
- max_length=max_sequence_length,
- truncation=True,
- add_special_tokens=True,
- return_tensors="pt",
- )
- text_input_ids = text_inputs.input_ids.to(device)
- prompt_attention_mask = text_inputs.attention_mask.to(device)
-
- # Move text encoder to GPU for encoding
- self.text_encoder.to(device)
- text_encoder_outputs = self.text_encoder(
- input_ids=text_input_ids,
- attention_mask=prompt_attention_mask,
- output_hidden_states=True,
- )
- # Move text encoder back to CPU immediately
- self.text_encoder.to("cpu")
- torch.cuda.empty_cache()
-
- hidden_states = text_encoder_outputs.hidden_states
-
- # LTX-2.3: Stack ALL 49 hidden states and flatten
- # [49 x (B, seq, 3840)] -> [B, seq, 3840, 49] -> [B, seq, 188160]
- prompt_embeds = torch.stack(hidden_states, dim=-1).flatten(2, 3).to(dtype=dtype)
-
- _, seq_len, _ = prompt_embeds.shape
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-
- prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
- prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
-
- return prompt_embeds, prompt_attention_mask
-
- def encode_prompt(
- self,
- prompt: str | list[str],
- negative_prompt: str | list[str] | None = None,
- do_classifier_free_guidance: bool = True,
- num_videos_per_prompt: int = 1,
- prompt_embeds: torch.Tensor | None = None,
- negative_prompt_embeds: torch.Tensor | None = None,
- prompt_attention_mask: torch.Tensor | None = None,
- negative_prompt_attention_mask: torch.Tensor | None = None,
- max_sequence_length: int = 1024,
- device: torch.device | None = None,
- dtype: torch.dtype | None = None,
- ):
- device = device or self.device
-
- prompt = [prompt] if isinstance(prompt, str) else prompt
- if prompt is not None:
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
- if prompt_embeds is None:
- prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
- prompt=prompt,
- num_videos_per_prompt=num_videos_per_prompt,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
-
- if do_classifier_free_guidance and negative_prompt_embeds is None:
- negative_prompt = negative_prompt or ""
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-
- if prompt is not None and type(prompt) is not type(negative_prompt):
- raise TypeError(
- f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
- f" {type(prompt)}."
- )
- if isinstance(negative_prompt, list) and batch_size != len(negative_prompt):
- raise ValueError(
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
- " the batch size of `prompt`."
- )
-
- negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
- prompt=negative_prompt,
- num_videos_per_prompt=num_videos_per_prompt,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
-
- return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
-
- # ------------------------------------------------------------------
- # Latent utilities (shared with LTX2Pipeline)
- # ------------------------------------------------------------------
-
- @staticmethod
- def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
- batch_size, num_channels, num_frames, height, width = latents.shape
- post_patch_num_frames = num_frames // patch_size_t
- post_patch_height = height // patch_size
- post_patch_width = width // patch_size
- latents = latents.reshape(
- batch_size,
- -1,
- post_patch_num_frames,
- patch_size_t,
- post_patch_height,
- patch_size,
- post_patch_width,
- patch_size,
- )
- latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
- return latents
-
- @staticmethod
- def _unpack_latents(
- latents: torch.Tensor,
- num_frames: int,
- height: int,
- width: int,
- patch_size: int = 1,
- patch_size_t: int = 1,
- ) -> torch.Tensor:
- batch_size = latents.size(0)
- latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
- latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
- return latents
-
- @staticmethod
- def _normalize_latents(
- latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
- ) -> torch.Tensor:
- latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
- latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
- latents = (latents - latents_mean) * scaling_factor / latents_std
- return latents
-
- @staticmethod
- def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
- latents_mean = latents_mean.to(latents.device, latents.dtype)
- latents_std = latents_std.to(latents.device, latents.dtype)
- return (latents - latents_mean) / latents_std
-
- @staticmethod
- def _denormalize_latents(
- latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
- ) -> torch.Tensor:
- latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
- latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
- latents = latents * latents_std / scaling_factor + latents_mean
- return latents
-
- @staticmethod
- def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
- latents_mean = latents_mean.to(latents.device, latents.dtype)
- latents_std = latents_std.to(latents.device, latents.dtype)
- return (latents * latents_std) + latents_mean
-
- @staticmethod
- def _pack_audio_latents(
- latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None
- ) -> torch.Tensor:
- if patch_size is not None and patch_size_t is not None:
- batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
- post_patch_latent_length = latent_length / patch_size_t
- post_patch_mel_bins = latent_mel_bins / patch_size
- latents = latents.reshape(
- batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
- )
- latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
- else:
- latents = latents.transpose(1, 2).flatten(2, 3)
- return latents
-
- @staticmethod
- def _unpack_audio_latents(
- latents: torch.Tensor,
- latent_length: int,
- num_mel_bins: int,
- patch_size: int | None = None,
- patch_size_t: int | None = None,
- ) -> torch.Tensor:
- if patch_size is not None and patch_size_t is not None:
- batch_size = latents.size(0)
- latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
- latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
- else:
- latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
- return latents
-
- @staticmethod
- def _unpad_audio_latents(latents: torch.Tensor, num_frames: int) -> torch.Tensor:
- return latents[:, :num_frames]
-
- # ------------------------------------------------------------------
- # Latent preparation
- # ------------------------------------------------------------------
-
- def prepare_latents(
- self,
- batch_size: int = 1,
- num_channels_latents: int = 128,
- height: int = 512,
- width: int = 768,
- num_frames: int = 121,
- noise_scale: float = 0.0,
- dtype: torch.dtype | None = None,
- device: torch.device | None = None,
- generator: torch.Generator | None = None,
- latents: torch.Tensor | None = None,
- ) -> torch.Tensor:
- if latents is not None:
- if latents.ndim == 5:
- latents = self._normalize_latents(
- latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
- )
- latents = self._pack_latents(
- latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
- )
- if latents.ndim != 3:
- raise ValueError(f"Provided `latents` has shape {latents.shape}, expected [batch, seq, features].")
- noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
- latents = noise_scale * noise + (1 - noise_scale) * latents
- return latents.to(device=device, dtype=dtype)
-
- height = height // self.vae_spatial_compression_ratio
- width = width // self.vae_spatial_compression_ratio
- num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
- shape = (batch_size, num_channels_latents, num_frames, height, width)
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
- return latents
-
- def prepare_audio_latents(
- self,
- batch_size: int = 1,
- num_channels_latents: int = 8,
- audio_latent_length: int = 1,
- num_mel_bins: int = 64,
- noise_scale: float = 0.0,
- dtype: torch.dtype | None = None,
- device: torch.device | None = None,
- generator: torch.Generator | list[torch.Generator] | None = None,
- latents: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, int, int]:
- original_latent_length = audio_latent_length
- padded_latent_length = original_latent_length
- latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
-
- if latents is not None:
- if latents.ndim == 4:
- latents = self._pack_audio_latents(latents)
- if latents.ndim != 3:
- raise ValueError(f"Provided `latents` has shape {latents.shape}, expected [batch, seq, features].")
- latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
- noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
- latents = noise_scale * noise + (1 - noise_scale) * latents
- return latents.to(device=device, dtype=dtype), original_latent_length, padded_latent_length
-
- shape = (batch_size, num_channels_latents, padded_latent_length, latent_mel_bins)
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = self._pack_audio_latents(latents)
- return latents, original_latent_length, padded_latent_length
-
- # ------------------------------------------------------------------
- # Properties
- # ------------------------------------------------------------------
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale is not None and self._guidance_scale > 1.0
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
- @property
- def current_timestep(self):
- return self._current_timestep
-
- @property
- def interrupt(self):
- return self._interrupt
-
- # ------------------------------------------------------------------
- # Input validation
- # ------------------------------------------------------------------
-
- def check_inputs(
- self,
- prompt,
- height,
- width,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- prompt_attention_mask=None,
- negative_prompt_attention_mask=None,
- ):
- if height % 32 != 0 or width % 32 != 0:
- raise ValueError(f"`height` and `width` must be divisible by 32 but are {height} and {width}.")
- if prompt is not None and prompt_embeds is not None:
- raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.")
- elif prompt is None and prompt_embeds is None:
- raise ValueError("Provide either `prompt` or `prompt_embeds`.")
- elif prompt is not None and not isinstance(prompt, (str, list)):
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
- if prompt_embeds is not None and prompt_attention_mask is None:
- raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
-
- if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
- raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
-
- if prompt_embeds is not None and negative_prompt_embeds is not None:
- if prompt_embeds.shape != negative_prompt_embeds.shape:
- raise ValueError(
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
- f" {negative_prompt_embeds.shape}."
- )
- if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
- raise ValueError(
- "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when "
- "passed directly, but got: `prompt_attention_mask` "
- f"{prompt_attention_mask.shape} != `negative_prompt_attention_mask` "
- f"{negative_prompt_attention_mask.shape}."
- )
-
- # ------------------------------------------------------------------
- # Cache context
- # ------------------------------------------------------------------
-
- def _transformer_cache_context(self, context_name: str):
- cache_context = getattr(self.transformer, "cache_context", None)
- if callable(cache_context):
- return cache_context(context_name)
- return nullcontext()
-
- # ------------------------------------------------------------------
- # Forward pass
- # ------------------------------------------------------------------
-
- @torch.no_grad()
- def forward(
- self,
- req: OmniDiffusionRequest,
- prompt: str | list[str] | None = None,
- negative_prompt: str | list[str] | None = None,
- height: int | None = None,
- width: int | None = None,
- num_frames: int | None = None,
- frame_rate: float | None = None,
- num_inference_steps: int | None = None,
- sigmas: list[float] | None = None,
- timesteps: list[int] | None = None,
- guidance_scale: float = 4.0,
- noise_scale: float = 0.0,
- num_videos_per_prompt: int | None = 1,
- generator: torch.Generator | list[torch.Generator] | None = None,
- latents: torch.Tensor | None = None,
- audio_latents: torch.Tensor | None = None,
- prompt_embeds: torch.Tensor | None = None,
- negative_prompt_embeds: torch.Tensor | None = None,
- prompt_attention_mask: torch.Tensor | None = None,
- negative_prompt_attention_mask: torch.Tensor | None = None,
- decode_timestep: float | list[float] = 0.0,
- decode_noise_scale: float | list[float] | None = None,
- output_type: str = "np",
- return_dict: bool = True,
- attention_kwargs: dict[str, Any] | None = None,
- max_sequence_length: int | None = None,
- ) -> DiffusionOutput:
- # ---- Extract from request ----
- prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
- if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts):
- negative_prompt = None
- elif req.prompts:
- negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]
-
- height = req.sampling_params.height or height or 512
- width = req.sampling_params.width or width or 768
- num_frames = req.sampling_params.num_frames or num_frames or 121
- frame_rate = req.sampling_params.resolved_frame_rate or frame_rate or 24.0
- num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps or 40
- # Enforce minimum of 2 timesteps for flow matching scheduler
- if timesteps is None:
- num_inference_steps = max(int(num_inference_steps), 2)
- elif len(timesteps) < 2:
- raise ValueError("`timesteps` must contain at least 2 values for FlowMatchEulerDiscreteScheduler.")
- num_videos_per_prompt = (
- req.sampling_params.num_outputs_per_prompt
- if req.sampling_params.num_outputs_per_prompt > 0
- else num_videos_per_prompt or 1
- )
- max_sequence_length = (
- req.sampling_params.max_sequence_length or max_sequence_length or self.tokenizer_max_length
- )
-
- if req.sampling_params.guidance_scale_provided:
- guidance_scale = req.sampling_params.guidance_scale
-
- if generator is None:
- generator = req.sampling_params.generator
- if generator is None and req.sampling_params.seed is not None:
- generator = torch.Generator(device=self.device).manual_seed(req.sampling_params.seed)
-
- latents = req.sampling_params.latents if req.sampling_params.latents is not None else latents
- audio_latents = (
- req.sampling_params.audio_latents
- if req.sampling_params.audio_latents is not None
- else req.sampling_params.extra_args.get("audio_latents", audio_latents)
- )
-
- # Override with pre-computed embeddings if provided in request
- req_prompt_embeds = [_get_prompt_field(p, "prompt_embeds") for p in req.prompts]
- if any(p is not None for p in req_prompt_embeds):
- prompt_embeds = torch.stack(req_prompt_embeds)
-
- req_negative_prompt_embeds = [_get_prompt_field(p, "negative_prompt_embeds") for p in req.prompts]
- if any(p is not None for p in req_negative_prompt_embeds):
- negative_prompt_embeds = torch.stack(req_negative_prompt_embeds)
-
- req_prompt_attention_masks = [
- _get_prompt_field(p, "prompt_attention_mask") or _get_prompt_field(p, "attention_mask") for p in req.prompts
- ]
- if any(m is not None for m in req_prompt_attention_masks):
- prompt_attention_mask = torch.stack(req_prompt_attention_masks)
-
- req_negative_attention_masks = [
- _get_prompt_field(p, "negative_prompt_attention_mask") or _get_prompt_field(p, "negative_attention_mask")
- for p in req.prompts
- ]
- if any(m is not None for m in req_negative_attention_masks):
- negative_prompt_attention_mask = torch.stack(req_negative_attention_masks)
-
- if req.sampling_params.decode_timestep is not None:
- decode_timestep = req.sampling_params.decode_timestep
- if req.sampling_params.decode_noise_scale is not None:
- decode_noise_scale = req.sampling_params.decode_noise_scale
- if req.sampling_params.output_type is not None:
- output_type = req.sampling_params.output_type
-
- self.check_inputs(
- prompt=prompt,
- height=height,
- width=width,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- prompt_attention_mask=prompt_attention_mask,
- negative_prompt_attention_mask=negative_prompt_attention_mask,
- )
-
- self._guidance_scale = guidance_scale
- self._attention_kwargs = attention_kwargs
- self._interrupt = False
- self._current_timestep = None
-
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
- device = self.device
-
- # ---- Encode prompts ----
- (
- prompt_embeds,
- prompt_attention_mask,
- negative_prompt_embeds,
- negative_prompt_attention_mask,
- ) = self.encode_prompt(
- prompt=prompt,
- negative_prompt=negative_prompt,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
- num_videos_per_prompt=num_videos_per_prompt,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- prompt_attention_mask=prompt_attention_mask,
- negative_prompt_attention_mask=negative_prompt_attention_mask,
- max_sequence_length=max_sequence_length,
- device=device,
- )
-
- # ---- Connectors (LTX-2.3: padding_side API) ----
- # Concatenate negative + positive embeddings BEFORE connector call,
- # matching diffusers which calls connectors once with batch=2.
- # This ensures batch-dependent operations produce identical results.
- if self.do_classifier_free_guidance:
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
- prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
-
- self.connectors.to(device)
- tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
- connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
- prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
- )
- self.connectors.to("cpu")
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
- # ---- Prepare latents ----
- latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
- latent_height = height // self.vae_spatial_compression_ratio
- latent_width = width // self.vae_spatial_compression_ratio
- if latents is not None and latents.ndim == 5:
- _, _, latent_num_frames, latent_height, latent_width = latents.shape
-
- num_channels_latents = self.transformer.config.in_channels
- latents = self.prepare_latents(
- batch_size * num_videos_per_prompt,
- num_channels_latents,
- height,
- width,
- num_frames,
- noise_scale,
- torch.float32,
- device,
- generator,
- latents,
- )
-
- duration_s = num_frames / frame_rate
- audio_latents_per_second = (
- self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio)
- )
- audio_num_frames = round(duration_s * audio_latents_per_second)
- if audio_latents is not None and audio_latents.ndim == 4:
- _, _, audio_num_frames, _ = audio_latents.shape
-
- num_mel_bins = self.audio_vae.config.mel_bins if self.audio_vae is not None else 64
- latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
- num_channels_latents_audio = self.audio_vae.config.latent_channels if self.audio_vae is not None else 8
- audio_latents, original_audio_num_frames, padded_audio_num_frames = self.prepare_audio_latents(
- batch_size * num_videos_per_prompt,
- num_channels_latents=num_channels_latents_audio,
- audio_latent_length=audio_num_frames,
- num_mel_bins=num_mel_bins,
- noise_scale=noise_scale,
- dtype=torch.float32,
- device=device,
- generator=generator,
- latents=audio_latents,
- )
-
- # ---- Scheduler setup ----
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
- # Use max_image_seq_len (not actual video_sequence_length) for mu calculation,
- # matching diffusers' LTX2Pipeline which hardcodes this value.
- mu = calculate_shift(
- self.scheduler.config.get("max_image_seq_len", 4096),
- self.scheduler.config.get("base_image_seq_len", 1024),
- self.scheduler.config.get("max_image_seq_len", 4096),
- self.scheduler.config.get("base_shift", 0.95),
- self.scheduler.config.get("max_shift", 2.05),
- )
- audio_scheduler = copy.deepcopy(self.scheduler)
- _ = retrieve_timesteps(audio_scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu)
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler,
- num_inference_steps,
- device,
- timesteps,
- sigmas=sigmas,
- mu=mu,
- )
- self._num_timesteps = len(timesteps)
-
- # ---- RoPE coordinates ----
- video_coords = self.transformer.rope.prepare_video_coords(
- latents.shape[0],
- latent_num_frames,
- latent_height,
- latent_width,
- latents.device,
- fps=frame_rate,
- )
- audio_coords = self.transformer.audio_rope.prepare_audio_coords(
- audio_latents.shape[0],
- padded_audio_num_frames,
- audio_latents.device,
- )
-
- # ---- CFG: duplicate coords for batch=2 ----
- # Connector outputs are already batch=2 (neg+pos concatenated before connector call)
- if self.do_classifier_free_guidance:
- video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1))
- audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1))
-
- # ---- Denoising loop ----
- # Uses x0-space CFG (delta formulation) matching diffusers' LTX2Pipeline.
- # The velocity predictions are converted to x0, guidance is applied in x0
- # space, then converted back to velocity for the scheduler step.
- with self.progress_bar(total=len(timesteps)) as pbar:
- for i, t in enumerate(timesteps):
- if self.interrupt:
- continue
-
- self._current_timestep = t
-
- # Duplicate latents for CFG (uncond + cond)
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
- latent_model_input = latent_model_input.to(connector_prompt_embeds.dtype)
- audio_latent_model_input = (
- torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
- )
- audio_latent_model_input = audio_latent_model_input.to(connector_prompt_embeds.dtype)
- ts = t.expand(latent_model_input.shape[0])
-
- with self._transformer_cache_context("cond_uncond"):
- noise_pred_video, noise_pred_audio = self.transformer(
- hidden_states=latent_model_input,
- audio_hidden_states=audio_latent_model_input,
- encoder_hidden_states=connector_prompt_embeds,
- audio_encoder_hidden_states=connector_audio_prompt_embeds,
- timestep=ts,
- sigma=ts, # LTX-2.3: sigma for prompt_adaln
- encoder_attention_mask=connector_attention_mask,
- audio_encoder_attention_mask=connector_attention_mask,
- num_frames=latent_num_frames,
- height=latent_height,
- width=latent_width,
- fps=frame_rate,
- audio_num_frames=padded_audio_num_frames,
- video_coords=video_coords,
- audio_coords=audio_coords,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )
-
- noise_pred_video = noise_pred_video.float()
- noise_pred_audio = noise_pred_audio.float()
-
- # CFG in x0-space (delta formulation matching diffusers)
- if self.do_classifier_free_guidance:
- noise_pred_video_uncond, noise_pred_video_cond = noise_pred_video.chunk(2)
- # Convert velocity to x0: x0 = sample - velocity * sigma
- x0_video_cond = latents - noise_pred_video_cond * self.scheduler.sigmas[i]
- x0_video_uncond = latents - noise_pred_video_uncond * self.scheduler.sigmas[i]
- video_cfg_delta = (guidance_scale - 1) * (x0_video_cond - x0_video_uncond)
- x0_video_guided = x0_video_cond + video_cfg_delta
-
- noise_pred_audio_uncond, noise_pred_audio_cond = noise_pred_audio.chunk(2)
- x0_audio_cond = audio_latents - noise_pred_audio_cond * audio_scheduler.sigmas[i]
- x0_audio_uncond = audio_latents - noise_pred_audio_uncond * audio_scheduler.sigmas[i]
- audio_cfg_delta = (guidance_scale - 1) * (x0_audio_cond - x0_audio_uncond)
- x0_audio_guided = x0_audio_cond + audio_cfg_delta
-
- # Convert x0 back to velocity: v = (sample - x0) / sigma
- noise_pred_video = (latents - x0_video_guided) / self.scheduler.sigmas[i]
- noise_pred_audio = (audio_latents - x0_audio_guided) / audio_scheduler.sigmas[i]
-
- latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
- audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
-
- pbar.update()
-
- # ---- Unpack and denormalize ----
- latents = self._unpack_latents(
- latents,
- latent_num_frames,
- latent_height,
- latent_width,
- self.transformer_spatial_patch_size,
- self.transformer_temporal_patch_size,
- )
- latents = self._denormalize_latents(
- latents,
- self.vae.latents_mean,
- self.vae.latents_std,
- self.vae.config.scaling_factor,
- )
-
- audio_latents = self._unpad_audio_latents(audio_latents, original_audio_num_frames)
- audio_latents = self._denormalize_audio_latents(
- audio_latents,
- self.audio_vae.latents_mean,
- self.audio_vae.latents_std,
- )
- audio_latents = self._unpack_audio_latents(
- audio_latents,
- original_audio_num_frames,
- num_mel_bins=latent_mel_bins,
- )
-
- # ---- Decode ----
- if output_type == "latent":
- video = latents
- audio = audio_latents
- else:
- latents = latents.to(connector_prompt_embeds.dtype)
-
- if not self.vae.config.timestep_conditioning:
- timestep_decode = None
- else:
- noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
- if not isinstance(decode_timestep, list):
- decode_timestep = [decode_timestep] * batch_size
- if decode_noise_scale is None:
- decode_noise_scale = decode_timestep
- elif not isinstance(decode_noise_scale, list):
- decode_noise_scale = [decode_noise_scale] * batch_size
- timestep_decode = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
- decode_noise_scale_t = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
- :, None, None, None, None
- ]
- latents = (1 - decode_noise_scale_t) * latents + decode_noise_scale_t * noise
-
- # Move VAE, audio_vae, vocoder to GPU for decoding
- self.vae.to(device)
- latents = latents.to(self.vae.dtype)
- video = self.vae.decode(latents, timestep_decode, return_dict=False)[0]
- video = self.video_processor.postprocess_video(video, output_type=output_type)
- self.vae.to("cpu")
-
- self.audio_vae.to(device)
- audio_latents = audio_latents.to(self.audio_vae.dtype)
- generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
- self.audio_vae.to("cpu")
-
- self.vocoder.to(device)
- audio = self.vocoder(generated_mel_spectrograms)
- self.vocoder.to("cpu")
- torch.cuda.empty_cache()
-
- return DiffusionOutput(output=(video, audio))
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
- return loader.load_weights(weights)
-
-
-class LTX23ImageToVideoPipeline(nn.Module):
- """LTX-2.3 image-to-video pipeline placeholder."""
-
- def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
- super().__init__()
- raise NotImplementedError(
- "LTX23ImageToVideoPipeline is not yet implemented. "
- "Use LTX23Pipeline for single-stage text-to-video generation."
- )
diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_3_image2video.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_3_image2video.py
deleted file mode 100644
index d30ef546400..00000000000
--- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_3_image2video.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Re-exports for LTX-2.3 I2V pipeline variants.
-
-The registry loads pipeline classes by (mod_folder, mod_relname, cls_name).
-This module exposes the I2V class names so the registry can find them.
-"""
-
-from .pipeline_ltx2_3 import (
- LTX23ImageToVideoPipeline,
- get_ltx2_post_process_func, # noqa: F401 - loaded by registry via getattr
-)
-
-__all__ = [
- "LTX23ImageToVideoPipeline",
- "get_ltx2_post_process_func",
-]
diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py
index 4cc65f74908..11091518b4e 100644
--- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py
+++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py
@@ -6,7 +6,7 @@
import copy
import os
from collections.abc import Iterable
-from typing import Any, ClassVar
+from typing import Any
import numpy as np
import PIL.Image
@@ -14,7 +14,7 @@
import torch.nn as nn
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
@@ -22,11 +22,10 @@
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
+from vllm_omni.diffusion.distributed.parallel_state import get_cfg_group, get_classifier_free_guidance_rank
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin
-from vllm_omni.diffusion.models.interface import SupportsModuleOffload
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.lora.request import LoRARequest
@@ -47,32 +46,6 @@ def get_ltx2_post_process_func(od_config: OmniDiffusionConfig):
return _get_ltx2_post_process_func(od_config)
-class _I2VVideoAudioScheduler:
- """Composite scheduler for I2V: uses _step_video_latents_i2v for video, standard step for audio."""
-
- def __init__(self, pipeline, audio_scheduler, latent_num_frames, latent_height, latent_width):
- self.video_scheduler = pipeline.scheduler
- self.audio_scheduler = audio_scheduler
- self._pipeline = pipeline
- self._latent_num_frames = latent_num_frames
- self._latent_height = latent_height
- self._latent_width = latent_width
-
- def step(self, noise_pred, t, latents, return_dict=False, generator=None):
- video_out = self._pipeline._step_video_latents_i2v(
- noise_pred[0],
- latents[0],
- t[0],
- self._latent_num_frames,
- self._latent_height,
- self._latent_width,
- )
- audio_out = self.audio_scheduler.step(noise_pred[1], t[1], latents[1], return_dict=False, generator=generator)[
- 0
- ]
- return ((video_out, audio_out),)
-
-
class LTX2ImageToVideoPipeline(LTX2Pipeline):
support_image_input = True
@@ -314,8 +287,6 @@ def forward(
attention_kwargs: dict[str, Any] | None = None,
max_sequence_length: int | None = None,
) -> DiffusionOutput:
- # Extract prompt/negative_prompt from request.
- # Input format: req.prompts is a list of str or dict with "prompt"/"negative_prompt" keys.
prompt = [p if isinstance(p, str) else (p.get("prompt") or "") for p in req.prompts] or prompt
if all(isinstance(p, str) or p.get("negative_prompt") is None for p in req.prompts):
negative_prompt = None
@@ -357,7 +328,6 @@ def forward(
else req.sampling_params.extra_args.get("audio_latents", audio_latents)
)
- # Override with pre-computed embeddings if provided in request.
req_prompt_embeds = [_get_prompt_field(p, "prompt_embeds") for p in req.prompts]
if any(p is not None for p in req_prompt_embeds):
prompt_embeds = torch.stack(req_prompt_embeds) # type: ignore[arg-type]
@@ -459,17 +429,20 @@ def forward(
max_sequence_length=max_sequence_length,
device=device,
)
- # Compute positive prompt connectors
+ cfg_parallel_ready = self._is_cfg_parallel_enabled(self.do_classifier_free_guidance)
+ if self.do_classifier_free_guidance and not cfg_parallel_ready:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
prompt_embeds, additive_attention_mask, additive_mask=True
)
- # Compute negative prompt connectors when CFG is enabled
negative_connector_prompt_embeds = None
negative_connector_audio_prompt_embeds = None
negative_connector_attention_mask = None
- if self.do_classifier_free_guidance:
+ if cfg_parallel_ready:
negative_additive_attention_mask = (
1 - negative_prompt_attention_mask.to(negative_prompt_embeds.dtype)
) * -1000000.0
@@ -527,6 +500,8 @@ def forward(
generator,
latents,
)
+ if self.do_classifier_free_guidance and not cfg_parallel_ready:
+ conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
duration_s = num_frames / frame_rate
audio_latents_per_second = (
@@ -554,7 +529,20 @@ def forward(
num_channels_latents_audio = (
self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8
)
- audio_latents, original_audio_num_frames, padded_audio_num_frames = self.prepare_audio_latents(
+
+ # padding audio_latents if needed
+ sp_size = getattr(self.od_config.parallel_config, "sequence_parallel_size", 1)
+ if sp_size > 1:
+ pad_len = (sp_size - (audio_num_frames % sp_size)) % sp_size
+ if pad_len > 0:
+ if audio_latents is not None:
+ pad_shape = list(audio_latents.shape)
+ pad_shape[2] = pad_len
+ padding = torch.zeros(pad_shape, dtype=audio_latents.dtype, device=audio_latents.device)
+ audio_latents = torch.cat([audio_latents, padding], dim=2)
+ audio_num_frames += pad_len
+
+ audio_latents = self.prepare_audio_latents(
batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents_audio,
audio_latent_length=audio_num_frames,
@@ -597,17 +585,12 @@ def forward(
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
)
audio_coords = self.transformer.audio_rope.prepare_audio_coords(
- audio_latents.shape[0], padded_audio_num_frames, audio_latents.device
- )
-
- i2v_scheduler = _I2VVideoAudioScheduler(
- pipeline=self,
- audio_scheduler=audio_scheduler,
- latent_num_frames=latent_num_frames,
- latent_height=latent_height,
- latent_width=latent_width,
+ audio_latents.shape[0], audio_num_frames, audio_latents.device
)
- # No coord duplication needed: mixin handles CFG via separate forward calls.
+ # Duplicate the positional ids as well if using CFG
+ if self.do_classifier_free_guidance and not cfg_parallel_ready:
+ video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim
+ audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1))
with self.progress_bar(total=len(timesteps)) as pbar:
for i, t in enumerate(timesteps):
@@ -616,62 +599,140 @@ def forward(
self._current_timestep = t
- latent_model_input = latents.to(prompt_embeds.dtype)
- audio_latent_model_input = audio_latents.to(prompt_embeds.dtype)
- timestep = t.expand(latent_model_input.shape[0])
- video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
- do_true_cfg = self.do_classifier_free_guidance
-
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "audio_hidden_states": audio_latent_model_input,
- "encoder_hidden_states": connector_prompt_embeds,
- "audio_encoder_hidden_states": connector_audio_prompt_embeds,
- "timestep": video_timestep,
- "audio_timestep": timestep,
- "encoder_attention_mask": connector_attention_mask,
- "audio_encoder_attention_mask": connector_attention_mask,
- "num_frames": latent_num_frames,
- "height": latent_height,
- "width": latent_width,
- "fps": frame_rate,
- "audio_num_frames": padded_audio_num_frames,
- "video_coords": video_coords,
- "audio_coords": audio_coords,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- }
- negative_kwargs = (
- {
- **positive_kwargs,
+ if cfg_parallel_ready:
+ latent_model_input = latents.to(prompt_embeds.dtype)
+ audio_latent_model_input = audio_latents.to(prompt_embeds.dtype)
+
+ timestep = t.expand(latent_model_input.shape[0])
+ video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
+
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "audio_hidden_states": audio_latent_model_input,
+ "encoder_hidden_states": connector_prompt_embeds,
+ "audio_encoder_hidden_states": connector_audio_prompt_embeds,
+ "timestep": video_timestep,
+ "audio_timestep": timestep,
+ "encoder_attention_mask": connector_attention_mask,
+ "audio_encoder_attention_mask": connector_attention_mask,
+ "num_frames": latent_num_frames,
+ "height": latent_height,
+ "width": latent_width,
+ "fps": frame_rate,
+ "audio_num_frames": audio_num_frames,
+ "video_coords": video_coords,
+ "audio_coords": audio_coords,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ }
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "audio_hidden_states": audio_latent_model_input,
"encoder_hidden_states": negative_connector_prompt_embeds,
"audio_encoder_hidden_states": negative_connector_audio_prompt_embeds,
+ "timestep": video_timestep,
+ "audio_timestep": timestep,
"encoder_attention_mask": negative_connector_attention_mask,
"audio_encoder_attention_mask": negative_connector_attention_mask,
+ "num_frames": latent_num_frames,
+ "height": latent_height,
+ "width": latent_width,
+ "fps": frame_rate,
+ "audio_num_frames": audio_num_frames,
+ "video_coords": video_coords,
+ "audio_coords": audio_coords,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
}
- if do_true_cfg
- else None
- )
- noise_pred_video, noise_pred_audio = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
+ noise_pred_video, noise_pred_audio = self.predict_noise_av_maybe_with_cfg(
+ do_true_cfg=True,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ guidance_rescale=guidance_rescale,
+ cfg_normalize=False,
+ )
+
+ if get_classifier_free_guidance_rank() == 0:
+ latents = self._step_video_latents_i2v(
+ noise_pred_video,
+ latents,
+ t,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ )
+ audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
- latents, audio_latents = self.scheduler_step_maybe_with_cfg(
- (noise_pred_video, noise_pred_audio),
- (t, t),
- (latents, audio_latents),
- do_true_cfg=do_true_cfg,
- per_request_scheduler=i2v_scheduler,
- )
- latents, audio_latents = self._synchronize_cfg_parallel_step_output(
- (latents, audio_latents),
- do_true_cfg=do_true_cfg,
- )
+ cfg_group = get_cfg_group()
+ latents = latents.contiguous()
+ audio_latents = audio_latents.contiguous()
+ cfg_group.broadcast(latents, src=0)
+ cfg_group.broadcast(audio_latents, src=0)
+ else:
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+ audio_latent_model_input = (
+ torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents
+ )
+ audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype)
+
+ timestep = t.expand(latent_model_input.shape[0])
+ video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
+
+ with self._transformer_cache_context("cond_uncond"):
+ noise_pred_video, noise_pred_audio = self.transformer(
+ hidden_states=latent_model_input,
+ audio_hidden_states=audio_latent_model_input,
+ encoder_hidden_states=connector_prompt_embeds,
+ audio_encoder_hidden_states=connector_audio_prompt_embeds,
+ timestep=video_timestep,
+ audio_timestep=timestep,
+ encoder_attention_mask=connector_attention_mask,
+ audio_encoder_attention_mask=connector_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ fps=frame_rate,
+ audio_num_frames=audio_num_frames,
+ video_coords=video_coords,
+ audio_coords=audio_coords,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )
+ noise_pred_video = noise_pred_video.float()
+ noise_pred_audio = noise_pred_audio.float()
+
+ if self.do_classifier_free_guidance:
+ noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
+ noise_pred_video = noise_pred_video_uncond + guidance_scale * (
+ noise_pred_video_text - noise_pred_video_uncond
+ )
+
+ noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
+ noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (
+ noise_pred_audio_text - noise_pred_audio_uncond
+ )
+
+ if guidance_rescale > 0:
+ noise_pred_video = rescale_noise_cfg(
+ noise_pred_video, noise_pred_video_text, guidance_rescale=guidance_rescale
+ )
+ noise_pred_audio = rescale_noise_cfg(
+ noise_pred_audio, noise_pred_audio_text, guidance_rescale=guidance_rescale
+ )
+
+ latents = self._step_video_latents_i2v(
+ noise_pred_video,
+ latents,
+ t,
+ latent_num_frames,
+ latent_height,
+ latent_width,
+ )
+
+ audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0]
pbar.update()
@@ -687,15 +748,10 @@ def forward(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
- audio_latents = self._unpad_audio_latents(audio_latents, original_audio_num_frames)
audio_latents = self._denormalize_audio_latents(
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
)
- audio_latents = self._unpack_audio_latents(
- audio_latents,
- original_audio_num_frames,
- num_mel_bins=latent_mel_bins,
- )
+ audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
if output_type == "latent":
video = latents
@@ -734,15 +790,11 @@ def forward(
return DiffusionOutput(output=(video, audio))
-class LTX2ImageToVideoTwoStagesPipeline(nn.Module, SupportsModuleOffload):
+class LTX2ImageToVideoTwoStagesPipeline(nn.Module):
"""LTXImageToVideoTwoStagesPipeline is for two stages image to video generation"""
support_image_input = True
- _dit_modules: ClassVar[list[str]] = ["pipe.transformer"]
- _encoder_modules: ClassVar[list[str]] = ["pipe.text_encoder"]
- _vae_modules: ClassVar[list[str]] = ["pipe.vae", "pipe.audio_vae"]
-
def __init__(
self,
*,
@@ -895,11 +947,3 @@ def forward(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
-
-
-class LTX2I2VDMD2Pipeline(DMD2PipelineMixin, LTX2ImageToVideoPipeline):
- """LTX-2 I2V pipeline for FastGen DMD2-distilled models."""
-
- def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
- super().__init__(od_config=od_config, prefix=prefix)
- self.__init_dmd2__()
diff --git a/vllm_omni/diffusion/models/magi_human/__init__.py b/vllm_omni/diffusion/models/magi_human/__init__.py
deleted file mode 100644
index 9881313609a..00000000000
--- a/vllm_omni/diffusion/models/magi_human/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
diff --git a/vllm_omni/diffusion/models/magi_human/magi_human_dit.py b/vllm_omni/diffusion/models/magi_human/magi_human_dit.py
deleted file mode 100644
index 491b1b3c40d..00000000000
--- a/vllm_omni/diffusion/models/magi_human/magi_human_dit.py
+++ /dev/null
@@ -1,1624 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright (c) 2026 SandAI. All Rights Reserved.
-# Ported from daVinci-MagiHuman inference/model/dit/dit_module.py
-# Adaptations: removed Ulysses context-parallelism, inlined Modality/VarlenHandler.
-
-from __future__ import annotations
-
-import importlib
-from collections.abc import Callable
-from dataclasses import dataclass, field
-from enum import Enum, IntEnum
-from typing import TYPE_CHECKING, Any, Literal
-
-import torch
-import torch.nn as nn
-from einops import rearrange, repeat
-from torch.nn import Parameter
-from torch.nn import functional as F
-from vllm.distributed import (
- get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.layers.linear import (
- ColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear,
-)
-from vllm.vllm_flash_attn import flash_attn_varlen_func as _vllm_fa_varlen
-
-try:
- from magi_compiler.api import magi_register_custom_op
- from magi_compiler.config import CompileConfig
-except Exception:
-
- class CompileConfig: # type: ignore[no-redef]
- pass
-
- def magi_register_custom_op(*args, **kwargs): # type: ignore[no-redef]
- def decorator(func):
- return func
-
- return decorator
-
-
-def magi_compile(*args, **kwargs):
- """No-op stub — vllm-omni handles execution; magi compilation is skipped."""
-
- def decorator(cls_or_fn):
- return cls_or_fn
-
- return decorator
-
-
-# ---------------------------------------------------------------------------
-# Inlined from inference/common/sequence_schema.py
-# ---------------------------------------------------------------------------
-class Modality(IntEnum):
- VIDEO = 0
- AUDIO = 1
- TEXT = 2
-
-
-@dataclass
-class VarlenHandler:
- cu_seqlens_q: torch.Tensor
- cu_seqlens_k: torch.Tensor
- max_seqlen_q: int
- max_seqlen_k: int
-
-
-def _is_hopper_arch() -> bool:
- if not torch.cuda.is_available():
- return False
- return torch.cuda.get_device_capability()[0] == 9
-
-
-# ---------------------------------------------------------------------------
-# FFA handler for local / flex attention
-# ---------------------------------------------------------------------------
-@dataclass
-class FFAHandler:
- q_ranges: torch.Tensor
- k_ranges: torch.Tensor
- max_seqlen_q: int
- max_seqlen_k: int
- attn_type_map: torch.Tensor
- softmax_scale: float
-
-
-# ---------------------------------------------------------------------------
-# Activation helpers
-# ---------------------------------------------------------------------------
-class MLPActivationType(Enum):
- SWIGLU7 = "swiglu7"
- GELU7 = "gelu7"
-
-
-def swiglu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: torch.dtype | None = None):
- out_dtype = x.dtype if out_dtype is None else out_dtype
- x = x.to(torch.float32)
- x_glu, x_linear = x[..., ::2], x[..., 1::2]
- x_glu = x_glu.clamp(min=None, max=limit)
- x_linear = x_linear.clamp(min=-limit, max=limit)
- out_glu = x_glu * torch.sigmoid(alpha * x_glu)
- return (out_glu * (x_linear + 1)).to(out_dtype)
-
-
-def gelu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: torch.dtype | None = None):
- out_dtype = x.dtype if out_dtype is None else out_dtype
- x = x.to(torch.float32)
- x_glu = x.clamp(min=None, max=limit)
- out_glu = x_glu * torch.sigmoid(alpha * x_glu)
- return out_glu.to(out_dtype)
-
-
-def create_activation_func(activation_type: MLPActivationType) -> Callable:
- match activation_type:
- case MLPActivationType.SWIGLU7:
- return swiglu7
- case MLPActivationType.GELU7:
- return gelu7
- case _:
- raise ValueError(f"Unknown activation type: {activation_type}")
-
-
-# ---------------------------------------------------------------------------
-# Modality dispatcher (permutation helper)
-# ---------------------------------------------------------------------------
-class ModalityDispatcher:
- permuted_modality_mapping: torch.Tensor
- group_size: torch.Tensor
- group_size_cpu: list[int]
- num_modalities: int
-
- def __init__(self, modality_mapping: torch.Tensor, num_modalities: int):
- self.modality_mapping = modality_mapping
- self.num_modalities = num_modalities
- self.permuted_modality_mapping = self._precompute_permute_mapping(modality_mapping)
- self.group_size = torch.bincount(self.permuted_modality_mapping, minlength=num_modalities).to(torch.int32)
- self.group_size_cpu: list[int] = [int(x) for x in self.group_size.to("cpu").tolist()]
-
- def _precompute_permute_mapping(self, modality_mapping):
- self.permute_mapping = torch.argsort(modality_mapping)
- self.inv_permute_mapping = torch.argsort(self.permute_mapping)
- return modality_mapping[self.permute_mapping]
-
- def dispatch(self, x: torch.Tensor) -> list[torch.Tensor]:
- return list(torch.split(x, self.group_size_cpu, dim=0))
-
- def undispatch(self, *processed_groups: list[torch.Tensor]) -> torch.Tensor:
- return torch.cat(processed_groups, dim=0)
-
- @staticmethod
- def permute(x: torch.Tensor, permute_mapping: torch.Tensor) -> torch.Tensor:
- return x[permute_mapping]
-
- @staticmethod
- def inv_permute(x: torch.Tensor, inv_permute_mapping: torch.Tensor) -> torch.Tensor:
- return x[inv_permute_mapping]
-
-
-# ---------------------------------------------------------------------------
-# Positional / rotary embedding helpers
-# ---------------------------------------------------------------------------
-def freq_bands(
- num_bands: int, temperature: float = 10000.0, step: int = 2, device: torch.device | None = None
-) -> torch.Tensor:
- exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
- return 1.0 / (temperature**exp)
-
-
-def rotate_half(x, interleaved=False):
- if not interleaved:
- x1, x2 = x.chunk(2, dim=-1)
- return torch.cat((-x2, x1), dim=-1)
- else:
- x1, x2 = x[..., ::2], x[..., 1::2]
- return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
-
-
-def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
- ro_dim = cos.shape[-1] * 2
- assert ro_dim <= x.shape[-1]
- cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
- sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
- return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], dim=-1)
-
-
-# ---------------------------------------------------------------------------
-# Fourier positional embedding
-# ---------------------------------------------------------------------------
-class ElementWiseFourierEmbed(nn.Module):
- def __init__(
- self,
- dim: int,
- max_res: int = 224,
- temperature: float = 10000.0,
- in_pixels: bool = True,
- linear_bands: bool = False,
- learnable: bool = False,
- device: torch.device = torch.device("cpu"),
- dtype: torch.dtype = torch.float32,
- ):
- super().__init__()
- self.dim = dim
- self.in_pixels = in_pixels
- self.learnable = learnable
- self.temperature = temperature
- self.max_res = max_res
- self.linear_bands = linear_bands
- self.device = device
- self.dtype = dtype
- bands = self.get_default_bands()
- self.bands = nn.Parameter(bands, requires_grad=self.learnable)
-
- def forward(self, coords: torch.Tensor) -> torch.Tensor:
- coords_xyz = coords[:, :3]
- sizes = coords[:, 3:6]
- refs = coords[:, 6:9]
-
- scales = (refs - 1) / (sizes - 1)
- scales[(refs == 1) & (sizes == 1)] = 1
- assert not scales.isnan().any(), "scales has nan"
- assert not scales.isinf().any(), "scales has inf"
-
- centers = (sizes - 1) / 2
- centers[:, 0] = 0
- coords_xyz = coords_xyz - centers
-
- bands = self.bands.to(coords.device, coords.dtype)
- proj = coords_xyz.unsqueeze(-1) * scales.unsqueeze(-1) * bands
- sin_proj = proj.sin()
- cos_proj = proj.cos()
- return torch.cat((sin_proj, cos_proj), dim=1).flatten(1)
-
- def reset_parameters(self):
- self.bands.copy_(self.get_default_bands())
-
- def get_default_bands(self):
- if self.in_pixels:
- raise NotImplementedError("in_pixels are not implemented yet")
- return freq_bands(self.dim // 8, temperature=self.temperature, step=1, device=self.device).to(self.dtype)
-
-
-# ---------------------------------------------------------------------------
-# Multi-modality RMSNorm
-# ---------------------------------------------------------------------------
-class MultiModalityRMSNorm(nn.Module):
- __constants__ = ["dim", "eps", "num_modality"]
-
- def __init__(self, dim: int, eps: float = 1e-6, device: torch.device | None = None, num_modality: int = 1):
- super().__init__()
- self.dim = dim
- self.eps = eps
- self.num_modality = num_modality
- self.weight = nn.Parameter(torch.zeros(dim * num_modality, device=device, dtype=torch.float32))
- if num_modality > 1:
- self.forward = self.forward_multi_experts
- else:
- self.forward = self.forward_single_expert
- self.reset_parameters()
-
- def reset_parameters(self):
- nn.init.zeros_(self.weight)
-
- def rms(self, x: torch.Tensor) -> torch.Tensor:
- t = x.float()
- return t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
-
- def forward_multi_experts(self, x: torch.Tensor, modality_dispatcher: ModalityDispatcher) -> torch.Tensor:
- original_dtype = x.dtype
- t = self.rms(x)
- weight_chunked = self.weight.chunk(self.num_modality, dim=0)
- t_list = modality_dispatcher.dispatch(t)
- for i in range(self.num_modality):
- t_list[i] = t_list[i] * (weight_chunked[i] + 1)
- t = modality_dispatcher.undispatch(*t_list)
- return t.to(original_dtype)
-
- def forward_single_expert(
- self, x: torch.Tensor, modality_dispatcher: ModalityDispatcher | None = None
- ) -> torch.Tensor:
- t, original_dtype = x.float(), x.dtype
- t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
- return (t * (self.weight + 1)).to(original_dtype)
-
-
-# ---------------------------------------------------------------------------
-# Linear layers with bf16 compute and MoE dispatch
-# ---------------------------------------------------------------------------
-class _BF16ComputeLinear(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- input: torch.Tensor,
- weight: torch.Tensor,
- bias: torch.Tensor | None,
- output_dtype: torch.dtype | None,
- compute_dtype: torch.dtype = torch.bfloat16,
- ):
- input_cast = input.to(compute_dtype)
- weight_cast = weight.to(compute_dtype)
- output = torch.matmul(input_cast, weight_cast.t())
- if bias is not None:
- output = output + bias.to(compute_dtype)
- return output.to(output_dtype)
-
-
-class BaseLinear(nn.Module):
- __constants__ = ["in_features", "out_features", "num_layers", "num_experts"]
-
- def __init__(
- self, in_features, out_features, num_layers_for_initialization, num_experts, bias=True, device=None, dtype=None
- ):
- super().__init__()
- factory_kwargs = {"device": device, "dtype": torch.bfloat16}
- self.in_features = in_features
- self.out_features = out_features
- self.num_layers_for_initialization = num_layers_for_initialization
- self.num_experts = num_experts
- self.use_bias = bias
- self.weight = Parameter(torch.empty((out_features * num_experts, in_features), **factory_kwargs))
- if bias:
- self.bias = Parameter(torch.empty(out_features * num_experts, **factory_kwargs))
- else:
- self.register_parameter("bias", None)
-
- def forward(
- self,
- input: torch.Tensor,
- output_dtype: torch.dtype | None = None,
- modality_dispatcher: ModalityDispatcher | None = None,
- ) -> torch.Tensor:
- output_dtype = input.dtype if output_dtype is None else output_dtype
- return _BF16ComputeLinear.apply(input, self.weight, self.bias, output_dtype, torch.bfloat16)
-
-
-class NativeMoELinear(BaseLinear):
- def forward(
- self,
- input: torch.Tensor,
- output_dtype: torch.dtype | None = None,
- modality_dispatcher: ModalityDispatcher | None = None,
- ) -> torch.Tensor:
- output_dtype = input.dtype if output_dtype is None else output_dtype
- input_list = modality_dispatcher.dispatch(input) # type: ignore
- weight_chunked = self.weight.chunk(self.num_experts, dim=0)
- if self.bias is not None:
- bias_chunked = self.bias.chunk(self.num_experts, dim=0)
- for i in range(self.num_experts):
- input_list[i] = _BF16ComputeLinear.apply(
- input_list[i],
- weight_chunked[i],
- bias_chunked[i] if self.bias is not None else None,
- output_dtype,
- torch.bfloat16,
- )
- return modality_dispatcher.undispatch(*input_list) # type: ignore
-
-
-def create_linear(
- in_features, out_features, num_layers=1, num_experts=1, bias=True, device=None, dtype=None
-) -> BaseLinear | NativeMoELinear:
- if num_experts == 1:
- return BaseLinear(in_features, out_features, num_layers, num_experts, bias, device, dtype)
- else:
- return NativeMoELinear(in_features, out_features, num_layers, num_experts, bias, device, dtype)
-
-
-# ---------------------------------------------------------------------------
-# MoE TP parallel linear wrappers: per-expert vLLM parallel layers
-# ---------------------------------------------------------------------------
-class MoEQKVParallelLinear(nn.Module):
- """Per-expert QKVParallelLinear with modality dispatch.
-
- Wraps ``num_experts`` independent QKVParallelLinear instances.
- Forward: dispatch tokens by modality → per-expert QKV matmul (TP-sharded)
- → undispatch.
- """
-
- def __init__(
- self,
- hidden_size: int,
- head_size: int,
- total_num_heads: int,
- total_num_kv_heads: int,
- num_experts: int,
- bias: bool = False,
- ):
- super().__init__()
- self.num_experts = num_experts
- self.experts = nn.ModuleList(
- [
- QKVParallelLinear(
- hidden_size=hidden_size,
- head_size=head_size,
- total_num_heads=total_num_heads,
- total_num_kv_heads=total_num_kv_heads,
- bias=bias,
- return_bias=False,
- )
- for _ in range(num_experts)
- ]
- )
- # Expose per-rank head info from the first expert (all are identical).
- self.num_heads = self.experts[0].num_heads
- self.num_kv_heads = self.experts[0].num_kv_heads
- self.head_size = head_size
-
- def forward(
- self,
- x: torch.Tensor,
- modality_dispatcher: ModalityDispatcher,
- ) -> torch.Tensor:
- x_list = modality_dispatcher.dispatch(x)
- out_list: list[torch.Tensor] = []
- for i in range(self.num_experts):
- out = self.experts[i](x_list[i])
- out_list.append(out)
- return modality_dispatcher.undispatch(*out_list)
-
-
-class MoEColumnParallelLinear(nn.Module):
- """Per-expert ColumnParallelLinear with modality dispatch.
-
- Forward: dispatch → per-expert column-parallel matmul → undispatch.
- Output stays TP-local (no gather).
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int,
- num_experts: int,
- bias: bool = False,
- ):
- super().__init__()
- self.num_experts = num_experts
- self.experts = nn.ModuleList(
- [
- ColumnParallelLinear(
- input_size=input_size,
- output_size=output_size,
- bias=bias,
- gather_output=False,
- return_bias=False,
- )
- for _ in range(num_experts)
- ]
- )
-
- def forward(
- self,
- x: torch.Tensor,
- modality_dispatcher: ModalityDispatcher,
- ) -> torch.Tensor:
- x_list = modality_dispatcher.dispatch(x)
- out_list: list[torch.Tensor] = []
- for i in range(self.num_experts):
- out = self.experts[i](x_list[i])
- out_list.append(out)
- return modality_dispatcher.undispatch(*out_list)
-
-
-class MoERowParallelLinear(nn.Module):
- """Per-expert RowParallelLinear with modality dispatch.
-
- Forward: dispatch → per-expert row-parallel matmul (includes all-reduce)
- → undispatch.
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int,
- num_experts: int,
- bias: bool = False,
- ):
- super().__init__()
- self.num_experts = num_experts
- self.experts = nn.ModuleList(
- [
- RowParallelLinear(
- input_size=input_size,
- output_size=output_size,
- bias=bias,
- input_is_parallel=True,
- return_bias=False,
- )
- for _ in range(num_experts)
- ]
- )
-
- def forward(
- self,
- x: torch.Tensor,
- modality_dispatcher: ModalityDispatcher,
- ) -> torch.Tensor:
- x_list = modality_dispatcher.dispatch(x)
- out_list: list[torch.Tensor] = []
- for i in range(self.num_experts):
- out = self.experts[i](x_list[i])
- out_list.append(out)
- return modality_dispatcher.undispatch(*out_list)
-
-
-def validate_magi_human_tp_constraints(
- *,
- hidden_size: int,
- num_heads_q: int,
- num_heads_kv: int,
- tensor_parallel_size: int,
-) -> None:
- """Validate MagiHuman TP divisibility constraints.
-
- Both shared layers (num_modality == 1) and MoE layers (num_modality == 3)
- support TP via vLLM's parallel linear layers (QKVParallelLinear /
- ColumnParallelLinear / RowParallelLinear). MoE layers use per-expert
- parallel layers with modality dispatch.
-
- Supported tp_sizes given default config (hidden=5120, heads_q=40, kv=8): 1, 2, 4.
- """
- tp = tensor_parallel_size
- if tp <= 1:
- return
- errors: list[str] = []
- if num_heads_q % tp != 0:
- errors.append(f"num_heads_q ({num_heads_q}) must be divisible by tensor_parallel_size ({tp})")
- if num_heads_kv % tp != 0:
- errors.append(f"num_heads_kv ({num_heads_kv}) must be divisible by tensor_parallel_size ({tp})")
- # SWIGLU layers use intermediate = int(hidden * 8/3) // 4 * 4
- intermediate_swiglu = int(hidden_size * 4 * 2 / 3) // 4 * 4
- if intermediate_swiglu % tp != 0:
- errors.append(
- f"swiglu intermediate_size ({intermediate_swiglu}) must be divisible by "
- f"tensor_parallel_size ({tp}). Supported tp values: 1, 2, 4"
- )
- # GELU7 MoE layers use intermediate = hidden * 4
- intermediate_gelu = hidden_size * 4
- if intermediate_gelu % tp != 0:
- errors.append(f"gelu intermediate_size ({intermediate_gelu}) must be divisible by tensor_parallel_size ({tp})")
- if errors:
- raise ValueError("MagiHuman TP constraint violations:\n" + "\n".join(f" - {e}" for e in errors))
-
-
-# ---------------------------------------------------------------------------
-# Flash attention (no context-parallelism) — uses vllm's flash attention
-# ---------------------------------------------------------------------------
-
-HAS_MAGI_ATTENTION = importlib.util.find_spec("magi_attention") is not None
-
-
-def _fa_varlen_simple(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
-) -> torch.Tensor:
- had_batch = query.ndim == 4
- if had_batch:
- query = query.squeeze(0)
- key = key.squeeze(0)
- value = value.squeeze(0)
- seq_len = query.shape[0]
- cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=query.device)
- out = _vllm_fa_varlen(
- q=query,
- k=key,
- v=value,
- cu_seqlens_q=cu_seqlens,
- cu_seqlens_k=cu_seqlens,
- max_seqlen_q=seq_len,
- max_seqlen_k=seq_len,
- )
- if had_batch:
- out = out.unsqueeze(0)
- return out
-
-
-@magi_register_custom_op(name="infra::flash_attn_func", is_subgraph_boundary=True)
-def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
- return _fa_varlen_simple(query, key, value)
-
-
-def _split_q_range_with_no_overlap(
- q_ranges: torch.Tensor, k_ranges: torch.Tensor
-) -> tuple[list[list[int]], list[list[list[int]]]]:
- range_boundary = torch.unique(q_ranges, sorted=True).tolist()
- candidates = [[start, end, []] for start, end in zip(range_boundary[:-1], range_boundary[1:])]
- q_ranges = q_ranges.tolist()
- k_ranges = k_ranges.tolist()
- for q_range, k_range in zip(q_ranges, k_ranges):
- q_start, q_end = q_range
- for q_range_cand in candidates:
- if q_start <= q_range_cand[0] and q_range_cand[1] <= q_end:
- q_range_cand[2].append(k_range)
- q_ranges_out = []
- k_ranges_out = []
- for q_range_cand in candidates:
- if len(q_range_cand[2]) > 0:
- q_ranges_out.append(q_range_cand[0:2])
- k_ranges_out.append(q_range_cand[2])
- return q_ranges_out, k_ranges_out
-
-
-def _flash_attn_with_correction(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- q_ranges: list[list[int]],
- k_range_list: list[list[list[int]]],
-):
- output = torch.zeros_like(query)
- output_lse = torch.zeros((query.shape[0], query.shape[1]), dtype=torch.float32, device=query.device)
-
- for q_range, k_ranges in zip(q_ranges, k_range_list):
- q_start, q_end = q_range
- q_chunk = query[q_start:q_end]
- q_len = q_chunk.shape[0]
-
- # Concatenate all k_ranges into a single key/value block, then run one
- # flash-attention call. This avoids the need to merge per-chunk LSEs.
- k_parts = [key[ks:ke] for ks, ke in k_ranges]
- v_parts = [value[ks:ke] for ks, ke in k_ranges]
- k_combined = torch.cat(k_parts, dim=0) if len(k_parts) > 1 else k_parts[0]
- v_combined = torch.cat(v_parts, dim=0) if len(v_parts) > 1 else v_parts[0]
- k_len = k_combined.shape[0]
-
- cu_q = torch.tensor([0, q_len], dtype=torch.int32, device=query.device)
- cu_k = torch.tensor([0, k_len], dtype=torch.int32, device=query.device)
- qo_out = _vllm_fa_varlen(
- q=q_chunk,
- k=k_combined,
- v=v_combined,
- cu_seqlens_q=cu_q,
- cu_seqlens_k=cu_k,
- max_seqlen_q=q_len,
- max_seqlen_k=k_len,
- )
- output[q_start:q_end] = qo_out
- return output, output_lse
-
-
-def _flex_flash_attn_func_infer_output_meta(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor
-) -> tuple[torch.Tensor, torch.Tensor]:
- output = torch.empty_like(query)
- output_lse = torch.empty((query.shape[0], query.shape[1]), dtype=torch.float32, device=query.device)
- return output, output_lse
-
-
-@magi_register_custom_op(
- name="infra::flex_flash_attn_func",
- mutates_args=(),
- infer_output_meta_fn=_flex_flash_attn_func_infer_output_meta,
- is_subgraph_boundary=True,
-)
-def flex_flash_attn_func(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor
-) -> tuple[torch.Tensor, torch.Tensor]:
- if HAS_MAGI_ATTENTION and _is_hopper_arch():
- from magi_attention.api import flex_flash_attn_func as magi_flex_flash_attn_func
-
- return magi_flex_flash_attn_func(query, key, value, q_ranges, k_ranges)
- else:
- q_ranges_split, k_range_list = _split_q_range_with_no_overlap(q_ranges, k_ranges)
- return _flash_attn_with_correction(query, key, value, q_ranges_split, k_range_list)
-
-
-def flash_attn_no_cp(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
- q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)
- return flash_attn_func(q, k, v).squeeze(0)
-
-
-def flex_flash_attn_no_cp(
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
- q_ranges: torch.Tensor,
- k_ranges: torch.Tensor,
-) -> torch.Tensor:
- q, k, v = q.to(torch.bfloat16).squeeze(0), k.to(torch.bfloat16).squeeze(0), v.to(torch.bfloat16).squeeze(0)
- out, _ = flex_flash_attn_func(q, k, v, q_ranges=q_ranges, k_ranges=k_ranges)
- return out
-
-
-# ---------------------------------------------------------------------------
-# Attention module (no context-parallelism)
-# ---------------------------------------------------------------------------
-@dataclass
-class AttentionConfig:
- hidden_size: int
- num_heads_q: int
- num_heads_kv: int
- head_dim: int
- params_dtype: torch.dtype
- checkpoint_qk_layernorm_rope: bool
- num_modality: int
- num_layers: int
- use_local_attn: bool = False
- enable_attn_gating: bool = False
-
-
-class Attention(torch.nn.Module):
- config: AttentionConfig
-
- def __init__(self, config: AttentionConfig):
- super().__init__()
- self.config = config
- self.pre_norm = MultiModalityRMSNorm(config.hidden_size, eps=1e-6, num_modality=config.num_modality)
- self.gating_size = config.num_heads_q if config.enable_attn_gating else 0
-
- # Both shared blocks (num_modality == 1) and MoE blocks (num_modality > 1)
- # use vLLM's parallel linear layers for TP support.
- # MoE blocks wrap per-expert parallel layers with modality dispatch.
- if config.num_modality == 1:
- # QKVParallelLinear handles GQA head-sharding for any tp_size.
- # The combined checkpoint weight [Q, K, V, G] is split during
- # load_weights: Q+K+V → linear_qkv, G → linear_gating.
- self.linear_qkv = QKVParallelLinear(
- hidden_size=config.hidden_size,
- head_size=config.head_dim,
- total_num_heads=config.num_heads_q,
- total_num_kv_heads=config.num_heads_kv,
- bias=False,
- return_bias=False,
- )
- self.linear_proj = RowParallelLinear(
- input_size=config.num_heads_q * config.head_dim,
- output_size=config.hidden_size,
- bias=False,
- input_is_parallel=True,
- return_bias=False,
- )
- if config.enable_attn_gating:
- self.linear_gating = ColumnParallelLinear(
- input_size=config.hidden_size,
- output_size=config.num_heads_q,
- bias=False,
- gather_output=False,
- return_bias=False,
- )
- else:
- self.linear_gating = None
- else:
- # MoE blocks: per-expert TP-sharded parallel layers.
- self.linear_qkv = MoEQKVParallelLinear(
- hidden_size=config.hidden_size,
- head_size=config.head_dim,
- total_num_heads=config.num_heads_q,
- total_num_kv_heads=config.num_heads_kv,
- num_experts=config.num_modality,
- bias=False,
- )
- self.linear_proj = MoERowParallelLinear(
- input_size=config.num_heads_q * config.head_dim,
- output_size=config.hidden_size,
- num_experts=config.num_modality,
- bias=False,
- )
- if config.enable_attn_gating:
- self.linear_gating = MoEColumnParallelLinear(
- input_size=config.hidden_size,
- output_size=config.num_heads_q,
- num_experts=config.num_modality,
- bias=False,
- )
- else:
- self.linear_gating = None
-
- self.q_norm = MultiModalityRMSNorm(config.head_dim, num_modality=config.num_modality)
- self.k_norm = MultiModalityRMSNorm(config.head_dim, num_modality=config.num_modality)
-
- # q_size / kv_size reflect the per-rank head count when tp > 1.
- # Both shared and MoE QKV layers expose .num_heads / .num_kv_heads.
- if config.num_modality == 1:
- self.q_size = self.linear_qkv.num_heads * config.head_dim
- self.kv_size = self.linear_qkv.num_kv_heads * config.head_dim
- self._local_heads_q = self.linear_qkv.num_heads
- self._local_heads_kv = self.linear_qkv.num_kv_heads
- else:
- self.q_size = self.linear_qkv.num_heads * config.head_dim
- self.kv_size = self.linear_qkv.num_kv_heads * config.head_dim
- self._local_heads_q = self.linear_qkv.num_heads
- self._local_heads_kv = self.linear_qkv.num_kv_heads
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- rope: torch.Tensor,
- permute_mapping: torch.Tensor,
- inv_permute_mapping: torch.Tensor,
- varlen_handler: VarlenHandler,
- local_attn_handler: FFAHandler | None,
- modality_dispatcher: ModalityDispatcher,
- ) -> torch.Tensor:
- hidden_states = self.pre_norm(hidden_states, modality_dispatcher=modality_dispatcher).to(torch.bfloat16)
-
- if self.config.num_modality == 1:
- # vLLM parallel layers with return_bias=False return a single tensor.
- qkv = self.linear_qkv(hidden_states).to(torch.float32)
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
- if self.linear_gating is not None:
- g = self.linear_gating(hidden_states).to(torch.float32)
- else:
- g = hidden_states.new_empty(hidden_states.shape[0], 0)
- else:
- # MoE TP path: per-expert QKV parallel layers.
- qkv = self.linear_qkv(hidden_states, modality_dispatcher=modality_dispatcher).to(torch.float32)
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
- if self.linear_gating is not None:
- g = self.linear_gating(hidden_states, modality_dispatcher=modality_dispatcher).to(torch.float32)
- else:
- g = hidden_states.new_empty(hidden_states.shape[0], 0)
-
- q = q.view(-1, self._local_heads_q, self.config.head_dim)
- k = k.view(-1, self._local_heads_kv, self.config.head_dim)
- v = v.view(-1, self._local_heads_kv, self.config.head_dim)
- g = g.view(k.shape[0], self._local_heads_q, -1)
-
- q = self.q_norm(q, modality_dispatcher=modality_dispatcher)
- k = self.k_norm(k, modality_dispatcher=modality_dispatcher)
-
- q = ModalityDispatcher.inv_permute(q, inv_permute_mapping).unsqueeze(0)
- k = ModalityDispatcher.inv_permute(k, inv_permute_mapping).unsqueeze(0)
- v = ModalityDispatcher.inv_permute(v, inv_permute_mapping).unsqueeze(0)
-
- sin_emb, cos_emb = rope.tensor_split(2, -1)
- q = apply_rotary_emb_torch(q, cos_emb, sin_emb)
- k = apply_rotary_emb_torch(k, cos_emb, sin_emb)
-
- if self.config.use_local_attn and local_attn_handler is not None:
- self_attn_out = flex_flash_attn_no_cp(q, k, v, local_attn_handler.q_ranges, local_attn_handler.k_ranges)
- else:
- self_attn_out = flash_attn_no_cp(q, k, v)
- self_attn_out = ModalityDispatcher.permute(self_attn_out, permute_mapping)
-
- if self.config.enable_attn_gating:
- self_attn_out = self_attn_out * torch.sigmoid(g)
-
- self_attn_out = self_attn_out.view(-1, self._local_heads_q * self.config.head_dim).to(torch.bfloat16)
- if self.config.num_modality == 1:
- return self.linear_proj(self_attn_out)
- return self.linear_proj(self_attn_out, modality_dispatcher=modality_dispatcher)
-
-
-# ---------------------------------------------------------------------------
-# MLP module
-# ---------------------------------------------------------------------------
-@dataclass
-class MLPConfig:
- hidden_size: int
- intermediate_size: int
- activation_type: MLPActivationType
- params_dtype: torch.dtype
- num_modality: int = 1
- num_layers: int = 1
- gated_act: bool = False
-
-
-class MLP(torch.nn.Module):
- config: MLPConfig
-
- def __init__(self, config: MLPConfig):
- super().__init__()
- num_experts = config.num_modality
- self.pre_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=config.num_modality)
- intermediate_size_up = config.intermediate_size * 2 if config.gated_act else config.intermediate_size
-
- # Both shared blocks (num_experts == 1) and MoE blocks (num_experts > 1)
- # use vLLM's parallel linear layers for TP support.
- if num_experts == 1:
- # ColumnParallelLinear shards the output dim uniformly. For
- # SWIGLU7 the interleaved [up0, gate0, up1, gate1, ...] format
- # is preserved within each rank's contiguous slice, so swiglu7
- # (which uses x[..., ::2] / x[..., 1::2]) still works correctly.
- self.up_gate_proj = ColumnParallelLinear(
- input_size=config.hidden_size,
- output_size=intermediate_size_up,
- bias=False,
- gather_output=False,
- return_bias=False,
- )
- self.down_proj = RowParallelLinear(
- input_size=config.intermediate_size,
- output_size=config.hidden_size,
- bias=False,
- input_is_parallel=True,
- return_bias=False,
- )
- else:
- # MoE blocks: per-expert TP-sharded parallel layers.
- self.up_gate_proj = MoEColumnParallelLinear(
- input_size=config.hidden_size,
- output_size=intermediate_size_up,
- num_experts=num_experts,
- bias=False,
- )
- self.down_proj = MoERowParallelLinear(
- input_size=config.intermediate_size,
- output_size=config.hidden_size,
- num_experts=num_experts,
- bias=False,
- )
- self.activation_func = create_activation_func(config.activation_type)
-
- def forward(self, x: torch.Tensor, modality_dispatcher: ModalityDispatcher) -> torch.Tensor:
- x = self.pre_norm(x, modality_dispatcher=modality_dispatcher).to(torch.bfloat16)
- if isinstance(self.up_gate_proj, ColumnParallelLinear):
- x = self.up_gate_proj(x).to(torch.float32)
- x = self.activation_func(x).to(torch.bfloat16)
- return self.down_proj(x).to(torch.float32)
- # MoE TP path: per-expert column/row parallel layers.
- x = self.up_gate_proj(x, modality_dispatcher=modality_dispatcher).to(torch.float32)
- x = self.activation_func(x).to(torch.bfloat16)
- x = self.down_proj(x, modality_dispatcher=modality_dispatcher).to(torch.float32)
- return x
-
-
-# ---------------------------------------------------------------------------
-# Adapter (per-modality embedders + RoPE)
-# ---------------------------------------------------------------------------
-@dataclass
-class AdapterConfig:
- hidden_size: int
- num_attention_heads: int
- text_in_channels: int
- video_in_channels: int
- audio_in_channels: int
- params_dtype: torch.dtype
-
-
-class Adapter(torch.nn.Module):
- config: AdapterConfig
-
- def __init__(self, config: AdapterConfig):
- super().__init__()
- self.config = config
- self.video_embedder = nn.Linear(config.video_in_channels, config.hidden_size, bias=True, dtype=torch.float32)
- self.text_embedder = nn.Linear(config.text_in_channels, config.hidden_size, bias=True, dtype=torch.float32)
- self.audio_embedder = nn.Linear(config.audio_in_channels, config.hidden_size, bias=True, dtype=torch.float32)
- self.rope = ElementWiseFourierEmbed(
- config.hidden_size // config.num_attention_heads, in_pixels=False, learnable=False
- )
-
- def forward(self, x, coords_mapping, video_mask, audio_mask, text_mask):
- rope = self.rope(coords_mapping)
-
- text_input = x[text_mask, : self.config.text_in_channels]
- audio_input = x[audio_mask, : self.config.audio_in_channels]
- video_input = x[video_mask, : self.config.video_in_channels]
-
- text_out = self.text_embedder(text_input)
- audio_out = self.audio_embedder(audio_input)
- video_out = self.video_embedder(video_input)
-
- output_x = torch.zeros(x.shape[0], self.config.hidden_size, device=x.device, dtype=x.dtype)
- output_x[text_mask] = text_out
- output_x[audio_mask] = audio_out
- output_x[video_mask] = video_out
- return output_x, rope
-
-
-# ---------------------------------------------------------------------------
-# Transformer layer (no CP)
-# ---------------------------------------------------------------------------
-class TransFormerLayer(torch.nn.Module):
- def __init__(self, config: Any, layer_idx: int):
- super().__init__()
- num_modality = 3 if layer_idx in config.mm_layers else 1
- use_local_attn = layer_idx in config.local_attn_layers
- self.post_norm = layer_idx in config.post_norm_layers
- attention_config = AttentionConfig(
- hidden_size=config.hidden_size,
- num_heads_q=config.num_heads_q,
- num_heads_kv=config.num_heads_kv,
- head_dim=config.head_dim,
- params_dtype=config.params_dtype,
- checkpoint_qk_layernorm_rope=config.checkpoint_qk_layernorm_rope,
- num_modality=num_modality,
- num_layers=config.num_layers,
- use_local_attn=use_local_attn,
- enable_attn_gating=config.enable_attn_gating,
- )
- self.attention: Attention = Attention(attention_config)
-
- activation_type = MLPActivationType.GELU7 if layer_idx in config.gelu7_layers else MLPActivationType.SWIGLU7
- if activation_type == MLPActivationType.SWIGLU7:
- gated_act = True
- intermediate_size = int(config.hidden_size * 4 * 2 / 3) // 4 * 4
- else:
- gated_act = False
- intermediate_size = config.hidden_size * 4
- mlp_config = MLPConfig(
- hidden_size=config.hidden_size,
- intermediate_size=intermediate_size,
- activation_type=activation_type,
- params_dtype=config.params_dtype,
- num_modality=num_modality,
- num_layers=config.num_layers,
- gated_act=gated_act,
- )
- self.mlp: MLP = MLP(mlp_config)
- if self.post_norm:
- self.attn_post_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=num_modality)
- self.mlp_post_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=num_modality)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- rope: torch.Tensor,
- permute_mapping: torch.Tensor,
- inv_permute_mapping: torch.Tensor,
- varlen_handler: VarlenHandler,
- local_attn_handler: FFAHandler | None,
- modality_dispatcher: ModalityDispatcher,
- ) -> torch.Tensor:
- attn_out = self.attention(
- hidden_states,
- rope,
- permute_mapping,
- inv_permute_mapping,
- varlen_handler,
- local_attn_handler,
- modality_dispatcher,
- )
- if self.post_norm:
- attn_out = self.attn_post_norm(attn_out, modality_dispatcher=modality_dispatcher)
- hidden_states = hidden_states + attn_out
-
- mlp_out = self.mlp(hidden_states, modality_dispatcher)
- if self.post_norm:
- mlp_out = self.mlp_post_norm(mlp_out, modality_dispatcher=modality_dispatcher)
- hidden_states = hidden_states + mlp_out
- return hidden_states
-
-
-# ---------------------------------------------------------------------------
-# TransformerBlock with magi_compile
-# ---------------------------------------------------------------------------
-is_base_model = True
-
-
-def config_patch(compile_config: CompileConfig) -> CompileConfig:
- global is_base_model
- if is_base_model:
- is_base_model = False
- else:
- compile_config.offload_config.gpu_resident_weight_ratio = 0.0
- return compile_config
-
-
-@magi_compile(
- config_patch=config_patch, dynamic_arg_dims={"x": 0, "rope": 0, "permute_mapping": 0, "inv_permute_mapping": 0}
-)
-class TransformerBlock(torch.nn.Module):
- def __init__(self, model_config: Any):
- super().__init__()
- self.layers: list[TransFormerLayer] = nn.ModuleList()
- for layer_idx in range(model_config.num_layers):
- self.layers.append(TransFormerLayer(model_config, layer_idx))
-
- def forward(
- self,
- x: torch.Tensor,
- rope: torch.Tensor,
- permute_mapping: torch.Tensor,
- inv_permute_mapping: torch.Tensor,
- varlen_handler: VarlenHandler,
- local_attn_handler: FFAHandler | None,
- modality_dispatcher: ModalityDispatcher,
- ) -> torch.Tensor:
- for layer in self.layers:
- x = layer(
- x, rope, permute_mapping, inv_permute_mapping, varlen_handler, local_attn_handler, modality_dispatcher
- )
- return x
-
-
-# ---------------------------------------------------------------------------
-# Internal config for TransformerBlock / DiTModel construction
-# ---------------------------------------------------------------------------
-@dataclass
-class TransformerConfig:
- hidden_size: int
- video_in_channels: int
- audio_in_channels: int
- text_in_channels: int
- params_dtype: torch.dtype
- post_process_dtype: torch.dtype
-
-
-# ---------------------------------------------------------------------------
-# DiTModel (no context-parallelism)
-# ---------------------------------------------------------------------------
-class DiTModel(torch.nn.Module):
- config: TransformerConfig
- _layerwise_offload_blocks_attr = "blocks"
-
- @property
- def blocks(self) -> nn.ModuleList:
- return self.block.layers
-
- def __init__(self, model_config: Any):
- super().__init__()
- validate_magi_human_tp_constraints(
- hidden_size=model_config.hidden_size,
- num_heads_q=model_config.hidden_size // model_config.head_dim,
- num_heads_kv=model_config.num_query_groups,
- tensor_parallel_size=get_tensor_model_parallel_world_size(),
- )
- self.config = TransformerConfig(
- hidden_size=model_config.hidden_size,
- video_in_channels=model_config.video_in_channels,
- audio_in_channels=model_config.audio_in_channels,
- text_in_channels=model_config.text_in_channels,
- params_dtype=model_config.params_dtype,
- post_process_dtype=torch.float32,
- )
- adapter_config = AdapterConfig(
- hidden_size=model_config.hidden_size,
- num_attention_heads=model_config.num_heads_q,
- text_in_channels=model_config.text_in_channels,
- video_in_channels=model_config.video_in_channels,
- audio_in_channels=model_config.audio_in_channels,
- params_dtype=torch.float32,
- )
- self.adapter: Adapter = Adapter(adapter_config)
- self.block: TransformerBlock = TransformerBlock(model_config=model_config)
- self.final_norm_video = MultiModalityRMSNorm(self.config.hidden_size)
- self.final_norm_audio = MultiModalityRMSNorm(self.config.hidden_size)
- self.final_linear_video = nn.Linear(
- self.config.hidden_size, self.config.video_in_channels, bias=False, dtype=torch.float32
- )
- self.final_linear_audio = nn.Linear(
- self.config.hidden_size, self.config.audio_in_channels, bias=False, dtype=torch.float32
- )
-
- def forward(
- self,
- x: torch.Tensor,
- coords_mapping: torch.Tensor,
- modality_mapping: torch.Tensor,
- varlen_handler: VarlenHandler,
- local_attn_handler: FFAHandler | None,
- ):
- modality_dispatcher = ModalityDispatcher(modality_mapping, 3)
- permute_mapping = modality_dispatcher.permute_mapping
- inv_permute_mapping = modality_dispatcher.inv_permute_mapping
- video_mask = modality_mapping == Modality.VIDEO
- audio_mask = modality_mapping == Modality.AUDIO
- text_mask = modality_mapping == Modality.TEXT
-
- x, rope = self.adapter(x, coords_mapping, video_mask, audio_mask, text_mask)
-
- x = x.to(self.config.params_dtype)
- x = ModalityDispatcher.permute(x, permute_mapping)
-
- x = self.block(
- x,
- rope,
- permute_mapping=permute_mapping,
- inv_permute_mapping=inv_permute_mapping,
- varlen_handler=varlen_handler,
- local_attn_handler=local_attn_handler,
- modality_dispatcher=modality_dispatcher,
- )
-
- x = ModalityDispatcher.inv_permute(x, inv_permute_mapping)
-
- x_video = x[video_mask].to(self.final_norm_video.weight.dtype)
- x_video = self.final_norm_video(x_video)
- x_video = self.final_linear_video(x_video)
-
- x_audio = x[audio_mask].to(self.final_norm_audio.weight.dtype)
- x_audio = self.final_norm_audio(x_audio)
- x_audio = self.final_linear_audio(x_audio)
-
- x_out = torch.zeros(
- x.shape[0],
- max(self.config.video_in_channels, self.config.audio_in_channels),
- device=x.device,
- dtype=x.dtype,
- )
- x_out[video_mask, : self.config.video_in_channels] = x_video
- x_out[audio_mask, : self.config.audio_in_channels] = x_audio
-
- return x_out
-
-
-# ---------------------------------------------------------------------------
-# Public config dataclass for building DiTModel from JSON
-# ---------------------------------------------------------------------------
-@dataclass
-class MagiHumanDiTConfig:
- num_layers: int = 40
- hidden_size: int = 5120
- head_dim: int = 128
- num_query_groups: int = 8
- video_in_channels: int = 48 * 4
- audio_in_channels: int = 64
- text_in_channels: int = 3584
- checkpoint_qk_layernorm_rope: bool = False
- params_dtype: torch.dtype = torch.float32
- mm_layers: list = field(default_factory=lambda: [0, 1, 2, 3, 36, 37, 38, 39])
- local_attn_layers: list = field(default_factory=list)
- enable_attn_gating: bool = True
- gelu7_layers: list = field(default_factory=lambda: [0, 1, 2, 3])
- post_norm_layers: list = field(default_factory=list)
-
- def __post_init__(self):
- self.num_heads_q = self.hidden_size // self.head_dim
- self.num_heads_kv = self.num_query_groups
-
-
-if TYPE_CHECKING:
- from .pipeline_magi_human import EvalInput
-
-
-# ===========================================================================
-# Data proxy (ported from daVinci-MagiHuman inference/pipeline/data_proxy.py)
-# ===========================================================================
-def _unfold_3d(
- x: torch.Tensor,
- kernel_size: tuple[int, int, int],
- stride: tuple[int, int, int],
-) -> torch.Tensor:
- """Pure-PyTorch 3D unfold matching UnfoldAnd behavior.
-
- After N unfold ops the shape is (batch, C, oD, oH, oW, kD, kH, kW).
- UnfoldAnd permutes kernel dims next to channel before reshape so that the
- col_dim axis is ordered as (C, kD, kH, kW) -- matching F.unfold semantics.
- Without this permute, .view() interleaves spatial and kernel positions.
-
- Args:
- x: (N, C, D, H, W)
- kernel_size: (kD, kH, kW)
- stride: (sD, sH, sW)
- Returns:
- (N, C*kD*kH*kW, L) where L = product of output spatial dims.
- """
- ndim = len(kernel_size)
- for d in range(ndim):
- x = x.unfold(d + 2, kernel_size[d], stride[d])
- # x: (N, C, oD, oH, oW, kD, kH, kW)
- # Permute to (N, C, kD, kH, kW, oD, oH, oW) so that view groups correctly
- perm = [0, 1] + list(range(ndim + 2, 2 * ndim + 2)) + list(range(2, ndim + 2))
- x = x.permute(*perm).contiguous()
-
- batch_size = x.shape[0]
- col_dim = 1
- for i in range(1, ndim + 2):
- col_dim *= x.shape[i]
- spatial = 1
- for i in range(ndim + 2, 2 * ndim + 2):
- spatial *= x.shape[i]
- return x.view(batch_size, col_dim, spatial)
-
-
-def calc_local_qk_range(
- num_video_tokens,
- num_audio_and_txt_tokens,
- num_frames,
- frame_receptive_field,
-):
- token_per_frame = num_video_tokens // num_frames
- total_tokens = num_video_tokens + num_audio_and_txt_tokens
-
- q_range_list = []
- k_range_list = []
- for i in range(num_frames):
- q_range_list.append(torch.tensor([i * token_per_frame, (i + 1) * token_per_frame]))
- k_range_list.append(
- torch.tensor(
- [
- (i - frame_receptive_field) * token_per_frame,
- (i + frame_receptive_field + 1) * token_per_frame,
- ]
- )
- )
- local_q_range = torch.stack(q_range_list, dim=0)
- local_k_range = torch.stack(k_range_list, dim=0)
-
- local_k_range[local_k_range < 0] = 0
- local_k_range[local_k_range > num_video_tokens] = num_video_tokens
-
- video_q_range = torch.tensor([[0, num_video_tokens]])
- video_k_range = torch.tensor([[num_video_tokens, num_video_tokens + num_audio_and_txt_tokens]])
-
- at_q_ranges = torch.tensor([[num_video_tokens, total_tokens]])
- at_k_ranges = torch.tensor([[0, total_tokens]])
-
- q_ranges = (
- torch.cat([local_q_range, video_q_range, at_q_ranges], dim=0).to(torch.int32).to("cuda", non_blocking=True)
- )
- k_ranges = (
- torch.cat([local_k_range, video_k_range, at_k_ranges], dim=0).to(torch.int32).to("cuda", non_blocking=True)
- )
- return q_ranges, k_ranges
-
-
-def calc_local_attn_ffa_handler(
- num_video_tokens,
- num_audio_and_txt_tokens,
- num_frames,
- frame_receptive_field,
-):
- q_ranges, k_ranges = calc_local_qk_range(
- num_video_tokens,
- num_audio_and_txt_tokens,
- num_frames,
- frame_receptive_field,
- )
- total = num_video_tokens + num_audio_and_txt_tokens
- return FFAHandler(
- q_ranges=q_ranges,
- k_ranges=k_ranges,
- max_seqlen_q=total,
- max_seqlen_k=total,
- attn_type_map=torch.zeros([q_ranges.shape[0]], device="cuda", dtype=torch.int32),
- softmax_scale=None,
- )
-
-
-def get_coords(
- shape: list[int],
- ref_feat_shape: list[int],
- offset_thw: list[int] | None = None,
- device: torch.device = torch.device("cpu"),
- dtype: torch.dtype = torch.float32,
-):
- if offset_thw is None:
- offset_thw = [0, 0, 0]
- ori_t, ori_h, ori_w = shape
- ref_t, ref_h, ref_w = ref_feat_shape
-
- offset_t, offset_h, offset_w = offset_thw
- time_rng = torch.arange(ori_t, device=device, dtype=dtype) + offset_t
- height_rng = torch.arange(ori_h, device=device, dtype=dtype) + offset_h
- width_rng = torch.arange(ori_w, device=device, dtype=dtype) + offset_w
-
- time_grid, height_grid, width_grid = torch.meshgrid(
- time_rng,
- height_rng,
- width_rng,
- indexing="ij",
- )
- coords_flat = torch.stack([time_grid, height_grid, width_grid], dim=-1).reshape(-1, 3)
-
- meta = torch.tensor(
- [ori_t, ori_h, ori_w, ref_t, ref_h, ref_w],
- device=device,
- dtype=dtype,
- )
- meta_expanded = meta.expand(coords_flat.size(0), -1)
- return torch.cat([coords_flat, meta_expanded], dim=-1)
-
-
-@dataclass
-class SingleData:
- video_x_t: torch.Tensor
- audio_x_t: torch.Tensor
- audio_feat_len: int
- txt_feat: torch.Tensor
- txt_feat_len: int
- t: int
- h: int
- w: int
- patch_size: int
- t_patch_size: int
- spatial_rope_interpolation: Literal["inter", "extra"]
- ref_audio_offset: int
- text_offset: int
- coords_style: Literal["v1", "v2"] = "v1"
-
- def __post_init__(self):
- self.video_token_num = self.video_x_t.shape[0]
- self.audio_x_t = self.audio_x_t[: self.audio_feat_len]
- self.txt_feat = self.txt_feat[: self.txt_feat_len]
- self.video_channel = self.video_x_t.shape[-1]
- self.audio_channel = self.audio_x_t.shape[-1]
- self.txt_channel = self.txt_feat.shape[-1]
-
- @property
- def device(self):
- return self.video_x_t.device
-
- @property
- def default_dtype(self):
- return self.video_x_t.dtype
-
- @property
- def total_token_num(self):
- return self.video_token_num + self.audio_feat_len + self.txt_feat_len
-
- @property
- def token_sequence(self):
- tensors = [self.video_x_t, self.audio_x_t, self.txt_feat]
- max_channel = max(t.shape[-1] for t in tensors)
- padded = [F.pad(t, (0, max_channel - t.shape[-1])) for t in tensors]
- return torch.cat(padded, dim=0)
-
- @property
- def modality_mapping(self):
- v_map = torch.full((self.video_token_num,), Modality.VIDEO, dtype=torch.int64, device=self.device)
- a_map = torch.full((self.audio_feat_len,), Modality.AUDIO, dtype=torch.int64, device=self.device)
- t_map = torch.full((self.txt_feat_len,), Modality.TEXT, dtype=torch.int64, device=self.device)
- return torch.cat([v_map, a_map, t_map], dim=0)
-
- def default_coords(self, shape, ref_feat_shape, offset_thw=None):
- if offset_thw is None:
- offset_thw = [0, 0, 0]
- return get_coords(
- shape=shape,
- ref_feat_shape=ref_feat_shape,
- offset_thw=offset_thw,
- device=self.device,
- dtype=self.default_dtype,
- )
-
- @property
- def coords_mapping(self):
- if self.spatial_rope_interpolation == "inter":
- video_ref_feat_shape = (self.t // self.t_patch_size, 32, 32)
- else:
- video_ref_feat_shape = (
- self.t // self.t_patch_size,
- self.h // self.patch_size,
- self.w // self.patch_size,
- )
-
- video_coords = self.default_coords(
- shape=(
- self.t // self.t_patch_size,
- self.h // self.patch_size,
- self.w // self.patch_size,
- ),
- ref_feat_shape=video_ref_feat_shape,
- )
-
- if self.coords_style == "v1":
- audio_coords = self.default_coords(
- shape=(self.audio_feat_len, 1, 1),
- ref_feat_shape=(self.t // self.t_patch_size, 1, 1),
- )
- text_coords = self.default_coords(
- shape=(self.txt_feat_len, 1, 1),
- ref_feat_shape=(2, 1, 1),
- offset_thw=[self.text_offset, 0, 0],
- )
- elif self.coords_style == "v2":
- magic_audio_ref_t = (self.audio_feat_len - 1) // 4 + 1
- audio_coords = self.default_coords(
- shape=(self.audio_feat_len, 1, 1),
- ref_feat_shape=(magic_audio_ref_t // self.t_patch_size, 1, 1),
- )
- text_coords = self.default_coords(
- shape=(self.txt_feat_len, 1, 1),
- ref_feat_shape=(1, 1, 1),
- offset_thw=[-self.txt_feat_len, 0, 0],
- )
- else:
- raise ValueError(f"Unknown coords_style: {self.coords_style}")
-
- return torch.cat([video_coords, audio_coords, text_coords], dim=0)
-
- def depack_token_sequence(self, token_sequence):
- video_x_t = token_sequence[: self.video_token_num, : self.video_channel]
- video_x_t = rearrange(
- video_x_t,
- "(T H W) (pT pH pW C) -> C (T pT) (H pH) (W pW)",
- H=self.h // self.patch_size,
- W=self.w // self.patch_size,
- pT=self.t_patch_size,
- pH=self.patch_size,
- pW=self.patch_size,
- ).contiguous()
- audio_x_t = token_sequence[
- self.video_token_num : self.video_token_num + self.audio_feat_len,
- : self.audio_channel,
- ]
- return video_x_t, audio_x_t
-
-
-@dataclass
-class SimplePackedData:
- items: list[SingleData]
-
- @property
- def token_sequence(self):
- return torch.cat([item.token_sequence for item in self.items], dim=0)
-
- @property
- def modality_mapping(self):
- return torch.cat([item.modality_mapping for item in self.items], dim=0)
-
- @property
- def coords_mapping(self):
- return torch.cat([item.coords_mapping for item in self.items], dim=0)
-
- @property
- def total_token_num(self):
- return sum(item.total_token_num for item in self.items)
-
- def __getitem__(self, index):
- return self.items[index]
-
- @property
- def cu_seqlen(self):
- cu = torch.cumsum(
- torch.tensor([item.total_token_num for item in self.items]),
- dim=0,
- )
- return F.pad(cu, (1, 0))
-
- @property
- def max_seqlen(self):
- return torch.tensor(max(item.total_token_num for item in self.items))
-
- def depack_token_sequence(self, token_sequence):
- video_list, audio_list = [], []
- parts = torch.split(
- token_sequence,
- [item.total_token_num for item in self.items],
- dim=0,
- )
- for item, part in zip(self.items, parts):
- v, a = item.depack_token_sequence(part)
- video_list.append(v)
- audio_list.append(a)
- return torch.stack(video_list, dim=0), torch.stack(audio_list, dim=0)
-
-
-class MagiDataProxy:
- def __init__(
- self,
- patch_size: int = 2,
- t_patch_size: int = 1,
- frame_receptive_field: int = 11,
- spatial_rope_interpolation: str = "extra",
- ref_audio_offset: int = 1000,
- text_offset: int = 0,
- coords_style: str = "v2",
- ):
- self.patch_size = patch_size
- self.t_patch_size = t_patch_size
- self.frame_receptive_field = frame_receptive_field
- self.spatial_rope_interpolation = spatial_rope_interpolation
- self.ref_audio_offset = ref_audio_offset
- self.text_offset = text_offset
- self.coords_style = coords_style
- self._kernel = (t_patch_size, patch_size, patch_size)
- self._stride = (t_patch_size, patch_size, patch_size)
- self._saved_data: dict[str, Any] = {}
-
- def saved_for_output(self, **kwargs):
- self._saved_data.update(kwargs)
-
- def get_saved_data(self, key: str):
- return self._saved_data[key]
-
- def img2tokens(self, x_t: torch.Tensor):
- x_t_unfolded = _unfold_3d(x_t, self._kernel, self._stride)
- return rearrange(
- x_t_unfolded,
- "N col_dim num_tokens -> N num_tokens col_dim",
- ).contiguous()
-
- def process_input(self, transported_data: EvalInput):
- batch_size, _, t, h, w = transported_data.x_t.shape
- x_t = self.img2tokens(transported_data.x_t)
- audio_x_t = transported_data.audio_x_t.contiguous()
- text_in = transported_data.txt_feat.contiguous()
-
- simple_packed_data = SimplePackedData(items=[])
- for i in range(batch_size):
- single_data = SingleData(
- video_x_t=x_t[i],
- audio_x_t=audio_x_t[i],
- audio_feat_len=transported_data.audio_feat_len[i],
- txt_feat=text_in[i],
- txt_feat_len=transported_data.txt_feat_len[i],
- t=t,
- h=h,
- w=w,
- patch_size=self.patch_size,
- t_patch_size=self.t_patch_size,
- spatial_rope_interpolation=self.spatial_rope_interpolation,
- ref_audio_offset=self.ref_audio_offset,
- text_offset=self.text_offset,
- coords_style=self.coords_style,
- )
- simple_packed_data.items.append(single_data)
-
- if self.frame_receptive_field != -1:
- assert batch_size == 1, "local attention only supports batch size 1"
- local_attn_handler = calc_local_attn_ffa_handler(
- num_video_tokens=simple_packed_data[0].video_token_num,
- num_audio_and_txt_tokens=(simple_packed_data[0].audio_feat_len + simple_packed_data[0].txt_feat_len),
- num_frames=t,
- frame_receptive_field=self.frame_receptive_field,
- )
- if isinstance(local_attn_handler.max_seqlen_k, torch.Tensor):
- local_attn_handler.max_seqlen_k = local_attn_handler.max_seqlen_k.item()
- if isinstance(local_attn_handler.max_seqlen_q, torch.Tensor):
- local_attn_handler.max_seqlen_q = local_attn_handler.max_seqlen_q.item()
- else:
- local_attn_handler = None
-
- varlen_handler = VarlenHandler(
- cu_seqlens_q=simple_packed_data.cu_seqlen.to(torch.int32).cuda(),
- cu_seqlens_k=simple_packed_data.cu_seqlen.to(torch.int32).cuda(),
- max_seqlen_q=simple_packed_data.max_seqlen.to(torch.int32).cuda(),
- max_seqlen_k=simple_packed_data.max_seqlen.to(torch.int32).cuda(),
- )
-
- self.saved_for_output(simple_packed_data=simple_packed_data)
-
- x = simple_packed_data.token_sequence
- coords_mapping = simple_packed_data.coords_mapping
- modality_mapping = simple_packed_data.modality_mapping
- return (x, coords_mapping, modality_mapping, varlen_handler, local_attn_handler)
-
- def process_output(self, x: torch.Tensor):
- simple_packed_data: SimplePackedData = self.get_saved_data("simple_packed_data")
- return simple_packed_data.depack_token_sequence(x)
diff --git a/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py b/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py
deleted file mode 100644
index c1abdf91f04..00000000000
--- a/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py
+++ /dev/null
@@ -1,2269 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright (c) 2026 SandAI. All Rights Reserved.
-# Ported from daVinci-MagiHuman inference/pipeline/video_generate.py
-# Adapted for vllm-omni: single-GPU, diffusers VAE, configurable dit_subfolder.
-
-from __future__ import annotations
-
-import json
-import logging
-import math
-import os
-from collections.abc import Iterable
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any, Literal
-
-import numpy as np
-import torch
-import torch.nn as nn
-import whisper
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.schedulers.scheduling_utils import (
- KarrasDiffusionSchedulers,
- SchedulerMixin,
- SchedulerOutput,
-)
-from diffusers.utils import deprecate, load_image
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers.video_processor import VideoProcessor
-from einops import rearrange
-from PIL import Image
-from safetensors.torch import load_file
-from torch.nn import functional as F
-from torch.nn.utils import weight_norm
-from transformers import AutoTokenizer
-from transformers.models.t5gemma import T5GemmaEncoderModel
-from vllm.distributed import (
- get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
-from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
-from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import (
- DistributedAutoencoderKLWan,
-)
-from vllm_omni.diffusion.model_loader.diffusers_loader import (
- DiffusersPipelineLoader,
-)
-from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
-from vllm_omni.diffusion.models.t5_encoder.t5_gemma_encoder import T5GemmaEncoderModelTP
-from vllm_omni.diffusion.models.utils import _load_json
-from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import (
- DiffusionPipelineProfilerMixin,
-)
-from vllm_omni.diffusion.request import OmniDiffusionRequest
-
-from .magi_human_dit import (
- DiTModel,
- FFAHandler,
- MagiHumanDiTConfig,
- Modality,
- VarlenHandler,
-)
-
-logger = logging.getLogger(__name__)
-
-
-# ===========================================================================
-# Scheduler (ported from daVinci-MagiHuman inference/pipeline/scheduler_unipc.py)
-# ===========================================================================
-class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
- _compatibles = [e.name for e in KarrasDiffusionSchedulers]
- order = 1
-
- @register_to_config
- def __init__(
- self,
- num_train_timesteps: int = 1000,
- solver_order: int = 2,
- prediction_type: str = "flow_prediction",
- shift: float = 1.0,
- use_dynamic_shifting=False,
- thresholding: bool = False,
- dynamic_thresholding_ratio: float = 0.995,
- sample_max_value: float = 1.0,
- predict_x0: bool = True,
- solver_type: str = "bh2",
- lower_order_final: bool = True,
- disable_corrector: list[int] = [],
- solver_p: SchedulerMixin = None,
- timestep_spacing: str = "linspace",
- steps_offset: int = 0,
- final_sigmas_type: str | None = "zero",
- ):
- if solver_type not in ["bh1", "bh2"]:
- if solver_type in ["midpoint", "heun", "logrho"]:
- self.register_to_config(solver_type="bh2")
- else:
- raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
-
- self.predict_x0 = predict_x0
- self.num_inference_steps = None
- alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
- sigmas = 1.0 - alphas
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
-
- if not use_dynamic_shifting:
- sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
-
- self.sigmas = sigmas
- self.timesteps = sigmas * num_train_timesteps
-
- self.model_outputs = [None] * solver_order
- self.timestep_list = [None] * solver_order
- self.lower_order_nums = 0
- self.disable_corrector = disable_corrector
- self.solver_p = solver_p
- self.last_sample = None
- self._step_index: int | None = None
- self._begin_index: int | None = None
-
- self.sigmas = self.sigmas.to("cpu")
- self.sigma_min = self.sigmas[-1].item()
- self.sigma_max = self.sigmas[0].item()
-
- @property
- def step_index(self):
- return self._step_index
-
- @property
- def begin_index(self):
- return self._begin_index
-
- def set_begin_index(self, begin_index: int = 0):
- self._begin_index = begin_index
-
- def set_timesteps(
- self,
- num_inference_steps: int | None = None,
- device: str | torch.device = None,
- sigmas: list[float] | None = None,
- mu: float | None | None = None,
- shift: float | None | None = None,
- ):
- if self.config.use_dynamic_shifting and mu is None:
- raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
-
- if sigmas is None:
- sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
-
- if self.config.use_dynamic_shifting:
- sigmas = self.time_shift(mu, 1.0, sigmas)
- else:
- if shift is None:
- shift = self.config.shift
- sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
-
- if self.config.final_sigmas_type == "sigma_min":
- sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
- elif self.config.final_sigmas_type == "zero":
- sigma_last = 0
- else:
- raise ValueError(
- f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
- )
-
- timesteps = sigmas * self.config.num_train_timesteps
- sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
-
- self.sigmas = torch.from_numpy(sigmas)
- self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
-
- self.num_inference_steps = len(timesteps)
-
- self.model_outputs = [None] * self.config.solver_order
- self.lower_order_nums = 0
- self.last_sample = None
- if self.solver_p:
- self.solver_p.set_timesteps(self.num_inference_steps, device=device)
-
- self._step_index = None
- self._begin_index = None
- self.sigmas = self.sigmas.to("cpu")
-
- def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
- dtype = sample.dtype
- batch_size, channels, *remaining_dims = sample.shape
-
- if dtype not in (torch.float32, torch.float64):
- sample = sample.float()
-
- sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
- abs_sample = sample.abs()
- s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
- s = torch.clamp(s, min=1, max=self.config.sample_max_value)
- s = s.unsqueeze(1)
- sample = torch.clamp(sample, -s, s) / s
- sample = sample.reshape(batch_size, channels, *remaining_dims)
- return sample.to(dtype)
-
- def _sigma_to_t(self, sigma):
- return sigma * self.config.num_train_timesteps
-
- def _sigma_to_alpha_sigma_t(self, sigma):
- return 1 - sigma, sigma
-
- def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
-
- def convert_model_output(
- self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs
- ) -> torch.Tensor:
- timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
- if sample is None:
- if len(args) > 1:
- sample = args[1]
- else:
- raise ValueError("missing `sample` as a required keyword argument")
- if timestep is not None:
- deprecate(
- "timesteps",
- "1.0.0",
- "Passing `timesteps` is deprecated and has no effect as model output "
- "conversion is now handled via an internal counter `self.step_index`",
- )
-
- sigma = self.sigmas[self.step_index]
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
-
- if self.predict_x0:
- if self.config.prediction_type == "flow_prediction":
- sigma_t = self.sigmas[self.step_index]
- x0_pred = sample - sigma_t * model_output
- else:
- raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
- " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
- )
- if self.config.thresholding:
- x0_pred = self._threshold_sample(x0_pred)
- return x0_pred
- else:
- if self.config.prediction_type == "flow_prediction":
- sigma_t = self.sigmas[self.step_index]
- epsilon = sample - (1 - sigma_t) * model_output
- else:
- raise ValueError(
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
- " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
- )
- if self.config.thresholding:
- sigma_t = self.sigmas[self.step_index]
- x0_pred = sample - sigma_t * model_output
- x0_pred = self._threshold_sample(x0_pred)
- epsilon = model_output + x0_pred
- return epsilon
-
- def multistep_uni_p_bh_update(
- self,
- model_output: torch.Tensor,
- *args,
- sample: torch.Tensor | None = None,
- order: int | None = None,
- **kwargs,
- ) -> torch.Tensor:
- prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
- if sample is None:
- if len(args) > 1:
- sample = args[1]
- else:
- raise ValueError(" missing `sample` as a required keyword argument")
- if order is None:
- if len(args) > 2:
- order = args[2]
- else:
- raise ValueError(" missing `order` as a required keyword argument")
- if prev_timestep is not None:
- deprecate("prev_timestep", "1.0.0", "Passing `prev_timestep` is deprecated and has no effect.")
-
- model_output_list = self.model_outputs
- s0 = self.timestep_list[-1]
- m0 = model_output_list[-1]
- x = sample
-
- if self.solver_p:
- return self.solver_p.step(model_output, s0, x).prev_sample
-
- sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
-
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
- h = lambda_t - lambda_s0
- device = sample.device
-
- rks = []
- D1s: list[Any] | None = []
- for i in range(1, order):
- si = self.step_index - i
- mi = model_output_list[-(i + 1)]
- alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
- lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
- rk = (lambda_si - lambda_s0) / h
- rks.append(rk)
- D1s.append((mi - m0) / rk)
-
- rks.append(1.0)
- rks = torch.tensor(rks, device=device)
-
- R = []
- b = []
- hh = -h if self.predict_x0 else h
- h_phi_1 = torch.expm1(hh)
- h_phi_k = h_phi_1 / hh - 1
- factorial_i = 1
-
- if self.config.solver_type == "bh1":
- B_h = hh
- elif self.config.solver_type == "bh2":
- B_h = torch.expm1(hh)
- else:
- raise NotImplementedError()
-
- for i in range(1, order + 1):
- R.append(torch.pow(rks, i - 1))
- b.append(h_phi_k * factorial_i / B_h)
- factorial_i *= i + 1
- h_phi_k = h_phi_k / hh - 1 / factorial_i
-
- R = torch.stack(R)
- b = torch.tensor(b, device=device)
-
- if len(D1s) > 0:
- D1s = torch.stack(D1s, dim=1)
- if order == 2:
- rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
- else:
- rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
- else:
- D1s = None
-
- if self.predict_x0:
- x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
- pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) if D1s is not None else 0
- x_t = x_t_ - alpha_t * B_h * pred_res
- else:
- x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
- pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) if D1s is not None else 0
- x_t = x_t_ - sigma_t * B_h * pred_res
-
- return x_t.to(x.dtype)
-
- def multistep_uni_c_bh_update(
- self,
- this_model_output: torch.Tensor,
- *args,
- last_sample: torch.Tensor = None,
- this_sample: torch.Tensor = None,
- order: int | None = None,
- **kwargs,
- ) -> torch.Tensor:
- this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
- if last_sample is None:
- if len(args) > 1:
- last_sample = args[1]
- else:
- raise ValueError(" missing`last_sample` as a required keyword argument")
- if this_sample is None:
- if len(args) > 2:
- this_sample = args[2]
- else:
- raise ValueError(" missing`this_sample` as a required keyword argument")
- if order is None:
- if len(args) > 3:
- order = args[3]
- else:
- raise ValueError(" missing`order` as a required keyword argument")
- if this_timestep is not None:
- deprecate("this_timestep", "1.0.0", "Passing `this_timestep` is deprecated and has no effect.")
-
- model_output_list = self.model_outputs
- m0 = model_output_list[-1]
- x = last_sample
- x_t = this_sample
- model_t = this_model_output
-
- sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
-
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
- h = lambda_t - lambda_s0
- device = this_sample.device
-
- rks = []
- D1s: list[Any] | None = []
- for i in range(1, order):
- si = self.step_index - (i + 1)
- mi = model_output_list[-(i + 1)]
- alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
- lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
- rk = (lambda_si - lambda_s0) / h
- rks.append(rk)
- D1s.append((mi - m0) / rk)
-
- rks.append(1.0)
- rks = torch.tensor(rks, device=device)
-
- R = []
- b = []
- hh = -h if self.predict_x0 else h
- h_phi_1 = torch.expm1(hh)
- h_phi_k = h_phi_1 / hh - 1
- factorial_i = 1
-
- if self.config.solver_type == "bh1":
- B_h = hh
- elif self.config.solver_type == "bh2":
- B_h = torch.expm1(hh)
- else:
- raise NotImplementedError()
-
- for i in range(1, order + 1):
- R.append(torch.pow(rks, i - 1))
- b.append(h_phi_k * factorial_i / B_h)
- factorial_i *= i + 1
- h_phi_k = h_phi_k / hh - 1 / factorial_i
-
- R = torch.stack(R)
- b = torch.tensor(b, device=device)
-
- if len(D1s) > 0:
- D1s = torch.stack(D1s, dim=1)
- else:
- D1s = None
-
- if order == 1:
- rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
- else:
- rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
-
- if self.predict_x0:
- x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
- corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) if D1s is not None else 0
- D1_t = model_t - m0
- x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
- else:
- x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
- corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) if D1s is not None else 0
- D1_t = model_t - m0
- x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
- return x_t.to(x.dtype)
-
- def index_for_timestep(self, timestep, schedule_timesteps=None):
- if schedule_timesteps is None:
- schedule_timesteps = self.timesteps
- indices = (schedule_timesteps == timestep).nonzero()
- pos = 1 if len(indices) > 1 else 0
- return indices[pos].item()
-
- def _init_step_index(self, timestep):
- if self.begin_index is None:
- if isinstance(timestep, torch.Tensor):
- timestep = timestep.to(self.timesteps.device)
- self._step_index = self.index_for_timestep(timestep)
- else:
- self._step_index = self._begin_index
-
- def step(
- self,
- model_output: torch.Tensor,
- timestep: int | torch.Tensor,
- sample: torch.Tensor,
- return_dict: bool = True,
- generator=None,
- ) -> SchedulerOutput | tuple:
- if self.num_inference_steps is None:
- raise ValueError(
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
- )
-
- if self.step_index is None:
- self._init_step_index(timestep)
-
- use_corrector = (
- self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
- )
-
- model_output_convert = self.convert_model_output(model_output, sample=sample)
- if use_corrector:
- sample = self.multistep_uni_c_bh_update(
- this_model_output=model_output_convert,
- last_sample=self.last_sample,
- this_sample=sample,
- order=self.this_order,
- )
-
- for i in range(self.config.solver_order - 1):
- self.model_outputs[i] = self.model_outputs[i + 1]
- self.timestep_list[i] = self.timestep_list[i + 1]
-
- self.model_outputs[-1] = model_output_convert
- self.timestep_list[-1] = timestep
-
- if self.config.lower_order_final:
- this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
- else:
- this_order = self.config.solver_order
-
- self.this_order = min(this_order, self.lower_order_nums + 1)
- assert self.this_order > 0
-
- self.last_sample = sample
- prev_sample = self.multistep_uni_p_bh_update(model_output=model_output, sample=sample, order=self.this_order)
-
- if self.lower_order_nums < self.config.solver_order:
- self.lower_order_nums += 1
-
- self._step_index += 1
-
- if not return_dict:
- return (prev_sample,)
- return SchedulerOutput(prev_sample=prev_sample)
-
- def step_ddim(
- self,
- velocity: torch.FloatTensor,
- t: int,
- curr_state: torch.FloatTensor,
- prev_state: torch.FloatTensor | None = None,
- generator: torch.Generator | None = None,
- ):
- device = curr_state.device
- curr_t = self.sigmas[t]
- prev_t = self.sigmas[t + 1]
- variance_noise = randn_tensor(curr_state.shape, generator=generator, device=device, dtype=curr_state.dtype)
- cur_clean_ = curr_state - curr_t * velocity
- return prev_t * variance_noise + (1 - prev_t) * cur_clean_
-
- def step_sde(
- self,
- velocity: torch.FloatTensor,
- t: int,
- curr_state: torch.FloatTensor,
- noise_theta: float = 1.0,
- prev_state: torch.FloatTensor | None = None,
- generator: torch.Generator | None = None,
- ):
- device = curr_state.device
- curr_t = self.sigmas[t]
- prev_t = self.sigmas[t + 1]
- cos = torch.cos(torch.tensor(noise_theta) * torch.pi / 2).to(device)
- sin = torch.sin(torch.tensor(noise_theta) * torch.pi / 2).to(device)
- prev_sample_mean = (1 - prev_t + prev_t * cos) * (curr_state - curr_t * velocity) + prev_t * cos * velocity
- std_dev_t = prev_t * sin
- std_dev_t = torch.ones((1, 1)).to(curr_state) * std_dev_t
- if prev_state is None:
- variance_noise = randn_tensor(curr_state.shape, generator=generator, device=device, dtype=curr_state.dtype)
- prev_state = prev_sample_mean + std_dev_t * variance_noise
- else:
- prev_state = prev_sample_mean + (prev_state - prev_sample_mean.detach())
- return prev_state
-
- def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
- return sample
-
- def add_noise(
- self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor
- ) -> torch.Tensor:
- sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
- if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
- schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
- timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
- else:
- schedule_timesteps = self.timesteps.to(original_samples.device)
- timesteps = timesteps.to(original_samples.device)
-
- if self.begin_index is None:
- step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
- elif self.step_index is not None:
- step_indices = [self.step_index] * timesteps.shape[0]
- else:
- step_indices = [self.begin_index] * timesteps.shape[0]
-
- sigma = sigmas[step_indices].flatten()
- while len(sigma.shape) < len(original_samples.shape):
- sigma = sigma.unsqueeze(-1)
-
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
- return alpha_t * original_samples + sigma_t * noise
-
- def __len__(self):
- return self.config.num_train_timesteps
-
-
-# ===========================================================================
-# Audio VAE (ported from daVinci-MagiHuman inference/model/sa_audio/)
-# ===========================================================================
-def _snake_beta(x, alpha, beta):
- return x + (1.0 / (beta + 1e-9)) * torch.pow(torch.sin(x * alpha), 2)
-
-
-class _SnakeBeta(nn.Module):
- def __init__(self, in_features: int, alpha: float = 1.0, alpha_trainable: bool = True, alpha_logscale: bool = True):
- super().__init__()
- self.alpha_logscale = alpha_logscale
- if self.alpha_logscale:
- self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
- self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
- else:
- self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
- self.beta = nn.Parameter(torch.ones(in_features) * alpha)
- self.alpha.requires_grad = alpha_trainable
- self.beta.requires_grad = alpha_trainable
-
- def forward(self, x):
- alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
- beta = self.beta.unsqueeze(0).unsqueeze(-1)
- if self.alpha_logscale:
- alpha = torch.exp(alpha)
- beta = torch.exp(beta)
- return _snake_beta(x, alpha, beta)
-
-
-def _vae_sample(mean, scale):
- stdev = F.softplus(scale) + 1e-4
- var = stdev * stdev
- logvar = torch.log(var)
- latents = torch.randn_like(mean) * stdev + mean
- kl = (mean * mean + var - logvar - 1).sum(1).mean()
- return latents, kl
-
-
-class _VAEBottleneck(nn.Module):
- def encode(self, x, return_info=False, **kwargs):
- info = {}
- mean, scale = x.chunk(2, dim=1)
- x, kl = _vae_sample(mean, scale)
- info["kl"] = kl
- return (x, info) if return_info else x
-
- def decode(self, x):
- return x
-
-
-def _WNConv1d(*args, **kwargs):
- return weight_norm(nn.Conv1d(*args, **kwargs))
-
-
-def _WNConvTranspose1d(*args, **kwargs):
- return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
-
-
-def _checkpoint(function, *args, **kwargs):
- kwargs.setdefault("use_reentrant", False)
- return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
-
-
-def _get_activation(activation: Literal["elu", "snake", "none"], antialias: bool = False, channels=None) -> nn.Module:
- if antialias:
- raise NotImplementedError("antialias activation not supported")
- if activation == "elu":
- return nn.ELU()
- if activation == "snake":
- return _SnakeBeta(channels)
- if activation == "none":
- return nn.Identity()
- raise ValueError(f"Unknown activation {activation}")
-
-
-class _ResidualUnit(nn.Module):
- def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
- super().__init__()
- padding = (dilation * (7 - 1)) // 2
- self.layers = nn.Sequential(
- _get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
- _WNConv1d(in_channels, out_channels, kernel_size=7, dilation=dilation, padding=padding),
- _get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
- _WNConv1d(out_channels, out_channels, kernel_size=1),
- )
-
- def forward(self, x):
- return (_checkpoint(self.layers, x) if self.training else self.layers(x)) + x
-
-
-class _EncoderBlock(nn.Module):
- def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
- super().__init__()
- self.layers = nn.Sequential(
- _ResidualUnit(in_channels, in_channels, 1, use_snake=use_snake),
- _ResidualUnit(in_channels, in_channels, 3, use_snake=use_snake),
- _ResidualUnit(in_channels, in_channels, 9, use_snake=use_snake),
- _get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
- _WNConv1d(in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
- )
-
- def forward(self, x):
- return self.layers(x)
-
-
-class _DecoderBlock(nn.Module):
- def __init__(
- self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False
- ):
- super().__init__()
- if use_nearest_upsample:
- upsample_layer = nn.Sequential(
- nn.Upsample(scale_factor=stride, mode="nearest"),
- _WNConv1d(in_channels, out_channels, kernel_size=2 * stride, stride=1, bias=False, padding="same"),
- )
- else:
- upsample_layer = _WNConvTranspose1d(
- in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
- )
- self.layers = nn.Sequential(
- _get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
- upsample_layer,
- _ResidualUnit(out_channels, out_channels, 1, use_snake=use_snake),
- _ResidualUnit(out_channels, out_channels, 3, use_snake=use_snake),
- _ResidualUnit(out_channels, out_channels, 9, use_snake=use_snake),
- )
-
- def forward(self, x):
- return self.layers(x)
-
-
-class _OobleckEncoder(nn.Module):
- def __init__(
- self,
- in_channels=2,
- channels=128,
- latent_dim=32,
- c_mults=[1, 2, 4, 8],
- strides=[2, 4, 8, 8],
- use_snake=False,
- antialias_activation=False,
- ):
- super().__init__()
- c_mults = [1] + c_mults
- depth = len(c_mults)
- layers = [_WNConv1d(in_channels, c_mults[0] * channels, kernel_size=7, padding=3)]
- for i in range(depth - 1):
- layers.append(
- _EncoderBlock(c_mults[i] * channels, c_mults[i + 1] * channels, strides[i], use_snake=use_snake)
- )
- layers.extend(
- [
- _get_activation(
- "snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels
- ),
- _WNConv1d(c_mults[-1] * channels, latent_dim, kernel_size=3, padding=1),
- ]
- )
- self.layers = nn.Sequential(*layers)
-
- def forward(self, x):
- return self.layers(x)
-
-
-class _OobleckDecoder(nn.Module):
- def __init__(
- self,
- out_channels=2,
- channels=128,
- latent_dim=32,
- c_mults=[1, 2, 4, 8],
- strides=[2, 4, 8, 8],
- use_snake=False,
- antialias_activation=False,
- use_nearest_upsample=False,
- final_tanh=True,
- ):
- super().__init__()
- c_mults = [1] + c_mults
- depth = len(c_mults)
- layers = [_WNConv1d(latent_dim, c_mults[-1] * channels, kernel_size=7, padding=3)]
- for i in range(depth - 1, 0, -1):
- layers.append(
- _DecoderBlock(
- c_mults[i] * channels,
- c_mults[i - 1] * channels,
- strides[i - 1],
- use_snake=use_snake,
- antialias_activation=antialias_activation,
- use_nearest_upsample=use_nearest_upsample,
- )
- )
- layers.extend(
- [
- _get_activation(
- "snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels
- ),
- _WNConv1d(c_mults[0] * channels, out_channels, kernel_size=7, padding=3, bias=False),
- nn.Tanh() if final_tanh else nn.Identity(),
- ]
- )
- self.layers = nn.Sequential(*layers)
-
- def forward(self, x):
- return self.layers(x)
-
-
-class _AudioAutoencoder(nn.Module):
- def __init__(
- self,
- encoder,
- decoder,
- latent_dim,
- downsampling_ratio,
- sample_rate,
- io_channels=2,
- bottleneck=None,
- in_channels=None,
- out_channels=None,
- soft_clip=False,
- ):
- super().__init__()
- self.downsampling_ratio = downsampling_ratio
- self.sample_rate = sample_rate
- self.latent_dim = latent_dim
- self.io_channels = io_channels
- self.in_channels = in_channels if in_channels is not None else io_channels
- self.out_channels = out_channels if out_channels is not None else io_channels
- self.bottleneck = bottleneck
- self.encoder = encoder
- self.decoder = decoder
- self.soft_clip = soft_clip
-
- def encode(self, audio, skip_bottleneck=False, return_info=False, **kwargs):
- info = {}
- latents = self.encoder(audio)
- info["pre_bottleneck_latents"] = latents
- if self.bottleneck is not None and not skip_bottleneck:
- latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
- info.update(bottleneck_info)
- return (latents, info) if return_info else latents
-
- def decode(self, latents, skip_bottleneck=False, **kwargs):
- if self.bottleneck is not None and not skip_bottleneck:
- latents = self.bottleneck.decode(latents)
- decoded = self.decoder(latents, **kwargs)
- if self.soft_clip:
- decoded = torch.tanh(decoded)
- return decoded
-
-
-def _create_encoder_from_config(cfg: dict[str, Any]):
- assert cfg.get("type") == "oobleck", f"Only 'oobleck' encoder supported, got: {cfg.get('type')}"
- enc = _OobleckEncoder(**cfg["config"])
- if not cfg.get("requires_grad", True):
- for p in enc.parameters():
- p.requires_grad = False
- return enc
-
-
-def _create_decoder_from_config(cfg: dict[str, Any]):
- assert cfg.get("type") == "oobleck", f"Only 'oobleck' decoder supported, got: {cfg.get('type')}"
- dec = _OobleckDecoder(**cfg["config"])
- if not cfg.get("requires_grad", True):
- for p in dec.parameters():
- p.requires_grad = False
- return dec
-
-
-def _create_bottleneck_from_config(cfg: dict[str, Any]):
- assert cfg.get("type") == "vae", f"Only 'vae' bottleneck supported, got: {cfg.get('type')}"
- bn = _VAEBottleneck()
- if not cfg.get("requires_grad", True):
- for p in bn.parameters():
- p.requires_grad = False
- return bn
-
-
-def _create_autoencoder_from_config(config: dict[str, Any]):
- ae_config = config["model"]
- if ae_config.get("pretransform") is not None:
- raise NotImplementedError("Nested pretransform not supported")
- encoder = _create_encoder_from_config(ae_config["encoder"])
- decoder = _create_decoder_from_config(ae_config["decoder"])
- bottleneck_cfg = ae_config.get("bottleneck")
- bottleneck = _create_bottleneck_from_config(bottleneck_cfg) if bottleneck_cfg else None
- return _AudioAutoencoder(
- encoder=encoder,
- decoder=decoder,
- latent_dim=ae_config["latent_dim"],
- downsampling_ratio=ae_config["downsampling_ratio"],
- sample_rate=config["sample_rate"],
- io_channels=ae_config["io_channels"],
- bottleneck=bottleneck,
- in_channels=ae_config.get("in_channels"),
- out_channels=ae_config.get("out_channels"),
- soft_clip=ae_config["decoder"].get("soft_clip", False),
- )
-
-
-class SAAudioFeatureExtractor:
- def __init__(self, device, model_path):
- self.device = device
- self.vae_model, self.sample_rate = self._load_vae(model_path)
- self.resampler = None
-
- def _load_vae(self, model_path):
- if not (isinstance(model_path, str) and Path(model_path).is_dir()):
- raise ValueError("model_path must be a local directory")
-
- model_config_path = os.path.join(model_path, "model_config.json")
- with open(model_config_path) as f:
- full_config = json.load(f)
-
- vae_config = full_config["model"]["pretransform"]["config"]
- sample_rate = full_config["sample_rate"]
-
- autoencoder_config = {
- "model_type": "autoencoder",
- "sample_rate": sample_rate,
- "model": vae_config,
- }
- vae_model = _create_autoencoder_from_config(autoencoder_config)
-
- weights_path = Path(model_path) / "model.safetensors"
- if not weights_path.exists():
- raise FileNotFoundError(f"Weight file does not exist: {weights_path}")
-
- full_state_dict = load_file(weights_path, device=str(self.device))
- vae_state_dict = {}
- for key, value in full_state_dict.items():
- if key.startswith("pretransform.model."):
- vae_state_dict[key[len("pretransform.model.") :]] = value
-
- model_keys = set(vae_model.state_dict().keys())
- vae_keys = set(vae_state_dict.keys())
- missing = model_keys - vae_keys
- extra = vae_keys - model_keys
- if missing:
- logger.warning("Audio VAE missing keys (%d): %s", len(missing), list(missing)[:5])
- if extra:
- logger.warning("Audio VAE unexpected keys (%d): %s", len(extra), list(extra)[:5])
-
- vae_model.load_state_dict(vae_state_dict)
- vae_model.to(self.device)
- return vae_model, sample_rate
-
- def decode(self, latents):
- with torch.no_grad():
- return self.vae_model.decode(latents)
-
- def encode(self, waveform):
- with torch.no_grad():
- return self.vae_model.encode(waveform)
-
-
-# ===========================================================================
-# Audio utilities (ported from daVinci-MagiHuman inference/pipeline/video_process.py)
-# ===========================================================================
-_SAMPLE_RATE = 51200
-_AUDIO_CHUNK_DURATION = 29
-_OVERLAP_RATIO = 0.5
-
-
-def _merge_overlapping_vae_features(audio_feats: list[torch.Tensor], overlap_ratio: float = 0.5) -> torch.Tensor | None:
- if not audio_feats:
- return None
- if len(audio_feats) == 1:
- return audio_feats[0]
-
- batch_size, total_frames, feature_dim = audio_feats[0].shape
- overlap_frames = int(total_frames * overlap_ratio)
- step_frames = total_frames - overlap_frames
- final_length = (len(audio_feats) - 1) * step_frames + total_frames
- output_feat = torch.zeros(
- batch_size, final_length, feature_dim, device=audio_feats[0].device, dtype=audio_feats[0].dtype
- )
-
- for block_idx, current_feat in enumerate(audio_feats):
- output_start = block_idx * step_frames
- if block_idx == 0:
- output_feat[:, output_start : output_start + total_frames, :] = current_feat
- continue
-
- non_overlap_start = output_start + overlap_frames
- non_overlap_end = output_start + total_frames
- output_feat[:, non_overlap_start:non_overlap_end, :] = current_feat[:, overlap_frames:, :]
-
- for frame_idx in range(overlap_frames):
- output_pos = output_start + frame_idx
- prev_weight = (overlap_frames - frame_idx) / overlap_frames
- curr_weight = frame_idx / overlap_frames
- output_feat[:, output_pos, :] = (
- prev_weight * output_feat[:, output_pos, :] + curr_weight * current_feat[:, frame_idx, :]
- )
- return output_feat
-
-
-def load_audio_and_encode(audio_vae, audio_path: str, seconds: int | None = None) -> torch.Tensor:
- """Load audio from file and encode to latent space using the Stable Audio VAE."""
- audio_full = whisper.load_audio(audio_path, sr=_SAMPLE_RATE)
- if seconds is not None:
- audio_full = audio_full[: min(int(seconds * _SAMPLE_RATE), audio_full.shape[0])]
- total_samples = audio_full.shape[0]
-
- window_size = int(_AUDIO_CHUNK_DURATION * _SAMPLE_RATE)
- step_size = int(window_size * (1 - _OVERLAP_RATIO))
- if total_samples <= window_size:
- audio = torch.from_numpy(audio_full).cuda()
- audio = audio.unsqueeze(0).expand(2, -1)
- return audio_vae.vae_model.encode(audio)
-
- encoded_chunks = []
- latent_to_audio_ratio = None
- for offset_start in range(0, total_samples, step_size):
- offset_end = min(offset_start + window_size, total_samples)
- chunk = whisper.pad_or_trim(audio_full[offset_start:offset_end], length=window_size)
- chunk_tensor = torch.from_numpy(chunk).cuda().unsqueeze(0).expand(2, -1)
- encoded_chunk = audio_vae.vae_model.encode(chunk_tensor)
-
- if latent_to_audio_ratio is None:
- latent_to_audio_ratio = encoded_chunk.shape[-1] / window_size
-
- encoded_chunks.append(encoded_chunk.permute(0, 2, 1))
- if offset_end >= total_samples:
- break
-
- final_feat = _merge_overlapping_vae_features(encoded_chunks, overlap_ratio=_OVERLAP_RATIO).permute(0, 2, 1)
- final_target_len = math.ceil(total_samples * latent_to_audio_ratio)
- return final_feat[:, :, :final_target_len]
-
-
-# ===========================================================================
-# Data proxy (ported from daVinci-MagiHuman inference/pipeline/data_proxy.py)
-# ===========================================================================
-def _unfold_3d(x: torch.Tensor, kernel_size: tuple[int, int, int], stride: tuple[int, int, int]) -> torch.Tensor:
- """Pure-PyTorch 3D unfold matching UnfoldAnd behavior.
-
- After N unfold ops the shape is (batch, C, oD, oH, oW, kD, kH, kW).
- UnfoldAnd permutes kernel dims next to channel before reshape so that the
- col_dim axis is ordered as (C, kD, kH, kW) -- matching F.unfold semantics.
- Without this permute, .view() interleaves spatial and kernel positions.
-
- Args:
- x: (N, C, D, H, W)
- kernel_size: (kD, kH, kW)
- stride: (sD, sH, sW)
- Returns:
- (N, C*kD*kH*kW, L) where L = product of output spatial dims.
- """
- ndim = len(kernel_size)
- for d in range(ndim):
- x = x.unfold(d + 2, kernel_size[d], stride[d])
- perm = [0, 1] + list(range(ndim + 2, 2 * ndim + 2)) + list(range(2, ndim + 2))
- x = x.permute(*perm).contiguous()
-
- batch_size = x.shape[0]
- col_dim = 1
- for i in range(1, ndim + 2):
- col_dim *= x.shape[i]
- spatial = 1
- for i in range(ndim + 2, 2 * ndim + 2):
- spatial *= x.shape[i]
- return x.view(batch_size, col_dim, spatial)
-
-
-def _calc_local_qk_range(num_video_tokens, num_audio_and_txt_tokens, num_frames, frame_receptive_field):
- token_per_frame = num_video_tokens // num_frames
- total_tokens = num_video_tokens + num_audio_and_txt_tokens
-
- q_range_list = []
- k_range_list = []
- for i in range(num_frames):
- q_range_list.append(torch.tensor([i * token_per_frame, (i + 1) * token_per_frame]))
- k_range_list.append(
- torch.tensor(
- [
- (i - frame_receptive_field) * token_per_frame,
- (i + frame_receptive_field + 1) * token_per_frame,
- ]
- )
- )
- local_q_range = torch.stack(q_range_list, dim=0)
- local_k_range = torch.stack(k_range_list, dim=0)
-
- local_k_range[local_k_range < 0] = 0
- local_k_range[local_k_range > num_video_tokens] = num_video_tokens
-
- video_q_range = torch.tensor([[0, num_video_tokens]])
- video_k_range = torch.tensor([[num_video_tokens, num_video_tokens + num_audio_and_txt_tokens]])
-
- at_q_ranges = torch.tensor([[num_video_tokens, total_tokens]])
- at_k_ranges = torch.tensor([[0, total_tokens]])
-
- q_ranges = (
- torch.cat([local_q_range, video_q_range, at_q_ranges], dim=0).to(torch.int32).to("cuda", non_blocking=True)
- )
- k_ranges = (
- torch.cat([local_k_range, video_k_range, at_k_ranges], dim=0).to(torch.int32).to("cuda", non_blocking=True)
- )
- return q_ranges, k_ranges
-
-
-def _calc_local_attn_ffa_handler(num_video_tokens, num_audio_and_txt_tokens, num_frames, frame_receptive_field):
- q_ranges, k_ranges = _calc_local_qk_range(
- num_video_tokens, num_audio_and_txt_tokens, num_frames, frame_receptive_field
- )
- total = num_video_tokens + num_audio_and_txt_tokens
- return FFAHandler(
- q_ranges=q_ranges,
- k_ranges=k_ranges,
- max_seqlen_q=total,
- max_seqlen_k=total,
- attn_type_map=torch.zeros([q_ranges.shape[0]], device="cuda", dtype=torch.int32),
- softmax_scale=None,
- )
-
-
-def _get_coords(
- shape: list[int],
- ref_feat_shape: list[int],
- offset_thw: list[int] | None = None,
- device: torch.device = torch.device("cpu"),
- dtype: torch.dtype = torch.float32,
-):
- if offset_thw is None:
- offset_thw = [0, 0, 0]
- ori_t, ori_h, ori_w = shape
- ref_t, ref_h, ref_w = ref_feat_shape
-
- offset_t, offset_h, offset_w = offset_thw
- time_rng = torch.arange(ori_t, device=device, dtype=dtype) + offset_t
- height_rng = torch.arange(ori_h, device=device, dtype=dtype) + offset_h
- width_rng = torch.arange(ori_w, device=device, dtype=dtype) + offset_w
-
- time_grid, height_grid, width_grid = torch.meshgrid(time_rng, height_rng, width_rng, indexing="ij")
- coords_flat = torch.stack([time_grid, height_grid, width_grid], dim=-1).reshape(-1, 3)
-
- meta = torch.tensor([ori_t, ori_h, ori_w, ref_t, ref_h, ref_w], device=device, dtype=dtype)
- meta_expanded = meta.expand(coords_flat.size(0), -1)
- return torch.cat([coords_flat, meta_expanded], dim=-1)
-
-
-@dataclass
-class _SingleData:
- video_x_t: torch.Tensor
- audio_x_t: torch.Tensor
- audio_feat_len: int
- txt_feat: torch.Tensor
- txt_feat_len: int
- t: int
- h: int
- w: int
- patch_size: int
- t_patch_size: int
- spatial_rope_interpolation: Literal["inter", "extra"]
- ref_audio_offset: int
- text_offset: int
- coords_style: Literal["v1", "v2"] = "v1"
-
- def __post_init__(self):
- self.video_token_num = self.video_x_t.shape[0]
- self.audio_x_t = self.audio_x_t[: self.audio_feat_len]
- self.txt_feat = self.txt_feat[: self.txt_feat_len]
- self.video_channel = self.video_x_t.shape[-1]
- self.audio_channel = self.audio_x_t.shape[-1]
- self.txt_channel = self.txt_feat.shape[-1]
-
- @property
- def device(self):
- return self.video_x_t.device
-
- @property
- def default_dtype(self):
- return self.video_x_t.dtype
-
- @property
- def total_token_num(self):
- return self.video_token_num + self.audio_feat_len + self.txt_feat_len
-
- @property
- def token_sequence(self):
- tensors = [self.video_x_t, self.audio_x_t, self.txt_feat]
- max_channel = max(t.shape[-1] for t in tensors)
- padded = [F.pad(t, (0, max_channel - t.shape[-1])) for t in tensors]
- return torch.cat(padded, dim=0)
-
- @property
- def modality_mapping(self):
- v_map = torch.full((self.video_token_num,), Modality.VIDEO, dtype=torch.int64, device=self.device)
- a_map = torch.full((self.audio_feat_len,), Modality.AUDIO, dtype=torch.int64, device=self.device)
- t_map = torch.full((self.txt_feat_len,), Modality.TEXT, dtype=torch.int64, device=self.device)
- return torch.cat([v_map, a_map, t_map], dim=0)
-
- def _default_coords(self, shape, ref_feat_shape, offset_thw=None):
- if offset_thw is None:
- offset_thw = [0, 0, 0]
- return _get_coords(
- shape=shape,
- ref_feat_shape=ref_feat_shape,
- offset_thw=offset_thw,
- device=self.device,
- dtype=self.default_dtype,
- )
-
- @property
- def coords_mapping(self):
- if self.spatial_rope_interpolation == "inter":
- video_ref_feat_shape = (self.t // self.t_patch_size, 32, 32)
- else:
- video_ref_feat_shape = (self.t // self.t_patch_size, self.h // self.patch_size, self.w // self.patch_size)
-
- video_coords = self._default_coords(
- shape=(self.t // self.t_patch_size, self.h // self.patch_size, self.w // self.patch_size),
- ref_feat_shape=video_ref_feat_shape,
- )
-
- if self.coords_style == "v1":
- audio_coords = self._default_coords(
- shape=(self.audio_feat_len, 1, 1),
- ref_feat_shape=(self.t // self.t_patch_size, 1, 1),
- )
- text_coords = self._default_coords(
- shape=(self.txt_feat_len, 1, 1),
- ref_feat_shape=(2, 1, 1),
- offset_thw=[self.text_offset, 0, 0],
- )
- elif self.coords_style == "v2":
- magic_audio_ref_t = (self.audio_feat_len - 1) // 4 + 1
- audio_coords = self._default_coords(
- shape=(self.audio_feat_len, 1, 1),
- ref_feat_shape=(magic_audio_ref_t // self.t_patch_size, 1, 1),
- )
- text_coords = self._default_coords(
- shape=(self.txt_feat_len, 1, 1),
- ref_feat_shape=(1, 1, 1),
- offset_thw=[-self.txt_feat_len, 0, 0],
- )
- else:
- raise ValueError(f"Unknown coords_style: {self.coords_style}")
-
- return torch.cat([video_coords, audio_coords, text_coords], dim=0)
-
- def depack_token_sequence(self, token_sequence):
- video_x_t = token_sequence[: self.video_token_num, : self.video_channel]
- video_x_t = rearrange(
- video_x_t,
- "(T H W) (pT pH pW C) -> C (T pT) (H pH) (W pW)",
- H=self.h // self.patch_size,
- W=self.w // self.patch_size,
- pT=self.t_patch_size,
- pH=self.patch_size,
- pW=self.patch_size,
- ).contiguous()
- audio_x_t = token_sequence[
- self.video_token_num : self.video_token_num + self.audio_feat_len, : self.audio_channel
- ]
- return video_x_t, audio_x_t
-
-
-@dataclass
-class _SimplePackedData:
- items: list[_SingleData]
-
- @property
- def token_sequence(self):
- return torch.cat([item.token_sequence for item in self.items], dim=0)
-
- @property
- def modality_mapping(self):
- return torch.cat([item.modality_mapping for item in self.items], dim=0)
-
- @property
- def coords_mapping(self):
- return torch.cat([item.coords_mapping for item in self.items], dim=0)
-
- @property
- def total_token_num(self):
- return sum(item.total_token_num for item in self.items)
-
- def __getitem__(self, index):
- return self.items[index]
-
- @property
- def cu_seqlen(self):
- cu = torch.cumsum(torch.tensor([item.total_token_num for item in self.items]), dim=0)
- return F.pad(cu, (1, 0))
-
- @property
- def max_seqlen(self):
- return torch.tensor(max(item.total_token_num for item in self.items))
-
- def depack_token_sequence(self, token_sequence):
- video_list, audio_list = [], []
- parts = torch.split(token_sequence, [item.total_token_num for item in self.items], dim=0)
- for item, part in zip(self.items, parts):
- v, a = item.depack_token_sequence(part)
- video_list.append(v)
- audio_list.append(a)
- return torch.stack(video_list, dim=0), torch.stack(audio_list, dim=0)
-
-
-class MagiDataProxy:
- def __init__(
- self,
- patch_size: int = 2,
- t_patch_size: int = 1,
- frame_receptive_field: int = 11,
- spatial_rope_interpolation: str = "extra",
- ref_audio_offset: int = 1000,
- text_offset: int = 0,
- coords_style: str = "v2",
- ):
- self.patch_size = patch_size
- self.t_patch_size = t_patch_size
- self.frame_receptive_field = frame_receptive_field
- self.spatial_rope_interpolation = spatial_rope_interpolation
- self.ref_audio_offset = ref_audio_offset
- self.text_offset = text_offset
- self.coords_style = coords_style
- self._kernel = (t_patch_size, patch_size, patch_size)
- self._stride = (t_patch_size, patch_size, patch_size)
- self._saved_data: dict[str, Any] = {}
-
- def saved_for_output(self, **kwargs):
- self._saved_data.update(kwargs)
-
- def get_saved_data(self, key: str):
- return self._saved_data[key]
-
- def img2tokens(self, x_t: torch.Tensor):
- x_t_unfolded = _unfold_3d(x_t, self._kernel, self._stride)
- return rearrange(x_t_unfolded, "N col_dim num_tokens -> N num_tokens col_dim").contiguous()
-
- def process_input(self, transported_data: EvalInput):
- batch_size, _, t, h, w = transported_data.x_t.shape
- x_t = self.img2tokens(transported_data.x_t)
- audio_x_t = transported_data.audio_x_t.contiguous()
- text_in = transported_data.txt_feat.contiguous()
-
- simple_packed_data = _SimplePackedData(items=[])
- for i in range(batch_size):
- single_data = _SingleData(
- video_x_t=x_t[i],
- audio_x_t=audio_x_t[i],
- audio_feat_len=transported_data.audio_feat_len[i],
- txt_feat=text_in[i],
- txt_feat_len=transported_data.txt_feat_len[i],
- t=t,
- h=h,
- w=w,
- patch_size=self.patch_size,
- t_patch_size=self.t_patch_size,
- spatial_rope_interpolation=self.spatial_rope_interpolation,
- ref_audio_offset=self.ref_audio_offset,
- text_offset=self.text_offset,
- coords_style=self.coords_style,
- )
- simple_packed_data.items.append(single_data)
-
- if self.frame_receptive_field != -1:
- assert batch_size == 1, "local attention only supports batch size 1"
- local_attn_handler = _calc_local_attn_ffa_handler(
- num_video_tokens=simple_packed_data[0].video_token_num,
- num_audio_and_txt_tokens=simple_packed_data[0].audio_feat_len + simple_packed_data[0].txt_feat_len,
- num_frames=t,
- frame_receptive_field=self.frame_receptive_field,
- )
- if isinstance(local_attn_handler.max_seqlen_k, torch.Tensor):
- local_attn_handler.max_seqlen_k = local_attn_handler.max_seqlen_k.item()
- if isinstance(local_attn_handler.max_seqlen_q, torch.Tensor):
- local_attn_handler.max_seqlen_q = local_attn_handler.max_seqlen_q.item()
- else:
- local_attn_handler = None
-
- varlen_handler = VarlenHandler(
- cu_seqlens_q=simple_packed_data.cu_seqlen.to(torch.int32).cuda(),
- cu_seqlens_k=simple_packed_data.cu_seqlen.to(torch.int32).cuda(),
- max_seqlen_q=simple_packed_data.max_seqlen.to(torch.int32).cuda(),
- max_seqlen_k=simple_packed_data.max_seqlen.to(torch.int32).cuda(),
- )
-
- self.saved_for_output(simple_packed_data=simple_packed_data)
-
- x = simple_packed_data.token_sequence
- coords_mapping = simple_packed_data.coords_mapping
- modality_mapping = simple_packed_data.modality_mapping
- return (x, coords_mapping, modality_mapping, varlen_handler, local_attn_handler)
-
- def process_output(self, x: torch.Tensor):
- simple_packed_data: _SimplePackedData = self.get_saved_data("simple_packed_data")
- return simple_packed_data.depack_token_sequence(x)
-
-
-# ===========================================================================
-# Pipeline helpers
-# ===========================================================================
-@dataclass
-class EvalInput:
- x_t: torch.Tensor
- audio_x_t: torch.Tensor
- audio_feat_len: torch.Tensor | list[int]
- txt_feat: torch.Tensor
- txt_feat_len: torch.Tensor | list[int]
-
-
-class _T5GemmaEncoder:
- def __init__(self, model_path: str, device: str, weight_dtype: torch.dtype, subfolder: str | None = None):
- from vllm.distributed import get_tensor_model_parallel_world_size
-
- self.device = device
- hf_kwargs: dict[str, Any] = {}
- if subfolder is not None:
- hf_kwargs["subfolder"] = subfolder
- self.tokenizer = AutoTokenizer.from_pretrained(model_path, **hf_kwargs)
-
- tp_size = get_tensor_model_parallel_world_size()
- if tp_size > 1:
- from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig
-
- config = T5GemmaConfig.from_pretrained(model_path, **hf_kwargs)
- # The config we need is the encoder config
- config_encoder = config.encoder
- # Propagate some outer config values
- config_encoder.vocab_size = config.vocab_size
- config_encoder.rms_norm_eps = getattr(config, "rms_norm_eps", config_encoder.rms_norm_eps)
- self.model = T5GemmaEncoderModelTP(config_encoder).to(device).to(weight_dtype)
- self.is_tp = True
- else:
- self.model = T5GemmaEncoderModel.from_pretrained(
- model_path, is_encoder_decoder=False, dtype=weight_dtype, **hf_kwargs
- ).to(device)
- self.is_tp = False
-
- @torch.inference_mode()
- def encode(self, prompt: str) -> torch.Tensor:
- inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
- outputs = self.model(**inputs)
-
- if self.is_tp:
- # T5GemmaEncoderModelTP just returns the hidden states tensor
- return outputs.half()
- else:
- # HF model returns BaseModelOutput
- return outputs["last_hidden_state"].half()
-
-
-def _pad_or_trim(tensor: torch.Tensor, target_size: int, dim: int, pad_value: float = 0.0) -> tuple[torch.Tensor, int]:
- current_size = tensor.size(dim)
- if current_size < target_size:
- padding_amount = target_size - current_size
- padding_tuple = [0] * (2 * tensor.dim())
- padding_dim_index = tensor.dim() - 1 - dim
- padding_tuple[2 * padding_dim_index + 1] = padding_amount
- return F.pad(tensor, tuple(padding_tuple), "constant", pad_value), current_size
- slicing = [slice(None)] * tensor.dim()
- slicing[dim] = slice(0, target_size)
- return tensor[tuple(slicing)], target_size
-
-
-def _get_padded_t5_gemma_embedding(
- prompt: str,
- encoder: _T5GemmaEncoder,
- target_length: int,
-) -> tuple[torch.Tensor, int]:
- txt_feat = encoder.encode(prompt)
- txt_feat, original_len = _pad_or_trim(txt_feat, target_size=target_length, dim=1)
- return txt_feat.to(torch.float32), original_len
-
-
-def _resizecrop(img: Image.Image, target_height: int, target_width: int) -> Image.Image:
- """Centre-crop resize keeping aspect ratio then letterbox to target."""
- pil_image = img.convert("RGB")
- original_width, original_height = pil_image.size
- scale_x = target_width / original_width
- scale_y = target_height / original_height
- scale = max(scale_x, scale_y)
- new_width = int(round(original_width * scale))
- new_height = int(round(original_height * scale))
- resized_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
- left = (new_width - target_width) // 2
- top = (new_height - target_height) // 2
- return resized_image.crop((left, top, left + target_width, top + target_height))
-
-
-class ZeroSNRDDPMDiscretization:
- """ZeroSNR DDPM sigma schedule, ported from daVinci-MagiHuman.
- Used to compute sigma values for SR noise injection.
- """
-
- def __init__(
- self,
- linear_start: float = 0.00085,
- linear_end: float = 0.0120,
- num_timesteps: int = 1000,
- shift_scale: float = 1.0,
- keep_start: bool = False,
- post_shift: bool = False,
- ):
- from functools import partial
-
- if keep_start and not post_shift:
- linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
- self.num_timesteps = num_timesteps
- betas = torch.linspace(linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64) ** 2
- alphas = 1.0 - betas.cpu().numpy()
- self.alphas_cumprod = np.cumprod(alphas, axis=0)
- self.to_torch = partial(torch.tensor, dtype=torch.float32)
- if not post_shift:
- self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod)
- self.post_shift = post_shift
- self.shift_scale = shift_scale
-
- def __call__(
- self,
- n: int,
- do_append_zero: bool = True,
- device: str = "cpu",
- flip: bool = False,
- return_idx: bool = False,
- ):
- from functools import partial
-
- if n < self.num_timesteps:
- timesteps = np.linspace(self.num_timesteps - 1, 0, n, endpoint=False).astype(int)[::-1]
- alphas_cumprod = self.alphas_cumprod[timesteps]
- elif n == self.num_timesteps:
- alphas_cumprod = self.alphas_cumprod
- else:
- raise ValueError(f"n={n} > num_timesteps={self.num_timesteps}")
-
- to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
- alphas_cumprod = to_torch(alphas_cumprod)
- alphas_cumprod_sqrt = alphas_cumprod.sqrt()
- alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
- alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()
- alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
- alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)
-
- if self.post_shift:
- alphas_cumprod_sqrt = (
- alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
- ) ** 0.5
-
- sigmas = torch.flip(alphas_cumprod_sqrt, (0,))
- sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) if do_append_zero else sigmas
- if return_idx:
- return sigmas if not flip else torch.flip(sigmas, (0,)), timesteps
- return sigmas if not flip else torch.flip(sigmas, (0,))
-
-
-def _schedule_latent_step(
- *,
- video_scheduler: FlowUniPCMultistepScheduler,
- audio_scheduler: FlowUniPCMultistepScheduler,
- latent_video: torch.Tensor,
- latent_audio: torch.Tensor,
- t,
- idx: int,
- steps,
- v_cfg_video: torch.Tensor,
- v_cfg_audio: torch.Tensor,
- is_a2v: bool,
- cfg_number: int,
- using_sde_flag: bool,
- use_sr_model: bool = False,
-):
- # Fast DDIM path for cfg_number==1, only used during the BR stage
- if cfg_number == 1 and not use_sr_model:
- latent_video = video_scheduler.step_ddim(v_cfg_video, idx, latent_video)
- latent_audio = audio_scheduler.step_ddim(v_cfg_audio, idx, latent_audio)
- return latent_video, latent_audio
-
- if using_sde_flag:
- if use_sr_model:
- # SR stage with SDE: only update video, keep audio unchanged
- latent_video = video_scheduler.step(v_cfg_video, t, latent_video, return_dict=False)[0]
- return latent_video, latent_audio
- if idx < int(len(steps) * (3 / 4)):
- noise_theta = 1.0 if (idx + 1) % 2 == 0 else 0.0
- else:
- noise_theta = 1.0 if idx % 3 == 0 else 0.0
- latent_video = video_scheduler.step_sde(v_cfg_video, idx, latent_video, noise_theta=noise_theta)
- if not is_a2v:
- latent_audio = audio_scheduler.step_sde(v_cfg_audio, idx, latent_audio, noise_theta=noise_theta)
- return latent_video, latent_audio
-
- latent_video = video_scheduler.step(v_cfg_video, t, latent_video, return_dict=False)[0]
- # Do not update audio latent during the SR stage
- if not is_a2v and not use_sr_model:
- latent_audio = audio_scheduler.step(v_cfg_audio, t, latent_audio, return_dict=False)[0]
- return latent_video, latent_audio
-
-
-_NEGATIVE_PROMPT = (
- "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, "
- "overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, "
- "poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, "
- "still picture, messy background, three legs, many people in the background, walking backwards"
- ", low quality, worst quality, poor quality, noise, background noise, hiss, hum, buzz, crackle, static, "
- "compression artifacts, MP3 artifacts, digital clipping, distortion, muffled, muddy, unclear, echo, "
- "reverb, room echo, over-reverberated, hollow sound, distant, washed out, harsh, shrill, piercing, "
- "grating, tinny, thin sound, boomy, bass-heavy, flat EQ, over-compressed, abrupt cut, jarring transition, "
- "sudden silence, looping artifact, music, instrumental, sirens, alarms, crowd noise, unrelated sound "
- "effects, chaotic, disorganized, messy, cheap sound"
- ", emotionless, flat delivery, deadpan, lifeless, apathetic, robotic, mechanical, monotone, flat "
- "intonation, undynamic, boring, reading from a script, AI voice, synthetic, text-to-speech, TTS, "
- "insincere, fake emotion, exaggerated, overly dramatic, melodramatic, cheesy, cringey, hesitant, "
- "unconfident, tired, weak voice, stuttering, stammering, mumbling, slurred speech, mispronounced, "
- "bad articulation, lisp, vocal fry, creaky voice, mouth clicks, lip smacks, wet mouth sounds, heavy "
- "breathing, audible inhales, plosives, p-pops, coughing, clearing throat, sneezing, speaking too fast, "
- "rushed, speaking too slow, dragged out, unnatural pauses, awkward silence, choppy, disjointed, multiple "
- "speakers, two voices, background talking, out of tune, off-key, autotune artifacts"
-)
-
-
-# ===========================================================================
-# Pre/post process funcs (registered in registry)
-# ===========================================================================
-def get_magi_human_pre_process_func(*args, **kwargs):
- def pre_process(request: OmniDiffusionRequest):
- return request
-
- return pre_process
-
-
-def get_magi_human_post_process_func(*args, **kwargs):
- def post_process(output):
- if isinstance(output, tuple) and len(output) == 2:
- video, audio = output
- return {
- "video": video,
- "audio": audio,
- "audio_sample_rate": 44100,
- "fps": 25,
- }
- return output
-
- return post_process
-
-
-# ===========================================================================
-# HF Hub / local path helpers
-# ===========================================================================
-
-
-def _resolve_subdir(
- model_path: str,
- subfolder: str,
- local_files_only: bool = True,
- required_files: list[str] | None = None,
-) -> str:
- """Resolve a model subfolder to a local directory path.
-
- For HF Hub repos, downloads all ``required_files`` (default: ``["config.json"]``)
- into the HF cache and returns the parent directory.
- """
- if local_files_only:
- return os.path.join(model_path, subfolder)
- from huggingface_hub import hf_hub_download
-
- files = required_files or ["config.json"]
- last_cached: str | None = None
- for fname in files:
- last_cached = hf_hub_download(repo_id=model_path, filename=f"{subfolder}/{fname}")
- return os.path.dirname(last_cached)
-
-
-# ===========================================================================
-# Main Pipeline
-# ===========================================================================
-class MagiHumanPipeline(nn.Module, ProgressBarMixin, DiffusionPipelineProfilerMixin):
- def __init__(self, od_config: OmniDiffusionConfig, **kwargs):
- super().__init__()
- model_path = od_config.model
- local_files_only = os.path.exists(model_path)
- device = f"cuda:{torch.cuda.current_device()}"
- self.device_str = device
- self.dtype = od_config.dtype or torch.bfloat16
-
- model_index = _load_json(model_path, "model_index.json", local_files_only)
- eval_cfg = model_index
- dp_cfg = model_index.get("data_proxy", {})
-
- dit_subfolder = "transformer"
-
- dit_json = _load_json(model_path, f"{dit_subfolder}/config.json", local_files_only)
- dit_model_config = MagiHumanDiTConfig(**dit_json)
-
- self.dit = DiTModel(dit_model_config)
- self.dit.eval()
-
- self.vae = DistributedAutoencoderKLWan.from_pretrained(model_path, subfolder="vae")
- self.vae.to(device)
- self.vae.eval()
- vae_cfg = _load_json(model_path, "vae/config.json", local_files_only)
- self.vae_latent_mean = torch.tensor(vae_cfg["latents_mean"], dtype=torch.float32)
- self.vae_latent_std = torch.tensor(vae_cfg["latents_std"], dtype=torch.float32)
-
- self.audio_vae = SAAudioFeatureExtractor(
- device=device,
- model_path=_resolve_subdir(
- model_path,
- "audio_vae",
- local_files_only,
- required_files=["config.json", "model_config.json", "model.safetensors"],
- ),
- )
-
- logger.info("Loading T5Gemma text encoder from %s (subfolder=text_encoder)", model_path)
- if local_files_only:
- txt_enc_path = os.path.join(model_path, "text_encoder")
- txt_enc_subfolder = None
- else:
- txt_enc_path = model_path
- txt_enc_subfolder = "text_encoder"
- self.text_encoder = _T5GemmaEncoder(
- model_path=txt_enc_path,
- device=device,
- weight_dtype=self.dtype,
- subfolder=txt_enc_subfolder,
- )
-
- self.data_proxy = MagiDataProxy(
- patch_size=dp_cfg.get("patch_size", 2),
- t_patch_size=dp_cfg.get("t_patch_size", 1),
- frame_receptive_field=dp_cfg.get("frame_receptive_field", 11),
- spatial_rope_interpolation=dp_cfg.get("spatial_rope_interpolation", "extra"),
- ref_audio_offset=dp_cfg.get("ref_audio_offset", 1000),
- text_offset=dp_cfg.get("text_offset", 0),
- coords_style=dp_cfg.get("coords_style", "v2"),
- )
- # SR DataProxy forces v1 coordinate style (consistent with the original)
- self.sr_data_proxy = MagiDataProxy(
- patch_size=dp_cfg.get("patch_size", 2),
- t_patch_size=dp_cfg.get("t_patch_size", 1),
- frame_receptive_field=dp_cfg.get("frame_receptive_field", 11),
- spatial_rope_interpolation=dp_cfg.get("spatial_rope_interpolation", "extra"),
- ref_audio_offset=dp_cfg.get("ref_audio_offset", 1000),
- text_offset=dp_cfg.get("text_offset", 0),
- coords_style="v1",
- )
-
- self.fps = eval_cfg.get("fps", 25)
- self.num_inference_steps_default = eval_cfg.get("num_inference_steps", 32)
- self.video_txt_guidance_scale = eval_cfg.get("video_txt_guidance_scale", 5.0)
- self.audio_txt_guidance_scale = eval_cfg.get("audio_txt_guidance_scale", 5.0)
- self.shift = eval_cfg.get("shift", 5.0)
- self.cfg_number = eval_cfg.get("cfg_number", 2)
- self.use_cfg_trick = eval_cfg.get("use_cfg_trick", True)
- self.cfg_trick_start_frame = eval_cfg.get("cfg_trick_start_frame", 13)
- self.cfg_trick_value = eval_cfg.get("cfg_trick_value", 2.0)
- self.using_sde_flag = eval_cfg.get("using_sde_flag", False)
- self.t5_gemma_target_length = eval_cfg.get("t5_gemma_target_length", 640)
- self.vae_stride = eval_cfg.get("vae_stride", [4, 16, 16])
- self.z_dim = eval_cfg.get("z_dim", 48)
- self.patch_size = eval_cfg.get("patch_size", [1, 2, 2])
- # SR-specific hyperparameters
- self.sr_num_inference_steps_default = eval_cfg.get("sr_num_inference_steps", 5)
- self.sr_cfg_number = eval_cfg.get("sr_cfg_number", 2)
- self.sr_video_txt_guidance_scale = eval_cfg.get("sr_video_txt_guidance_scale", 3.5)
- self.noise_value = eval_cfg.get("noise_value", 220)
- self.sr_audio_noise_scale = eval_cfg.get("sr_audio_noise_scale", 0.7)
- # ZeroSNR sigma schedule for SR noise injection (flip=True, high to low)
- self.zerosnr_sigmas = ZeroSNRDDPMDiscretization()(1000, do_append_zero=False, flip=True)
-
- self.context_null, self.original_context_null_len = _get_padded_t5_gemma_embedding(
- _NEGATIVE_PROMPT,
- self.text_encoder,
- self.t5_gemma_target_length,
- )
- self.video_processor = VideoProcessor(vae_scale_factor=16)
-
- # SR DiT model (loaded from the sr/ subdirectory)
- sr_dit_subfolder = "sr"
- sr_dit_json = _load_json(model_path, f"{sr_dit_subfolder}/config.json", local_files_only)
- sr_dit_model_config = MagiHumanDiTConfig(**sr_dit_json)
- self.sr_dit = DiTModel(sr_dit_model_config)
- self.sr_dit.eval()
-
- self.weights_sources = [
- DiffusersPipelineLoader.ComponentSource(
- model_or_path=model_path,
- subfolder=dit_subfolder,
- revision=None,
- prefix="dit.",
- fall_back_to_pt=True,
- ),
- DiffusersPipelineLoader.ComponentSource(
- model_or_path=model_path,
- subfolder=sr_dit_subfolder,
- revision=None,
- prefix="sr_dit.",
- fall_back_to_pt=True,
- ),
- ]
- if getattr(self.text_encoder, "is_tp", False):
- self.weights_sources.append(
- DiffusersPipelineLoader.ComponentSource(
- model_or_path=model_path,
- subfolder="text_encoder",
- revision=None,
- prefix="text_encoder.",
- fall_back_to_pt=True,
- ),
- )
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- # Weight loading for MagiHuman DiT with TP support.
- #
- # The checkpoint stores weights with these naming patterns:
- # - attention.linear_qkv.weight: fused [Q, K, V, G] for shared layers,
- # or stacked per-expert [expert0_Q|K|V|G, expert1_..., expert2_...] for MoE.
- # - attention.linear_proj.weight: single for shared, stacked per-expert for MoE.
- # - mlp.up_gate_proj.weight / mlp.down_proj.weight: similarly stacked for MoE.
- #
- # The model now uses per-expert vLLM parallel layers for MoE blocks:
- # attention.linear_qkv.experts.{i}.weight (QKVParallelLinear per expert)
- # attention.linear_gating.experts.{i}.weight (ColumnParallelLinear per expert)
- # attention.linear_proj.experts.{i}.weight (RowParallelLinear per expert)
- # mlp.up_gate_proj.experts.{i}.weight (ColumnParallelLinear per expert)
- # mlp.down_proj.experts.{i}.weight (RowParallelLinear per expert)
- #
- # Shared layers keep the same naming (no .experts.).
- params_dict = dict(self.named_parameters())
- modules_dict = dict(self.named_modules())
- loaded_params: set[str] = set()
-
- for name, loaded_weight in weights:
- # ── Text Encoder weights ──
- if name.startswith("text_encoder."):
- if getattr(self.text_encoder, "is_tp", False):
- # Strip "text_encoder." prefix for the T5Gemma TP model
- # The T5GemmaEncoderModelTP load_weights handles the "encoder." prefix itself
- sub_name = name[len("text_encoder.") :]
- loaded_params.update(
- f"text_encoder.{k}" for k in self.text_encoder.model.load_weights([(sub_name, loaded_weight)])
- )
- else:
- loaded_params.add(name)
- continue
-
- # ── Shared attention QKV + Gating split ──
- # Checkpoint: attention.linear_qkv.weight = [Q, K, V, G] fused.
- # Model: attention.linear_qkv.weight (QKVParallelLinear) + attention.linear_gating.weight.
- if "attention.linear_qkv.weight" in name:
- gating_name = name.replace("attention.linear_qkv.weight", "attention.linear_gating.weight")
- # Check if this is a shared layer (direct param exists, no .experts.)
- if name in params_dict and gating_name in params_dict:
- qkv_param = params_dict[name]
- gating_param = params_dict[gating_name]
-
- mod_path = name[: -len(".weight")]
- qkv_mod = modules_dict.get(mod_path)
- if qkv_mod is not None and hasattr(qkv_mod, "total_num_heads"):
- total_heads_q = qkv_mod.total_num_heads
- total_heads_kv = qkv_mod.total_num_kv_heads
- head_dim = qkv_mod.head_size
- else:
- head_dim = 128
- tp_size = get_tensor_model_parallel_world_size()
- total_heads_q = gating_param.data.shape[0] * tp_size
- total_heads_kv = (loaded_weight.shape[0] - total_heads_q * head_dim - total_heads_q) // (
- 2 * head_dim
- )
-
- q_size = total_heads_q * head_dim
- kv_size = total_heads_kv * head_dim
-
- q_w = loaded_weight[:q_size]
- k_w = loaded_weight[q_size : q_size + kv_size]
- v_w = loaded_weight[q_size + kv_size : q_size + 2 * kv_size]
- g_w = loaded_weight[q_size + 2 * kv_size :]
-
- qkv_loader = getattr(qkv_param, "weight_loader", default_weight_loader)
- qkv_loader(qkv_param, q_w, "q")
- qkv_loader(qkv_param, k_w, "k")
- qkv_loader(qkv_param, v_w, "v")
-
- gating_loader = getattr(gating_param, "weight_loader", default_weight_loader)
- gating_loader(gating_param, g_w)
-
- loaded_params.add(name)
- loaded_params.add(gating_name)
- continue
-
- # ── MoE attention QKV + Gating split ──
- # Checkpoint: attention.linear_qkv.weight = stacked [expert0_QKVG, expert1_QKVG, ...].
- # Model: attention.linear_qkv.experts.{i}.weight (QKVParallelLinear per expert)
- # + attention.linear_gating.experts.{i}.weight (ColumnParallelLinear per expert).
- expert0_name = name.replace("attention.linear_qkv.weight", "attention.linear_qkv.experts.0.weight")
- if expert0_name in params_dict:
- # Determine num_experts by checking which expert indices exist.
- moe_qkv_mod_path = name[: -len(".weight")]
- moe_qkv_mod = modules_dict.get(moe_qkv_mod_path)
- num_experts = moe_qkv_mod.num_experts if moe_qkv_mod is not None else 3
-
- # Get head info from the first expert's QKVParallelLinear.
- expert0_mod_path = name.replace("attention.linear_qkv.weight", "attention.linear_qkv.experts.0")
- expert0_mod = modules_dict.get(expert0_mod_path)
- if expert0_mod is not None and hasattr(expert0_mod, "total_num_heads"):
- total_heads_q = expert0_mod.total_num_heads
- total_heads_kv = expert0_mod.total_num_kv_heads
- head_dim = expert0_mod.head_size
- else:
- head_dim = 128
- # Infer from checkpoint weight shape.
- # We'll get exact sizes from model config below.
- total_heads_q = 40 # fallback for default config
- total_heads_kv = 8
-
- q_size = total_heads_q * head_dim
- kv_size = total_heads_kv * head_dim
- # Check if gating is present.
- gating_expert0_name = name.replace(
- "attention.linear_qkv.weight", "attention.linear_gating.experts.0.weight"
- )
- has_gating = gating_expert0_name in params_dict
-
- # Split stacked checkpoint weight into per-expert chunks.
- expert_weights = loaded_weight.chunk(num_experts, dim=0)
-
- for i in range(num_experts):
- expert_w = expert_weights[i]
- # Each expert chunk: [Q, K, V, G (optional)].
- q_w = expert_w[:q_size]
- k_w = expert_w[q_size : q_size + kv_size]
- v_w = expert_w[q_size + kv_size : q_size + 2 * kv_size]
-
- expert_param_name = name.replace(
- "attention.linear_qkv.weight",
- f"attention.linear_qkv.experts.{i}.weight",
- )
- expert_param = params_dict[expert_param_name]
- expert_loader = getattr(expert_param, "weight_loader", default_weight_loader)
- expert_loader(expert_param, q_w, "q")
- expert_loader(expert_param, k_w, "k")
- expert_loader(expert_param, v_w, "v")
- loaded_params.add(expert_param_name)
-
- if has_gating:
- g_w = expert_w[q_size + 2 * kv_size :]
- gating_param_name = name.replace(
- "attention.linear_qkv.weight",
- f"attention.linear_gating.experts.{i}.weight",
- )
- gating_param = params_dict[gating_param_name]
- gating_loader = getattr(gating_param, "weight_loader", default_weight_loader)
- gating_loader(gating_param, g_w)
- loaded_params.add(gating_param_name)
- continue
-
- # ── MoE stacked weight splitting for proj / MLP layers ──
- # Checkpoint: x.y.weight (stacked [expert0, expert1, ...]).
- # Model: x.y.experts.{i}.weight.
- if name not in params_dict:
- # Check if this is a stacked MoE weight by looking for .experts.0.
- base, _, suffix = name.rpartition(".")
- expert0_name = f"{base}.experts.0.{suffix}" if base else None
- if expert0_name and expert0_name in params_dict:
- # Determine num_experts.
- moe_mod = modules_dict.get(base)
- num_experts = getattr(moe_mod, "num_experts", 3) if moe_mod is not None else 3
-
- # Split stacked weight into per-expert chunks.
- expert_weights = loaded_weight.chunk(num_experts, dim=0)
- for i in range(num_experts):
- expert_param_name = f"{base}.experts.{i}.{suffix}"
- if expert_param_name not in params_dict:
- continue
- expert_param = params_dict[expert_param_name]
- expert_loader = getattr(expert_param, "weight_loader", default_weight_loader)
- expert_loader(expert_param, expert_weights[i])
- loaded_params.add(expert_param_name)
- continue
- # Truly unknown weight — skip.
- continue
-
- # ── Standard weight loading (shared layers + non-MoE params) ──
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
-
- if getattr(self.text_encoder, "is_tp", False):
- self.context_null, self.original_context_null_len = _get_padded_t5_gemma_embedding(
- _NEGATIVE_PROMPT,
- self.text_encoder,
- self.t5_gemma_target_length,
- )
-
- return loaded_params
-
- def _dit_forward(self, eval_input: EvalInput) -> tuple[torch.Tensor, torch.Tensor]:
- packed = self.data_proxy.process_input(eval_input)
- noise_pred = self.dit(*packed)
- return self.data_proxy.process_output(noise_pred)
-
- def _sr_dit_forward(self, eval_input: EvalInput) -> tuple[torch.Tensor, torch.Tensor]:
- """SR stage uses sr_data_proxy (coords_style=v1) and sr_dit model."""
- packed = self.sr_data_proxy.process_input(eval_input)
- noise_pred = self.sr_dit(*packed)
- return self.sr_data_proxy.process_output(noise_pred)
-
- @torch.inference_mode()
- def _evaluate_with_latent(
- self,
- context: torch.Tensor,
- original_context_len: int,
- latent_image: torch.Tensor | None,
- latent_video: torch.Tensor,
- latent_audio: torch.Tensor,
- num_inference_steps: int,
- is_a2v: bool = False,
- use_sr_model: bool = False,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- # Select cfg_number and guidance_scale based on BR/SR stage
- cfg_number = self.sr_cfg_number if use_sr_model else self.cfg_number
- video_guidance = self.sr_video_txt_guidance_scale if use_sr_model else self.video_txt_guidance_scale
- forward_fn = self._sr_dit_forward if use_sr_model else self._dit_forward
-
- video_scheduler = FlowUniPCMultistepScheduler()
- audio_scheduler = FlowUniPCMultistepScheduler()
- video_scheduler.set_timesteps(num_inference_steps, device=self.device_str, shift=self.shift)
- audio_scheduler.set_timesteps(num_inference_steps, device=self.device_str, shift=self.shift)
- timesteps = video_scheduler.timesteps
-
- latent_length = latent_video.shape[2]
- cfg_trick_guidance = (
- torch.tensor(video_guidance, device=self.device_str).expand(1, 1, latent_length, 1, 1).clone()
- )
- if self.use_cfg_trick:
- cfg_trick_guidance[:, :, : self.cfg_trick_start_frame] = min(self.cfg_trick_value, video_guidance)
-
- with self.progress_bar(total=len(timesteps)) as pbar:
- for idx, t in enumerate(timesteps):
- if latent_image is not None:
- latent_video[:, :, :1] = latent_image[:, :, :1]
-
- # Reduce guidance when t<=500 during BR stage (original behavior)
- cur_video_guidance = video_guidance if (use_sr_model or t > 500) else 2.0
-
- eval_input_cond = EvalInput(
- x_t=latent_video,
- audio_x_t=latent_audio,
- audio_feat_len=[latent_audio.shape[1]],
- txt_feat=context,
- txt_feat_len=[original_context_len],
- )
-
- v_cond_video, v_cond_audio = forward_fn(eval_input_cond)
-
- if cfg_number == 1:
- v_cfg_video = v_cond_video
- v_cfg_audio = v_cond_audio
- elif cfg_number == 2:
- eval_input_uncond = EvalInput(
- x_t=latent_video,
- audio_x_t=latent_audio,
- audio_feat_len=[latent_audio.shape[1]],
- txt_feat=self.context_null,
- txt_feat_len=[self.original_context_null_len],
- )
- v_uncond_video, v_uncond_audio = forward_fn(eval_input_uncond)
- v_cfg_video = v_uncond_video + cur_video_guidance * (v_cond_video - v_uncond_video)
- v_cfg_audio = v_uncond_audio + self.audio_txt_guidance_scale * (v_cond_audio - v_uncond_audio)
- else:
- raise ValueError(f"Invalid cfg_number: {cfg_number}")
-
- latent_video, latent_audio = _schedule_latent_step(
- video_scheduler=video_scheduler,
- audio_scheduler=audio_scheduler,
- latent_video=latent_video,
- latent_audio=latent_audio,
- t=t,
- idx=idx,
- steps=timesteps,
- v_cfg_video=v_cfg_video,
- v_cfg_audio=v_cfg_audio,
- is_a2v=is_a2v,
- cfg_number=cfg_number,
- using_sde_flag=self.using_sde_flag,
- use_sr_model=use_sr_model,
- )
-
- pbar.update()
-
- if latent_image is not None:
- latent_video[:, :, :1] = latent_image[:, :, :1]
- return latent_video, latent_audio
-
- def _encode_image(self, image: Image.Image, height: int, width: int) -> torch.Tensor:
- image = load_image(image)
- image = _resizecrop(image, height, width)
- image = self.video_processor.preprocess(image, height=height, width=width)
- image = image.to(device=self.device_str, dtype=self.dtype).unsqueeze(2)
- vae_out = self.vae.encode(image)
- if hasattr(vae_out, "latent_dist"):
- return vae_out.latent_dist.mode().to(torch.float32)
- return vae_out.to(torch.float32)
-
- def _decode_video(self, latent: torch.Tensor) -> list[np.ndarray]:
- mean = self.vae_latent_mean.to(latent.device, dtype=latent.dtype).view(1, -1, 1, 1, 1)
- std = self.vae_latent_std.to(latent.device, dtype=latent.dtype).view(1, -1, 1, 1, 1)
- latent = latent * std + mean
-
- videos = self.vae.decode(latent.to(self.dtype))
- if hasattr(videos, "sample"):
- videos = videos.sample
- videos.mul_(0.5).add_(0.5).clamp_(0, 1)
- videos = [v.float().cpu().permute(1, 2, 3, 0) * 255 for v in videos]
- return [v.numpy().astype(np.uint8) for v in videos]
-
- def _decode_audio(self, latent_audio: torch.Tensor) -> np.ndarray:
- latent_audio = latent_audio.squeeze(0).to(self.dtype)
- audio_output = self.audio_vae.decode(latent_audio.T)
- audio_np = audio_output.squeeze(0).T.float().cpu().numpy()
- target_len = int(audio_np.shape[0] * 441 / 512)
- from scipy.signal import resample
-
- return resample(audio_np, target_len)
-
- @torch.inference_mode()
- def forward(
- self,
- req: OmniDiffusionRequest,
- prompt: str | None = None,
- height: int = 256,
- width: int = 448,
- num_inference_steps: int | None = None,
- seconds: int = 10,
- seed: int | None = None,
- image_path: str | None = None,
- audio_path: str | None = None,
- **kwargs,
- ) -> DiffusionOutput:
- if len(req.prompts) >= 1:
- p = req.prompts[0]
- prompt = p if isinstance(p, str) else p.get("prompt", prompt)
- if not isinstance(p, str):
- image_path = p.get("image_path", image_path)
- audio_path = p.get("audio_path", audio_path)
- if prompt is None:
- raise ValueError("prompt is required")
-
- height = req.sampling_params.height or height
- width = req.sampling_params.width or width
- seed = req.sampling_params.seed if req.sampling_params.seed is not None else seed
- num_steps = req.sampling_params.num_inference_steps or num_inference_steps or self.num_inference_steps_default
- sr_height: int | None = None
- sr_width: int | None = None
- sr_num_steps: int | None = None
- if hasattr(req.sampling_params, "extra_args") and req.sampling_params.extra_args:
- seconds = req.sampling_params.extra_args.get("seconds", seconds)
- audio_path = req.sampling_params.extra_args.get("audio_path", audio_path)
- image_path = req.sampling_params.extra_args.get("image_path", image_path)
- sr_height = req.sampling_params.extra_args.get("sr_height", None)
- sr_width = req.sampling_params.extra_args.get("sr_width", None)
- sr_num_steps = req.sampling_params.extra_args.get("sr_num_inference_steps", None)
-
- device = self.device_str
-
- br_latent_height = height // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]
- br_latent_width = width // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]
- br_height = br_latent_height * self.vae_stride[1]
- br_width = br_latent_width * self.vae_stride[2]
-
- if seed is not None:
- torch.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
-
- if audio_path is not None:
- latent_audio = load_audio_and_encode(self.audio_vae, audio_path, seconds)
- latent_audio = latent_audio.permute(0, 2, 1)
- num_frames = latent_audio.shape[1]
- is_a2v = True
- else:
- num_frames = seconds * self.fps + 1
- latent_audio = torch.randn(1, num_frames, 64, dtype=torch.float32, device=device)
- is_a2v = False
-
- latent_length = (num_frames - 1) // 4 + 1
- latent_video = torch.randn(
- 1,
- self.z_dim,
- latent_length,
- br_latent_height,
- br_latent_width,
- dtype=torch.float32,
- device=device,
- )
-
- context, original_context_len = _get_padded_t5_gemma_embedding(
- prompt,
- self.text_encoder,
- self.t5_gemma_target_length,
- )
-
- if image_path is not None:
- br_image = self._encode_image(load_image(image_path), br_height, br_width)
- else:
- br_image = None
-
- # ── BR stage ─────────────────────────────────────────────────────────
- br_latent_video, br_latent_audio = self._evaluate_with_latent(
- context,
- original_context_len,
- br_image,
- latent_video.clone(),
- latent_audio.clone(),
- num_steps,
- is_a2v,
- use_sr_model=False,
- )
-
- # ── SR stage (optional, triggered when sr_height/sr_width are provided) ──
- if sr_height is not None and sr_width is not None:
- sr_latent_height = sr_height // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]
- sr_latent_width = sr_width // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]
- sr_height = sr_latent_height * self.vae_stride[1]
- sr_width = sr_latent_width * self.vae_stride[2]
-
- # Image condition (at SR resolution)
- if image_path is not None:
- sr_image = self._encode_image(load_image(image_path), sr_height, sr_width)
- else:
- sr_image = None
-
- # Trilinear interpolation of BR latent to SR resolution
- sr_latent_video = torch.nn.functional.interpolate(
- br_latent_video,
- size=(latent_length, sr_latent_height, sr_latent_width),
- mode="trilinear",
- align_corners=True,
- )
-
- # Noise injection: sigma-weighted blend (noise_value indexes the ZeroSNR sigma schedule)
- if self.noise_value != 0:
- noise = torch.randn_like(sr_latent_video)
- sigma = self.zerosnr_sigmas.to(sr_latent_video.device)[self.noise_value]
- sr_latent_video = sr_latent_video * sigma + noise * (1 - sigma**2) ** 0.5
-
- # Audio: blend with noise (noised version used during SR inference; final audio keeps BR result)
- sr_latent_audio = torch.randn_like(br_latent_audio) * self.sr_audio_noise_scale + br_latent_audio * (
- 1 - self.sr_audio_noise_scale
- )
-
- torch.cuda.empty_cache()
- sr_steps = sr_num_steps or self.sr_num_inference_steps_default
- final_latent_video, _ = self._evaluate_with_latent(
- context,
- original_context_len,
- sr_image,
- sr_latent_video.clone(),
- sr_latent_audio.clone(),
- sr_steps,
- is_a2v,
- use_sr_model=True,
- )
- # SR stage does not update audio; keep the BR result
- final_latent_video = final_latent_video
- final_latent_audio = br_latent_audio
- else:
- final_latent_video = br_latent_video
- final_latent_audio = br_latent_audio
-
- torch.cuda.empty_cache()
- videos_np = self._decode_video(final_latent_video)
- torch.cuda.empty_cache()
- audio_np = self._decode_audio(final_latent_audio)
-
- return DiffusionOutput(output=(videos_np, audio_np))
diff --git a/vllm_omni/diffusion/models/mammoth_moda2/rope_real.py b/vllm_omni/diffusion/models/mammoth_moda2/rope_real.py
index 64cc4324869..d16181a6913 100644
--- a/vllm_omni/diffusion/models/mammoth_moda2/rope_real.py
+++ b/vllm_omni/diffusion/models/mammoth_moda2/rope_real.py
@@ -18,8 +18,6 @@
from einops import repeat
from torch import nn
-from vllm_omni.platforms import current_omni_platform
-
def apply_real_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
"""
@@ -121,7 +119,7 @@ def get_freqs_real(
axes_dim: tuple[int, int, int], axes_lens: tuple[int, int, int], theta: int
) -> list[tuple[torch.Tensor, torch.Tensor]]:
freqs_real = []
- freqs_dtype = torch.float64 if current_omni_platform.supports_float64() else torch.float32
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
cos_emb, sin_emb = get_1d_rotary_pos_embed_real(d, e, theta=theta, freqs_dtype=freqs_dtype)
freqs_real.append((cos_emb, sin_emb))
diff --git a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py
index d2b3eb81e31..ded3079265e 100644
--- a/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py
+++ b/vllm_omni/diffusion/models/nextstep_1_1/modeling_nextstep.py
@@ -114,8 +114,6 @@ def from_json(cls, path: str) -> NextStepConfig:
class NextStepModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["layers"]
-
def __init__(self, config: NextStepConfig):
super().__init__()
self.config = config
diff --git a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
index 3f03563a1cc..b626ca1d85b 100644
--- a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
+++ b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
@@ -5,8 +5,6 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
-import vllm._custom_ops as ops
from diffusers.models.activations import get_activation
from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
@@ -18,114 +16,13 @@
QKVParallelLinear,
RowParallelLinear,
)
-from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm_omni.diffusion.attention.layer import Attention
-from vllm_omni.platforms import current_omni_platform
logger = logging.getLogger(__name__)
-def _patch_cutlass_padded_fp8():
- """Monkey-patch vllm._custom_ops.cutlass_scaled_mm to pad tensors whose
- dimensions are not multiples of 16, so the CUTLASS FP8 kernel is used.
-
- OmniGen2 has hidden_size=2520 (2520 % 16 == 8). Without this patch,
- vLLM's cutlass_scaled_mm falls back to a Triton scaled_mm kernel for
- every FP8 linear layer (QKV, attn output, gate_up_proj, down_proj),
- which is dramatically slower than the native CUTLASS FP8 tensor-core
- path on H100/H200 GPUs.
-
- Weight tensors (b) are constant across forward passes, so padded
- versions are computed once and cached by data_ptr to avoid repeated
- allocation and column-major conversion overhead.
- """
- _orig_cutlass_scaled_mm = ops.cutlass_scaled_mm
- # Cache: data_ptr → (padded_b, padded_bias, padded_scale_b, pad_k, pad_n, orig_n)
- _weight_cache: dict[int, tuple] = {}
-
- def _padded_cutlass_scaled_mm(
- a: torch.Tensor,
- b: torch.Tensor,
- scale_a: torch.Tensor,
- scale_b: torch.Tensor,
- out_dtype: torch.dtype,
- bias: torch.Tensor | None = None,
- ) -> torch.Tensor:
- if b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0:
- return _orig_cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-
- # Reshape to 2D (mirrors the original function)
- target_shape = (*a.shape[:-1], b.shape[1])
- a = a.view(-1, a.shape[-1])
- orig_n = b.shape[1]
-
- # Cache the padded weight — it's a model parameter that never changes.
- key = b.data_ptr()
- if key not in _weight_cache:
- pad_k = (16 - b.shape[0] % 16) % 16
- pad_n = (16 - orig_n % 16) % 16
- b_pad = b
- if pad_k > 0:
- b_pad = F.pad(b_pad, (0, 0, 0, pad_k))
- if pad_n > 0:
- b_pad = F.pad(b_pad, (0, pad_n))
- # CUTLASS requires b column-major (stride(0)==1).
- b_pad = b_pad.t().contiguous().t()
-
- bias_pad = None
- if bias is not None and pad_n > 0:
- bias_pad = F.pad(bias, (0, pad_n))
-
- scale_b_pad = scale_b
- if scale_b.numel() > 1 and pad_n > 0:
- scale_b_pad = F.pad(
- scale_b.view(-1, scale_b.shape[-1]),
- (0, pad_n),
- value=1.0,
- )
-
- _weight_cache[key] = (
- b_pad,
- bias_pad,
- scale_b_pad,
- pad_k,
- pad_n,
- orig_n,
- )
-
- b_pad, bias_pad, scale_b_pad, pad_k, pad_n, orig_n = _weight_cache[key]
-
- # Pad activations on K dimension (cheap — activations are small).
- if pad_k > 0:
- a = F.pad(a, (0, pad_k)).contiguous()
-
- out = torch.empty((a.shape[0], b_pad.shape[1]), dtype=out_dtype, device=a.device)
- torch.ops._C.cutlass_scaled_mm(
- out,
- a,
- b_pad,
- scale_a,
- scale_b_pad,
- bias_pad if bias is not None else None,
- )
-
- if pad_n > 0:
- out = out[:, :orig_n]
-
- return out.view(*target_shape)
-
- ops.cutlass_scaled_mm = _padded_cutlass_scaled_mm
- logger.info(
- "Patched vllm._custom_ops.cutlass_scaled_mm with CUTLASS-padded FP8 "
- "variant (avoids slow Triton fallback for non-%%16 dimensions)"
- )
-
-
-_patch_cutlass_padded_fp8()
-
-
class OmniGen2Attention(nn.Module):
def __init__(
self,
@@ -133,8 +30,6 @@ def __init__(
num_heads: int,
num_kv_heads: int,
eps: float = 1e-5,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
):
super().__init__()
self.dim = dim
@@ -150,26 +45,12 @@ def __init__(
total_num_kv_heads=num_kv_heads,
disable_tp=True,
bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.to_qkv",
)
self.norm_q = RMSNorm(self.head_dim, eps=eps)
self.norm_k = RMSNorm(self.head_dim, eps=eps)
- self.to_out = nn.ModuleList(
- [
- RowParallelLinear(
- dim,
- dim,
- bias=False,
- input_is_parallel=False,
- quant_config=quant_config,
- return_bias=False,
- prefix=f"{prefix}.to_out.0",
- )
- ]
- )
+ self.to_out = nn.ModuleList([nn.Linear(dim, dim, bias=False)])
self.attn = Attention(
num_heads=num_heads,
head_size=self.head_dim,
@@ -196,9 +77,6 @@ def forward(
"""
batch_size = hidden_states.shape[0]
- # Contiguous layout for FP8 quantized linear GEMMs (matches FLUX DiT).
- hidden_states = hidden_states.contiguous()
-
# Get Query-Key-Value Pair
qkv, _ = self.to_qkv(hidden_states)
@@ -242,7 +120,7 @@ def forward(
hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(dtype)
- hidden_states = self.to_out[0](hidden_states.contiguous())
+ hidden_states = self.to_out[0](hidden_states)
return hidden_states
@@ -354,7 +232,6 @@ def __init__(
embedding_dim: int,
norm_eps: float,
norm_elementwise_affine: bool,
- **kwargs,
):
super().__init__()
self.silu = nn.SiLU()
@@ -447,8 +324,6 @@ def __init__(
inner_dim: int,
multiple_of: int | None = 256,
ffn_dim_multiplier: float | None = None,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
):
super().__init__()
@@ -462,8 +337,6 @@ def __init__(
[inner_dim, inner_dim],
bias=False,
return_bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.gate_up_proj",
)
self.act_fn = get_act_and_mul_fn("silu")
self.down_proj = RowParallelLinear(
@@ -472,8 +345,6 @@ def __init__(
bias=False,
input_is_parallel=True,
return_bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.down_proj",
)
def forward(self, x):
@@ -540,7 +411,7 @@ def get_freqs_cis(
axes_dim: tuple[int, int, int], axes_lens: tuple[int, int, int], theta: int
) -> list[torch.Tensor]:
freqs_cis = []
- freqs_dtype = torch.float64 if current_omni_platform.supports_float64() else torch.float32
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
freqs_cis.append(emb)
@@ -719,8 +590,6 @@ def __init__(
ffn_dim_multiplier: float,
norm_eps: float,
modulation: bool = True,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
) -> None:
"""Initialize the transformer block."""
super().__init__()
@@ -732,8 +601,6 @@ def __init__(
num_heads=num_attention_heads,
num_kv_heads=num_kv_heads,
eps=1e-5,
- quant_config=quant_config,
- prefix=f"{prefix}.attn",
)
# Initialize feed-forward network
@@ -742,19 +609,11 @@ def __init__(
inner_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
- quant_config=quant_config,
- prefix=f"{prefix}.feed_forward",
)
# Initialize normalization layers
if modulation:
- self.norm1 = LuminaRMSNormZero(
- embedding_dim=dim,
- norm_eps=norm_eps,
- norm_elementwise_affine=True,
- quant_config=quant_config,
- prefix=f"{prefix}.norm1",
- )
+ self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True)
else:
self.norm1 = RMSNorm(dim, eps=norm_eps)
@@ -853,7 +712,6 @@ def __init__(
axes_lens: tuple[int, int, int] = (1024, 1664, 1664),
text_feat_dim: int = 2048,
timestep_scale: float = 1000.0,
- quant_config: QuantizationConfig | None = None,
) -> None:
"""Initialize the OmniGen2 transformer model."""
super().__init__()
@@ -911,10 +769,8 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
modulation=True,
- quant_config=quant_config,
- prefix=f"noise_refiner.{i}",
)
- for i in range(num_refiner_layers)
+ for _ in range(num_refiner_layers)
]
)
@@ -928,10 +784,8 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
modulation=True,
- quant_config=quant_config,
- prefix=f"ref_image_refiner.{i}",
)
- for i in range(num_refiner_layers)
+ for _ in range(num_refiner_layers)
]
)
@@ -945,10 +799,8 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
modulation=False,
- quant_config=quant_config,
- prefix=f"context_refiner.{i}",
)
- for i in range(num_refiner_layers)
+ for _ in range(num_refiner_layers)
]
)
@@ -963,10 +815,8 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
modulation=True,
- quant_config=quant_config,
- prefix=f"layers.{i}",
)
- for i in range(num_layers)
+ for _ in range(num_layers)
]
)
@@ -996,25 +846,11 @@ def img_patch_embed_and_refine(
temb,
):
batch_size = len(hidden_states)
- has_ref_tokens = any(ref_img_len > 0 for ref_lens in l_effective_ref_img_len for ref_img_len in ref_lens)
max_combined_img_len = max(
[img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)]
)
hidden_states = self.x_embedder(hidden_states)
- if not has_ref_tokens:
- # FP8 kernels do not support zero-token GEMM on ref_image_patch_embedder; skip that path only.
- # Still run noise_refiner and return the same combined layout as the no-ref case below
- # (batch, max_combined_img_len, hidden) — not raw noise tokens alone.
- for layer in self.noise_refiner:
- hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
- combined_img_hidden_states = hidden_states.new_zeros(
- batch_size, max_combined_img_len, self.config.hidden_size
- )
- for i, img_len in enumerate(l_effective_img_len):
- combined_img_hidden_states[i, :img_len] = hidden_states[i, :img_len]
- return combined_img_hidden_states
-
ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
for i in range(batch_size):
diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
index 46d634bfdc0..2d370aea19c 100644
--- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
+++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
@@ -8,7 +8,7 @@
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
-from typing import Any, ClassVar
+from typing import Any
import numpy as np
import PIL.Image
@@ -29,10 +29,8 @@
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
-from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.models.interface import SupportsModuleOffload
from vllm_omni.diffusion.models.omnigen2.omnigen2_transformer import (
OmniGen2RotaryPosEmbed,
OmniGen2Transformer2DModel,
@@ -621,7 +619,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class OmniGen2Pipeline(CFGParallelMixin, nn.Module, SupportsModuleOffload):
+class OmniGen2Pipeline(nn.Module):
"""
Pipeline for text-to-image generation using OmniGen2.
@@ -635,10 +633,6 @@ class OmniGen2Pipeline(CFGParallelMixin, nn.Module, SupportsModuleOffload):
od_config (OmniDiffusionConfig): The OmniDiffusion configuration.
"""
- _dit_modules: ClassVar[list[str]] = ["transformer"]
- _encoder_modules: ClassVar[list[str]] = ["mllm"]
- _vae_modules: ClassVar[list[str]] = ["vae"]
-
def __init__(
self,
*,
@@ -681,10 +675,7 @@ def __init__(
)
transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, OmniGen2Transformer2DModel)
- self.transformer = OmniGen2Transformer2DModel(
- **transformer_kwargs,
- quant_config=od_config.quantization_config,
- )
+ self.transformer = OmniGen2Transformer2DModel(**transformer_kwargs)
self.mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model, subfolder="mllm", local_files_only=local_files_only
).to(self.device)
@@ -1180,14 +1171,7 @@ def processing(
self._num_timesteps = len(timesteps)
for i, t in enumerate(timesteps):
- text_guidance_scale = (
- self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
- )
- image_guidance_scale = (
- self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
- )
-
- positive_kwargs = dict(
+ model_pred = self.predict(
t=t,
latents=latents,
prompt_embeds=prompt_embeds,
@@ -1195,18 +1179,15 @@ def processing(
prompt_attention_mask=prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
- uncond_kwargs = dict(
- t=t,
- latents=latents,
- prompt_embeds=negative_prompt_embeds,
- freqs_cis=freqs_cis,
- prompt_attention_mask=negative_prompt_attention_mask,
- ref_image_hidden_states=None,
+ text_guidance_scale = (
+ self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
+ )
+ image_guidance_scale = (
+ self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0
)
if text_guidance_scale > 1.0 and image_guidance_scale > 1.0:
- # 3-branch CFG: pos + ref_neg + uncond
- ref_neg_kwargs = dict(
+ model_pred_ref = self.predict(
t=t,
latents=latents,
prompt_embeds=negative_prompt_embeds,
@@ -1214,24 +1195,31 @@ def processing(
prompt_attention_mask=negative_prompt_attention_mask,
ref_image_hidden_states=ref_latents,
)
- model_pred = self.predict_noise_with_multi_branch_cfg(
- do_true_cfg=True,
- true_cfg_scale={
- "text": text_guidance_scale,
- "image": image_guidance_scale,
- },
- branches_kwargs=[positive_kwargs, ref_neg_kwargs, uncond_kwargs],
+
+ model_pred_uncond = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=negative_prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ ref_image_hidden_states=None,
+ )
+
+ model_pred = (
+ model_pred_uncond
+ + image_guidance_scale * (model_pred_ref - model_pred_uncond)
+ + text_guidance_scale * (model_pred - model_pred_ref)
)
elif text_guidance_scale > 1.0:
- # 2-branch CFG: pos + uncond
- model_pred = self.predict_noise_with_multi_branch_cfg(
- do_true_cfg=True,
- true_cfg_scale=text_guidance_scale,
- branches_kwargs=[positive_kwargs, uncond_kwargs],
+ model_pred_uncond = self.predict(
+ t=t,
+ latents=latents,
+ prompt_embeds=negative_prompt_embeds,
+ freqs_cis=freqs_cis,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ ref_image_hidden_states=None,
)
- else:
- # No CFG
- model_pred = self.predict_noise(**positive_kwargs)
+ model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond)
latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0]
@@ -1261,6 +1249,8 @@ def predict(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ batch_size, num_channels_latents, height, width = latents.shape
+
optional_kwargs = {}
if "ref_image_hidden_states" in set(inspect.signature(self.transformer.forward).parameters.keys()):
optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states
@@ -1275,21 +1265,6 @@ def predict(
)
return model_pred
- def predict_noise(self, **kwargs):
- """Override CFGParallelMixin.predict_noise to use self.predict."""
- return self.predict(**kwargs)
-
- def combine_multi_branch_cfg_noise(self, predictions, true_cfg_scale, cfg_normalize=False):
- """Override: 3-branch dual scale or 2-branch standard CFG."""
- if len(predictions) == 3:
- text_scale = true_cfg_scale["text"]
- image_scale = true_cfg_scale["image"]
- pos, ref, uncond = predictions[0], predictions[1], predictions[2]
- return uncond + image_scale * (ref - uncond) + text_scale * (pos - ref)
- # 2-branch: standard CFG
- pos, neg = predictions[0], predictions[1]
- return neg + true_cfg_scale * (pos - neg)
-
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
diff --git a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py
index c330e91de8d..568e2f51640 100644
--- a/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py
+++ b/vllm_omni/diffusion/models/omnivoice/pipeline_omnivoice.py
@@ -16,7 +16,6 @@
from collections.abc import Iterable
from typing import ClassVar
-import numpy as np
import torch
from tokenizers import Tokenizer as HFTokenizer
from torch import nn
@@ -31,13 +30,6 @@
from vllm_omni.model_executor.models.omnivoice.omnivoice_decoder import OmniVoiceDecoder
from vllm_omni.model_executor.models.omnivoice.omnivoice_generator import OmniVoiceGenerator
-try:
- from transformers import HiggsAudioV2TokenizerModel
-except ImportError:
- HiggsAudioV2TokenizerModel = None
-
-import torchaudio
-
logger = init_logger(__name__)
@@ -87,17 +79,6 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
tokenizer_path = os.path.join(self.model_path, "tokenizer.json")
self.tokenizer = HFTokenizer.from_file(tokenizer_path)
- # Audio tokenizer for voice cloning (requires transformers>=5.3)
- if HiggsAudioV2TokenizerModel is not None:
- audio_tokenizer_path = os.path.join(self.model_path, "audio_tokenizer")
- self.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
- audio_tokenizer_path, device_map=self.device
- ).eval()
- logger.info("HiggsAudioV2 tokenizer loaded for voice cloning on %s", self.device)
- else:
- self.audio_tokenizer = None
- logger.warning("Voice cloning disabled (requires transformers>=5.3.0).")
-
# Duration estimator
self.duration_estimator = RuleDurationEstimator()
@@ -110,46 +91,20 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
self.class_temperature = self.config.class_temperature
self.sample_rate = self.config.sample_rate
- def _encode_ref_audio(self, audio_signal: torch.Tensor, sr: int) -> torch.Tensor:
- """Encode reference audio to 8-codebook tokens for voice cloning."""
- if self.audio_tokenizer is None:
- raise RuntimeError("Audio tokenizer not available for voice cloning")
- if audio_signal.dim() == 1:
- audio_signal = audio_signal.unsqueeze(0)
- # Resample to tokenizer's expected sample rate
- target_sr = self.audio_tokenizer.config.sample_rate
- if sr != target_sr:
- audio_signal = torchaudio.functional.resample(audio_signal, sr, target_sr)
- # Ensure mono [B, 1, samples]
- if audio_signal.dim() == 2:
- audio_signal = audio_signal.unsqueeze(1)
- with torch.inference_mode():
- tokens = self.audio_tokenizer.encode(
- audio_signal.to(self.audio_tokenizer.device), return_dict=False
- ) # [B, 8, T_ref]
- tokens = tokens.squeeze(0) # [8, T_ref]
- return tokens
-
@torch.inference_mode()
def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
- """Generate speech audio from text, optionally with voice cloning.
+ """Generate speech audio from text.
+
+ Args:
+ req: Diffusion request containing text prompt(s).
- Accepts either a plain text prompt or a structured dict:
- {"text": "...", "ref_audio": (samples, sr), "ref_text": "...",
- "lang": "...", "instruct": "..."}
+ Returns:
+ DiffusionOutput with audio tensor in .output
"""
+ # Extract text from request
prompt = req.prompts[0] if req.prompts else ""
- ref_audio = None
- ref_text = None
- lang = "None"
- instruct = "None"
-
if isinstance(prompt, dict):
text = prompt.get("input", prompt.get("text", str(prompt)))
- ref_audio = prompt.get("ref_audio")
- ref_text = prompt.get("ref_text")
- lang = prompt.get("lang") or "None"
- instruct = prompt.get("instruct") or "None"
else:
text = str(prompt)
@@ -164,37 +119,17 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
target_len = self.duration_estimator.estimate_duration(text, "Nice to meet you.", 25)
target_len = max(1, int(target_len))
- # Build text prompt with control tokens
- style = f"<|denoise|><|lang_start|>{lang}<|lang_end|><|instruct_start|>{instruct}<|instruct_end|>"
- if ref_text:
- full_text = f"{ref_text} {text}"
- else:
- full_text = text
- full_prompt = f"{style}<|text_start|>{full_text}<|text_end|>"
+ # Tokenize with control tokens
+ style = "<|denoise|><|lang_start|>None<|lang_end|><|instruct_start|>None<|instruct_end|>"
+ full_prompt = f"{style}<|text_start|>{text}<|text_end|>"
encoding = self.tokenizer.encode(full_prompt)
text_tokens = torch.tensor(encoding.ids, dtype=torch.long, device=device)
text_len = text_tokens.shape[0]
- # Encode reference audio tokens if provided
- ref_audio_tokens = None
- if ref_audio is not None:
- if self.audio_tokenizer is None:
- raise RuntimeError(
- "Voice cloning requires transformers>=5.3.0. Try: uv pip install 'transformers>=5.3.0'"
- )
- audio_signal, sr = ref_audio
- if isinstance(audio_signal, np.ndarray):
- audio_signal = torch.from_numpy(audio_signal).float()
- ref_audio_tokens = self._encode_ref_audio(audio_signal, int(sr)).to(device)
-
# Build conditional + unconditional batches [2, 8, max_len]
text_ids = text_tokens.unsqueeze(0).repeat(num_cb, 1)
target_ids = torch.full((num_cb, target_len), mask_id, dtype=torch.long, device=device)
-
- if ref_audio_tokens is not None:
- cond_ids = torch.cat([text_ids, ref_audio_tokens, target_ids], dim=1)
- else:
- cond_ids = torch.cat([text_ids, target_ids], dim=1)
+ cond_ids = torch.cat([text_ids, target_ids], dim=1)
cond_len = cond_ids.shape[1]
uncond_ids = target_ids.clone()
diff --git a/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py b/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py
index 0e98729c3d6..bd2a3b48346 100644
--- a/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py
+++ b/vllm_omni/diffusion/models/ovis_image/ovis_image_transformer.py
@@ -366,7 +366,6 @@ class OvisImageTransformer2DModel(nn.Module):
"""
_repeated_blocks = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"]
- _layerwise_offload_blocks_attrs = ["transformer_blocks", "single_transformer_blocks"]
def __init__(
self,
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
index 9ef0cacd5a0..5056b5342ea 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
@@ -34,12 +34,6 @@
)
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.prompt_utils import (
- validate_prompt_sequence_lengths,
-)
-from vllm_omni.diffusion.utils.size_utils import (
- normalize_min_aligned_size,
-)
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
if TYPE_CHECKING:
@@ -366,10 +360,8 @@ def check_inputs(
"that was used to generate `negative_prompt_embeds`."
)
- if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
- raise ValueError(
- f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
- )
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
@@ -383,8 +375,6 @@ def _get_qwen_prompt_embeds(
self,
prompt: str | list[str] = None,
dtype: torch.dtype | None = None,
- max_sequence_length: int | None = None,
- prompt_name: str = "prompt",
):
dtype = dtype or self.text_encoder.dtype
@@ -395,27 +385,12 @@ def _get_qwen_prompt_embeds(
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt,
+ max_length=self.tokenizer_max_length + drop_idx,
padding=True,
- truncation=False,
+ truncation=True,
return_tensors="pt",
).to(self.device)
- # Validate only the user prompt contribution. The Qwen template also
- # adds a fixed suffix after the user text, so subtracting only
- # prompt_template_encode_start_idx would overcount near-limit prompts.
- template_tokens = self.tokenizer(
- [template.format("")],
- padding=True,
- truncation=False,
- return_tensors="pt",
- ).to(self.device)
- validate_prompt_sequence_lengths(
- txt_tokens.attention_mask,
- max_sequence_length=max_sequence_length or self.tokenizer_max_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- prompt_name=prompt_name,
- baseline_attention_mask=template_tokens.attention_mask,
- error_context="after applying the Qwen prompt template",
- )
+ # print(f"attention mask: {txt_tokens.attention_mask}")
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
@@ -444,7 +419,6 @@ def encode_prompt(
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
max_sequence_length: int = 1024,
- prompt_name: str = "prompt",
):
r"""
@@ -462,11 +436,7 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
- prompt,
- max_sequence_length=max_sequence_length,
- prompt_name=prompt_name,
- )
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt)
prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
@@ -659,7 +629,6 @@ def _prepare_generation_context(
prompt_embeds_mask=negative_prompt_embeds_mask,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
- prompt_name="negative_prompt",
)
else:
negative_prompt_embeds = None
@@ -731,7 +700,7 @@ def prepare_encode(
num_images_per_prompt=sampling.num_outputs_per_prompt if sampling.num_outputs_per_prompt > 0 else 1,
generator=sampling.generator,
true_cfg_scale=sampling.true_cfg_scale or 4.0,
- max_sequence_length=sampling.max_sequence_length or self.tokenizer_max_length,
+ max_sequence_length=sampling.max_sequence_length or 512,
attention_kwargs=kwargs.get("attention_kwargs"),
)
@@ -962,14 +931,13 @@ def forward(
output_type: str | None = "pil",
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
- max_sequence_length: int = 1024,
+ max_sequence_length: int = 512,
) -> DiffusionOutput:
extracted_prompt, negative_prompt = self._extract_prompts(req.prompts)
prompt = extracted_prompt or prompt
height = req.sampling_params.height or self.default_sample_size * self.vae_scale_factor
width = req.sampling_params.width or self.default_sample_size * self.vae_scale_factor
- height, width = normalize_min_aligned_size(height, width, self.vae_scale_factor * 2)
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
sigmas = req.sampling_params.sigmas or sigmas
max_sequence_length = req.sampling_params.max_sequence_length or max_sequence_length
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
index cef7fe473a8..3d0cd2a6d4d 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
@@ -37,12 +37,6 @@
)
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.prompt_utils import (
- validate_prompt_sequence_lengths,
-)
-from vllm_omni.diffusion.utils.size_utils import (
- normalize_min_aligned_size,
-)
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.model_executor.model_loader.weight_utils import (
@@ -103,7 +97,9 @@ def pre_process_func(
width = request.sampling_params.width or calculated_width
# Ensure dimensions are multiples of vae_scale_factor * 2
- height, width = normalize_min_aligned_size(height, width, vae_scale_factor * 2)
+ multiple_of = vae_scale_factor * 2
+ height = height // multiple_of * multiple_of
+ width = width // multiple_of * multiple_of
# Store calculated dimensions in request
prompt["additional_information"]["calculated_height"] = calculated_height
@@ -326,10 +322,8 @@ def check_inputs(
"that was used to generate `negative_prompt_embeds`."
)
- if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
- raise ValueError(
- f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
- )
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
@@ -389,8 +383,6 @@ def _get_qwen_prompt_embeds(
prompt: str | list[str] = None,
image: PIL.Image.Image | torch.Tensor | None = None,
dtype: torch.dtype | None = None,
- max_sequence_length: int | None = None,
- prompt_name: str = "prompt",
):
"""Get prompt embeddings with image support for editing."""
dtype = dtype or self.text_encoder.dtype
@@ -400,33 +392,6 @@ def _get_qwen_prompt_embeds(
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
- txt_tokens = self.tokenizer(
- txt,
- padding=True,
- truncation=False,
- return_tensors="pt",
- ).to(self.device)
- # The edit template contains fixed multimodal scaffolding around the
- # instruction. Validate against the empty-template baseline so image
- # placeholder text does not consume the user's text budget.
- template_tokens = self.tokenizer(
- [template.format("")],
- padding=True,
- truncation=False,
- return_tensors="pt",
- ).to(self.device)
- # Qwen-Image-Edit expands image placeholders into many vision tokens
- # inside the processor. `max_sequence_length` is meant to constrain the
- # prompt text length, so validate on the text template before image
- # token expansion.
- validate_prompt_sequence_lengths(
- txt_tokens.attention_mask,
- max_sequence_length=max_sequence_length or self.tokenizer_max_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- prompt_name=prompt_name,
- baseline_attention_mask=template_tokens.attention_mask,
- error_context="after applying the Qwen prompt template",
- )
# Use processor to handle both text and image inputs
model_inputs = self.processor(
@@ -468,7 +433,6 @@ def encode_prompt(
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
max_sequence_length: int = 1024,
- prompt_name: str = "prompt",
):
r"""
@@ -488,12 +452,7 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
- prompt,
- image,
- max_sequence_length=max_sequence_length,
- prompt_name=prompt_name,
- )
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -664,7 +623,7 @@ def forward(
output_type: str | None = "pil",
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
- max_sequence_length: int = 1024,
+ max_sequence_length: int = 512,
) -> DiffusionOutput:
"""Forward pass for image editing."""
# TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "")
@@ -702,7 +661,9 @@ def forward(
height = height or calculated_height
width = width or calculated_width
- height, width = normalize_min_aligned_size(height, width, self.vae_scale_factor * 2)
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
image = self.image_processor.resize(image, calculated_height, calculated_width)
@@ -779,7 +740,6 @@ def forward(
prompt_embeds_mask=negative_prompt_embeds_mask,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
- prompt_name="negative_prompt",
)
num_channels_latents = self.transformer.in_channels // 4
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
index 2e25d0fe6b2..cb5a36579f9 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
@@ -25,7 +25,6 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.qwen_image.cfg_parallel import (
QwenImageCFGParallelMixin,
@@ -41,12 +40,6 @@
)
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.prompt_utils import (
- validate_prompt_sequence_lengths,
-)
-from vllm_omni.diffusion.utils.size_utils import (
- normalize_min_aligned_size,
-)
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.model_executor.model_loader.weight_utils import (
@@ -57,12 +50,6 @@
CONDITION_IMAGE_SIZE = 384 * 384
VAE_IMAGE_SIZE = 1024 * 1024
-# Keep this in sync with the practical conditioning-token budget for
-# Qwen-Image-Edit-2511. Empirically, 4 images stays within the supported range
-# while 5 images overflows the prompt/conditioning path and fails downstream.
-# Re-export the shared metadata value locally so this pipeline keeps a nearby,
-# descriptive constant for validation and tests without becoming the source of truth.
-MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
def get_qwen_image_edit_plus_pre_process_func(
@@ -100,11 +87,6 @@ def pre_process_func(
if not isinstance(raw_image, list):
raw_image = [raw_image]
- if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES:
- raise ValueError(
- f"Received {len(raw_image)} input images. "
- f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model."
- )
image = [
PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im)
for im in raw_image
@@ -117,7 +99,9 @@ def pre_process_func(
width = request.sampling_params.width or calculated_width
# Ensure dimensions are multiples of vae_scale_factor * 2
- height, width = normalize_min_aligned_size(height, width, vae_scale_factor * 2)
+ multiple_of = vae_scale_factor * 2
+ height = height // multiple_of * multiple_of
+ width = width // multiple_of * multiple_of
# Store calculated dimensions in request
prompt["additional_information"]["calculated_height"] = calculated_height
@@ -298,10 +282,8 @@ def check_inputs(
"that was used to generate `negative_prompt_embeds`."
)
- if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
- raise ValueError(
- f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
- )
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
@@ -316,8 +298,6 @@ def _get_qwen_prompt_embeds(
prompt: str | list[str],
image: list[torch.Tensor] | torch.Tensor | None = None,
dtype: torch.dtype | None = None,
- max_sequence_length: int | None = None,
- prompt_name: str = "prompt",
):
"""Get prompt embeddings with support for multiple images."""
dtype = dtype or self.text_encoder.dtype
@@ -338,32 +318,6 @@ def _get_qwen_prompt_embeds(
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(base_img_prompt + e) for e in prompt]
- txt_tokens = self.tokenizer(
- txt,
- padding=True,
- truncation=False,
- return_tensors="pt",
- ).to(self.device)
- # Multi-image edit prepends "Picture N" placeholders before the user
- # instruction. Subtract the placeholder-aware baseline so attached
- # images do not reduce the remaining prompt budget.
- template_tokens = self.tokenizer(
- [template.format(base_img_prompt)],
- padding=True,
- truncation=False,
- return_tensors="pt",
- ).to(self.device)
- # The processor expands image placeholders into many vision tokens.
- # `max_sequence_length` should guard the prompt text length before that
- # multimodal expansion happens.
- validate_prompt_sequence_lengths(
- txt_tokens.attention_mask,
- max_sequence_length=max_sequence_length or self.tokenizer_max_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- prompt_name=prompt_name,
- baseline_attention_mask=template_tokens.attention_mask,
- error_context="after applying the Qwen prompt template",
- )
# Use processor to handle both text and image inputs
model_inputs = self.processor(
@@ -405,7 +359,6 @@ def encode_prompt(
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
max_sequence_length: int = 1024,
- prompt_name: str = "prompt",
):
r"""
@@ -425,12 +378,7 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
- prompt,
- image,
- max_sequence_length=max_sequence_length,
- prompt_name=prompt_name,
- )
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -608,7 +556,7 @@ def forward(
output_type: str | None = "pil",
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
- max_sequence_length: int = 1024,
+ max_sequence_length: int = 512,
) -> DiffusionOutput:
"""Forward pass for image editing with support for multiple images."""
# TODO: In online mode, sometimes it receives [{"negative_prompt": None}, {...}], so cannot use .get("...", "")
@@ -656,7 +604,9 @@ def forward(
height = height or calculated_height
width = width or calculated_width
- height, width = normalize_min_aligned_size(height, width, self.vae_scale_factor * 2)
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
condition_images = []
vae_images = []
@@ -743,7 +693,6 @@ def forward(
prompt_embeds_mask=negative_prompt_embeds_mask,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
- prompt_name="negative_prompt",
)
num_channels_latents = self.transformer.in_channels // 4
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
index 905ef5b4243..f1d28f06857 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
@@ -36,12 +36,6 @@
)
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.prompt_utils import (
- validate_prompt_sequence_lengths,
-)
-from vllm_omni.diffusion.utils.size_utils import (
- normalize_min_aligned_size,
-)
from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.model_executor.model_loader.weight_utils import (
@@ -115,7 +109,9 @@ def pre_process_func(
height = calculated_height
width = calculated_width
- height, width = normalize_min_aligned_size(height, width, vae_scale_factor * 2)
+ multiple_of = vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
# Store calculated dimensions in request
prompt["additional_information"]["calculated_height"] = calculated_height
@@ -343,10 +339,8 @@ def check_inputs(
"generate `negative_prompt_embeds`."
)
- if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
- raise ValueError(
- f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
- )
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
@@ -361,8 +355,6 @@ def _get_qwen_prompt_embeds(
prompt: str | list[str] | None = None,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
- max_sequence_length: int | None = None,
- prompt_name: str = "prompt",
):
device = device or self.device
dtype = dtype or self.text_encoder.dtype
@@ -375,26 +367,8 @@ def _get_qwen_prompt_embeds(
txt_tokens = self.tokenizer(
txt,
padding=True,
- truncation=False,
return_tensors="pt",
).to(device)
- # The layered template also appends fixed non-user tokens after the
- # editable text, so use the empty-template tokenized baseline instead of
- # counting everything after prompt_template_encode_start_idx.
- template_tokens = self.tokenizer(
- [template.format("")],
- padding=True,
- truncation=False,
- return_tensors="pt",
- ).to(device)
- validate_prompt_sequence_lengths(
- txt_tokens.attention_mask,
- max_sequence_length=max_sequence_length or self.tokenizer_max_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- prompt_name=prompt_name,
- baseline_attention_mask=template_tokens.attention_mask,
- error_context="after applying the Qwen prompt template",
- )
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
@@ -424,7 +398,6 @@ def encode_prompt(
prompt_embeds: torch.Tensor | None = None,
prompt_embeds_mask: torch.Tensor | None = None,
max_sequence_length: int = 1024,
- prompt_name: str = "prompt",
):
r"""
@@ -445,11 +418,7 @@ def encode_prompt(
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
if prompt_embeds is None:
- prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
- prompt,
- max_sequence_length=max_sequence_length,
- prompt_name=prompt_name,
- )
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt)
prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
@@ -633,7 +602,7 @@ def forward(
negative_prompt_embeds_mask: torch.Tensor | None = None,
output_type: str | None = "pil",
attention_kwargs: dict[str, Any] | None = None,
- max_sequence_length: int = 1024,
+ max_sequence_length: int = 512,
resolution: int = 640,
cfg_normalize: bool = False,
use_en_prompt: bool = False,
@@ -696,7 +665,9 @@ def forward(
height = calculated_height
width = calculated_width
- height, width = normalize_min_aligned_size(height, width, self.vae_scale_factor * 2)
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
image = self.image_processor.resize(image, calculated_height, calculated_width)
@@ -766,7 +737,6 @@ def forward(
device=self.device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
- prompt_name="negative_prompt",
)
# 4. Prepare latent variables
diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
index 88a66d7f6b0..c2115670697 100644
--- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
+++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
@@ -169,15 +169,12 @@ def __init__(
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
- # Time embedding MLP is kept full precision (quant_config=None) —
- # small layers that feed per-block modulation; precision-sensitive
- # (see #2728).
self.timestep_embedder.linear_1 = ReplicatedLinear(
256,
embedding_dim,
bias=True,
return_bias=False,
- quant_config=None,
+ quant_config=quant_config,
prefix="timestep_embedder.linear_1",
)
self.timestep_embedder.linear_2 = ReplicatedLinear(
@@ -185,7 +182,7 @@ def __init__(
embedding_dim,
bias=True,
return_bias=False,
- quant_config=None,
+ quant_config=quant_config,
prefix="timestep_embedder.linear_2",
)
self.use_additional_t_cond = use_additional_t_cond
@@ -704,10 +701,7 @@ def __init__(
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
- # Image processing modules.
- # Modulation linear is kept full precision (quant_config=None) — it
- # produces shift/scale/gate values that are precision-sensitive
- # (see #2728).
+ # Image processing modules
self.img_mod = nn.Sequential(
nn.SiLU(),
ReplicatedLinear(
@@ -715,7 +709,7 @@ def __init__(
6 * dim,
bias=True,
return_bias=False,
- quant_config=None,
+ quant_config=quant_config,
prefix="img_mod.1",
),
)
@@ -731,7 +725,7 @@ def __init__(
self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
self.img_mlp = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="img_mlp")
- # Text processing modules.
+ # Text processing modules
self.txt_mod = nn.Sequential(
nn.SiLU(),
ReplicatedLinear(
@@ -739,7 +733,7 @@ def __init__(
6 * dim,
bias=True,
return_bias=False,
- quant_config=None,
+ quant_config=quant_config,
prefix="txt_mod.1",
),
)
@@ -750,9 +744,9 @@ def __init__(
self.zero_cond_t = zero_cond_t
- def _modulate(self, mod_params, index=None):
+ def _modulate(self, x, mod_params, index=None):
"""Apply modulation to input tensor"""
- # shift: b d, scale: b d, gate: b d
+ # x: b l d, shift: b d, scale: b d, gate: b d
shift, scale, gate = mod_params.chunk(3, dim=-1)
if index is not None:
@@ -784,7 +778,7 @@ def _modulate(self, mod_params, index=None):
scale_result = scale.unsqueeze(1)
gate_result = gate.unsqueeze(1)
- return scale_result, shift_result, gate_result
+ return x * (1 + scale_result) + shift_result, gate_result
def forward(
self,
@@ -810,12 +804,10 @@ def forward(
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Process image stream - norm1 + modulation
- img_scale1, img_shift1, img_gate1 = self._modulate(img_mod1, modulate_index)
- img_modulated = self.img_norm1(hidden_states, img_scale1, img_shift1)
+ img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1, modulate_index)
# Process text stream - norm1 + modulation
- txt_scale1, txt_shift1, txt_gate1 = self._modulate(txt_mod1)
- txt_modulated = self.txt_norm1(encoder_hidden_states, txt_scale1, txt_shift1)
+ txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1)
# Use QwenAttnProcessor2_0 for joint attention computation
# This directly implements the DoubleStreamLayerMegatron logic:
@@ -840,16 +832,13 @@ def forward(
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Process image stream - norm2 + MLP
- img_scale2, img_shift2, img_gate2 = self._modulate(img_mod2, modulate_index)
- img_modulated2 = self.img_norm2(hidden_states, img_scale2, img_shift2)
+ img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2, modulate_index)
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# Process text stream - norm2 + MLP
- txt_scale2, txt_shift2, txt_gate2 = self._modulate(txt_mod2)
- txt_modulated2 = self.txt_norm2(encoder_hidden_states, txt_scale2, txt_shift2)
-
+ txt_modulated2, txt_gate2 = self.txt_norm2(encoder_hidden_states, txt_mod2)
txt_mlp_output = self.txt_mlp(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
@@ -893,7 +882,7 @@ class QwenImageTransformer2DModel(CachedTransformer):
# -- typically a transformer layer
# used for torch compile optimizations
_repeated_blocks = ["QwenImageTransformerBlock"]
- _layerwise_offload_blocks_attrs = ["transformer_blocks"]
+ _layerwise_offload_blocks_attr = "transformer_blocks"
packed_modules_mapping = {
"to_qkv": ["to_q", "to_k", "to_v"],
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
@@ -969,14 +958,12 @@ def __init__(
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
- # Entry projections (image/text) are kept full precision —
- # small sensitive layers at the network boundary (see #2728).
self.img_in = ReplicatedLinear(
in_channels,
self.inner_dim,
bias=True,
return_bias=False,
- quant_config=None,
+ quant_config=quant_config,
prefix="img_in",
)
self.txt_in = ReplicatedLinear(
@@ -984,7 +971,7 @@ def __init__(
self.inner_dim,
bias=True,
return_bias=False,
- quant_config=None,
+ quant_config=quant_config,
prefix="txt_in",
)
@@ -1001,16 +988,13 @@ def __init__(
]
)
- # Final modulation and output projection are kept full precision —
- # they produce the output latent and are precision-sensitive
- # (see #2728).
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.norm_out.linear = ReplicatedLinear(
self.inner_dim,
2 * self.inner_dim,
bias=True,
return_bias=False,
- quant_config=None,
+ quant_config=quant_config,
prefix="norm_out.linear",
)
self.proj_out = ReplicatedLinear(
@@ -1018,7 +1002,7 @@ def __init__(
patch_size * patch_size * self.out_channels,
bias=True,
return_bias=False,
- quant_config=None,
+ quant_config=quant_config,
prefix="proj_out",
)
diff --git a/vllm_omni/diffusion/models/schedulers/__init__.py b/vllm_omni/diffusion/models/schedulers/__init__.py
index e683ed27203..6f8df78ebf0 100644
--- a/vllm_omni/diffusion/models/schedulers/__init__.py
+++ b/vllm_omni/diffusion/models/schedulers/__init__.py
@@ -1,12 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm_omni.diffusion.models.schedulers.scheduling_dmd2_euler import DMD2EulerScheduler
from vllm_omni.diffusion.models.schedulers.scheduling_flow_unipc_multistep import (
FlowUniPCMultistepScheduler,
)
__all__ = [
- "DMD2EulerScheduler",
"FlowUniPCMultistepScheduler",
]
diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py
deleted file mode 100644
index 01447a41d77..00000000000
--- a/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py
+++ /dev/null
@@ -1,23 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from __future__ import annotations
-
-import torch
-from diffusers import FlowMatchEulerDiscreteScheduler
-
-
-class DMD2EulerScheduler(FlowMatchEulerDiscreteScheduler):
- """Euler scheduler that always uses the fixed DMD2 training timestep schedule."""
-
- def __init__(self, *args, dmd2_timesteps: list[int], **kwargs):
- super().__init__(*args, **kwargs)
- self._dmd2_timesteps = dmd2_timesteps
-
- def set_timesteps(
- self,
- num_inference_steps: int | None = None,
- device: str | torch.device | None = None,
- **kwargs,
- ) -> None:
- super().set_timesteps(timesteps=self._dmd2_timesteps, device=device)
diff --git a/vllm_omni/diffusion/models/sd3/sd3_transformer.py b/vllm_omni/diffusion/models/sd3/sd3_transformer.py
index 89f06157758..308bd35a133 100644
--- a/vllm_omni/diffusion/models/sd3/sd3_transformer.py
+++ b/vllm_omni/diffusion/models/sd3/sd3_transformer.py
@@ -387,7 +387,6 @@ class SD3Transformer2DModel(nn.Module):
"""
_repeated_blocks = ["SD3TransformerBlock"]
- _layerwise_offload_blocks_attrs = ["transformer_blocks"]
def __init__(
self,
diff --git a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py
index a3d4dc517f7..22d56ac1fd1 100644
--- a/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py
+++ b/vllm_omni/diffusion/models/stable_audio/stable_audio_transformer.py
@@ -17,7 +17,6 @@
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.data import OmniDiffusionConfig
-from vllm_omni.diffusion.distributed.hsdp_utils import is_transformer_block_module
logger = init_logger(__name__)
@@ -376,9 +375,6 @@ class StableAudioDiTModel(nn.Module):
- Output: [B, out_channels, L]
"""
- _repeated_blocks = ["StableAudioDiTBlock"]
- _hsdp_shard_conditions = [is_transformer_block_module]
-
def __init__(
self,
od_config: OmniDiffusionConfig | None = None,
diff --git a/vllm_omni/diffusion/models/t5_encoder/t5_encoder.py b/vllm_omni/diffusion/models/t5_encoder/t5_encoder.py
index 0b81a71f1f3..7b0e842d055 100644
--- a/vllm_omni/diffusion/models/t5_encoder/t5_encoder.py
+++ b/vllm_omni/diffusion/models/t5_encoder/t5_encoder.py
@@ -328,7 +328,6 @@ class T5EncoderModel(nn.Module):
def __init__(self, config: T5Config, prefix: str = ""):
super().__init__()
self.config = config
- self.prefix = prefix
self.shared = VocabParallelEmbedding(config.vocab_size, config.d_model)
self.encoder = T5Stack(config, self.shared, prefix=f"{prefix}.encoder")
@@ -360,62 +359,29 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("wi", "wi_1", 1),
]
- model_prefix = self.prefix
-
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
original_name = name
-
- if model_prefix and name.startswith(model_prefix + "."):
- name = name[len(model_prefix) + 1 :]
-
lookup_name = name
- matched = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if f".{weight_name}." not in name:
continue
- lookup_name = name.replace(f".{weight_name}.", f".{param_name}.", 1)
- if lookup_name in params_dict:
- param = params_dict[lookup_name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- matched = True
- break
-
- if not matched:
- if name in params_dict:
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- matched = True
-
- is_embed = "encoder.embed_tokens" in lookup_name
- is_shared = lookup_name.startswith("shared.") or ".shared." in lookup_name
- target_name = None
-
- if is_embed or is_shared:
- if is_embed:
- target_name = lookup_name.replace("encoder.embed_tokens", "shared")
- else:
- target_name = lookup_name.replace("shared.", "encoder.embed_tokens.", 1)
-
- if not matched and target_name in params_dict:
- weight_loader = getattr(params_dict[target_name], "weight_loader", default_weight_loader)
- weight_loader(params_dict[target_name], loaded_weight)
- loaded_params.add(target_name)
- matched = True
-
- if not matched:
- continue
-
- if target_name is not None and target_name in params_dict and target_name not in loaded_params:
- if target_name != lookup_name:
- weight_loader = getattr(params_dict[target_name], "weight_loader", default_weight_loader)
- weight_loader(params_dict[target_name], loaded_weight)
- loaded_params.add(target_name)
+ lookup_name = name.replace(weight_name, param_name)
+ if lookup_name not in params_dict:
+ continue
+ param = params_dict[lookup_name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ if name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
loaded_params.add(original_name)
loaded_params.add(lookup_name)
diff --git a/vllm_omni/diffusion/models/t5_encoder/t5_gemma_encoder.py b/vllm_omni/diffusion/models/t5_encoder/t5_gemma_encoder.py
deleted file mode 100644
index eca4267fa20..00000000000
--- a/vllm_omni/diffusion/models/t5_encoder/t5_gemma_encoder.py
+++ /dev/null
@@ -1,309 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from __future__ import annotations
-
-from collections.abc import Iterable
-
-import torch
-import torch.nn as nn
-from transformers import PretrainedConfig
-from vllm.config import VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (
- MergedColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear,
-)
-from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
-
-class T5GemmaRMSNorm(nn.Module):
- def __init__(self, hidden_size: int, eps: float = 1e-6):
- super().__init__()
- # Normal RMSNorm but T5Gemma requires (1 + weight)
- self.weight = nn.Parameter(torch.zeros(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return (hidden_states * (1.0 + self.weight.float())).to(input_dtype)
-
-
-class T5GemmaMLP(nn.Module):
- def __init__(
- self,
- hidden_size: int,
- intermediate_size: int,
- hidden_act: str,
- ) -> None:
- super().__init__()
- self.gate_up_proj = MergedColumnParallelLinear(
- input_size=hidden_size,
- output_sizes=[intermediate_size, intermediate_size],
- bias=False,
- gather_output=False,
- )
- self.down_proj = RowParallelLinear(
- input_size=intermediate_size,
- output_size=hidden_size,
- bias=False,
- input_is_parallel=True,
- )
- self.act_fn = get_act_fn(hidden_act)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- gate_up, _ = self.gate_up_proj(x)
- gate, up = gate_up.chunk(2, dim=-1)
- x = self.act_fn(gate) * up
- x, _ = self.down_proj(x)
- return x
-
-
-class T5GemmaAttention(nn.Module):
- def __init__(
- self,
- hidden_size: int,
- num_heads: int,
- num_kv_heads: int,
- head_dim: int,
- max_position_embeddings: int,
- rope_theta: float,
- cache_config: VllmConfig | None = None,
- quant_config: dict | None = None,
- ) -> None:
- super().__init__()
- self.hidden_size = hidden_size
- tp_size = get_tensor_model_parallel_world_size()
- self.total_num_heads = num_heads
- assert self.total_num_heads % tp_size == 0
- self.num_heads = self.total_num_heads // tp_size
- self.total_num_kv_heads = num_kv_heads
- if self.total_num_kv_heads >= tp_size:
- assert self.total_num_kv_heads % tp_size == 0
- else:
- assert tp_size % self.total_num_kv_heads == 0
- self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
- self.head_dim = head_dim
- self.q_size = self.num_heads * self.head_dim
- self.kv_size = self.num_kv_heads * self.head_dim
-
- self.qkv_proj = QKVParallelLinear(
- hidden_size=hidden_size,
- head_size=self.head_dim,
- total_num_heads=self.total_num_heads,
- total_num_kv_heads=self.total_num_kv_heads,
- bias=False,
- )
- self.o_proj = RowParallelLinear(
- input_size=self.total_num_heads * self.head_dim,
- output_size=hidden_size,
- bias=False,
- input_is_parallel=True,
- )
-
- self.rotary_emb = get_rope(
- self.head_dim,
- max_position=max_position_embeddings,
- is_neox_style=True,
- rope_parameters={"base": rope_theta},
- )
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- ) -> torch.Tensor:
- qkv, _ = self.qkv_proj(hidden_states)
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-
- q, k = self.rotary_emb(positions, q, k)
-
- # Scale Q appropriately. T5Gemma uses query_pre_attn_scalar=256 => 256**-0.5 = 1/16
- # The standard scaling is head_dim**-0.5. For T5Gemma, head_dim=256.
- # So we don't need to manually scale if F.scaled_dot_product_attention scales by head_dim.
- # But we must reshape.
- batch_size, seq_len, _ = hidden_states.shape
- q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
- k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
- v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
-
- # GQA repeat KV
- if self.num_kv_heads != self.num_heads:
- num_repeat = self.num_heads // self.num_kv_heads
- k = k.repeat_interleave(num_repeat, dim=1)
- v = v.repeat_interleave(num_repeat, dim=1)
-
- attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=0.0)
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.view(batch_size, seq_len, self.q_size)
-
- output, _ = self.o_proj(attn_output)
- return output
-
-
-class T5GemmaEncoderLayer(nn.Module):
- def __init__(self, config: PretrainedConfig) -> None:
- super().__init__()
- self.self_attn = T5GemmaAttention(
- hidden_size=config.hidden_size,
- num_heads=config.num_attention_heads,
- num_kv_heads=config.num_key_value_heads,
- head_dim=config.head_dim,
- max_position_embeddings=config.max_position_embeddings,
- rope_theta=config.rope_theta,
- )
- self.mlp = T5GemmaMLP(
- hidden_size=config.hidden_size,
- intermediate_size=config.intermediate_size,
- hidden_act=config.hidden_activation,
- )
- self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- attention_mask: torch.Tensor | None,
- ) -> torch.Tensor:
- # Self Attention
- residual = hidden_states
- hidden_states = self.pre_self_attn_layernorm(hidden_states)
- hidden_states = self.self_attn(
- positions=positions,
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- )
- hidden_states = self.post_self_attn_layernorm(hidden_states)
- hidden_states = residual + hidden_states
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.pre_feedforward_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = self.post_feedforward_layernorm(hidden_states)
- hidden_states = residual + hidden_states
- return hidden_states
-
-
-class T5GemmaEncoderModelTP(nn.Module):
- def __init__(self, config: PretrainedConfig) -> None:
- super().__init__()
- self.config = config
- self.vocab_size = config.vocab_size
-
- self.embed_tokens = VocabParallelEmbedding(
- config.vocab_size,
- config.hidden_size,
- )
-
- self.layers = nn.ModuleList([T5GemmaEncoderLayer(config) for _ in range(config.num_hidden_layers)])
- self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- @property
- def dtype(self) -> torch.dtype:
- return next(self.parameters()).dtype
-
- @property
- def device(self) -> torch.device:
- return next(self.parameters()).device
-
- def forward(
- self,
- input_ids: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- ) -> torch.Tensor:
- hidden_states = self.embed_tokens(input_ids)
-
- # Scaling inputs
- normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device)
- hidden_states = hidden_states * normalizer
-
- # Simple position ids for RoPE
- batch_size, seq_len = input_ids.shape
- positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
-
- # Build attention mask: (batch, seq) -> (batch, 1, 1, seq)
- # Assuming typical bidirectional causal mask handling in HF: T5Gemma uses non-causal encoder.
- if attention_mask is not None:
- # HuggingFace expects boolean mask for scaled_dot_product_attention
- # or additive mask (0 and -inf). Let's use boolean matching FA patterns.
- # SDPA expects attention_mask to be boolean (True = keep, False = masking)
- bool_mask = attention_mask.to(torch.bool)
- extended_mask = bool_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, S)
- else:
- extended_mask = None
-
- for idx, layer in enumerate(self.layers):
- # T5Gemma has layer_types switching between "sliding_attention" and "full_attention"
- # However, for text encoder inference, the sequences are typically < max sequence length
- # and local sliding window only affects very long contexts. For simplicity we use full.
- hidden_states = layer(
- positions=positions,
- hidden_states=hidden_states,
- attention_mask=extended_mask,
- )
-
- hidden_states = self.norm(hidden_states)
- return hidden_states
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- stacked_params_mapping = [
- # (param_name, shard_name, shard_id)
- ("qkv_proj", "q_proj", "q"),
- ("qkv_proj", "k_proj", "k"),
- ("qkv_proj", "v_proj", "v"),
- ("gate_up_proj", "gate_proj", 0),
- ("gate_up_proj", "up_proj", 1),
- ]
-
- params_dict = dict(self.named_parameters())
- loaded_params: set[str] = set()
-
- for name, loaded_weight in weights:
- # HF checkpoint keys may carry a "model." prefix (e.g.
- # "model.encoder.layers.0..."). Strip it so the rest of the
- # logic only needs to handle the "encoder.*" namespace.
- if name.startswith("model."):
- name = name[len("model.") :]
-
- if not name.startswith("encoder."):
- continue
-
- # Strip "encoder." prefix as this model only wraps the encoder
- name = name[len("encoder.") :]
-
- # Map self_attn to self_attn and correct normalization names
- # HF: layers.0.pre_self_attn_layernorm.weight -> Ours: layers.0.pre_self_attn_layernorm.weight
-
- lookup_name = name
- for param_name, weight_name, shard_id in stacked_params_mapping:
- if f".{weight_name}." not in name:
- continue
- lookup_name = name.replace(f".{weight_name}.", f".{param_name}.")
- if lookup_name not in params_dict:
- continue
- param = params_dict[lookup_name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- if name not in params_dict:
- continue
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
-
- loaded_params.add("encoder." + name)
- loaded_params.add("encoder." + lookup_name)
-
- return loaded_params
diff --git a/vllm_omni/diffusion/models/utils.py b/vllm_omni/diffusion/models/utils.py
deleted file mode 100644
index ba0d8dda20c..00000000000
--- a/vllm_omni/diffusion/models/utils.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from __future__ import annotations
-
-import json
-import os
-
-
-def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict:
- """Load a JSON config file from a local path or HuggingFace Hub repo."""
- if local_files_only:
- path = os.path.join(model_path, *filename.split("/"))
- with open(path) as f:
- return json.load(f)
- else:
- from huggingface_hub import hf_hub_download
-
- cached = hf_hub_download(repo_id=model_path, filename=filename)
- with open(cached) as f:
- return json.load(f)
diff --git a/vllm_omni/diffusion/models/wan2_2/__init__.py b/vllm_omni/diffusion/models/wan2_2/__init__.py
index ecbfdd47219..d418001d952 100644
--- a/vllm_omni/diffusion/models/wan2_2/__init__.py
+++ b/vllm_omni/diffusion/models/wan2_2/__init__.py
@@ -1,10 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from .patch_diffusers import patch_wan_rms_norm
from .pipeline_wan2_2 import (
Wan22Pipeline,
- WanT2VDMD2Pipeline,
create_transformer_from_config,
get_wan22_post_process_func,
get_wan22_pre_process_func,
@@ -13,7 +11,6 @@
)
from .pipeline_wan2_2_i2v import (
Wan22I2VPipeline,
- WanI2VDMD2Pipeline,
get_wan22_i2v_post_process_func,
get_wan22_i2v_pre_process_func,
)
@@ -31,7 +28,6 @@
from .wan2_2_vace_transformer import VaceWanTransformerBlock, WanVACETransformer3DModel
__all__ = [
- "WanT2VDMD2Pipeline",
"Wan22Pipeline",
"get_wan22_post_process_func",
"get_wan22_pre_process_func",
@@ -39,7 +35,6 @@
"load_transformer_config",
"create_transformer_from_config",
"Wan22I2VPipeline",
- "WanI2VDMD2Pipeline",
"get_wan22_i2v_post_process_func",
"get_wan22_i2v_pre_process_func",
"Wan22TI2VPipeline",
@@ -52,5 +47,3 @@
"VaceWanTransformerBlock",
"WanVACETransformer3DModel",
]
-
-patch_wan_rms_norm()
diff --git a/vllm_omni/diffusion/models/wan2_2/patch_diffusers.py b/vllm_omni/diffusion/models/wan2_2/patch_diffusers.py
deleted file mode 100644
index 9b590fc5e0a..00000000000
--- a/vllm_omni/diffusion/models/wan2_2/patch_diffusers.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import sys
-
-from vllm_omni.diffusion.layers.norm import RMSNormVAE
-
-
-def patch_wan_rms_norm():
- """Patch diffusers Wan RMSNorm implementation with RMSNormVAE."""
-
- for module_name, module in sys.modules.items():
- if hasattr(module, "WanRMS_norm"):
- setattr(module, "WanRMS_norm", RMSNormVAE)
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
index 34045a8aaa1..ea4b90a9bbf 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
@@ -22,12 +22,9 @@
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero
from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
-from vllm_omni.diffusion.models.wan2_2.scheduling_wan_euler import WanEulerScheduler
from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel
-from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniTextPrompt
@@ -35,46 +32,6 @@
logger = logging.getLogger(__name__)
DEBUG_PERF = False
-WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"}
-
-
-def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any:
- if sample_solver == "unipc":
- return FlowUniPCMultistepScheduler(
- num_train_timesteps=1000,
- shift=flow_shift,
- prediction_type="flow_prediction",
- )
- if sample_solver == "euler":
- return WanEulerScheduler(
- num_train_timesteps=1000,
- shift=flow_shift,
- )
-
- raise ValueError(
- f"Unsupported Wan sample_solver: {sample_solver}. Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}"
- )
-
-
-def resolve_wan_sample_solver(req: OmniDiffusionRequest, default: str = "unipc") -> str:
- extra_args = getattr(req.sampling_params, "extra_args", {}) or {}
- raw = extra_args.get("sample_solver", default)
- sample_solver = str(raw).strip().lower()
- if sample_solver not in WAN_SAMPLE_SOLVER_CHOICES:
- raise ValueError(f"Invalid sample_solver={raw!r}. Expected one of: {sorted(WAN_SAMPLE_SOLVER_CHOICES)}")
- return sample_solver
-
-
-def resolve_wan_flow_shift(req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> float:
- extra_args = getattr(req.sampling_params, "extra_args", {}) or {}
- raw_flow_shift = extra_args.get("flow_shift")
- if raw_flow_shift is None:
- raw_flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
-
- try:
- return float(raw_flow_shift)
- except (TypeError, ValueError) as exc:
- raise ValueError(f"Invalid flow_shift={raw_flow_shift!r}. flow_shift must be a float.") from exc
def retrieve_latents(
@@ -164,23 +121,10 @@ def get_wan22_post_process_func(
def post_process_func(
video: torch.Tensor,
output_type: str = "np",
- sampling_params=None,
):
if output_type == "latent":
return video
- custom_output = {}
- if sampling_params is not None and getattr(sampling_params, "enable_frame_interpolation", False):
- video, multiplier = interpolate_video_tensor(
- video,
- exp=sampling_params.frame_interpolation_exp,
- scale=sampling_params.frame_interpolation_scale,
- model_path=sampling_params.frame_interpolation_model_path,
- )
- custom_output["video_fps_multiplier"] = multiplier
- return {
- "video": video_processor.postprocess_video(video, output_type=output_type),
- "custom_output": custom_output,
- }
+ return video_processor.postprocess_video(video, output_type=output_type)
return post_process_func
@@ -190,6 +134,9 @@ def get_wan22_pre_process_func(
):
"""Pre-process function for Wan2.2: optionally load and resize input image for I2V mode."""
import numpy as np
+ from diffusers.video_processor import VideoProcessor
+
+ video_processor = VideoProcessor(vae_scale_factor=8)
def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest:
for i, prompt in enumerate(request.prompts):
@@ -233,6 +180,10 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest:
)
prompt["multi_modal_data"]["image"] = image # type: ignore # key existence already checked above
+ # Preprocess for VAE
+ prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess(
+ image, height=request.sampling_params.height, width=request.sampling_params.width
+ )
request.prompts[i] = prompt
return request
@@ -357,9 +308,13 @@ def __init__(
else:
raise RuntimeError("No transformer loaded")
- self._sample_solver = "unipc"
- self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
- self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift)
+ # Initialize UniPC scheduler
+ flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
+ self.scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=flow_shift,
+ prediction_type="flow_prediction",
+ )
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
@@ -393,102 +348,6 @@ def num_timesteps(self):
def current_timestep(self):
return self._current_timestep
- def diffuse(
- self,
- latents: torch.Tensor,
- timesteps: torch.Tensor,
- prompt_embeds: torch.Tensor,
- negative_prompt_embeds: torch.Tensor | None,
- guidance_low: float,
- guidance_high: float,
- boundary_timestep: float | None,
- dtype: torch.dtype,
- attention_kwargs: dict[str, Any],
- latent_condition: torch.Tensor | None = None,
- first_frame_mask: torch.Tensor | None = None,
- ) -> torch.Tensor:
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
-
- # Select model based on timestep and boundary_ratio
- # High noise stage (t >= boundary_timestep): use transformer
- # Low noise stage (t < boundary_timestep): use transformer_2
- if boundary_timestep is not None and t < boundary_timestep:
- # Low noise stage - always use guidance_high for this stage
- current_guidance_scale = guidance_high
- if self.transformer_2 is not None:
- current_model = self.transformer_2
- elif self.transformer is not None:
- # Fallback to transformer if transformer_2 not loaded
- current_model = self.transformer
- else:
- raise RuntimeError("No transformer available for low-noise stage")
- else:
- # High noise stage - always use guidance_low for this stage
- current_guidance_scale = guidance_low
- if self.transformer is not None:
- current_model = self.transformer
- elif self.transformer_2 is not None:
- # Fallback to transformer_2 if transformer not loaded
- current_model = self.transformer_2
- else:
- raise RuntimeError("No transformer available for high-noise stage")
-
- if self.expand_timesteps and latent_condition is not None:
- # I2V mode: blend condition with latents using mask
- latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
- latent_model_input = latent_model_input.to(dtype)
-
- # Expand timesteps per patch - use floor division to match patch embedding
- patch_size = self.transformer_config.patch_size
- patch_height = latents.shape[3] // patch_size[1]
- patch_width = latents.shape[4] // patch_size[2]
-
- # Create mask at patch resolution (same as hidden states sequence length)
- patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]]
- patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] # Ensure correct dimensions
- temp_ts = (patch_mask[0][0] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
- else:
- # T2V mode: standard forward
- latent_model_input = latents.to(dtype)
- timestep = t.expand(latents.shape[0])
-
- do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- else:
- negative_kwargs = None
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=current_guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
- pbar.update()
-
- return latents
-
def forward(
self,
req: OmniDiffusionRequest,
@@ -620,13 +479,6 @@ def forward(
current_omni_platform.synchronize()
_t_text_enc_ms = (time.perf_counter() - _t_text_enc_start) * 1000
- sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver)
- flow_shift = resolve_wan_flow_shift(req, self.od_config)
- if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6:
- self.scheduler = build_wan_scheduler(sample_solver, flow_shift)
- self._sample_solver = sample_solver
- self._flow_shift = flow_shift
-
# Timesteps
self.scheduler.set_timesteps(num_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -736,19 +588,90 @@ def forward(
if DEBUG_PERF:
_t_denoise_start = time.perf_counter()
- latents = self.diffuse(
- latents=latents,
- timesteps=timesteps,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- guidance_low=guidance_low,
- guidance_high=guidance_high,
- boundary_timestep=boundary_timestep,
- dtype=dtype,
- attention_kwargs=attention_kwargs,
- latent_condition=latent_condition,
- first_frame_mask=first_frame_mask,
- )
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+
+ # Select model based on timestep and boundary_ratio
+ # High noise stage (t >= boundary_timestep): use transformer
+ # Low noise stage (t < boundary_timestep): use transformer_2
+ if boundary_timestep is not None and t < boundary_timestep:
+ # Low noise stage - always use guidance_high for this stage
+ current_guidance_scale = guidance_high
+ if self.transformer_2 is not None:
+ current_model = self.transformer_2
+ elif self.transformer is not None:
+ # Fallback to transformer if transformer_2 not loaded
+ current_model = self.transformer
+ else:
+ raise RuntimeError("No transformer available for low-noise stage")
+ else:
+ # High noise stage - always use guidance_low for this stage
+ current_guidance_scale = guidance_low
+ if self.transformer is not None:
+ current_model = self.transformer
+ elif self.transformer_2 is not None:
+ # Fallback to transformer_2 if transformer not loaded
+ current_model = self.transformer_2
+ else:
+ raise RuntimeError("No transformer available for high-noise stage")
+
+ if self.expand_timesteps and latent_condition is not None:
+ # I2V mode: blend condition with latents using mask
+ latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(dtype)
+
+ # Expand timesteps per patch - use floor division to match patch embedding
+ patch_size = self.transformer_config.patch_size
+ num_latent_frames = latents.shape[2]
+ patch_height = latents.shape[3] // patch_size[1]
+ patch_width = latents.shape[4] // patch_size[2]
+
+ # Create mask at patch resolution (same as hidden states sequence length)
+ patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]]
+ patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] # Ensure correct dimensions
+ temp_ts = (patch_mask[0][0] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # T2V mode: standard forward
+ latent_model_input = latents.to(dtype)
+ timestep = t.expand(latents.shape[0])
+
+ do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
+ # Prepare kwargs for positive and negative predictions
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ if do_true_cfg:
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ else:
+ negative_kwargs = None
+
+ # Predict noise with automatic CFG parallel handling
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=current_guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+
+ pbar.update()
# Wan2.2 is prone to out of memory errors when predicting large videos
# so we empty the cache here to avoid OOM before vae decoding.
@@ -985,16 +908,3 @@ def check_inputs(
if boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.")
-
-
-# ---------------------------------------------------------------------------
-# DMD2-distilled variant
-# ---------------------------------------------------------------------------
-
-
-class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline):
- """Wan 2.x T2V pipeline for FastGen DMD2-distilled models."""
-
- def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
- super().__init__(od_config=od_config, prefix=prefix)
- self.__init_dmd2__()
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
index 373350c70eb..1e8a94eb3c1 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
@@ -12,7 +12,6 @@
import numpy as np
import PIL.Image
import torch
-import torchvision.transforms.functional as TF
from diffusers.utils.torch_utils import randn_tensor
from torch import nn
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
@@ -23,19 +22,14 @@
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero
-from vllm_omni.diffusion.models.utils import _load_json
+from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
- build_wan_scheduler,
create_transformer_from_config,
load_transformer_config,
- resolve_wan_flow_shift,
- resolve_wan_sample_solver,
retrieve_latents,
)
-from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniTextPrompt
@@ -45,6 +39,29 @@
DEBUG_PERF = False
+def _load_model_index(model: str, local_files_only: bool) -> dict:
+ """Load model_index.json from local path or HF Hub."""
+ if local_files_only:
+ model_index_path = os.path.join(model, "model_index.json")
+ if os.path.exists(model_index_path):
+ import json
+
+ with open(model_index_path) as f:
+ return json.load(f)
+ else:
+ try:
+ import json
+
+ from huggingface_hub import hf_hub_download
+
+ model_index_path = hf_hub_download(model, "model_index.json")
+ with open(model_index_path) as f:
+ return json.load(f)
+ except Exception:
+ pass
+ return {}
+
+
def get_wan22_i2v_post_process_func(
od_config: OmniDiffusionConfig,
):
@@ -55,23 +72,10 @@ def get_wan22_i2v_post_process_func(
def post_process_func(
video: torch.Tensor,
output_type: str = "np",
- sampling_params=None,
):
if output_type == "latent":
return video
- custom_output = {}
- if sampling_params is not None and getattr(sampling_params, "enable_frame_interpolation", False):
- video, multiplier = interpolate_video_tensor(
- video,
- exp=sampling_params.frame_interpolation_exp,
- scale=sampling_params.frame_interpolation_scale,
- model_path=sampling_params.frame_interpolation_model_path,
- )
- custom_output["video_fps_multiplier"] = multiplier
- return {
- "video": video_processor.postprocess_video(video, output_type=output_type),
- "custom_output": custom_output,
- }
+ return video_processor.postprocess_video(video, output_type=output_type)
return post_process_func
@@ -80,6 +84,9 @@ def get_wan22_i2v_pre_process_func(
od_config: OmniDiffusionConfig,
):
"""Pre-process function for I2V: load and resize input image."""
+ from diffusers.video_processor import VideoProcessor
+
+ video_processor = VideoProcessor(vae_scale_factor=8)
def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest:
for i, prompt in enumerate(request.prompts):
@@ -125,6 +132,10 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest:
)
prompt["multi_modal_data"]["image"] = image # type: ignore # key existence already checked above
+ # Preprocess for VAE
+ prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess(
+ image, height=request.sampling_params.height, width=request.sampling_params.width
+ )
request.prompts[i] = prompt
return request
@@ -168,10 +179,7 @@ def __init__(
]
# Load model_index.json to detect available components
- try:
- model_index = _load_json(model, "model_index.json", local_files_only)
- except Exception:
- model_index = {}
+ model_index = _load_model_index(model, local_files_only)
# Check if this is a two-stage model (MoE with transformer_2)
self.has_transformer_2 = "transformer_2" in model_index
@@ -209,7 +217,7 @@ def __init__(
# VAE
self.vae = DistributedAutoencoderKLWan.from_pretrained(
- model, subfolder="vae", torch_dtype=dtype, local_files_only=local_files_only
+ model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only
).to(self.device)
# Transformers (weights loaded via load_weights)
@@ -222,9 +230,13 @@ def __init__(
else:
self.transformer_2 = None
- self._sample_solver = "unipc"
- self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
- self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift)
+ # Initialize UniPC scheduler
+ flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
+ self.scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=flow_shift,
+ prediction_type="flow_prediction",
+ )
# VAE scale factors
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4
@@ -260,82 +272,6 @@ def num_timesteps(self):
def current_timestep(self):
return self._current_timestep
- def diffuse(
- self,
- latents: torch.Tensor,
- timesteps: torch.Tensor,
- prompt_embeds: torch.Tensor,
- negative_prompt_embeds: torch.Tensor | None,
- image_embeds: torch.Tensor | None,
- guidance_low: float,
- guidance_high: float,
- boundary_timestep: float | None,
- dtype: torch.dtype,
- attention_kwargs: dict[str, Any],
- condition: torch.Tensor,
- first_frame_mask: torch.Tensor,
- ) -> torch.Tensor:
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
-
- # Select model and guidance scale based on timestep
- current_model = self.transformer
- current_guidance_scale = guidance_low
- if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None:
- current_model = self.transformer_2
- current_guidance_scale = guidance_high
-
- # Prepare latent input
- if self.expand_timesteps:
- # TI2V-5B style: blend condition with latents using mask
- latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
- latent_model_input = latent_model_input.to(dtype)
-
- # Expand timesteps for each patch
- temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
- else:
- # Wan2.1 style: concatenate condition with latents
- latent_model_input = torch.cat([latents, condition], dim=1).to(dtype)
- timestep = t.expand(latents.shape[0])
-
- do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "encoder_hidden_states_image": image_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "encoder_hidden_states_image": image_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- else:
- negative_kwargs = None
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=current_guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
- pbar.update()
-
- return latents
-
def encode_image(
self,
image: PIL.Image.Image | list[PIL.Image.Image],
@@ -504,13 +440,6 @@ def forward(
current_omni_platform.synchronize()
_t_img_enc_ms = (time.perf_counter() - _t_img_enc_start) * 1000
- sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver)
- flow_shift = resolve_wan_flow_shift(req, self.od_config)
- if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6:
- self.scheduler = build_wan_scheduler(sample_solver, flow_shift)
- self._sample_solver = sample_solver
- self._flow_shift = flow_shift
-
# Timesteps
self.scheduler.set_timesteps(num_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -530,7 +459,6 @@ def forward(
video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
if isinstance(image, PIL.Image.Image):
- image = TF.to_tensor(image).to(device)
image_tensor = video_processor.preprocess(image, height=height, width=width)
else:
image_tensor = image
@@ -539,7 +467,6 @@ def forward(
# Handle last_image if provided
if last_image is not None:
if isinstance(last_image, PIL.Image.Image):
- image = TF.to_tensor(last_image).to(device)
last_image_tensor = video_processor.preprocess(last_image, height=height, width=width)
else:
last_image_tensor = last_image
@@ -570,20 +497,68 @@ def forward(
if DEBUG_PERF:
_t_denoise_start = time.perf_counter()
- latents = self.diffuse(
- latents=latents,
- timesteps=timesteps,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- image_embeds=image_embeds,
- guidance_low=guidance_low,
- guidance_high=guidance_high,
- boundary_timestep=boundary_timestep,
- dtype=dtype,
- attention_kwargs=attention_kwargs,
- condition=condition,
- first_frame_mask=first_frame_mask,
- )
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+
+ # Select model and guidance scale based on timestep
+ current_model = self.transformer
+ current_guidance_scale = guidance_low
+ if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None:
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_high
+
+ # Prepare latent input
+ if self.expand_timesteps:
+ # TI2V-5B style: blend condition with latents using mask
+ latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(dtype)
+
+ # Expand timesteps for each patch
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # Wan2.1 style: concatenate condition with latents
+ latent_model_input = torch.cat([latents, condition], dim=1).to(dtype)
+ timestep = t.expand(latents.shape[0])
+
+ do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
+ # Prepare kwargs for positive and negative predictions
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "encoder_hidden_states_image": image_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ if do_true_cfg:
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "encoder_hidden_states_image": image_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ else:
+ negative_kwargs = None
+
+ # Predict noise with automatic CFG parallel handling
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=current_guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+
+ pbar.update()
# Wan2.2 is prone to out of memory errors when predicting large videos
# so we empty the cache here to avoid OOM before vae decoding.
@@ -814,14 +789,12 @@ def prepare_latents(
return latents, latent_condition, first_frame_mask
# Wan2.1 style: create mask and concatenate with condition
- mask_lat_size = torch.ones(
- batch_size, 1, num_frames, latent_height, latent_width, device=latent_condition.device
- )
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
if last_image is None:
- mask_lat_size[:, :, 1:] = 0
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
else:
- mask_lat_size[:, :, 1 : num_frames - 1] = 0
+ mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
@@ -878,16 +851,3 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights using AutoWeightsLoader for vLLM integration."""
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
-
-
-# ---------------------------------------------------------------------------
-# DMD2-distilled variant
-# ---------------------------------------------------------------------------
-
-
-class WanI2VDMD2Pipeline(DMD2PipelineMixin, Wan22I2VPipeline):
- """Wan 2.x I2V pipeline for FastGen DMD2-distilled models."""
-
- def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
- super().__init__(od_config=od_config, prefix=prefix)
- self.__init_dmd2__()
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
index 04d5c4c9571..f116834cf28 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
@@ -24,27 +24,24 @@
import numpy as np
import PIL.Image
import torch
+from diffusers import AutoencoderKLWan
from diffusers.utils.torch_utils import randn_tensor
from torch import nn
from transformers import AutoTokenizer, UMT5EncoderModel
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
-from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import OmniAutoencoderKLWan
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
+from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
- build_wan_scheduler,
create_transformer_from_config,
load_transformer_config,
- resolve_wan_flow_shift,
- resolve_wan_sample_solver,
retrieve_latents,
)
-from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform
@@ -62,23 +59,10 @@ def get_wan22_ti2v_post_process_func(
def post_process_func(
video: torch.Tensor,
output_type: str = "np",
- sampling_params=None,
):
if output_type == "latent":
return video
- custom_output = {}
- if sampling_params is not None and getattr(sampling_params, "enable_frame_interpolation", False):
- video, multiplier = interpolate_video_tensor(
- video,
- exp=sampling_params.frame_interpolation_exp,
- scale=sampling_params.frame_interpolation_scale,
- model_path=sampling_params.frame_interpolation_model_path,
- )
- custom_output["video_fps_multiplier"] = multiplier
- return {
- "video": video_processor.postprocess_video(video, output_type=output_type),
- "custom_output": custom_output,
- }
+ return video_processor.postprocess_video(video, output_type=output_type)
return post_process_func
@@ -87,6 +71,9 @@ def get_wan22_ti2v_pre_process_func(
od_config: OmniDiffusionConfig,
):
"""Pre-process function for TI2V: optionally load and resize input image."""
+ from diffusers.video_processor import VideoProcessor
+
+ video_processor = VideoProcessor(vae_scale_factor=8)
def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest:
for i, prompt in enumerate(request.prompts):
@@ -132,6 +119,10 @@ def pre_process_func(request: OmniDiffusionRequest) -> OmniDiffusionRequest:
)
prompt["multi_modal_data"]["image"] = image # type: ignore # key existence already checked above
+ # Preprocess for VAE
+ prompt["additional_information"]["preprocessed_image"] = video_processor.preprocess(
+ image, height=request.sampling_params.height, width=request.sampling_params.width
+ )
request.prompts[i] = prompt
return request
@@ -183,8 +174,8 @@ def __init__(
).to(self.device)
# VAE
- self.vae = OmniAutoencoderKLWan.from_pretrained(
- model, subfolder="vae", torch_dtype=dtype, local_files_only=local_files_only
+ self.vae = AutoencoderKLWan.from_pretrained(
+ model, subfolder="vae", torch_dtype=torch.float32, local_files_only=local_files_only
).to(self.device)
# Single transformer (TI2V uses dense 5B model, not MoE)
@@ -192,9 +183,13 @@ def __init__(
transformer_config = load_transformer_config(model, "transformer", local_files_only)
self.transformer = create_transformer_from_config(transformer_config)
- self._sample_solver = "unipc"
- self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0
- self.scheduler = build_wan_scheduler(self._sample_solver, self._flow_shift)
+ # Initialize UniPC scheduler
+ flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
+ self.scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=flow_shift,
+ prediction_type="flow_prediction",
+ )
# VAE scale factors
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4
@@ -223,77 +218,6 @@ def num_timesteps(self):
def current_timestep(self):
return self._current_timestep
- def diffuse(
- self,
- latents: torch.Tensor,
- timesteps: torch.Tensor,
- prompt_embeds: torch.Tensor,
- negative_prompt_embeds: torch.Tensor | None,
- guidance_scale: float,
- dtype: torch.dtype,
- attention_kwargs: dict[str, Any],
- num_latent_frames: int,
- latent_height: int,
- latent_width: int,
- latent_condition: torch.Tensor | None = None,
- first_frame_mask: torch.Tensor | None = None,
- ) -> torch.Tensor:
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
-
- # Prepare latent input
- if latent_condition is not None:
- # I2V mode: blend condition with latents using mask
- latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
- latent_model_input = latent_model_input.to(dtype)
-
- # Expand timesteps for each patch (TI2V style)
- temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
- else:
- # T2V mode: use latents directly
- latent_model_input = latents.to(dtype)
-
- # Expand timesteps for TI2V model architecture
- mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, device=latents.device)
- temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
-
- do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": self.transformer,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": self.transformer,
- }
- else:
- negative_kwargs = None
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
- pbar.update()
-
- return latents
-
def forward(
self,
req: OmniDiffusionRequest,
@@ -399,13 +323,6 @@ def forward(
batch_size = prompt_embeds.shape[0]
- sample_solver = resolve_wan_sample_solver(req, default=self._sample_solver)
- flow_shift = resolve_wan_flow_shift(req, self.od_config)
- if sample_solver != self._sample_solver or abs(flow_shift - self._flow_shift) > 1e-6:
- self.scheduler = build_wan_scheduler(sample_solver, flow_shift)
- self._sample_solver = sample_solver
- self._flow_shift = flow_shift
-
# Timesteps
self.scheduler.set_timesteps(num_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -463,20 +380,64 @@ def forward(
if attention_kwargs is None:
attention_kwargs = {}
- latents = self.diffuse(
- latents=latents,
- timesteps=timesteps,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- guidance_scale=guidance_scale,
- dtype=dtype,
- attention_kwargs=attention_kwargs,
- num_latent_frames=num_latent_frames,
- latent_height=latent_height,
- latent_width=latent_width,
- latent_condition=latent_condition,
- first_frame_mask=first_frame_mask,
- )
+ # Denoising loop
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+
+ # Prepare latent input
+ if latent_condition is not None:
+ # I2V mode: blend condition with latents using mask
+ latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(dtype)
+
+ # Expand timesteps for each patch (TI2V style)
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # T2V mode: use latents directly
+ latent_model_input = latents.to(dtype)
+
+ # Expand timesteps for TI2V model architecture
+ mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, device=device)
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+
+ do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
+ # Prepare kwargs for positive and negative predictions
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": self.transformer,
+ }
+ if do_true_cfg:
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": self.transformer,
+ }
+ else:
+ negative_kwargs = None
+
+ # Predict noise with automatic CFG parallel handling
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+
+ pbar.update()
# Wan2.2 is prone to out of memory errors when predicting large videos
# so we empty the cache here to avoid OOM before vae decoding.
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
index 0458f88597e..ea523363111 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
@@ -176,62 +176,6 @@ def _create_transformer(self, config: dict) -> WanVACETransformer3DModel:
"""Build VACE transformer directly from config dict."""
return create_vace_transformer_from_config(config)
- def diffuse(
- self,
- latents: torch.Tensor,
- timesteps: torch.Tensor,
- prompt_embeds: torch.Tensor,
- negative_prompt_embeds: torch.Tensor | None,
- guidance_scale: float,
- dtype: torch.dtype,
- attention_kwargs: dict[str, object],
- vace_context: torch.Tensor | None,
- vace_context_scale: float,
- ) -> torch.Tensor:
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
- latent_model_input = latents.to(dtype)
- timestep = t.expand(latents.shape[0])
-
- do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
-
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "vace_context": vace_context,
- "vace_context_scale": vace_context_scale,
- "return_dict": False,
- }
- negative_kwargs = (
- {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "vace_context": vace_context,
- "vace_context_scale": vace_context_scale,
- "return_dict": False,
- }
- if do_true_cfg
- else None
- )
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
- pbar.update()
-
- return latents
-
def check_inputs(
self,
prompt,
@@ -625,17 +569,48 @@ def forward(
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
- latents = self.diffuse(
- latents=latents,
- timesteps=timesteps,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- guidance_scale=guidance_scale,
- dtype=dtype,
- attention_kwargs=attention_kwargs,
- vace_context=vace_context,
- vace_context_scale=vace_context_scale,
- )
+ # Denoising loop
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+ latent_model_input = latents.to(dtype)
+ timestep = t.expand(latents.shape[0])
+
+ do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
+
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "vace_context": vace_context,
+ "vace_context_scale": vace_context_scale,
+ "return_dict": False,
+ }
+ negative_kwargs = (
+ {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "vace_context": vace_context,
+ "vace_context_scale": vace_context_scale,
+ "return_dict": False,
+ }
+ if do_true_cfg
+ else None
+ )
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+ pbar.update()
self._current_timestep = None
diff --git a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py b/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py
deleted file mode 100644
index 25444044c2d..00000000000
--- a/vllm_omni/diffusion/models/wan2_2/scheduling_wan_euler.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-from __future__ import annotations
-
-from dataclasses import dataclass
-from types import SimpleNamespace
-
-import numpy as np
-import torch
-
-
-@dataclass
-class WanEulerSchedulerOutput:
- prev_sample: torch.FloatTensor
-
-
-def _unsqueeze_to_ndim(in_tensor: torch.Tensor, target_ndim: int) -> torch.Tensor:
- if in_tensor.ndim >= target_ndim:
- return in_tensor
- return in_tensor[(...,) + (None,) * (target_ndim - in_tensor.ndim)]
-
-
-def _get_timesteps(num_steps: int, max_steps: int = 1000) -> np.ndarray:
- # Keep num_steps + 1 points so Euler update can always access sigma_next.
- return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32)
-
-
-def _timestep_shift(timesteps: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
- return shift * timesteps / (1 + (shift - 1) * timesteps)
-
-
-class WanEulerScheduler:
- order = 1
-
- def __init__(
- self,
- num_train_timesteps: int = 1000,
- shift: float = 1.0,
- device: torch.device | str = "cpu",
- ) -> None:
- self.num_train_timesteps = int(num_train_timesteps)
- self._shift = float(shift)
- self.device = device
- self.config = SimpleNamespace(num_train_timesteps=self.num_train_timesteps)
- self.init_noise_sigma = 1.0
-
- self._step_index: int | None = None
- self._begin_index: int | None = None
-
- self.timesteps = torch.empty(0, dtype=torch.float32)
- self.sigmas = torch.empty(0, dtype=torch.float32)
- self.timesteps_ori = torch.empty(0, dtype=torch.float32)
-
- self.set_timesteps(num_inference_steps=self.num_train_timesteps, device=self.device)
-
- @property
- def step_index(self) -> int | None:
- return self._step_index
-
- @property
- def begin_index(self) -> int | None:
- return self._begin_index
-
- def set_begin_index(self, begin_index: int = 0) -> None:
- self._begin_index = int(begin_index)
-
- def index_for_timestep(self, timestep: torch.Tensor) -> int:
- indices = (self.timesteps == timestep).nonzero()
- if len(indices) > 0:
- pos = 1 if len(indices) > 1 else 0
- return int(indices[pos].item())
- # Fallback for tiny float drift
- return int(torch.argmin(torch.abs(self.timesteps - timestep)).item())
-
- def _init_step_index(self, timestep: float | torch.Tensor) -> None:
- if self.begin_index is None:
- if isinstance(timestep, torch.Tensor):
- timestep_t = timestep.to(self.timesteps.device, dtype=self.timesteps.dtype)
- else:
- timestep_t = torch.tensor(timestep, device=self.timesteps.device, dtype=self.timesteps.dtype)
- self._step_index = self.index_for_timestep(timestep_t)
- else:
- self._step_index = self._begin_index
-
- def set_shift(self, shift: float = 1.0) -> None:
- # Compute shifted sigma schedule on [0, 1].
- sigmas_full = self.timesteps_ori / float(self.num_train_timesteps)
- sigmas_full = _timestep_shift(sigmas_full, shift=float(shift))
- self.sigmas = sigmas_full
- # Public timesteps are the first N points; next point is consumed as sigma_next.
- self.timesteps = self.sigmas[:-1] * self.num_train_timesteps
- self._shift = float(shift)
-
- def set_timesteps(
- self,
- num_inference_steps: int,
- device: torch.device | str | int | None = None,
- **kwargs, # noqa: ARG002 - kept for scheduler API compatibility
- ) -> None:
- timesteps = _get_timesteps(
- num_steps=int(num_inference_steps),
- max_steps=self.num_train_timesteps,
- )
- self.timesteps_ori = torch.from_numpy(timesteps).to(
- dtype=torch.float32,
- device=device or self.device,
- )
- self.set_shift(self._shift)
- self._step_index = None
- self._begin_index = None
-
- def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None) -> torch.Tensor: # noqa: ARG002
- return sample
-
- def step(
- self,
- model_output: torch.FloatTensor,
- timestep: float | torch.FloatTensor,
- sample: torch.FloatTensor,
- return_dict: bool = True,
- **kwargs, # noqa: ARG002 - kept for scheduler API compatibility
- ) -> WanEulerSchedulerOutput | tuple[torch.FloatTensor]:
- if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
- raise ValueError(
- "Passing integer indices as timesteps is not supported. Use one value from scheduler.timesteps instead."
- )
-
- if self.step_index is None:
- self._init_step_index(timestep)
- assert self._step_index is not None
-
- sample_fp32 = sample.to(torch.float32)
- sigma = _unsqueeze_to_ndim(self.sigmas[self._step_index], sample_fp32.ndim).to(sample_fp32.device)
- sigma_next = _unsqueeze_to_ndim(self.sigmas[self._step_index + 1], sample_fp32.ndim).to(sample_fp32.device)
-
- prev_sample = sample_fp32 + (sigma_next - sigma) * model_output
- prev_sample = prev_sample.to(model_output.dtype)
-
- self._step_index += 1
-
- if not return_dict:
- return (prev_sample,)
- return WanEulerSchedulerOutput(prev_sample=prev_sample)
-
- def __len__(self) -> int:
- return self.num_train_timesteps
diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
index 0edd2214282..30dd696f840 100644
--- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
+++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py
@@ -11,6 +11,7 @@
from diffusers.models.attention import FeedForward
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.normalization import FP32LayerNorm
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -28,14 +29,39 @@
SequenceParallelOutput,
)
from vllm_omni.diffusion.forward_context import get_forward_context
-from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNorm
-from vllm_omni.diffusion.layers.norm import LayerNorm, RMSNorm
-from vllm_omni.diffusion.layers.rope import RotaryEmbeddingWan
-from vllm_omni.platforms import current_omni_platform
logger = init_logger(__name__)
+def apply_rotary_emb_wan(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensors.
+
+ Args:
+ hidden_states: Input tensor of shape [B, S, H, D]
+ freqs_cos: Cosine frequencies
+ freqs_sin: Sine frequencies
+
+ Returns:
+ Tensor with rotary embeddings applied
+ """
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ rotated = torch.stack(
+ (
+ x1 * cos - x2 * sin,
+ x1 * sin + x2 * cos,
+ ),
+ dim=-1,
+ )
+ return rotated.flatten(-2, -1).to(hidden_states.dtype)
+
+
class DistributedRMSNorm(nn.Module):
"""
RMSNorm that computes global RMS across tensor parallel ranks.
@@ -145,7 +171,7 @@ def __init__(
# Split dimensions for temporal, height, width
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
- freqs_dtype = torch.float64 if current_omni_platform.supports_float64() else torch.float32
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
freqs_sin = []
@@ -208,9 +234,9 @@ class WanImageEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len: int | None = None):
super().__init__()
- self.norm1 = LayerNorm(in_features)
+ self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
- self.norm2 = LayerNorm(out_features)
+ self.norm2 = FP32LayerNorm(out_features)
if pos_embed_seq_len is not None:
self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
else:
@@ -350,12 +376,8 @@ def __init__(
self.tp_inner_dim = self.num_heads * head_dim
# QK normalization using vLLM's RMSNorm
- if get_tensor_model_parallel_world_size() > 1:
- self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
- self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
- else:
- self.norm_q = RMSNorm(self.tp_inner_dim, eps=eps)
- self.norm_k = RMSNorm(self.tp_inner_dim, eps=eps)
+ self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
self.to_out = RowParallelLinear(
self.inner_dim,
@@ -399,10 +421,9 @@ def forward(
# Apply rotary embeddings
if rotary_emb is not None:
- self.rotary_embedding = RotaryEmbeddingWan(is_neox_style=False, half_head_dim=True)
freqs_cos, freqs_sin = rotary_emb
- query = self.rotary_embedding(query, freqs_cos, freqs_sin)
- key = self.rotary_embedding(key, freqs_cos, freqs_sin)
+ query = apply_rotary_emb_wan(query, freqs_cos, freqs_sin)
+ key = apply_rotary_emb_wan(key, freqs_cos, freqs_sin)
# Create attention metadata if mask is provided
attn_metadata = None
@@ -475,12 +496,8 @@ def __init__(
self.tp_inner_dim = self.num_heads * head_dim
# QK normalization
- if get_tensor_model_parallel_world_size() > 1:
- self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
- self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
- else:
- self.norm_q = RMSNorm(self.tp_inner_dim, eps=eps)
- self.norm_k = RMSNorm(self.tp_inner_dim, eps=eps)
+ self.norm_q = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
+ self.norm_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
# Optional added KV projections for I2V (image embeddings)
self.added_kv_proj_dim = added_kv_proj_dim
@@ -499,10 +516,7 @@ def __init__(
gather_output=False,
return_bias=False,
)
- if get_tensor_model_parallel_world_size() > 1:
- self.norm_added_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
- else:
- self.norm_added_k = RMSNorm(self.tp_inner_dim, eps=eps)
+ self.norm_added_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps)
else:
self.add_k_proj = None
self.add_v_proj = None
@@ -605,7 +619,7 @@ def __init__(
head_dim = dim // num_heads
# 1. Self-attention
- self.norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = WanSelfAttention(
dim=dim,
num_heads=num_heads,
@@ -621,11 +635,11 @@ def __init__(
eps=eps,
added_kv_proj_dim=added_kv_proj_dim,
)
- self.norm2 = LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 3. Feed-forward
self.ffn = WanFeedForward(dim=dim, inner_dim=ffn_dim, dim_out=dim)
- self.norm3 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
# Scale-shift table for modulation
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
@@ -641,7 +655,7 @@ def forward(
if temb.ndim == 4:
# temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
- self.scale_shift_table.unsqueeze(0) + temb
+ self.scale_shift_table.unsqueeze(0) + temb.float()
).chunk(6, dim=2)
shift_msa = shift_msa.squeeze(2)
scale_msa = scale_msa.squeeze(2)
@@ -652,23 +666,25 @@ def forward(
else:
# temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
- self.scale_shift_table + temb
+ self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
- norm_hidden_states = self.norm1(hidden_states, scale_msa, shift_msa).type_as(hidden_states)
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(norm_hidden_states, rotary_emb, hidden_states_mask)
- hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
- norm_hidden_states = self.norm2(hidden_states).type_as(hidden_states)
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
- norm_hidden_states = self.norm3(hidden_states, c_scale_msa, c_shift_msa).type_as(hidden_states)
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
ff_output = self.ffn(norm_hidden_states)
- hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
@@ -709,7 +725,7 @@ class WanTransformer3DModel(nn.Module):
"""
_repeated_blocks = ["WanTransformerBlock"]
- _layerwise_offload_blocks_attrs = ["blocks"]
+ _layerwise_offload_blocks_attr = "blocks"
packed_modules_mapping = {
"to_qkv": ["to_q", "to_k", "to_v"],
}
@@ -837,17 +853,13 @@ def __init__(
)
# 4. Output norm & projection
- self.norm_out = AdaLayerNorm(inner_dim, elementwise_affine=False, eps=eps)
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
# SP helper modules
self.timestep_proj_prepare = TimestepProjPrepare()
self.output_scale_shift_prepare = OutputScaleShiftPrepare(inner_dim)
- # ROPE helper
- self._cached_rope_emb = None
- self._cached_rope_resolution = None
-
@property
def dtype(self) -> torch.dtype:
"""Return the dtype of the model parameters."""
@@ -869,14 +881,7 @@ def forward(
post_patch_width = width // p_w
# Compute RoPE embeddings (sharded by _sp_plan via split_output=True)
- current_rope_resolution = (post_patch_num_frames, post_patch_height, post_patch_width)
- if self._cached_rope_resolution == current_rope_resolution and self._cached_rope_emb is not None:
- rotary_emb = self._cached_rope_emb
- else:
- freqs_cos, freqs_sin = self.rope(hidden_states)
- rotary_emb = (freqs_cos[..., 0::2].to(hidden_states.dtype), freqs_sin[..., 1::2].to(hidden_states.dtype))
- self._hidden_states_shape = hidden_states.shape
- self._cached_rope_emb = rotary_emb
+ rotary_emb = self.rope(hidden_states)
# Patch embedding and flatten to sequence
# (hidden_states is sharded at blocks.0 input by _sp_plan)
@@ -936,7 +941,7 @@ def forward(
shift = shift.unsqueeze(1)
scale = scale.unsqueeze(1)
- hidden_states = self.norm_out(hidden_states, scale, shift).type_as(hidden_states)
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
@@ -1014,14 +1019,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if ".to_out.0." in lookup_name:
lookup_name = lookup_name.replace(".to_out.0.", ".to_out.")
- # Compatibility: some Wan conversion pipelines still keep
- # block modulation keys as `blocks.N.modulation` instead of
- # `blocks.N.scale_shift_table`.
- if lookup_name.endswith(".modulation"):
- modulation_alias = lookup_name[: -len(".modulation")] + ".scale_shift_table"
- if modulation_alias in params_dict:
- lookup_name = modulation_alias
-
if lookup_name not in params_dict:
logger.warning(f"Skipping weight {original_name} -> {lookup_name}")
continue
diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py
index 5060f1904f2..4f4217dabfa 100644
--- a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py
+++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py
@@ -123,10 +123,6 @@ def __init__(
]
)
- # ROPE helper
- self._cached_rope_emb = None
- self._cached_rope_resolution = None
-
def embed_vace_context(
self,
vace_context: torch.Tensor,
@@ -169,14 +165,7 @@ def forward(
post_patch_width = width // p_w
# Compute RoPE embeddings (sharded by _sp_plan via split_output=True)
- current_rope_resolution = (post_patch_num_frames, post_patch_height, post_patch_width)
- if self._cached_rope_resolution == current_rope_resolution and self._cached_rope_emb is not None:
- rotary_emb = self._cached_rope_emb
- else:
- freqs_cos, freqs_sin = self.rope(hidden_states)
- rotary_emb = (freqs_cos[..., 0::2].to(hidden_states.dtype), freqs_sin[..., 1::2].to(hidden_states.dtype))
- self._hidden_states_shape = hidden_states.shape
- self._cached_rope_emb = rotary_emb
+ rotary_emb = self.rope(hidden_states)
# Patch embedding and flatten to sequence
hidden_states = self.patch_embedding(hidden_states)
@@ -250,7 +239,7 @@ def forward(
shift = shift.unsqueeze(1)
scale = scale.unsqueeze(1)
- hidden_states = self.norm_out(hidden_states, scale, shift).type_as(hidden_states)
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py
index 5bea59a2098..b9aceed2e5c 100644
--- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py
+++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py
@@ -21,10 +21,9 @@
from collections.abc import Callable, Iterable
from typing import Any
-import PIL.Image
import torch
import torch.nn as nn
-from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
@@ -60,7 +59,7 @@ def get_post_process_func(
vae_config = json.load(f)
vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8
- image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2, do_convert_rgb=True)
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
def post_process_func(
images: torch.Tensor,
@@ -84,20 +83,6 @@ def calculate_shift(
return mu
-# Copied from diffusers
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -202,8 +187,6 @@ def __init__(
enable_diffusion_pipeline_profiler=self.od_config.enable_diffusion_pipeline_profiler
)
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_convert_rgb=True)
-
def encode_prompt(
self,
prompt: str | list[str],
@@ -299,45 +282,12 @@ def prepare_latents(
device,
generator,
latents=None,
- image=None,
- timestep=None,
):
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
- if image is not None:
- if latents is not None:
- return latents.to(device=device, dtype=dtype)
-
- image = image.to(device=device, dtype=dtype)
- if image.shape[1] != num_channels_latents:
- if isinstance(generator, list):
- image_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
- for i in range(image.shape[0])
- ]
- image_latents = torch.cat(image_latents, dim=0)
- else:
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
-
- image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
- else:
- image_latents = image
-
- if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
- additional_image_per_prompt = batch_size // image_latents.shape[0]
- image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
- elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
- raise ValueError(
- f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
- )
-
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = self.scheduler.scale_noise(image_latents, timestep, noise)
- return latents
-
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
@@ -346,14 +296,6 @@ def prepare_latents(
latents = latents.to(device)
return latents
- def get_timesteps(self, num_inference_steps, strength, device):
- init_timestep = min(num_inference_steps * strength, num_inference_steps)
- t_start = int(max(num_inference_steps - init_timestep, 0))
- timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
- if hasattr(self.scheduler, "set_begin_index"):
- self.scheduler.set_begin_index(t_start * self.scheduler.order)
- return timesteps, num_inference_steps - t_start
-
@property
def guidance_scale(self):
return self._guidance_scale
@@ -378,8 +320,6 @@ def forward(
self,
req: OmniDiffusionRequest,
prompt: str | list[str] | None = None,
- image: PipelineImageInput = None,
- strength: float = 0.6,
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 50,
@@ -407,11 +347,6 @@ def forward(
prompt (`str` or `list[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
- image (`PipelineImageInput`, *optional*):
- The image to use for img2img generation. If provided, the pipeline
- will perform img2img instead of text-to-image.
- strength (`float`, *optional*, defaults to 0.6):
- Indicates extent to transform the reference `image`. Must be between 0 and 1.
height (`int`, *optional*, defaults to 1024):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 1024):
@@ -490,34 +425,6 @@ def forward(
elif req.prompts:
negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts]
- # Handle img2img: extract image from request
- if image is None and req.prompts:
- if len(req.prompts) > 1:
- logger.warning(
- "This model only supports a single prompt for img2img, not a batched request. "
- "Taking only the first image for now."
- )
- first_prompt = req.prompts[0]
- if not isinstance(first_prompt, str):
- raw_image = first_prompt.get("multi_modal_data", {}).get("image")
- if raw_image is not None:
- if isinstance(raw_image, list):
- image = [PIL.Image.open(im) if isinstance(im, str) else raw_image[0] for im in raw_image[:1]]
- else:
- image = PIL.Image.open(raw_image) if isinstance(raw_image, str) else raw_image
-
- # strength is currently only applicable for Z-Image I2I; other pipelines ignore this parameter
- strength = req.sampling_params.strength if req.sampling_params.strength is not None else strength
- if strength is not None and image is None:
- logger.warning(
- "strength parameter (%.2f) is only applicable for image-to-image (I2I) generation. "
- "It will be ignored for text-to-image (T2I) generation.",
- strength,
- )
- strength = None
- if image is not None and strength is not None and (strength < 0 or strength > 1):
- raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
-
height = req.sampling_params.height or height
width = req.sampling_params.width or width
num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps
@@ -584,71 +491,16 @@ def forward(
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels
- # img2img mode: prepare latents from input image
- if image is not None:
- # Handle image list - take first image
- if isinstance(image, list):
- image = image[0]
-
- # Prepare image for VAE encoding using image_processor
- if not isinstance(image, torch.Tensor):
- init_image = self.image_processor.preprocess(image, height, width)
- image = init_image.to(dtype=torch.float32, device=device)
-
- # Initialize scheduler kwargs for img2img
- mu = calculate_shift(
- (height // self.vae_scale_factor // 2) * (width // self.vae_scale_factor // 2),
- self.scheduler.config.get("base_image_seq_len", 256),
- self.scheduler.config.get("max_image_seq_len", 4096),
- self.scheduler.config.get("base_shift", 0.5),
- self.scheduler.config.get("max_shift", 1.15),
- )
- self.scheduler.sigma_min = 0.0
- scheduler_kwargs = {"mu": mu}
-
- # First initialize timesteps in scheduler
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler,
- num_inference_steps,
- device,
- sigmas=sigmas,
- **scheduler_kwargs,
- )
-
- # Then adjust timesteps based on strength
- timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
-
- if num_inference_steps < 1:
- raise ValueError(
- f"After adjusting the num_inference_steps by strength parameter: "
- f"{strength}, the number of pipeline steps is {num_inference_steps} "
- f"which is < 1 and not appropriate for this pipeline."
- )
- latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-
- latents = self.prepare_latents(
- batch_size * num_images_per_prompt,
- num_channels_latents,
- height,
- width,
- prompt_embeds[0].dtype,
- device,
- generator,
- latents,
- image,
- latent_timestep,
- )
- else:
- latents = self.prepare_latents(
- batch_size * num_images_per_prompt,
- num_channels_latents,
- height,
- width,
- torch.float32,
- device,
- generator,
- latents,
- )
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
# Repeat prompt_embeds for num_images_per_prompt
if num_images_per_prompt > 1:
@@ -657,28 +509,25 @@ def forward(
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
actual_batch_size = batch_size * num_images_per_prompt
+ image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
# 5. Prepare timesteps
- if image is None:
- image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
- mu = calculate_shift(
- image_seq_len,
- self.scheduler.config.get("base_image_seq_len", 256),
- self.scheduler.config.get("max_image_seq_len", 4096),
- self.scheduler.config.get("base_shift", 0.5),
- self.scheduler.config.get("max_shift", 1.15),
- )
- self.scheduler.sigma_min = 0.0
- scheduler_kwargs = {"mu": mu}
-
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler,
- num_inference_steps,
- device,
- sigmas=sigmas,
- **scheduler_kwargs,
- )
-
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ self.scheduler.sigma_min = 0.0
+ scheduler_kwargs = {"mu": mu}
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py
index c36ea746654..fd8b0e490f2 100644
--- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py
+++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py
@@ -214,14 +214,12 @@ def __init__(
super().__init__()
if mid_size is None:
mid_size = out_size
- # Time embedding MLP is kept full precision (quant_config=None) —
- # small layers that feed adaLN; precision-sensitive (see #2728).
self.mlp = nn.Sequential(
ReplicatedLinear(
frequency_embedding_size,
mid_size,
bias=True,
- quant_config=None,
+ quant_config=quant_config,
return_bias=False,
),
nn.SiLU(),
@@ -229,7 +227,7 @@ def __init__(
mid_size,
out_size,
bias=True,
- quant_config=None,
+ quant_config=quant_config,
return_bias=False,
),
)
@@ -428,16 +426,9 @@ def __init__(
self.modulation = modulation
if modulation:
- # Modulation linear is kept at full precision (quant_config=None)
- # — it produces scale/gate values that are precision-sensitive
- # (see #2728, mirrors OmniGen2 fix).
self.adaLN_modulation = nn.Sequential(
ReplicatedLinear(
- min(dim, ADALN_EMBED_DIM),
- 4 * dim,
- bias=True,
- quant_config=None,
- return_bias=False,
+ min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True, return_bias=False, quant_config=quant_config
),
)
@@ -494,24 +485,14 @@ class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, quant_config: "QuantizationConfig | None" = None):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- # Final output projection and its modulation are precision-sensitive
- # (produce the output latent); keep at full precision (see #2728).
self.linear = ReplicatedLinear(
- hidden_size,
- out_channels,
- bias=True,
- quant_config=None,
- return_bias=False,
+ hidden_size, out_channels, bias=True, quant_config=quant_config, return_bias=False
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
ReplicatedLinear(
- min(hidden_size, ADALN_EMBED_DIM),
- hidden_size,
- bias=True,
- quant_config=None,
- return_bias=False,
+ min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True, quant_config=quant_config, return_bias=False
),
)
@@ -598,7 +579,6 @@ class ZImageTransformer2DModel(CachedTransformer):
"""
_repeated_blocks = ["ZImageTransformerBlock"]
- _layerwise_offload_blocks_attrs = ["layers"]
@staticmethod
def _is_transformer_block(name: str, module) -> bool:
@@ -692,13 +672,11 @@ def __init__(
all_x_embedder = {}
all_final_layer = {}
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
- # x_embedder (patch embed) is a small precision-sensitive entry
- # layer; keep full precision (see #2728).
x_embedder = ReplicatedLinear(
f_patch_size * patch_size * patch_size * in_channels,
dim,
bias=True,
- quant_config=None,
+ quant_config=quant_config,
return_bias=False,
)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
@@ -741,17 +719,9 @@ def __init__(
]
)
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024, quant_config=quant_config)
- # Caption embedder maps text features -> hidden; keep full precision
- # (see #2728).
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
- ReplicatedLinear(
- cap_feat_dim,
- dim,
- bias=True,
- quant_config=None,
- return_bias=False,
- ),
+ ReplicatedLinear(cap_feat_dim, dim, bias=True, return_bias=False, quant_config=quant_config),
)
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py
index 9979d01b103..20af5b5d828 100644
--- a/vllm_omni/diffusion/offloader/layerwise_backend.py
+++ b/vllm_omni/diffusion/offloader/layerwise_backend.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from __future__ import annotations
from itertools import chain
from typing import Any
@@ -298,20 +297,13 @@ def enable(self, pipeline: nn.Module) -> None:
for enc in modules.encoders:
enc.to(self.device)
- # Move VAE(s) to GPU if available
- for vae in modules.vaes:
+ # Move VAE to GPU if available
+ if modules.vae is not None:
try:
- vae.to(self.device, non_blocking=True)
+ modules.vae.to(self.device, non_blocking=True)
except Exception as exc:
logger.debug("Failed to move VAE to GPU: %s", exc)
- # Move resident modules to GPU (small modules needed every forward)
- for name, module in zip(modules.resident_names, modules.resident_modules):
- try:
- module.to(self.device)
- except Exception as exc:
- logger.debug("Failed to move resident module %s to GPU: %s", name, exc)
-
logger.info("Applying layer-wise offloading on %s", modules.dit_names)
# Apply block-wise offloading hook for each of the blocks in DiT model(s)
@@ -320,9 +312,10 @@ def enable(self, pipeline: nn.Module) -> None:
dit_name = modules.dit_names[i]
logger.info(f"Applying hooks on {dit_name} ({dit_module.__class__.__name__})")
- blocks_attr_names, blocks = LayerWiseOffloadBackend.get_blocks_from_dit(dit_module)
+ blocks_attr_name = LayerWiseOffloadBackend.get_blocks_attr_name(dit_module)
+ blocks = LayerWiseOffloadBackend.get_blocks_from_dit(dit_module)
- if not blocks:
+ if not blocks_attr_name or not blocks:
logger.warning(
"Target layers (blocks) not found. Skipping offloading on %s (%s)",
dit_name,
@@ -343,20 +336,11 @@ def enable(self, pipeline: nn.Module) -> None:
# Move non-block modules to GPU (they stay resident)
for name, m in dit_module.named_children():
- if name not in blocks_attr_names:
- m.to(self.device)
- logger.debug(f"Moved {name} to device {self.device}")
- else:
+ if name == blocks_attr_name:
logger.debug(f"Skipped blocks module {name}")
-
- # Move top-level params/buffers to GPU (dit_module's own, not sub-modules)
- for param in dit_module._parameters.values():
- if param is not None:
- param.data = param.data.to(self.device, non_blocking=True)
-
- for buffer in dit_module._buffers.values():
- if buffer is not None:
- buffer.data = buffer.data.to(self.device, non_blocking=True)
+ continue
+ m.to(self.device)
+ logger.debug(f"Moved {name} to device {self.device}")
# Pre-fetch the first layer by manually calling the hook function on the last layer;
# For subsequent requests, the first layer/block will be pre-fetched
@@ -411,84 +395,40 @@ def disable(self) -> None:
logger.info("Layer-wise offloading disabled")
@staticmethod
- def get_blocks_attr_names(model: nn.Module) -> list[str]:
- """Get block attribute names from model class."""
- attrs: list[str] = getattr(model.__class__, "_layerwise_offload_blocks_attrs", [])
-
- if not attrs:
- old_attr = getattr(model.__class__, "_layerwise_offload_blocks_attr", None)
- if old_attr is not None:
- logger.warning(
- "'_layerwise_offload_blocks_attr' is deprecated, "
- "please use '_layerwise_offload_blocks_attrs' instead. "
- "Example: _layerwise_offload_blocks_attrs = ['blocks']"
- )
- attrs = [old_attr] if isinstance(old_attr, str) else list(old_attr)
-
- return attrs
+ def get_blocks_attr_name(model: nn.Module) -> str | None:
+ """Retrieve blocks attribute name from provided DiT model"""
+ return getattr(model.__class__, "_layerwise_offload_blocks_attr", None)
@staticmethod
- def set_blocks_attr_names(model: nn.Module, names: list[str]) -> None:
- if not hasattr(model.__class__, "_layerwise_offload_blocks_attrs"):
- setattr(model.__class__, "_layerwise_offload_blocks_attrs", names)
+ def set_blocks_attr_name(model: nn.Module, name: str) -> None:
+ if not hasattr(model.__class__, "_layerwise_offload_blocks_attr"):
+ setattr(model.__class__, "_layerwise_offload_blocks_attr", name)
@staticmethod
- def get_blocks_from_dit(model: nn.Module) -> tuple[list[str], list[nn.Module]]:
+ def get_blocks_from_dit(model: nn.Module) -> list[nn.Module]:
"""
- Retrieve blocks and attribute names from provided DiT model. Blocks attribute names
- are found by `_layerwise_offload_blocks_attrs` set to DiT models. For example,
+ Retrieve a list of blocks from provided DiT model. Blocks attribute name
+ are found by `_layerwise_offload_blocks_attr` set to DiT models. For example,
```
class WanTransformer3DModel(nn.Module):
- _layerwise_offload_blocks_attrs = ["blocks"]
+ _layerwise_offload_blocks_attr = "blocks"
```
-
- Returns:
- Tuple of (blocks_attr_names, blocks)
"""
- blocks_attr_names = LayerWiseOffloadBackend.get_blocks_attr_names(model)
- if not blocks_attr_names:
+ blocks_attr_name = LayerWiseOffloadBackend.get_blocks_attr_name(model)
+ if blocks_attr_name is None:
logger.warning(
- f"No _layerwise_offload_blocks_attrs defined for {model.__class__.__name__}, "
+ f"No _layerwise_offload_blocks_attr defined for {model.__class__.__name__}, "
"skipping layerwise offloading"
)
- return [], []
-
- blocks = []
- for name in blocks_attr_names:
- attr = getattr(model, name, None)
- if attr is None:
- raise AttributeError(
- f"Attribute '{name}' declared in _layerwise_offload_blocks_attrs "
- f"does not exist on model {model.__class__.__name__}"
- )
- try:
- attr_iter = iter(attr)
- except TypeError:
- if isinstance(attr, nn.Module):
- logger.warning(
- "Attribute '%s' on %s is not iterable; treating it as one block.",
- name,
- model.__class__.__name__,
- )
- blocks.append(attr)
- continue
+ return []
- logger.warning(
- "Attribute '%s' on %s is not iterable (got %s); skipping it.",
- name,
- model.__class__.__name__,
- type(attr).__name__,
- )
- else:
- blocks.extend(attr_iter)
-
- if not blocks:
+ _blocks = getattr(model, blocks_attr_name, None)
+ if _blocks is None:
logger.warning(
- "No blocks found in %s for %s, skipping layerwise offloading",
- blocks_attr_names,
- model.__class__.__name__,
+ f"Blocks (layers) '{blocks_attr_name}' not found on {model.__class__.__name__}, "
+ "skipping layerwise offloading"
)
- return [], []
+ return []
- return blocks_attr_names, blocks
+ return list(_blocks)
diff --git a/vllm_omni/diffusion/offloader/module_collector.py b/vllm_omni/diffusion/offloader/module_collector.py
index dfd81e98b89..307ca53a880 100644
--- a/vllm_omni/diffusion/offloader/module_collector.py
+++ b/vllm_omni/diffusion/offloader/module_collector.py
@@ -1,14 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass, field
-from operator import attrgetter
+from dataclasses import dataclass
from torch import nn
from vllm.logger import init_logger
-from vllm_omni.diffusion.models.interface import SupportsModuleOffload
-
logger = init_logger(__name__)
@@ -18,81 +15,15 @@ class PipelineModules:
dit_names: list[str]
encoders: list[nn.Module]
encoder_names: list[str]
- vaes: list[nn.Module]
- resident_modules: list[nn.Module] = field(default_factory=list)
- resident_names: list[str] = field(default_factory=list)
+ vae: nn.Module | None = None
class ModuleDiscovery:
- """Discovers pipeline components for offloading.
-
- If the pipeline implements :class:`SupportsModuleOffload`,
- its ``_dit_modules``, ``_encoder_modules``, and ``_vae_modules``
- class variables are used directly. Otherwise, falls back to
- scanning well-known attribute names.
- """
-
- # Fallback attribute names for pipelines that do not implement
- # SupportsModuleOffload.
- _FALLBACK_DIT_ATTRS = [
- "transformer",
- "transformer_2",
- "dit",
- "sr_dit",
- "language_model",
- "transformer_blocks",
- "model",
- ]
- _FALLBACK_ENCODER_ATTRS = [
- "text_encoder",
- "text_encoder_2",
- "text_encoder_3",
- "image_encoder",
- ]
- _FALLBACK_VAE_ATTRS = [
- "vae",
- "audio_vae",
- ]
+ """Discovers pipeline components for offloading"""
- @staticmethod
- def _collect_modules(
- pipeline: nn.Module,
- attr_names: list[str],
- *,
- warn_missing: bool = False,
- ) -> tuple[list[nn.Module], list[str]]:
- """Resolve attribute names to (module, name) pairs, skipping missing.
-
- Supports dotted paths via :func:`operator.attrgetter`.
- Warns on missing attributes when *warn_missing* is True.
- """
- modules: list[nn.Module] = []
- names: list[str] = []
- seen: set[int] = set()
- for attr in attr_names:
- try:
- module = attrgetter(attr)(pipeline)
- except AttributeError:
- module = None
- if module is None:
- if warn_missing:
- logger.warning(
- "Pipeline declares '%s' as offloadable but the attribute does not exist or is None",
- attr,
- )
- continue
- if not isinstance(module, nn.Module):
- logger.warning(
- "Expected '%s' to be nn.Module, got %r",
- attr,
- type(module),
- )
- continue
- if id(module) not in seen:
- seen.add(id(module))
- modules.append(module)
- names.append(attr)
- return modules, names
+ DIT_ATTRS = ["transformer", "transformer_2", "dit", "language_model", "transformer_blocks"]
+ ENCODER_ATTRS = ["text_encoder", "text_encoder_2", "text_encoder_3", "image_encoder"]
+ VAE_ATTRS = ["vae"]
@staticmethod
def discover(pipeline: nn.Module) -> PipelineModules:
@@ -104,29 +35,46 @@ def discover(pipeline: nn.Module) -> PipelineModules:
Returns:
PipelineModules with lists of discovered modules and names
"""
- declared = isinstance(pipeline, SupportsModuleOffload)
- if declared:
- dit_attrs = pipeline._dit_modules
- enc_attrs = pipeline._encoder_modules
- vae_attrs = pipeline._vae_modules
- res_attrs = pipeline._resident_modules
- else:
- dit_attrs = ModuleDiscovery._FALLBACK_DIT_ATTRS
- enc_attrs = ModuleDiscovery._FALLBACK_ENCODER_ATTRS
- vae_attrs = ModuleDiscovery._FALLBACK_VAE_ATTRS
- res_attrs = []
-
- dit_modules, dit_names = ModuleDiscovery._collect_modules(pipeline, dit_attrs, warn_missing=declared)
- encoders, encoder_names = ModuleDiscovery._collect_modules(pipeline, enc_attrs, warn_missing=declared)
- vaes, _ = ModuleDiscovery._collect_modules(pipeline, vae_attrs, warn_missing=declared)
- residents, resident_names = ModuleDiscovery._collect_modules(pipeline, res_attrs, warn_missing=declared)
+ # Collect DiT/transformer modules
+ dit_modules: list[nn.Module] = []
+ dit_names: list[str] = []
+ for attr in ModuleDiscovery.DIT_ATTRS:
+ if not hasattr(pipeline, attr):
+ continue
+ module_obj = getattr(pipeline, attr)
+ if module_obj is None:
+ continue
+
+ if not isinstance(module_obj, nn.Module):
+ logger.warning(f"Expected {attr} to be nn.Module, got {type(module_obj)!r}")
+ continue
+
+ if module_obj in dit_modules:
+ continue
+
+ dit_modules.append(module_obj)
+ dit_names.append(attr)
+
+ # Collect all encoders
+ encoders: list[nn.Module] = []
+ encoder_names: list[str] = []
+ for attr in ModuleDiscovery.ENCODER_ATTRS:
+ if hasattr(pipeline, attr) and getattr(pipeline, attr) is not None:
+ encoders.append(getattr(pipeline, attr))
+ encoder_names.append(attr)
+
+ # Collect VAE
+ vae = None
+ for attr in ModuleDiscovery.VAE_ATTRS:
+ module = getattr(pipeline, attr, None)
+ if module is not None:
+ vae = module
+ break
return PipelineModules(
dits=dit_modules,
dit_names=dit_names,
encoders=encoders,
encoder_names=encoder_names,
- vaes=vaes,
- resident_modules=residents,
- resident_names=resident_names,
+ vae=vae,
)
diff --git a/vllm_omni/diffusion/offloader/sequential_backend.py b/vllm_omni/diffusion/offloader/sequential_backend.py
index 06454ad5c63..46f48e99c5d 100644
--- a/vllm_omni/diffusion/offloader/sequential_backend.py
+++ b/vllm_omni/diffusion/offloader/sequential_backend.py
@@ -210,10 +210,10 @@ def enable(self, pipeline: nn.Module) -> None:
for enc in modules.encoders:
enc.to(self.device)
- # Move VAE(s) to GPU if available
- for vae in modules.vaes:
+ # Move VAE to GPU if available
+ if modules.vae is not None:
try:
- vae.to(self.device, non_blocking=True)
+ modules.vae.to(self.device, non_blocking=True)
except Exception as exc:
logger.debug("Failed to move VAE to GPU: %s", exc)
diff --git a/vllm_omni/diffusion/postprocess/__init__.py b/vllm_omni/diffusion/postprocess/__init__.py
deleted file mode 100644
index e6fe5b2d220..00000000000
--- a/vllm_omni/diffusion/postprocess/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Diffusion post-processing helpers."""
-
-from vllm_omni.diffusion.postprocess.rife_interpolator import (
- FrameInterpolator,
- interpolate_video_tensor,
-)
-
-__all__ = ["FrameInterpolator", "interpolate_video_tensor"]
diff --git a/vllm_omni/diffusion/postprocess/rife_interpolator.py b/vllm_omni/diffusion/postprocess/rife_interpolator.py
deleted file mode 100644
index 89297d0a446..00000000000
--- a/vllm_omni/diffusion/postprocess/rife_interpolator.py
+++ /dev/null
@@ -1,443 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-RIFE 4.22.lite frame interpolation for vLLM-Omni video generation.
-
-RIFE model code is vendored and adapted from:
- - https://github.com/hzwer/ECCV2022-RIFE (MIT License)
- - https://github.com/hzwer/Practical-RIFE (MIT License)
- Copyright (c) 2021 Zhewei Huang
-
-The FrameInterpolator wrapper and vLLM-Omni integration code are original work.
-"""
-
-from __future__ import annotations
-
-import os
-import threading
-from typing import Any
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-_DEFAULT_RIFE_HF_REPO = "elfgum/RIFE-4.22.lite"
-_MODEL_CACHE: dict[tuple[str, str], Model] = {}
-_MODEL_CACHE_LOCK = threading.Lock()
-
-
-def warp(ten_input: torch.Tensor, ten_flow: torch.Tensor) -> torch.Tensor:
- """Warp input tensor by optical flow using grid_sample."""
- ten_horizontal = (
- torch.linspace(-1.0, 1.0, ten_flow.shape[3], device=ten_flow.device)
- .view(1, 1, 1, ten_flow.shape[3])
- .expand(ten_flow.shape[0], -1, ten_flow.shape[2], -1)
- )
- ten_vertical = (
- torch.linspace(-1.0, 1.0, ten_flow.shape[2], device=ten_flow.device)
- .view(1, 1, ten_flow.shape[2], 1)
- .expand(ten_flow.shape[0], -1, -1, ten_flow.shape[3])
- )
- ten_grid = torch.cat([ten_horizontal, ten_vertical], dim=1)
-
- ten_flow = torch.cat(
- [
- ten_flow[:, 0:1, :, :] / ((ten_input.shape[3] - 1.0) / 2.0),
- ten_flow[:, 1:2, :, :] / ((ten_input.shape[2] - 1.0) / 2.0),
- ],
- dim=1,
- )
- grid = (ten_grid + ten_flow).permute(0, 2, 3, 1)
- return F.grid_sample(
- input=ten_input,
- grid=grid,
- mode="bilinear",
- padding_mode="border",
- align_corners=True,
- )
-
-
-def _conv(
- in_planes: int,
- out_planes: int,
- kernel_size: int = 3,
- stride: int = 1,
- padding: int = 1,
- dilation: int = 1,
-) -> nn.Sequential:
- return nn.Sequential(
- nn.Conv2d(
- in_planes,
- out_planes,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- bias=True,
- ),
- nn.LeakyReLU(0.2, True),
- )
-
-
-class ResConv(nn.Module):
- """Residual convolution block with learnable beta scaling."""
-
- def __init__(self, c: int, dilation: int = 1):
- super().__init__()
- self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1)
- self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True)
- self.relu = nn.LeakyReLU(0.2, True)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.relu(self.conv(x) * self.beta + x)
-
-
-class IFBlock(nn.Module):
- """Single-scale optical flow, mask, and feature block."""
-
- def __init__(self, in_planes: int, c: int = 64):
- super().__init__()
- self.conv0 = nn.Sequential(
- _conv(in_planes, c // 2, 3, 2, 1),
- _conv(c // 2, c, 3, 2, 1),
- )
- self.convblock = nn.Sequential(
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- ResConv(c),
- )
- self.lastconv = nn.Sequential(
- nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1),
- nn.PixelShuffle(2),
- )
-
- def forward(
- self,
- x: torch.Tensor,
- flow: torch.Tensor | None = None,
- scale: float = 1.0,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False)
- if flow is not None:
- flow = (
- F.interpolate(
- flow,
- scale_factor=1.0 / scale,
- mode="bilinear",
- align_corners=False,
- )
- * 1.0
- / scale
- )
- x = torch.cat((x, flow), 1)
- feat = self.conv0(x)
- feat = self.convblock(feat)
- tmp = self.lastconv(feat)
- tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False)
- flow = tmp[:, :4] * scale
- mask = tmp[:, 4:5]
- feat = tmp[:, 5:]
- return flow, mask, feat
-
-
-class Head(nn.Module):
- """Feature encoder producing four-channel features at full resolution."""
-
- def __init__(self):
- super().__init__()
- self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1)
- self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1)
- self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1)
- self.cnn3 = nn.ConvTranspose2d(16, 4, 4, 2, 1)
- self.relu = nn.LeakyReLU(0.2, True)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x0 = self.cnn0(x)
- x = self.relu(x0)
- x1 = self.cnn1(x)
- x = self.relu(x1)
- x2 = self.cnn2(x)
- x = self.relu(x2)
- x3 = self.cnn3(x)
- return x3
-
-
-class IFNet(nn.Module):
- """Four-scale IFNet optical flow network."""
-
- def __init__(self):
- super().__init__()
- self.block0 = IFBlock(7 + 8, c=192)
- self.block1 = IFBlock(8 + 4 + 8 + 8, c=128)
- self.block2 = IFBlock(8 + 4 + 8 + 8, c=64)
- self.block3 = IFBlock(8 + 4 + 8 + 8, c=32)
- self.encode = Head()
-
- def forward(
- self,
- x: torch.Tensor,
- timestep: float = 0.5,
- scale_list: list[float] | None = None,
- ) -> tuple[list[torch.Tensor], torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]]:
- if scale_list is None:
- scale_list = [8, 4, 2, 1]
-
- channel = x.shape[1] // 2
- img0 = x[:, :channel]
- img1 = x[:, channel:]
-
- if not torch.is_tensor(timestep):
- timestep = (x[:, :1].clone() * 0 + 1) * timestep
- else:
- timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])
-
- f0 = self.encode(img0[:, :3])
- f1 = self.encode(img1[:, :3])
-
- flow_list: list[torch.Tensor] = []
- merged: list[tuple[torch.Tensor, torch.Tensor] | torch.Tensor] = []
- mask_list: list[torch.Tensor] = []
- warped_img0 = img0
- warped_img1 = img1
- flow = None
- mask = None
-
- for i, block in enumerate([self.block0, self.block1, self.block2, self.block3]):
- if flow is None:
- flow, mask, feat = block(
- torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1),
- None,
- scale=scale_list[i],
- )
- else:
- wf0 = warp(f0, flow[:, :2])
- wf1 = warp(f1, flow[:, 2:4])
- fd, m0, feat = block(
- torch.cat(
- (
- warped_img0[:, :3],
- warped_img1[:, :3],
- wf0,
- wf1,
- timestep,
- mask,
- feat,
- ),
- 1,
- ),
- flow,
- scale=scale_list[i],
- )
- mask = m0
- flow = flow + fd
-
- mask_list.append(mask)
- flow_list.append(flow)
- warped_img0 = warp(img0, flow[:, :2])
- warped_img1 = warp(img1, flow[:, 2:4])
- merged.append((warped_img0, warped_img1))
-
- mask = torch.sigmoid(mask)
- merged[3] = warped_img0 * mask + warped_img1 * (1 - mask)
- return flow_list, mask_list[3], merged
-
-
-class Model:
- """Wraps IFNet and exposes RIFE-compatible load/inference helpers."""
-
- def __init__(self):
- self.flownet = IFNet()
-
- def eval(self) -> Model:
- self.flownet.eval()
- return self
-
- def device(self) -> torch.device:
- return next(self.flownet.parameters()).device
-
- def load_model(self, path: str) -> None:
- flownet_path = os.path.join(path, "flownet.pkl")
- if not os.path.isfile(flownet_path):
- raise FileNotFoundError(
- f"RIFE weight file not found: {flownet_path}. Expected layout: /flownet.pkl"
- )
-
- state = torch.load(flownet_path, map_location="cpu", weights_only=False)
- state = {k.removeprefix("module."): v for k, v in state.items()}
- self.flownet.load_state_dict(state, strict=False)
- logger.info("Loaded RIFE weights from %s", flownet_path)
-
- def inference(
- self,
- img0: torch.Tensor,
- img1: torch.Tensor,
- scale: float = 1.0,
- timestep: float = 0.5,
- ) -> torch.Tensor:
- _n, _c, h, w = img0.shape
- ph = ((h - 1) // 32 + 1) * 32
- pw = ((w - 1) // 32 + 1) * 32
- pad = (0, pw - w, 0, ph - h)
- img0 = F.pad(img0, pad)
- img1 = F.pad(img1, pad)
-
- imgs = torch.cat((img0, img1), 1)
- scale_list = [8 / scale, 4 / scale, 2 / scale, 1 / scale]
- with torch.no_grad():
- _flow_list, _mask, merged = self.flownet(
- imgs,
- timestep=timestep,
- scale_list=scale_list,
- )
- return merged[3][:, :, :h, :w]
-
-
-def _resolve_rife_model_path(model_path: str | None) -> str:
- model_path = model_path or _DEFAULT_RIFE_HF_REPO
- if os.path.isdir(model_path):
- return model_path
- from vllm_omni.model_executor.model_loader.weight_utils import (
- download_weights_from_hf_specific,
- )
-
- return download_weights_from_hf_specific(
- model_path,
- cache_dir=None,
- allow_patterns=["flownet.pkl"],
- require_all=True,
- )
-
-
-def _select_torch_device() -> torch.device:
- try:
- from vllm_omni.platforms import current_omni_platform
-
- return current_omni_platform.get_torch_device()
- except Exception as exc:
- logger.warning("Failed to resolve current vLLM-Omni torch device: %s", exc)
-
- if torch.cuda.is_available():
- return torch.device("cuda")
- return torch.device("cpu")
-
-
-def _normalize_video_tensor_layout(video: torch.Tensor) -> tuple[torch.Tensor, Any]:
- if video.ndim == 5:
- if video.shape[1] in (3, 4):
- return video, lambda out: out
- if video.shape[2] in (3, 4):
- return video.permute(0, 2, 1, 3, 4), lambda out: out.permute(0, 2, 1, 3, 4)
- elif video.ndim == 4:
- if video.shape[0] in (3, 4):
- return video.unsqueeze(0), lambda out: out.squeeze(0)
- if video.shape[1] in (3, 4):
- return video.permute(1, 0, 2, 3).unsqueeze(0), lambda out: out.squeeze(0).permute(1, 0, 2, 3)
- raise ValueError(f"Unsupported video tensor shape for interpolation: {tuple(video.shape)}")
-
-
-def _normalize_video_tensor_range(video: torch.Tensor) -> tuple[torch.Tensor, Any]:
- original_dtype = video.dtype
- video = video.detach()
- if video.is_floating_point():
- video = video.to(torch.float32)
- if torch.amin(video) < 0.0 or torch.amax(video) > 1.0:
- return video.clamp(-1.0, 1.0) * 0.5 + 0.5, lambda out: (out * 2.0 - 1.0).to(original_dtype)
- return video.clamp(0.0, 1.0), lambda out: out.to(original_dtype)
- return video.to(torch.float32) / 255.0, lambda out: (out * 255.0).round().clamp(0, 255).to(original_dtype)
-
-
-class FrameInterpolator:
- """Lazy-loaded RIFE 4.22.lite frame interpolator."""
-
- def __init__(self, model_path: str | None = None):
- self._model_path = model_path
- self._resolved_path: str | None = None
-
- def _ensure_model_loaded(self, preferred_device: torch.device | None = None) -> Model:
- resolved_path = _resolve_rife_model_path(self._model_path)
- self._resolved_path = resolved_path
- device = preferred_device or _select_torch_device()
- cache_key = (resolved_path, str(device))
-
- with _MODEL_CACHE_LOCK:
- if cache_key in _MODEL_CACHE:
- return _MODEL_CACHE[cache_key]
-
- model = Model()
- model.load_model(resolved_path)
- model.eval()
- model.flownet = model.flownet.to(device)
- _MODEL_CACHE[cache_key] = model
- logger.info("RIFE model loaded on device: %s", device)
- return model
-
- def _make_inference(
- self,
- model: Model,
- img0: torch.Tensor,
- img1: torch.Tensor,
- n: int,
- scale: float,
- ) -> list[torch.Tensor]:
- if n == 1:
- return [model.inference(img0, img1, scale=scale)]
- mid = model.inference(img0, img1, scale=scale)
- return (
- self._make_inference(model, img0, mid, n // 2, scale)
- + [mid]
- + self._make_inference(model, mid, img1, n // 2, scale)
- )
-
- def interpolate_tensor(
- self,
- video: torch.Tensor,
- exp: int = 1,
- scale: float = 1.0,
- ) -> tuple[torch.Tensor, int]:
- if exp < 1:
- raise ValueError(f"frame interpolation exp must be >= 1, got {exp}")
- if scale <= 0:
- raise ValueError(f"frame interpolation scale must be > 0, got {scale}")
-
- video, restore_layout = _normalize_video_tensor_layout(video)
- if video.shape[2] < 2:
- return restore_layout(video), 1
-
- video, restore_range = _normalize_video_tensor_range(video)
- # A CPU tensor may be transport/offload state rather than an execution
- # choice, so only trust it when it is already on an accelerator.
- preferred_device = video.device
- if preferred_device.type == "cpu":
- preferred_device = _select_torch_device()
- model = self._ensure_model_loaded(preferred_device=preferred_device)
- video = video.to(model.device())
- intermediates_per_pair = 2**exp // 2
-
- result_frames: list[torch.Tensor] = []
- for idx in range(video.shape[2] - 1):
- img0 = video[:, :, idx, :, :]
- img1 = video[:, :, idx + 1, :, :]
- result_frames.append(img0)
- result_frames.extend(self._make_inference(model, img0, img1, intermediates_per_pair, scale))
- result_frames.append(video[:, :, -1, :, :])
- result = torch.stack(result_frames, dim=2)
- return restore_layout(restore_range(result)), 2**exp
-
-
-def interpolate_video_tensor(
- video: torch.Tensor,
- exp: int = 1,
- scale: float = 1.0,
- model_path: str | None = None,
-) -> tuple[torch.Tensor, int]:
- """Interpolate a video tensor and return the FPS multiplier."""
- interpolator = FrameInterpolator(model_path=model_path)
- return interpolator.interpolate_tensor(video, exp=exp, scale=scale)
diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py
index ebcd263c143..c1f48137e19 100644
--- a/vllm_omni/diffusion/registry.py
+++ b/vllm_omni/diffusion/registry.py
@@ -5,7 +5,6 @@
import torch.nn as nn
from vllm.logger import init_logger
-from vllm.model_executor.model_loader.utils import configure_quant_config
from vllm.model_executor.models.registry import _LazyRegisteredModel, _ModelRegistry
from vllm_omni.diffusion.data import OmniDiffusionConfig
@@ -14,7 +13,6 @@
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.diffusion.hooks.sequence_parallel import apply_sequence_parallel
from vllm_omni.diffusion.utils.tf_utils import find_module_with_attr
-from vllm_omni.platforms import current_omni_platform
logger = init_logger(__name__)
@@ -85,26 +83,6 @@
"pipeline_ltx2_image2video",
"LTX2ImageToVideoTwoStagesPipeline",
),
- "LTX2T2VDMD2Pipeline": (
- "ltx2",
- "pipeline_ltx2",
- "LTX2T2VDMD2Pipeline",
- ),
- "LTX2I2VDMD2Pipeline": (
- "ltx2",
- "pipeline_ltx2_image2video",
- "LTX2I2VDMD2Pipeline",
- ),
- "LTX23Pipeline": (
- "ltx2",
- "pipeline_ltx2_3",
- "LTX23Pipeline",
- ),
- "LTX23ImageToVideoPipeline": (
- "ltx2",
- "pipeline_ltx2_3_image2video",
- "LTX23ImageToVideoPipeline",
- ),
"StableAudioPipeline": (
"stable_audio",
"pipeline_stable_audio",
@@ -115,16 +93,6 @@
"pipeline_wan2_2_i2v",
"Wan22I2VPipeline",
),
- "WanT2VDMD2Pipeline": (
- "wan2_2",
- "pipeline_wan2_2",
- "WanT2VDMD2Pipeline",
- ),
- "WanI2VDMD2Pipeline": (
- "wan2_2",
- "pipeline_wan2_2_i2v",
- "WanI2VDMD2Pipeline",
- ),
"LongCatImagePipeline": (
"longcat_image",
"pipeline_longcat_image",
@@ -151,8 +119,8 @@
"FluxKontextPipeline",
),
"HunyuanImage3ForCausalMM": (
- "hunyuan_image3",
- "pipeline_hunyuan_image3",
+ "hunyuan_image_3",
+ "pipeline_hunyuan_image_3",
"HunyuanImage3Pipeline",
),
"Flux2KleinPipeline": (
@@ -205,11 +173,6 @@
"pipeline_hunyuan_video_1_5_i2v",
"HunyuanVideo15I2VPipeline",
),
- "MagiHumanPipeline": (
- "magi_human",
- "pipeline_magi_human",
- "MagiHumanPipeline",
- ),
"OmniVoicePipeline": (
"omnivoice",
"pipeline_omnivoice",
@@ -220,11 +183,6 @@
"pipeline_omnivoice",
"OmniVoicePipeline",
),
- "DiffusersAdapterPipeline": (
- "diffusers_adapter",
- "pipeline_diffusers_adapter",
- "DiffusersAdapterPipeline",
- ),
}
@@ -244,22 +202,6 @@
}
-def _prepare_diffusion_quant_config(
- od_config: OmniDiffusionConfig,
- model_class: type[nn.Module],
-) -> None:
- """Prepare diffusion quant config using vLLM-style model bindings."""
- quant_config = od_config.quantization_config
- if quant_config is None:
- return
- if hasattr(quant_config, "maybe_update_config"):
- quant_config.maybe_update_config(od_config.model)
- diffusion_packed_modules_mapping = current_omni_platform.get_diffusion_packed_modules_mapping(model_class)
- if diffusion_packed_modules_mapping is not None:
- model_class.packed_modules_mapping = diffusion_packed_modules_mapping
- configure_quant_config(quant_config, model_class)
-
-
def initialize_model(
od_config: OmniDiffusionConfig,
) -> nn.Module:
@@ -282,7 +224,6 @@ def initialize_model(
"""
model_class = DiffusionModelRegistry._try_load_model_cls(od_config.model_class_name)
if model_class is not None:
- _prepare_diffusion_quant_config(od_config, model_class)
model = model_class(od_config=od_config)
vae_pp_size = od_config.parallel_config.vae_patch_parallel_size
@@ -411,14 +352,8 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) -
"LTX2TwoStagesPipeline": "get_ltx2_post_process_func",
"LTX2ImageToVideoPipeline": "get_ltx2_post_process_func",
"LTX2ImageToVideoTwoStagesPipeline": "get_ltx2_post_process_func",
- "LTX2T2VDMD2Pipeline": "get_ltx2_post_process_func",
- "LTX2I2VDMD2Pipeline": "get_ltx2_post_process_func",
- "LTX23Pipeline": "get_ltx2_post_process_func",
- "LTX23ImageToVideoPipeline": "get_ltx2_post_process_func",
"StableAudioPipeline": "get_stable_audio_post_process_func",
"WanImageToVideoPipeline": "get_wan22_i2v_post_process_func",
- "WanT2VDMD2Pipeline": "get_wan22_post_process_func",
- "WanI2VDMD2Pipeline": "get_wan22_i2v_post_process_func",
"LongCatImagePipeline": "get_longcat_image_post_process_func",
"BagelPipeline": "get_bagel_post_process_func",
"LongCatImageEditPipeline": "get_longcat_image_post_process_func",
@@ -433,9 +368,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) -
"Flux2Pipeline": "get_flux2_post_process_func",
"HunyuanVideo15Pipeline": "get_hunyuan_video_15_post_process_func",
"HunyuanVideo15ImageToVideoPipeline": "get_hunyuan_video_15_i2v_post_process_func",
- "MagiHumanPipeline": "get_magi_human_post_process_func",
"OmniVoicePipeline": "get_omnivoice_post_process_func",
- "DreamIDOmniPipeline": "get_dreamid_omni_post_process_func",
}
_DIFFUSION_PRE_PROCESS_FUNCS = {
@@ -450,13 +383,10 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) -
"WanPipeline": "get_wan22_pre_process_func",
"WanVACEPipeline": "get_wan22_vace_pre_process_func",
"WanImageToVideoPipeline": "get_wan22_i2v_pre_process_func",
- "WanT2VDMD2Pipeline": "get_wan22_pre_process_func",
- "WanI2VDMD2Pipeline": "get_wan22_i2v_pre_process_func",
"OmniGen2Pipeline": "get_omnigen2_pre_process_func",
"HeliosPipeline": "get_helios_pre_process_func",
"HeliosPyramidPipeline": "get_helios_pre_process_func",
"HunyuanVideo15ImageToVideoPipeline": "get_hunyuan_video_15_i2v_pre_process_func",
- "MagiHumanPipeline": "get_magi_human_pre_process_func",
}
diff --git a/vllm_omni/diffusion/request.py b/vllm_omni/diffusion/request.py
index 4d4328d2513..1d6d64905ae 100644
--- a/vllm_omni/diffusion/request.py
+++ b/vllm_omni/diffusion/request.py
@@ -26,8 +26,6 @@ class OmniDiffusionRequest:
sampling_params: OmniDiffusionSamplingParams
request_ids: list[str] = field(default_factory=list)
- request_id: str | None = None
- kv_sender_info: dict | None = None
def __post_init__(self):
"""Initialize dependent fields after dataclass initialization."""
diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py
index ff2d9f1e891..77db2b1b97c 100644
--- a/vllm_omni/diffusion/stage_diffusion_client.py
+++ b/vllm_omni/diffusion/stage_diffusion_client.py
@@ -8,20 +8,15 @@
from __future__ import annotations
import asyncio
-import multiprocessing.connection
import time
import uuid
-import weakref
from dataclasses import fields, is_dataclass
-from threading import Thread
from typing import TYPE_CHECKING, Any
import zmq
from vllm.logger import init_logger
-from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.diffusion.stage_diffusion_proc import (
- StageDiffusionProc,
complete_diffusion_handshake,
spawn_diffusion_proc,
)
@@ -39,24 +34,6 @@
logger = init_logger(__name__)
-def create_diffusion_client(
- model: str,
- od_config: OmniDiffusionConfig,
- metadata: StageMetadata,
- stage_init_timeout: int,
- batch_size: int = 1,
- use_inline: bool = False,
-) -> Any:
- """Factory to create either an inline or out-of-process diffusion client."""
- if use_inline:
- from vllm_omni.diffusion.inline_stage_diffusion_client import InlineStageDiffusionClient
-
- return InlineStageDiffusionClient(model, od_config, metadata, batch_size=batch_size)
- return StageDiffusionClient(
- model, od_config, metadata, stage_init_timeout=stage_init_timeout, batch_size=batch_size
- )
-
-
class StageDiffusionClient:
"""Communicates with StageDiffusionProc via ZMQ for use inside the Orchestrator.
@@ -73,43 +50,7 @@ def __init__(
model: str,
od_config: OmniDiffusionConfig,
metadata: StageMetadata,
- stage_init_timeout: int,
batch_size: int = 1,
- ) -> None:
- # Spawn StageDiffusionProc subprocess and wait for READY.
- proc, handshake_address, request_address, response_address = spawn_diffusion_proc(model, od_config)
- complete_diffusion_handshake(proc, handshake_address, stage_init_timeout)
- self._initialize_client(metadata, request_address, response_address, proc=proc, batch_size=batch_size)
-
- @classmethod
- def from_addresses(
- cls,
- metadata: StageMetadata,
- request_address: str,
- response_address: str,
- *,
- proc: Any = None,
- batch_size: int = 1,
- ) -> StageDiffusionClient:
- """Create a client for an already-running diffusion subprocess."""
- client = cls.__new__(cls)
- client._initialize_client(
- metadata,
- request_address,
- response_address,
- proc=proc,
- batch_size=batch_size,
- )
- return client
-
- def _initialize_client(
- self,
- metadata: StageMetadata,
- request_address: str,
- response_address: str,
- *,
- proc: Any,
- batch_size: int,
) -> None:
self.stage_id = metadata.stage_id
self.final_output = metadata.final_output
@@ -117,9 +58,13 @@ def _initialize_client(
self.default_sampling_params = metadata.default_sampling_params
self.custom_process_input_func = metadata.custom_process_input_func
self.engine_input_source = metadata.engine_input_source
+
+ # Spawn StageDiffusionProc subprocess and wait for READY.
+ proc, handshake_address, request_address, response_address = spawn_diffusion_proc(model, od_config)
+ complete_diffusion_handshake(proc, handshake_address)
self._proc = proc
- self._owns_process = proc is not None
+ # ZMQ sockets (sync) for communicating with the subprocess.
self._zmq_ctx = zmq.Context()
self._request_socket = self._zmq_ctx.socket(zmq.PUSH)
self._request_socket.connect(request_address)
@@ -129,55 +74,14 @@ def _initialize_client(
self._encoder = OmniMsgpackEncoder()
self._decoder = OmniMsgpackDecoder()
+ # Buffers for demultiplexing response messages.
self._output_queue: asyncio.Queue[OmniRequestOutput] = asyncio.Queue()
self._rpc_results: dict[str, Any] = {}
self._pending_rpcs: set[str] = set()
self._tasks: dict[str, asyncio.Task] = {}
self._shutting_down = False
- self._engine_dead: bool = False
-
- # Background thread to detect silent process death (SIGKILL, segfault)
- # where the subprocess cannot send the ZMQ death sentinel.
- # Mirrors MPClient.start_engine_core_monitor() in vLLM.
- self._start_proc_monitor()
-
- logger.info(
- "[StageDiffusionClient] Stage-%s initialized (owns_process=%s, batch_size=%d)",
- self.stage_id,
- self._owns_process,
- batch_size,
- )
- # ------------------------------------------------------------------
- # Process monitor (mirrors vLLM's MPClient.start_engine_core_monitor)
- # ------------------------------------------------------------------
-
- def _start_proc_monitor(self) -> None:
- """Start a daemon thread that watches the subprocess sentinel.
-
- When the subprocess dies without sending the ZMQ death sentinel
- (e.g. SIGKILL, segfault), this thread sets ``_engine_dead`` so
- subsequent calls raise ``EngineDeadError``.
- """
- proc = self._proc
- self_ref = weakref.ref(self)
-
- def _monitor() -> None:
- try:
- multiprocessing.connection.wait([proc.sentinel])
- except Exception:
- return
- client = self_ref()
- if client is None or client._shutting_down or client._engine_dead:
- return
- client._engine_dead = True
- logger.error(
- "[StageDiffusionClient] Stage-%s StageDiffusionProc died unexpectedly (exit code %s).",
- client.stage_id,
- proc.exitcode,
- )
-
- Thread(target=_monitor, daemon=True, name="DiffusionProcMonitor").start()
+ logger.info("[StageDiffusionClient] Stage-%s initialized (batch_size=%d)", self.stage_id, batch_size)
# ------------------------------------------------------------------
# Internal helpers
@@ -191,15 +95,6 @@ def _drain_responses(self) -> None:
except zmq.Again:
break
- # Check for the death sentinel (raw bytes, not msgpack-encoded).
- if raw == StageDiffusionProc.DIFFUSION_PROC_DEAD:
- self._engine_dead = True
- logger.error(
- "[StageDiffusionClient] Stage-%s received DIFFUSION_PROC_DEAD sentinel from subprocess.",
- self.stage_id,
- )
- break
-
msg = self._decoder.decode(raw)
msg_type = msg.get("type")
@@ -223,10 +118,6 @@ def _drain_responses(self) -> None:
"error": True,
"reason": error_msg,
}
- # Route request errors as error outputs so the Orchestrator
- # sees the request complete (instead of hanging forever).
- if req_id is not None:
- self._output_queue.put_nowait(OmniRequestOutput.from_error(req_id, error_msg))
# Fields that are subprocess-local and cannot be serialized across
# process boundaries. They are recreated in the subprocess with
@@ -288,10 +179,7 @@ async def add_request_async(
request_id: str,
prompt: OmniPromptType,
sampling_params: OmniDiffusionSamplingParams,
- kv_sender_info: dict[int, dict[str, Any]] | None = None,
) -> None:
- if self._engine_dead:
- raise EngineDeadError()
self._request_socket.send(
self._encoder.encode(
{
@@ -299,7 +187,6 @@ async def add_request_async(
"request_id": request_id,
"prompt": prompt,
"sampling_params": self._sampling_params_to_dict(sampling_params),
- "kv_sender_info": kv_sender_info,
}
)
)
@@ -311,7 +198,6 @@ async def add_batch_request_async(
request_id: str,
prompts: list[OmniPromptType],
sampling_params: OmniDiffusionSamplingParams,
- kv_sender_info: dict[int, dict[str, Any]] | None = None,
) -> None:
"""Submit a list of prompts as a single batched engine call.
@@ -319,15 +205,8 @@ async def add_batch_request_async(
and the combined result is placed on the output queue with a single
*request_id*.
"""
- if self._engine_dead:
- raise EngineDeadError()
task = asyncio.create_task(
- self._run_batch(
- request_id,
- prompts,
- sampling_params,
- kv_sender_info,
- ),
+ self._run_batch(request_id, prompts, sampling_params),
name=f"diffusion-batch-{request_id}",
)
self._tasks[request_id] = task
@@ -337,7 +216,6 @@ async def _run_batch(
request_id: str,
prompts: list[OmniPromptType],
sampling_params: OmniDiffusionSamplingParams,
- kv_sender_info: dict[int, dict[str, Any]] | None = None,
) -> None:
try:
self._request_socket.send(
@@ -347,7 +225,6 @@ async def _run_batch(
"request_id": request_id,
"prompts": prompts,
"sampling_params": self._sampling_params_to_dict(sampling_params),
- "kv_sender_info": kv_sender_info,
}
)
)
@@ -358,7 +235,6 @@ async def _run_batch(
request_id,
e,
)
- await self._output_queue.put(OmniRequestOutput.from_error(request_id, str(e)))
finally:
self._tasks.pop(request_id, None)
@@ -367,10 +243,7 @@ def get_diffusion_output_nowait(self) -> OmniRequestOutput | None:
try:
return self._output_queue.get_nowait()
except asyncio.QueueEmpty:
- if self._engine_dead:
- raise EngineDeadError()
- if not self._shutting_down and self._owns_process and self._proc is not None and not self._proc.is_alive():
- self._engine_dead = True
+ if not self._shutting_down and self._proc is not None and not self._proc.is_alive():
exitcode = self._proc.exitcode
# One final drain – the last ZMQ frame may have arrived
# between the first drain and the is_alive() check.
@@ -384,7 +257,7 @@ def get_diffusion_output_nowait(self) -> OmniRequestOutput | None:
logger.warning("StageDiffusionProc was killed by signal %d; treating as external shutdown.", sig)
self._shutting_down = True
return None
- raise EngineDeadError(f"StageDiffusionProc died unexpectedly (exit code {exitcode})")
+ raise RuntimeError(f"StageDiffusionProc died unexpectedly (exit code {exitcode})")
return None
async def abort_requests_async(self, request_ids: list[str]) -> None:
@@ -405,9 +278,6 @@ async def collective_rpc_async(
kwargs: dict[str, Any] | None = None,
) -> Any:
"""Forward control RPCs to the diffusion subprocess."""
- if self._engine_dead:
- raise EngineDeadError()
-
# Inject a default profile_prefix that includes stage_id when profiling.
if method == "profile":
args_list = list(args)
@@ -445,9 +315,8 @@ async def collective_rpc_async(
self._drain_responses()
if rpc_id in self._rpc_results:
return self._rpc_results.pop(rpc_id)
- if self._engine_dead or (self._owns_process and self._proc is not None and not self._proc.is_alive()):
- self._engine_dead = True
- raise EngineDeadError(
+ if self._proc is not None and not self._proc.is_alive():
+ raise RuntimeError(
f"StageDiffusionProc died while waiting for "
f"collective_rpc '{method}' (exit code {self._proc.exitcode})"
)
@@ -457,19 +326,6 @@ async def collective_rpc_async(
finally:
self._pending_rpcs.discard(rpc_id)
- def check_health(self) -> None:
- """Raise ``EngineDeadError`` if the diffusion engine is dead.
-
- Mirrors the ``check_health`` protocol on vLLM's ``EngineClient``.
- """
- if self._engine_dead:
- raise EngineDeadError(f"Stage-{self.stage_id} diffusion subprocess is dead")
- if self._proc is not None and not self._proc.is_alive():
- self._engine_dead = True
- raise EngineDeadError(
- f"Stage-{self.stage_id} diffusion subprocess is not alive (exit code: {self._proc.exitcode})."
- )
-
def shutdown(self) -> None:
self._shutting_down = True
try:
@@ -477,7 +333,7 @@ def shutdown(self) -> None:
except Exception:
pass
- if self._owns_process and self._proc is not None and self._proc.is_alive():
+ if self._proc is not None and self._proc.is_alive():
self._proc.join(timeout=10)
terminate_alive_proc(self._proc)
diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py
index d36c5c644c7..0a5fd359018 100644
--- a/vllm_omni/diffusion/stage_diffusion_proc.py
+++ b/vllm_omni/diffusion/stage_diffusion_proc.py
@@ -14,16 +14,15 @@
from typing import TYPE_CHECKING, Any
import msgspec
-import torch
import zmq
import zmq.asyncio
-from PIL import Image
from vllm.logger import init_logger
+from vllm.transformers_utils.config import get_hf_file_to_dict
from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.utils.system_utils import get_mp_context
from vllm.v1.utils import shutdown
-from vllm_omni.diffusion.data import DiffusionRequestAbortedError
+from vllm_omni.diffusion.data import DiffusionRequestAbortedError, TransformerConfig
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.distributed.omni_connectors.utils.serialization import (
@@ -38,6 +37,8 @@
logger = init_logger(__name__)
+_HANDSHAKE_POLL_TIMEOUT_S = 600
+
class StageDiffusionProc:
"""Subprocess entry point for diffusion inference.
@@ -46,8 +47,6 @@ class StageDiffusionProc:
and ZMQ-based communication with StageDiffusionClient.
"""
- DIFFUSION_PROC_DEAD = b"DIFFUSION_PROC_DEAD"
-
def __init__(self, model: str, od_config: OmniDiffusionConfig) -> None:
self._model = model
self._od_config = od_config
@@ -67,8 +66,47 @@ def initialize(self) -> None:
logger.info("StageDiffusionProc initialized with model: %s", self._model)
def _enrich_config(self) -> None:
- """Load model metadata from HuggingFace and populate od_config fields."""
- self._od_config.enrich_config()
+ """Load model metadata from HuggingFace and populate od_config fields.
+
+ Diffusers-style models expose ``model_index.json`` with ``_class_name``.
+ Non-diffusers models (e.g. Bagel, NextStep) only have ``config.json``,
+ so we fall back to reading that and mapping model_type manually.
+ """
+ od_config = self._od_config
+
+ try:
+ config_dict = get_hf_file_to_dict("model_index.json", od_config.model)
+ if config_dict is not None:
+ if od_config.model_class_name is None:
+ od_config.model_class_name = config_dict.get("_class_name", None)
+ od_config.update_multimodal_support()
+
+ tf_config_dict = get_hf_file_to_dict("transformer/config.json", od_config.model)
+ od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict)
+ else:
+ raise FileNotFoundError("model_index.json not found")
+ except (AttributeError, OSError, ValueError, FileNotFoundError):
+ cfg = get_hf_file_to_dict("config.json", od_config.model)
+ if cfg is None:
+ raise ValueError(f"Could not find config.json or model_index.json for model {od_config.model}")
+
+ od_config.tf_model_config = TransformerConfig.from_dict(cfg)
+ model_type = cfg.get("model_type")
+ architectures = cfg.get("architectures") or []
+
+ if model_type == "bagel" or "BagelForConditionalGeneration" in architectures:
+ od_config.model_class_name = "BagelPipeline"
+ od_config.tf_model_config = TransformerConfig()
+ od_config.update_multimodal_support()
+ elif model_type == "nextstep":
+ if od_config.model_class_name is None:
+ od_config.model_class_name = "NextStep11Pipeline"
+ od_config.tf_model_config = TransformerConfig()
+ od_config.update_multimodal_support()
+ elif architectures and len(architectures) == 1:
+ od_config.model_class_name = architectures[0]
+ else:
+ raise
# ------------------------------------------------------------------
# Request processing
@@ -90,7 +128,6 @@ async def _process_request(
request_id: str,
prompt: Any,
sampling_params_dict: dict,
- kv_sender_info: dict[str, Any] | None = None,
) -> OmniRequestOutput:
"""Build a diffusion request and run DiffusionEngine.step()."""
sampling_params = self._reconstruct_sampling_params(sampling_params_dict)
@@ -99,8 +136,6 @@ async def _process_request(
prompts=[prompt],
sampling_params=sampling_params,
request_ids=[request_id],
- request_id=request_id,
- kv_sender_info=kv_sender_info,
)
loop = asyncio.get_running_loop()
@@ -115,7 +150,6 @@ async def _process_batch_request(
request_id: str,
prompts: list[Any],
sampling_params_dict: dict,
- kv_sender_info: dict[str, Any] | None = None,
) -> OmniRequestOutput:
"""Build a batched diffusion request and run DiffusionEngine.step().
@@ -129,9 +163,7 @@ async def _process_batch_request(
request = OmniDiffusionRequest(
prompts=prompts,
sampling_params=sampling_params,
- request_ids=[f"{request_id}-{i}" for i in range(len(prompts))],
- request_id=request_id,
- kv_sender_info=kv_sender_info,
+ request_ids=[request_id] * len(prompts),
)
loop = asyncio.get_running_loop()
@@ -142,13 +174,8 @@ async def _process_batch_request(
merged_mm: dict[str, Any] = {}
merged_metrics: dict[str, Any] = {}
merged_durations: dict[str, float] = {}
- merged_custom: dict[str, Any] = {}
peak_mem = 0.0
latents = None
- trajectory_latents: list[torch.Tensor] | None = None
- trajectory_timesteps: list[torch.Tensor] | None = None
- trajectory_log_probs: torch.Tensor | None = None
- trajectory_decoded: list[Image.Image] | None = None
final_output_type = "image"
for r in results:
@@ -156,18 +183,9 @@ async def _process_batch_request(
merged_mm.update(r._multimodal_output)
merged_metrics.update(r.metrics)
merged_durations.update(r.stage_durations)
- merged_custom.update(r._custom_output)
peak_mem = max(peak_mem, r.peak_memory_mb)
if latents is None and r.latents is not None:
latents = r.latents
- if trajectory_latents is None:
- trajectory_latents = r.trajectory_latents
- if trajectory_timesteps is None:
- trajectory_timesteps = r.trajectory_timesteps
- if trajectory_log_probs is None:
- trajectory_log_probs = r.trajectory_log_probs
- if trajectory_decoded is None:
- trajectory_decoded = r.trajectory_decoded
if r.final_output_type != "image":
final_output_type = r.final_output_type
@@ -177,11 +195,6 @@ async def _process_batch_request(
prompt=prompts[0] if len(prompts) == 1 else None,
metrics=merged_metrics,
latents=latents,
- trajectory_latents=trajectory_latents,
- trajectory_timesteps=trajectory_timesteps,
- trajectory_log_probs=trajectory_log_probs,
- trajectory_decoded=trajectory_decoded,
- custom_output=merged_custom or None,
multimodal_output=merged_mm or None,
final_output_type=final_output_type,
stage_durations=merged_durations,
@@ -312,20 +325,10 @@ async def run_loop(
tasks: dict[str, asyncio.Task] = {}
- async def _dispatch_request(
- request_id: str,
- prompt: Any,
- sampling_params_dict: dict,
- kv_sender_info: dict[str, Any] | None = None,
- ) -> None:
+ async def _dispatch_request(request_id: str, prompt: Any, sampling_params_dict: dict) -> None:
"""Process a single diffusion request and send the response."""
try:
- result = await self._process_request(
- request_id,
- prompt,
- sampling_params_dict,
- kv_sender_info=kv_sender_info,
- )
+ result = await self._process_request(request_id, prompt, sampling_params_dict)
await response_socket.send(encoder.encode({"type": "result", "output": result}))
except DiffusionRequestAbortedError as e:
logger.info(
@@ -360,7 +363,6 @@ async def _dispatch_request(
request_id,
msg["prompt"],
msg["sampling_params"],
- msg.get("kv_sender_info"),
)
)
tasks[request_id] = task
@@ -368,19 +370,9 @@ async def _dispatch_request(
elif msg_type == "add_batch_request":
request_id = msg["request_id"]
- async def _dispatch_batch(
- rid: str,
- prompts: list,
- sp_dict: dict,
- kv_sender_info: dict[str, Any] | None = None,
- ) -> None:
+ async def _dispatch_batch(rid: str, prompts: list, sp_dict: dict) -> None:
try:
- result = await self._process_batch_request(
- rid,
- prompts,
- sp_dict,
- kv_sender_info=kv_sender_info,
- )
+ result = await self._process_batch_request(rid, prompts, sp_dict)
await response_socket.send(encoder.encode({"type": "result", "output": result}))
except DiffusionRequestAbortedError as e:
logger.info(
@@ -407,7 +399,6 @@ async def _dispatch_batch(
request_id,
msg["prompts"],
msg["sampling_params"],
- msg.get("kv_sender_info"),
)
)
tasks[request_id] = task
@@ -452,16 +443,6 @@ async def _dispatch_batch(
elif msg_type == "shutdown":
break
- except Exception:
- # Send the death sentinel so the client can detect the
- # fatal failure promptly (mirrors EngineCoreProc._send_engine_dead).
- try:
- response_socket.setsockopt(zmq.LINGER, 4000)
- await response_socket.send(StageDiffusionProc.DIFFUSION_PROC_DEAD)
- except Exception:
- logger.warning("Failed to send DIFFUSION_PROC_DEAD sentinel to client.")
- raise
-
finally:
for task in tasks.values():
task.cancel()
@@ -550,17 +531,14 @@ def signal_handler(signum: int, frame: Any) -> None:
def spawn_diffusion_proc(
model: str,
od_config: OmniDiffusionConfig,
- handshake_address: str | None = None,
- request_address: str | None = None,
- response_address: str | None = None,
) -> tuple[BaseProcess, str, str, str]:
"""Spawn a StageDiffusionProc subprocess.
Returns ``(proc, handshake_address, request_address, response_address)``.
"""
- handshake_address = handshake_address or get_open_zmq_ipc_path()
- request_address = request_address or get_open_zmq_ipc_path()
- response_address = response_address or get_open_zmq_ipc_path()
+ handshake_address = get_open_zmq_ipc_path()
+ request_address = get_open_zmq_ipc_path()
+ response_address = get_open_zmq_ipc_path()
ctx = get_mp_context()
proc = ctx.Process(
@@ -589,14 +567,13 @@ def spawn_diffusion_proc(
def complete_diffusion_handshake(
proc: BaseProcess,
handshake_address: str,
- handshake_timeout: int,
) -> None:
"""Wait for the diffusion subprocess to signal READY.
On failure the process is terminated before re-raising.
"""
try:
- _perform_diffusion_handshake(proc, handshake_address, handshake_timeout)
+ _perform_diffusion_handshake(proc, handshake_address)
except Exception:
shutdown([proc])
raise
@@ -605,7 +582,6 @@ def complete_diffusion_handshake(
def _perform_diffusion_handshake(
proc: BaseProcess,
handshake_address: str,
- handshake_timeout: int,
) -> None:
"""Run the handshake with the diffusion subprocess."""
with zmq_socket_ctx(handshake_address, zmq.ROUTER, bind=True) as handshake_socket:
@@ -613,15 +589,11 @@ def _perform_diffusion_handshake(
poller.register(handshake_socket, zmq.POLLIN)
poller.register(proc.sentinel, zmq.POLLIN)
- timeout_ms = handshake_timeout * 1000
+ timeout_ms = _HANDSHAKE_POLL_TIMEOUT_S * 1000
while True:
events = dict(poller.poll(timeout=timeout_ms))
if not events:
- raise TimeoutError(
- f"Timed out waiting for READY from StageDiffusionProc after {handshake_timeout}s. "
- f"This typically indicates model loading or warmup is taking too long. "
- f"Consider increasing `stage_init_timeout` for large models."
- )
+ raise TimeoutError("Timed out waiting for READY from StageDiffusionProc")
if handshake_socket in events:
identity, raw = handshake_socket.recv_multipart()
msg = msgspec.msgpack.decode(raw)
diff --git a/vllm_omni/diffusion/utils/media_utils.py b/vllm_omni/diffusion/utils/media_utils.py
deleted file mode 100644
index a09cd459539..00000000000
--- a/vllm_omni/diffusion/utils/media_utils.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Video/audio muxing utilities using PyAV (no ffmpeg binary dependency)."""
-
-from __future__ import annotations
-
-import io
-from fractions import Fraction
-
-import av
-import numpy as np
-
-
-def mux_video_audio_bytes(
- video_frames: np.ndarray,
- audio_waveform: np.ndarray | None = None,
- *,
- fps: float = 25.0,
- audio_sample_rate: int = 44100,
- video_codec: str = "h264",
- audio_codec: str = "aac",
- crf: str = "18",
- video_codec_options: dict[str, str] | None = None,
-) -> bytes:
- """Mux video frames and optional audio waveform into MP4 bytes.
-
- Args:
- video_frames: uint8 array of shape ``(T, H, W, 3)`` (RGB).
- audio_waveform: float32 array – mono ``(N,)`` or ``(N, C)`` / ``(C, N)``.
- fps: Video frame rate.
- audio_sample_rate: Audio sample rate in Hz.
- video_codec: Video codec name.
- audio_codec: Audio codec name.
- crf: Constant rate factor for the video encoder.
-
- Returns:
- Raw MP4 bytes ready to be written to disk or streamed.
- """
- buf = io.BytesIO()
- container = av.open(buf, mode="w", format="mp4")
-
- v_stream = container.add_stream(video_codec, rate=Fraction(fps).limit_denominator(10000))
- v_stream.width = video_frames.shape[2]
- v_stream.height = video_frames.shape[1]
- v_stream.pix_fmt = "yuv420p"
-
- options = {"crf": str(crf)}
- if video_codec_options:
- options.update(video_codec_options)
- v_stream.options = options
-
- a_stream = None
- if audio_waveform is not None:
- samples = audio_waveform.astype(np.float32)
- if samples.ndim == 1:
- samples = samples.reshape(1, -1)
- elif samples.ndim == 2 and samples.shape[0] > samples.shape[1]:
- samples = np.ascontiguousarray(samples.T)
- num_channels = samples.shape[0]
- layout = "stereo" if num_channels >= 2 else "mono"
- a_stream = container.add_stream(audio_codec, rate=audio_sample_rate)
- a_stream.layout = layout
-
- for frame_data in video_frames:
- frame = av.VideoFrame.from_ndarray(frame_data, format="rgb24")
- for packet in v_stream.encode(frame):
- container.mux(packet)
- for packet in v_stream.encode():
- container.mux(packet)
-
- if a_stream is not None and audio_waveform is not None:
- audio_frame = av.AudioFrame.from_ndarray(samples, format="fltp", layout=layout)
- audio_frame.sample_rate = audio_sample_rate
- for packet in a_stream.encode(audio_frame):
- container.mux(packet)
- for packet in a_stream.encode():
- container.mux(packet)
-
- container.close()
- return buf.getvalue()
diff --git a/vllm_omni/diffusion/utils/prompt_utils.py b/vllm_omni/diffusion/utils/prompt_utils.py
deleted file mode 100644
index fc1769f4d54..00000000000
--- a/vllm_omni/diffusion/utils/prompt_utils.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-
-
-def validate_prompt_sequence_lengths(
- attention_mask: torch.Tensor,
- *,
- max_sequence_length: int,
- supported_max_sequence_length: int,
- prompt_name: str = "prompt",
- length_offset: int = 0,
- baseline_attention_mask: torch.Tensor | None = None,
- error_context: str,
-) -> None:
- sequence_lengths = attention_mask.sum(dim=1)
- if baseline_attention_mask is not None:
- # Some callers need to validate only the user-controlled portion of a
- # templated prompt. In those cases we subtract the fully-tokenized
- # template baseline instead of only removing a fixed prefix length,
- # because the template may also contribute a suffix or image markers.
- baseline_lengths = baseline_attention_mask.sum(dim=1)
- if baseline_lengths.shape[0] == 1 and sequence_lengths.shape[0] > 1:
- baseline_lengths = baseline_lengths.expand(sequence_lengths.shape[0])
- sequence_lengths = sequence_lengths - baseline_lengths
- if length_offset:
- sequence_lengths = sequence_lengths - length_offset
- sequence_lengths = torch.clamp(sequence_lengths, min=0)
- too_long = torch.nonzero(sequence_lengths > max_sequence_length, as_tuple=False)
- if too_long.numel() == 0:
- return
-
- batch_idx = int(too_long[0].item())
- actual_length = int(sequence_lengths[batch_idx].item())
- prompt_ref = f"`{prompt_name}` at batch index {batch_idx}" if attention_mask.shape[0] > 1 else f"`{prompt_name}`"
- raise ValueError(
- f"{prompt_ref} is too long {error_context}: got {actual_length} tokens, but "
- f"`max_sequence_length` is {max_sequence_length}. Shorten the prompt or increase "
- f"`max_sequence_length` up to {supported_max_sequence_length}."
- )
diff --git a/vllm_omni/diffusion/utils/size_utils.py b/vllm_omni/diffusion/utils/size_utils.py
deleted file mode 100644
index 030e542f17f..00000000000
--- a/vllm_omni/diffusion/utils/size_utils.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Shared size normalization helpers for diffusion pipelines."""
-
-
-def normalize_min_aligned_size(height: int, width: int, alignment: int) -> tuple[int, int]:
- """Clamp dimensions to the minimum valid aligned size.
-
- This preserves floor-to-alignment behavior for normal requests while
- preventing very small dimensions from collapsing to zero after alignment.
- """
-
- alignment = int(alignment)
- if alignment <= 0:
- raise ValueError(f"Expected positive alignment, got {alignment}")
-
- normalized_height = max(alignment, (int(height) // alignment) * alignment)
- normalized_width = max(alignment, (int(width) // alignment) * alignment)
- return normalized_height, normalized_width
diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py
index 535f053c388..32ea5bf64dc 100644
--- a/vllm_omni/diffusion/worker/diffusion_model_runner.py
+++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py
@@ -35,12 +35,11 @@
from vllm_omni.diffusion.worker.utils import DiffusionRequestState, RunnerOutput
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
from vllm_omni.platforms import current_omni_platform
-from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin
logger = init_logger(__name__)
-class DiffusionModelRunner(OmniConnectorModelRunnerMixin):
+class DiffusionModelRunner:
"""
Model runner that handles model loading and execution for diffusion models.
diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py
index 927bbeb1a2a..ea4b9d96f71 100644
--- a/vllm_omni/diffusion/worker/diffusion_worker.py
+++ b/vllm_omni/diffusion/worker/diffusion_worker.py
@@ -13,7 +13,6 @@
import os
from collections.abc import Iterable
from contextlib import AbstractContextManager, nullcontext
-from types import SimpleNamespace
from typing import Any
import torch
@@ -21,18 +20,13 @@
from vllm.config import CompilationConfig, DeviceConfig, VllmConfig, set_current_vllm_config
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.logger import init_logger
-from vllm.profiler.wrapper import CudaProfilerWrapper, WorkerProfiler
-from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.mem_utils import GiB_bytes
from vllm.v1.worker.workspace import init_workspace_manager
from vllm_omni.diffusion.data import (
DiffusionOutput,
- OmniACK,
OmniDiffusionConfig,
- OmniSleepTask,
- OmniWakeTask,
)
from vllm_omni.diffusion.distributed.parallel_state import (
destroy_distributed_env,
@@ -82,7 +76,6 @@ def __init__(
self.model_runner: DiffusionModelRunner | None = None
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
self.lora_manager: DiffusionLoRAManager | None = None
- self.stage_id = getattr(od_config, "stage_id", 0)
self.init_device()
# Create model runner
self.model_runner = DiffusionModelRunner(
@@ -90,7 +83,15 @@ def __init__(
od_config=self.od_config,
device=self.device,
)
- self.profiler: WorkerProfiler | None = self._create_profiler()
+ # Initialize profiler if configured
+ self.profiler: OmniTorchProfilerWrapper | None = None
+ profiler_config = self.od_config.profiler_config
+ if profiler_config and profiler_config.profiler == "torch":
+ self.profiler = create_omni_profiler(
+ profiler_config=profiler_config,
+ worker_name=f"diffusion_worker_{self.rank}",
+ local_rank=self.local_rank,
+ )
if not skip_load_model:
self.load_model(load_format=self.od_config.diffusion_load_format)
self.init_lora_manager()
@@ -121,21 +122,6 @@ def init_device(self) -> None:
vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size
vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size
vllm_config.parallel_config.enable_expert_parallel = self.od_config.parallel_config.enable_expert_parallel
- vllm_config.profiler_config = self.od_config.profiler_config
- try:
- hf_config = get_config(self.od_config.model, trust_remote_code=self.od_config.trust_remote_code)
- except ValueError:
- hf_config = None
- logger.info("Skipping hf_config loading for diffusion model %r", self.od_config.model_class_name)
- hf_text_config = get_hf_text_config(hf_config) if hf_config is not None else None
- vllm_config.model_config = SimpleNamespace(
- hf_config=hf_config,
- hf_text_config=hf_text_config,
- enforce_eager=self.od_config.enforce_eager,
- dtype=self.od_config.dtype,
- enable_return_routed_experts=False,
- )
- vllm_config.quant_config = self.od_config.quantization_config
self.vllm_config = vllm_config
# Initialize distributed environment
@@ -161,28 +147,8 @@ def init_device(self) -> None:
)
init_workspace_manager(self.device)
- def _create_profiler(self) -> WorkerProfiler | None:
- profiler_config = self.od_config.profiler_config
- profiler_type = getattr(profiler_config, "profiler", None)
- if profiler_type == "torch":
- return create_omni_profiler(
- profiler_config=profiler_config,
- worker_name=f"diffusion_rank{self.rank}",
- local_rank=self.local_rank,
- )
- if profiler_type == "cuda":
- return CudaProfilerWrapper(profiler_config)
- if profiler_type is not None:
- logger.warning("Unknown profiler backend %r on diffusion worker %s", profiler_type, self.rank)
- return None
-
- def _get_profiler(self) -> WorkerProfiler | None:
- return getattr(self, "profiler", None)
-
- def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None, **kwargs) -> None:
+ def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None) -> None:
"""Load the diffusion model using DiffusionModelRunner."""
- load_format = kwargs.get("load_format", load_format)
- custom_pipeline_name = kwargs.get("custom_pipeline_name", custom_pipeline_name)
with (
set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config),
set_current_vllm_config(self.vllm_config),
@@ -192,8 +158,6 @@ def load_model(self, load_format: str = "default", custom_pipeline_name: str | N
load_format=load_format,
custom_pipeline_name=custom_pipeline_name,
)
- current_omni_platform.synchronize()
- gc.collect()
process_memory = get_process_gpu_memory(self.local_rank)
if process_memory is not None:
logger.info(
@@ -228,21 +192,27 @@ def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> N
Args:
is_start: True to start profiling, False to stop.
- profile_prefix: Optional prefix for trace filename.
+ profile_prefix: Optional prefix for trace filename (vLLM compat).
+
+ Note:
+ Matches vLLM's worker.profile() signature for consistency.
+ Traces are saved automatically via on_trace_ready callback.
"""
- profiler = self._get_profiler()
- if profiler is None:
+ if self.profiler is None:
+ logger.warning("Profiler not initialized, skipping profile(%s)", is_start)
return
if is_start:
- if isinstance(profiler, OmniTorchProfilerWrapper):
+ from vllm_omni.profiler import OmniTorchProfilerWrapper
+
+ if isinstance(self.profiler, OmniTorchProfilerWrapper):
import time
- filename = profile_prefix or f"diffusion_rank{self.rank}_{int(time.time())}"
- profiler.set_trace_filename(filename)
- profiler.start()
+ filename = profile_prefix or f"diffusion_{int(time.time())}"
+ self.profiler.set_trace_filename(filename)
+ self.profiler.start()
else:
- profiler.stop()
+ self.profiler.stop()
def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfig) -> DiffusionOutput:
"""Execute a forward pass by delegating to the model runner."""
@@ -254,13 +224,7 @@ def execute_model(self, req: OmniDiffusionRequest, od_config: OmniDiffusionConfi
if req.sampling_params.lora_request is not None:
raise
logger.warning("LoRA activation skipped: %s", exc)
- profiler = self._get_profiler()
- ctx = profiler.annotate_context_manager("diffusion_forward") if profiler else nullcontext()
- with ctx:
- output = self.model_runner.execute_model(req)
- if profiler:
- profiler.step()
- return output
+ return self.model_runner.execute_model(req)
def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput:
"""Execute one diffusion step by delegating to the model runner."""
@@ -272,13 +236,8 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner
if any(new_req.req.sampling_params.lora_request is not None for new_req in scheduler_output.scheduled_new_reqs):
raise ValueError("Step mode does not support LoRA yet.")
- profiler = self._get_profiler()
- ctx = profiler.annotate_context_manager("diffusion_step") if profiler else nullcontext()
- with ctx:
- output = self.model_runner.execute_stepwise(scheduler_output)
- if profiler:
- profiler.step()
- return output
+
+ return self.model_runner.execute_stepwise(scheduler_output)
def load_weights(self, weights) -> set[str]:
"""Load weights by delegating to the model runner."""
@@ -308,176 +267,77 @@ def sleep(self, level: int = 1) -> bool:
"""
from vllm.device_allocator.cumem import CuMemAllocator
- allocator = CuMemAllocator.get_instance()
-
- usage_before = allocator.get_current_usage()
+ process_memory_before_sleep = get_process_gpu_memory(self.local_rank)
+ free_bytes_before_sleep = None
+ if process_memory_before_sleep is None:
+ free_bytes_before_sleep = current_omni_platform.get_free_memory()
+ # Save the buffers before level 2 sleep
if level == 2 and self.model_runner is not None:
- if hasattr(self.model_runner, "graph_runners"):
- self.model_runner.graph_runners.clear()
- logger.info(f"[Worker {self.rank}] CUDA Graphs cleared.")
model = self.model_runner.pipeline
self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()}
- free_mem_before = current_omni_platform.get_free_memory()
-
- # Level 1: Offload weights; Level 2: Total Discard
- offload_tags = ("weights",) if level == 1 else tuple()
- allocator.sleep(offload_tags=offload_tags)
-
- current_omni_platform.empty_cache()
- current_omni_platform.synchronize()
-
- free_mem_after = current_omni_platform.get_free_memory()
- try:
- total_mem = current_omni_platform.get_device_total_memory()
- except (NotImplementedError, AttributeError):
- total_mem = torch.cuda.get_device_properties(self.device).total_memory
-
- phys_freed_bytes = max(0, free_mem_after - free_mem_before)
- phys_used_bytes = total_mem - free_mem_after
-
- if usage_before > 0:
- logger.info(
- f"[Diffusion Worker {self.rank}] Sleep Level {level}: "
- f"physically freed {phys_freed_bytes / GiB_bytes:.2f} GiB, "
- f"{phys_used_bytes / GiB_bytes:.2f} GiB is still in use."
- )
+ allocator = CuMemAllocator.get_instance()
+ allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
+ process_memory_after_sleep = get_process_gpu_memory(self.local_rank)
+ if process_memory_before_sleep is not None and process_memory_after_sleep is not None:
+ freed_bytes = process_memory_before_sleep - process_memory_after_sleep
+ used_bytes = process_memory_after_sleep
+ accounting_scope = "process-scoped"
else:
- logger.info(f"[Worker {self.rank}] Sleep Level {level} completed (GPU was already empty).")
- logger.info(f"[Worker {self.rank}] Memory usage before sleep: {usage_before / GiB_bytes:.2f} GiB.")
- return usage_before
+ free_bytes_after_sleep = current_omni_platform.get_free_memory()
+ assert free_bytes_before_sleep is not None
+ device_id = self.device.index if self.device.index is not None else 0
+ total = current_omni_platform.get_device_total_memory(device_id)
+ freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
+ used_bytes = total - free_bytes_after_sleep
+ accounting_scope = "device-scoped fallback"
+ assert freed_bytes >= 0, "Memory usage increased after sleeping."
+ logger.info(
+ "Sleep mode (%s) freed %.2f GiB memory, %.2f GiB memory is still in use.",
+ accounting_scope,
+ freed_bytes / GiB_bytes,
+ used_bytes / GiB_bytes,
+ )
+ return True
def wake_up(self, tags: list[str] | None = None) -> bool:
"""
- Wake up the worker from sleep mode.
-
- Re-activates the memory allocator for the specified tags and restores
- model buffers from CPU back to GPU if they were saved during Level 2 sleep.
+ Wake up the worker from sleep mode. See the sleep()
+ method for more details.
Args:
- tags: List of memory pool tags to re-activate (e.g., ["weights"]
- to match Level 1 sleep). If None, all pools are re-activated.
+ tags: An optional list of tags to reallocate the worker memory
+ for specific memory allocations. Values must be in
+ `("weights")`. If None, all memory is reallocated.
+ wake_up should be called with all tags (or None) before the
+ worker is used again.
"""
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags)
- current_omni_platform.synchronize()
+
+ # Restore the buffers after level 2 sleep
if len(self._sleep_saved_buffers) and self.model_runner is not None:
model = self.model_runner.pipeline
for name, buffer in model.named_buffers():
if name in self._sleep_saved_buffers:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
- logger.info(f"[Worker {self.rank}] Buffers restored from CPU.")
- logger.info(f"[Worker {self.rank}] Wake-up complete.")
return True
- def handle_sleep_task(self, task: OmniSleepTask) -> OmniACK:
- from vllm_omni.platforms import current_omni_platform
-
- try:
- if isinstance(task, dict):
- task = OmniSleepTask(**task)
- logger.info(f"[Worker {self.rank}] Handshake Received: Task {task.task_id}")
-
- current_omni_platform.synchronize()
- usage_before = current_omni_platform.get_current_memory_usage(self.device)
- self.sleep(level=task.level)
- current_omni_platform.synchronize()
- usage_after = current_omni_platform.get_current_memory_usage(self.device)
- real_freed = max(0, usage_before - usage_after)
- logger.info(f"[Worker {self.rank}] Preparing ACK: freed_bytes={real_freed / GiB_bytes:.2f} GiB.")
-
- # Ensure all ranks have completed sleep before measuring memory and sending ACK
- if torch.distributed.is_initialized():
- t_freed = torch.tensor([float(real_freed)], device=self.device)
- torch.distributed.all_reduce(t_freed)
- real_freed = int(t_freed.item())
-
- if self.rank != 0:
- return None
-
- ack = OmniACK(
- task_id=task.task_id,
- status="SUCCESS",
- stage_id=self.stage_id,
- rank=self.rank,
- freed_bytes=real_freed,
- # return RL need metadata
- metadata={
- "source": f"Platform_{current_omni_platform.get_device_name()}",
- "total_freed_gib": f"{real_freed / GiB_bytes:.2f}",
- "rank_residual_gib": f"{usage_after / GiB_bytes:.2f}",
- },
- )
- logger.info(f"[Worker {self.rank}] ACK emitted. Freed {real_freed / GiB_bytes:.2f} GiB.")
- return ack
- except Exception as e:
- logger.error(f"Sleep failed: {e}", exc_info=True)
- if torch.distributed.is_initialized():
- try:
- torch.distributed.barrier()
- except Exception:
- pass
- return OmniACK(task_id=task.task_id, status="ERROR", error_msg=str(e))
-
- def handle_wake_task(self, task: OmniWakeTask) -> OmniACK:
- from vllm_omni.platforms import current_omni_platform
-
- try:
- if isinstance(task, dict):
- task = OmniWakeTask(**task)
- logger.info(f"[Worker {self.rank}] Responding to Wake-up Task: {task.task_id}")
- self.wake_up(tags=task.tags)
-
- logger.info(f"[Worker {self.rank}] wake_up logic finished, entering barrier...")
- if torch.distributed.is_initialized():
- torch.distributed.barrier()
-
- current_omni_platform.synchronize()
- usage_now = current_omni_platform.get_current_memory_usage(self.device)
- current_used_gib = usage_now / (1024**3)
-
- if self.rank != 0:
- return None
- logger.info(f"[Worker {self.rank}] PASSED barrier, about to return to loop.")
-
- return OmniACK(
- task_id=task.task_id,
- status="SUCCESS",
- stage_id=self.stage_id,
- rank=self.rank,
- metadata={
- "state": "WARM",
- "source": f"Platform_{current_omni_platform.get_device_name()}",
- "current_vram_gib": f"{current_used_gib:.2f}",
- },
- )
- except Exception as e:
- logger.error(f"Wake-up failed on Rank {self.rank}: {e}", exc_info=True)
- if torch.distributed.is_initialized():
- try:
- torch.distributed.barrier()
- except Exception:
- pass
- return OmniACK(task_id=task.task_id, status="ERROR", error_msg=str(e))
-
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
"""Get memory pool context for sleep mode support."""
- is_sleep_enabled = getattr(self.od_config, "enable_sleep_mode", False)
- if is_sleep_enabled:
- current_omni_platform.synchronize()
- gc.collect()
+ if self.od_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
if tag == "weights":
assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process."
- logger.info(f"[Worker {self.rank}] Activating Diffusion CuMem pool for tag: {tag}")
return allocator.use_memory_pool(tag=tag)
- return nullcontext()
+ else:
+ return nullcontext()
def shutdown(self) -> None:
"""Shutdown the worker and cleanup distributed environment."""
@@ -518,13 +378,11 @@ def __init__(
od_config: OmniDiffusionConfig,
gpu_id: int,
broadcast_handle,
- wake_event: mp.Event,
worker_extension_cls: str | None = None,
custom_pipeline_args: dict[str, Any] | None = None,
):
self.od_config = od_config
self.gpu_id = gpu_id
- self.wake_event = wake_event
# Inter-process Communication
self.context = zmq.Context(io_threads=2)
@@ -539,13 +397,7 @@ def __init__(
if gpu_id == 0:
self.result_mq = MessageQueue(n_reader=1, n_local_reader=1, local_reader_ranks=[0])
self.result_mq_handle = self.result_mq.export_handle()
- WorkerProc._shared_result_handle = self.result_mq_handle
logger.info(f"Worker {gpu_id} created result MessageQueue")
- else:
- handle = getattr(WorkerProc, "_shared_result_handle", None)
- if handle:
- self.result_mq = MessageQueue.create_from_handle(handle, gpu_id)
- logger.info(f"Worker {gpu_id} attached to shared result MessageQueue")
assert od_config.master_port is not None
@@ -569,17 +421,13 @@ def _create_worker(
)
return wrapper
- def return_result(self, output: Any):
+ def return_result(self, output: object):
"""Reply to client, only on rank 0."""
if self.result_mq is not None:
- if isinstance(output, OmniACK):
- self.result_mq.enqueue(output)
- return
try:
pack_diffusion_output_shm(output)
except Exception as e:
- if hasattr(output, "output"):
- logger.warning("SHM pack failed for model output: %s", e)
+ logger.warning("SHM pack failed, falling back to raw enqueue: %s", e)
self.result_mq.enqueue(output)
def recv_message(self):
@@ -615,31 +463,20 @@ def worker_busy_loop(self) -> None:
while self._running:
msg = None
try:
- msg = self.mq.dequeue(timeout=1.0)
- except Exception:
- if self.wake_event and self.wake_event.is_set():
- self.wake_event.clear()
- logger.info(f"Worker {self.gpu_id} caught OOB POKE, forcing wake-up sequence.")
- msg = {"type": "wake_up", "task_id": "recovery-task", "tags": None}
- else:
- continue
- if msg is None:
+ msg = self.recv_message()
+ except Exception as e:
+ logger.error(
+ f"Error receiving message in worker loop: {e}",
+ exc_info=True,
+ )
continue
if msg is None or len(msg) == 0:
logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id)
continue
- if isinstance(msg, dict) and msg.get("type") == "sleep":
- task = OmniSleepTask(level=msg.get("level", 2), task_id=msg.get("task_id", "local"))
- ack = self.worker.handle_sleep_task(task)
- self.return_result(ack)
- elif isinstance(msg, dict) and msg.get("type") == "wake_up":
- task = OmniWakeTask(tags=msg.get("tags"), task_id=msg.get("task_id", "local"))
- ack = self.worker.handle_wake_task(task)
- self.return_result(ack)
# Route message based on type
- elif isinstance(msg, dict) and msg.get("type") == "rpc":
+ if isinstance(msg, dict) and msg.get("type") == "rpc":
try:
result, should_reply = self.execute_rpc(msg)
if should_reply:
@@ -684,7 +521,6 @@ def worker_main(
od_config: OmniDiffusionConfig,
pipe_writer: mp.connection.Connection,
broadcast_handle,
- wake_event: mp.Event,
worker_extension_cls: str | None = None,
custom_pipeline_args: dict[str, Any] | None = None,
) -> None:
@@ -696,7 +532,6 @@ def worker_main(
od_config,
gpu_id=rank,
broadcast_handle=broadcast_handle,
- wake_event=wake_event,
worker_extension_cls=worker_extension_cls,
custom_pipeline_args=custom_pipeline_args,
)
@@ -722,7 +557,6 @@ def __init__(
gpu_id: int,
od_config: OmniDiffusionConfig,
base_worker_class: type = DiffusionWorker,
- wake_event: mp.Event = None,
worker_extension_cls: str | None = None,
custom_pipeline_args: dict[str, Any] | None = None,
):
@@ -883,12 +717,6 @@ def wake_up(self, tags: list[str] | None = None) -> bool:
"""
return self.worker.wake_up(tags)
- def handle_sleep_task(self, task):
- return self.worker.handle_sleep_task(task)
-
- def handle_wake_task(self, task):
- return self.worker.handle_wake_task(task)
-
def shutdown(self) -> None:
"""Shutdown the worker and cleanup resources."""
return self.worker.shutdown()
diff --git a/vllm_omni/distributed/omni_connectors/connectors/base.py b/vllm_omni/distributed/omni_connectors/connectors/base.py
index 0df428f2ff5..83edb2ab0ae 100644
--- a/vllm_omni/distributed/omni_connectors/connectors/base.py
+++ b/vllm_omni/distributed/omni_connectors/connectors/base.py
@@ -34,21 +34,13 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[
pass
@abstractmethod
- def get(
- self, from_stage: str, to_stage: str, get_key: str, metadata: dict[str, Any] | None = None
- ) -> tuple[Any, int] | None:
+ def get(self, from_stage: str, to_stage: str, get_key: str, metadata=None) -> tuple[Any, int] | None:
"""Retrieve Python object and payload size (bytes).
Args:
from_stage: Source stage identifier
to_stage: Destination stage identifier
get_key: Unique request identifier
- metadata: Optional transport-specific metadata. When provided,
- the connector uses it directly (e.g. source_host, source_port,
- data_size) instead of querying the sender. For heterogeneous
- TP the manager may supply partial metadata (host/port only);
- the connector will query the sender at that address to fill
- in data_size.
Returns:
Tuple of (Python object, serialized byte size) if found, None otherwise
diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py
index fa1fc3286db..c672e35f793 100644
--- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py
+++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_store_connector.py
@@ -78,24 +78,7 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[
try:
serialized_data = self.serialize_obj(data)
key = self._make_key(put_key, from_stage, to_stage)
- put_rc = self.store.put(key, serialized_data, self.pin)
-
- if isinstance(put_rc, bool):
- put_ok = put_rc
- else:
- put_ok = put_rc is None or put_rc == 0
-
- if not put_ok:
- self._metrics["errors"] += 1
- logger.error(
- "MooncakeStoreConnector put failed for %s (%s -> %s), rc=%r, %d bytes",
- key,
- from_stage,
- to_stage,
- put_rc,
- len(serialized_data),
- )
- return False, 0, None
+ self.store.put(key, serialized_data, self.pin)
self._metrics["puts"] += 1
self._metrics["bytes_transferred"] += len(serialized_data)
diff --git a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py
index bd4160f3e63..b1dc8b8987c 100644
--- a/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py
+++ b/vllm_omni/distributed/omni_connectors/connectors/mooncake_transfer_engine_connector.py
@@ -230,19 +230,16 @@ class MooncakeTransferEngineConnector(OmniConnectorBase):
sender immediately cleans up the buffer (``cleanup()``), so only the
first receiver to pull a given key will succeed. Broadcast / multicast
(1 sender → N receivers sharing the same data) is not yet supported.
- - **1 receiver → N senders**: Supported via partial metadata. The
- manager constructs metadata with the target sender's
- ``source_host`` / ``source_port`` (computed from ``from_rank``)
- and passes it to ``get(metadata=...)``. The connector detects
- that ``data_size`` is missing, queries the specified sender at
- the given address to fill it in, then performs the RDMA pull.
- This enables heterogeneous TP (sender TP > receiver TP) where a
- single receiver must pull KV shards from multiple sender ranks.
+ - **1 receiver → 1 sender**: ``update_sender_info()`` stores a single
+ ``(sender_host, sender_zmq_port)`` pair, so a receiver can only query
+ metadata from one sender at a time.
Future work:
- Support 1 sender → N receivers (e.g. reference-counted buffers, or
explicit ``retain()`` / ``release()`` semantics so the buffer survives
multiple pulls).
+ - Support 1 receiver → N senders (e.g. a sender registry mapping
+ ``get_key`` prefixes to different sender endpoints).
"""
# RDMA connector copies raw bytes/tensor directly to the memory pool
@@ -270,7 +267,6 @@ def __init__(self, config: dict[str, Any]):
self._req_local = threading.local()
self._worker_local = threading.local()
self._last_ttl_check: float = _time_mod.monotonic()
- self._sender_endpoints: dict[int, tuple[str, int]] = {}
self._metrics = {
"puts": 0,
@@ -281,15 +277,13 @@ def __init__(self, config: dict[str, Any]):
}
self.config = config
- host_config = config.get("host")
- host_value = "auto" if host_config is None else str(host_config)
- # Default sender/receiver bootstrap to a routable local IP so the
- # advertised endpoint matches the interface Mooncake binds.
- if host_value.lower() == "auto" or host_value in {"", "*", "0.0.0.0", "::"}:
+ host_config = config.get("host", "127.0.0.1")
+ # Support "auto" to auto-detect local IP address
+ if host_config.lower() == "auto":
self.host = self._get_local_ip()
logger.info(f"Auto-detected local IP for RDMA: {self.host}")
else:
- self.host = host_value
+ self.host = host_config
self.zmq_port = config.get("zmq_port", 50051)
self.protocol = config.get("protocol", "rdma")
@@ -412,38 +406,16 @@ def get_connection_info(self) -> dict[str, Any]:
"can_put": self.can_put,
}
- def update_sender_info(
- self,
- sender_host: str,
- sender_zmq_port: int,
- sender_rank: int | None = None,
- ) -> None:
- """Inject a sender's ZMQ endpoint into the receiver connector.
-
- When ``sender_rank`` is ``None`` (default), sets the single default
- sender used by ``get()`` when no rank is specified — this preserves
- backward-compatible 1:1 semantics.
-
- When ``sender_rank`` is an integer, the endpoint is stored in a
- per-rank registry for internal use (e.g. by
- ``_query_metadata_from_sender(sender_rank=R)``).
+ def update_sender_info(self, sender_host: str, sender_zmq_port: int) -> None:
"""
- if sender_rank is not None:
- self._sender_endpoints[sender_rank] = (sender_host, sender_zmq_port)
- logger.info(
- "Sender info updated for rank %s: host=%r, zmq_port=%s",
- sender_rank,
- sender_host,
- sender_zmq_port,
- )
- else:
- self.sender_host = sender_host
- self.sender_zmq_port = sender_zmq_port
- logger.info(
- "Sender info updated (default): host=%r, zmq_port=%s",
- sender_host,
- sender_zmq_port,
- )
+ Inject the sender's ZMQ endpoint into the receiver connector.
+ Used by get() calls made without metadata (e.g. the KV-cache transfer path).
+ Must be called before using get() without metadata;
+ otherwise, get() will raise an error.
+ """
+ self.sender_host = sender_host
+ self.sender_zmq_port = sender_zmq_port
+ logger.info(f"Sender info updated: host={sender_host!r}, zmq_port={sender_zmq_port}")
def _get_local_ip(self) -> str:
"""
@@ -683,75 +655,56 @@ def put(self, from_stage: str, to_stage: str, put_key: str, data: Any) -> tuple[
logger.error(f"RDMA Put failed for {put_key}: {e}", exc_info=True)
return False, 0, None
- def _resolve_sender_endpoint(self, sender_rank: int | None = None) -> tuple[str, int] | None:
- """Return ``(host, zmq_port)`` for *sender_rank*.
+ def _query_metadata_from_sender(self, get_key: str) -> dict[str, Any] | None:
+ """Query metadata from sender via ZMQ (fallback when ``metadata=None``).
- Resolution order:
- 1. Per-rank registry (``_sender_endpoints[sender_rank]``)
- 2. Default sender (``sender_host`` / ``sender_zmq_port``)
- 3. ``None`` if nothing is configured.
- """
- if sender_rank is not None and sender_rank in self._sender_endpoints:
- return self._sender_endpoints[sender_rank]
- host = getattr(self, "sender_host", None)
- port = getattr(self, "sender_zmq_port", None)
- if host and port and str(host).lower() != "auto":
- return (host, int(port))
- return None
-
- def _query_metadata_at(self, get_key: str, host: str, port: int) -> dict[str, Any] | None:
- """Query metadata from a sender endpoint via ZMQ.
-
- Returns ``{source_host, source_port, data_size, is_fast_path}``
- or ``None`` when the key is not found / the query fails.
+ ``get()`` supports two metadata resolution paths::
+
+ get(metadata=?)
+ ├── metadata provided (adapter path)
+ │ → use metadata directly (source_host/port/data_size)
+ │ → RDMA pull
+ └── metadata=None (KV-transfer polling path)
+ → _query_metadata_from_sender(get_key) ← this method
+ │
+ ├── sender_host resolved (via update_sender_info)
+ │ → ZMQ query → get data_size/is_fast_path
+ │ → construct metadata → RDMA pull
+ └── sender_host unresolved ("auto" / None)
+ → return None → caller retries or times out
+
+ For the second path, the caller must call
+ :meth:`update_sender_info` before ``get()`` to resolve the sender's ZMQ endpoint.
+ Both paths are supported because the orchestrator may push request info to
+ different stages at the same time, whether or not the metadata is known.
"""
- zmq_addr = f"tcp://{host}:{port}"
+ zmq_addr = f"tcp://{self.sender_host}:{self.sender_zmq_port}"
req_socket = self._get_req_socket(zmq_addr, timeout_ms=5000)
+
try:
- req_socket.send(QUERY_INFO + msgspec.msgpack.encode(QueryRequest(request_id=get_key)))
+ # Send query request
+ query = QueryRequest(request_id=get_key)
+ req_socket.send(QUERY_INFO + msgspec.msgpack.encode(query))
resp = req_socket.recv()
+
if resp == INFO_NOT_FOUND:
return None
+
+ # Parse response
query_resp = msgspec.msgpack.decode(resp, type=QueryResponse)
return {
- "source_host": host,
- "source_port": port,
+ # source_host/source_port are used for verification
+ "source_host": self.sender_host,
+ "source_port": self.sender_zmq_port,
"data_size": query_resp.data_size,
"is_fast_path": query_resp.is_fast_path,
}
except Exception as e:
+ # Socket may be stuck in bad state after timeout; discard it
self._invalidate_req_socket(zmq_addr)
- logger.debug("Failed to query metadata at %s for %s: %s", zmq_addr, get_key, e)
+ logger.debug(f"Failed to query metadata for {get_key}: {e}")
return None
- def _query_metadata_from_sender(self, get_key: str, sender_rank: int | None = None) -> dict[str, Any] | None:
- """Query metadata from sender via ZMQ (fallback when ``metadata=None``).
-
- ``get()`` supports three metadata resolution paths::
-
- get(metadata=?)
- ├── Path 1: metadata has data_size (adapter path)
- │ → use metadata directly → RDMA pull
- ├── Path 2: metadata has source_host/port but no data_size
- │ → _query_metadata_at(host, port) → get data_size → RDMA pull
- └── Path 3: metadata=None (KV-transfer polling path)
- → _query_metadata_from_sender(get_key) ← this method
- │
- ├── sender endpoint resolved (via update_sender_info)
- │ → ZMQ query → get data_size/is_fast_path
- │ → construct metadata → RDMA pull
- └── sender endpoint unresolved
- → return None → caller retries or times out
-
- When *sender_rank* is provided, the query is routed to that
- rank's endpoint (registered via ``update_sender_info(rank=...)``).
- Otherwise the default sender is used.
- """
- endpoint = self._resolve_sender_endpoint(sender_rank)
- if endpoint is None:
- return None
- return self._query_metadata_at(get_key, *endpoint)
-
def get(
self,
from_stage: str,
@@ -759,18 +712,12 @@ def get(
get_key: str,
metadata: dict[str, Any] | None = None,
) -> tuple[Any, int] | None:
- """Consumer Side. Allocates from local pool and pulls data via RDMA.
-
- Metadata resolution:
+ """
+ Consumer Side.
+ Allocates from local pool and pulls data via RDMA.
- 1. ``metadata`` provided **with** ``data_size`` → use directly (RDMA pull).
- 2. ``metadata`` provided with ``source_host``/``source_port`` but
- **without** ``data_size`` → query that specific sender for
- ``data_size`` / ``is_fast_path``, then RDMA pull. This is the
- heterogeneous-TP path where the manager knows the target sender
- endpoint but not the payload size.
- 3. ``metadata=None`` → query the default sender (set via
- ``update_sender_info()``) for the full metadata.
+ If metadata is not provided, will attempt to query it from sender
+ using configured sender_host/sender_zmq_port.
Returns:
``(data, size)`` on success, ``None`` on failure.
@@ -778,6 +725,9 @@ def get(
- **is_fast_path=True** (tensor *or* bytes payload):
Returns ``(ManagedBuffer, size)``.
**CALLER MUST call ``ManagedBuffer.release()`` after consuming.**
+ Note: even if the producer ``put()`` raw ``bytes``, the consumer
+ receives a ``ManagedBuffer`` — use ``buf.to_bytes()`` to obtain
+ a ``bytes`` copy, or ``buf.tensor`` for zero-copy access.
- **is_fast_path=False** (serialized Python object):
Returns ``(DeserializedObject, size)``.
Buffer is auto-released internally after deserialization.
@@ -789,8 +739,9 @@ def get(
_t0 = _time_mod.perf_counter()
+ # If no metadata provided, try to query from sender
if not metadata:
- # Path 3: no metadata at all — query default sender
+ # Must insert sender info before using get() without metadata.
if not self.sender_host or not self.sender_zmq_port or str(self.sender_host).lower() == "auto":
raise RuntimeError(
f"get(metadata=None) requires sender info to be resolved, "
@@ -800,21 +751,6 @@ def get(
metadata = self._query_metadata_from_sender(get_key)
if not metadata:
return None
- elif "data_size" not in metadata:
- # Path 2: partial metadata (host/port only) — query that sender
- partial_host = metadata.get("source_host")
- partial_port = metadata.get("source_port")
- if not partial_host or not partial_port:
- logger.warning(
- "get(%s): partial metadata missing source_host/source_port, cannot resolve data_size. metadata=%s",
- get_key,
- metadata,
- )
- return None
- queried = self._query_metadata_at(get_key, str(partial_host), int(partial_port))
- if not queried:
- return None
- metadata = queried
_t1 = _time_mod.perf_counter()
_query_ms = (_t1 - _t0) * 1000
diff --git a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py
index 6cf5c2f15b5..5c7384c1f8b 100644
--- a/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py
+++ b/vllm_omni/distributed/omni_connectors/connectors/shm_connector.py
@@ -15,13 +15,9 @@
class SharedMemoryConnector(OmniConnectorBase):
- """Key-addressed local shared-memory connector.
-
- SHM is a local-only transport: it reads/writes POSIX shared memory
- segments identified purely by *key*. It does **not** understand
- remote-transport metadata such as ``source_host`` / ``source_port``
- (that is the RDMA connector's job). When such metadata is passed in,
- the connector silently falls back to key-based lookup.
+ """
+ Connector that uses SharedMemory for large objects and inline data for small objects.
+ Acts as a unified replacement for the legacy IPC fallback logic.
"""
def __init__(self, config: dict[str, Any]):
@@ -29,7 +25,6 @@ def __init__(self, config: dict[str, Any]):
self.stage_id = config.get("stage_id", -1)
self.device = config.get("device", "cuda:0")
self.threshold = int(config.get("shm_threshold_bytes", 65536))
- self._pending_keys: set[str] = set()
self._metrics = {
"puts": 0,
"gets": 0,
@@ -64,7 +59,6 @@ def put(
# meta contains {'name': ..., 'size': ...}
metadata = {"shm": meta, "size": size}
- self._pending_keys.add(put_key)
self._metrics["shm_writes"] += 1
else:
# Inline - pass bytes directly to avoid double serialization of the object
@@ -99,28 +93,6 @@ def _get_data_with_lock(self, lock_file: str, shm_handle: dict):
if obj and os.path.exists(lock_file):
os.remove(lock_file)
- def _get_by_key(self, get_key: str) -> tuple[Any, int] | None:
- """Read a SHM segment addressed purely by *get_key*."""
- shm = None
- try:
- shm = shm_pkg.SharedMemory(name=get_key)
- if shm is None or shm.size == 0:
- return None
- lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock"
- shm_handle = {"name": get_key, "size": shm.size}
- result = self._get_data_with_lock(lock_file, shm_handle)
- if result is not None:
- self._pending_keys.discard(get_key)
- return result
- except FileNotFoundError:
- return None
- except Exception:
- logger.debug("_get_by_key: unexpected error reading SHM segment %s", get_key, exc_info=True)
- return None
- finally:
- if shm:
- shm.close()
-
def get(
self,
from_stage: str,
@@ -129,16 +101,16 @@ def get(
metadata=None,
) -> tuple[Any, int] | None:
if metadata is not None:
+ # Some callers may wrap metadata by request id.
if isinstance(metadata, dict) and get_key in metadata:
metadata = metadata.get(get_key)
if not isinstance(metadata, dict):
- return self._get_by_key(get_key)
+ return None
if "inline_bytes" in metadata:
try:
obj = self.deserialize_obj(metadata["inline_bytes"])
- self._pending_keys.discard(get_key)
return obj, int(metadata.get("size", 0))
except Exception as e:
logger.error(f"SharedMemoryConnector inline get failed for req {get_key}: {e}")
@@ -147,64 +119,33 @@ def get(
if "shm" in metadata:
shm_handle = metadata["shm"]
lock_file = f"/dev/shm/shm_{shm_handle['name']}_lockfile.lock"
- result = self._get_data_with_lock(lock_file, shm_handle)
- if result is not None:
- self._pending_keys.discard(get_key)
- return result
+ return self._get_data_with_lock(lock_file, shm_handle)
- # Metadata is a dict but has no SHM-specific handle (e.g. RDMA-
- # style source_host/source_port). Fall back to key-based read.
- return self._get_by_key(get_key)
-
- return self._get_by_key(get_key)
+ return None
+ shm = None
+ try:
+ shm = shm_pkg.SharedMemory(name=get_key)
+ if shm is None or shm.size == 0:
+ return None
+ lock_file = f"/dev/shm/shm_{get_key}_lockfile.lock"
+ shm_handle = {"name": get_key, "size": shm.size}
+ return self._get_data_with_lock(lock_file, shm_handle)
+ except Exception:
+ return None
+ finally:
+ if shm:
+ shm.close()
def cleanup(self, request_id: str) -> None:
- """Best-effort cleanup of unconsumed SHM segments for *request_id*.
-
- Matches pending keys where *request_id* appears as the full key,
- as a ``_``-delimited prefix, or as a ``_``-delimited suffix.
- If ``get()`` was never called, we unlink it here so /dev/shm
- doesn't leak.
- """
- stale = [
- k
- for k in self._pending_keys
- if k == request_id or k.startswith(request_id + "_") or k.endswith("_" + request_id)
- ]
- for key in stale:
- self._pending_keys.discard(key)
- try:
- seg = shm_pkg.SharedMemory(name=key)
- seg.close()
- seg.unlink()
- logger.debug("cleanup: unlinked unconsumed SHM segment %s", key)
- except FileNotFoundError:
- pass
- except Exception as e:
- logger.debug("cleanup: failed to unlink SHM segment %s: %s", key, e)
- lock_file = f"/dev/shm/shm_{key}_lockfile.lock"
- if os.path.exists(lock_file):
- try:
- os.remove(lock_file)
- except OSError:
- pass
+ # SHM segments are automatically unlinked during 'get' (shm_read_bytes).
+ # If 'get' is never called (e.g. error flow), the SHM segment might leak.
+ # A robust implementation might track created segments and unlink them here
+ # if they haven't been consumed.
+ # For now, we rely on the consumer to read and unlink.
+ pass
def close(self) -> None:
- """Unlink all remaining tracked SHM segments."""
- for key in list(self._pending_keys):
- try:
- seg = shm_pkg.SharedMemory(name=key)
- seg.close()
- seg.unlink()
- except Exception:
- pass
- lock_file = f"/dev/shm/shm_{key}_lockfile.lock"
- if os.path.exists(lock_file):
- try:
- os.remove(lock_file)
- except OSError:
- pass
- self._pending_keys.clear()
+ pass
def health(self) -> dict[str, Any]:
return {"status": "healthy", "threshold": self.threshold, **self._metrics}
diff --git a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
index ad008c3971f..1f493843837 100644
--- a/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
+++ b/vllm_omni/distributed/omni_connectors/kv_transfer_manager.py
@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unified OmniConnector and KV cache transfer management."""
-import json
-import struct
import time
from collections.abc import Callable
from dataclasses import asdict, dataclass
@@ -14,48 +12,12 @@
from .factory import OmniConnectorFactory
from .utils.config import ConnectorSpec
-from .utils.initialization import KV_RANK_PORT_STRIDE
-from .utils.kv_utils import (
- KVTPTopology,
- build_rank_aware_recv_keys,
- build_rank_aware_send_keys,
- get_kv_target_ranks,
- get_local_tp_rank,
- get_tp_world_size,
- kv_zmq_port,
- merge_received_rank_shards,
- normalize_layer_kv,
- slice_layer_blocks,
- slice_received_rank_shard,
-)
+from .utils.kv_utils import normalize_layer_kv
logger = init_logger(__name__)
LayerKV = torch.Tensor | tuple[torch.Tensor, torch.Tensor]
-_SAFE_TORCH_DTYPES = {
- name: dtype
- for name in (
- "bool",
- "uint8",
- "int8",
- "int16",
- "int32",
- "int64",
- "float16",
- "float32",
- "float64",
- "bfloat16",
- "complex64",
- "complex128",
- "float8_e4m3fn",
- "float8_e4m3fnuz",
- "float8_e5m2",
- "float8_e5m2fnuz",
- )
- if isinstance((dtype := getattr(torch, name, None)), torch.dtype)
-}
-
@dataclass
class OmniKVCacheConfig:
@@ -69,8 +31,6 @@ class OmniKVCacheConfig:
need_recv_cache: bool = False
need_send_cache: bool = False
recv_timeout: float = 30.0
- from_tp: int = 1
- to_tp: int = 1
@dataclass
@@ -86,190 +46,6 @@ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
return asdict(self)
- def _build_tensors_desc(self, *, cpu: bool) -> tuple[list[dict[str, Any]], list, int, torch.device | None]:
- """Iterate layer blocks and build tensor descriptors + data chunks.
-
- Returns ``(tensors_desc, chunks, total_bytes, device)``.
- *chunks* contains ``bytes`` when *cpu* is True, flat uint8 GPU tensors otherwise.
- """
- tensors_desc: list[dict[str, Any]] = []
- chunks: list = []
- data_offset = 0
- device = None
-
- for cache_name in ("key_cache", "value_cache"):
- for layer_idx, tensor in enumerate(self.layer_blocks.get(cache_name, [])):
- if tensor is None:
- tensors_desc.append({"n": f"{cache_name}_{layer_idx}", "x": True})
- continue
- t = tensor.detach().contiguous()
- if cpu:
- t = t.cpu()
- elif device is None and t.is_cuda:
- device = t.device
- nbytes = t.numel() * t.element_size()
- tensors_desc.append(
- {
- "n": f"{cache_name}_{layer_idx}",
- "i": layer_idx,
- "d": str(t.dtype).removeprefix("torch."),
- "s": list(t.shape),
- "o": data_offset,
- "b": nbytes,
- }
- )
- chunks.append(t.view(torch.uint8).numpy().tobytes() if cpu else t.view(torch.uint8).flatten())
- data_offset += nbytes
-
- return tensors_desc, chunks, data_offset, device
-
- def _build_header_bytes(self, tensors_desc: list[dict[str, Any]]) -> bytes:
- header = json.dumps(
- {
- "rid": self.request_id,
- "bids": self.block_ids,
- "meta": self.metadata,
- "td": tensors_desc,
- "nl": len(self.layer_blocks.get("key_cache", [])),
- },
- separators=(",", ":"),
- ).encode("utf-8")
- return struct.pack(">I", len(header)) + header
-
- def to_bytes(self) -> bytes:
- """Convert to compact binary format for fast transfer."""
- tensors_desc, chunks, _, _ = self._build_tensors_desc(cpu=True)
- return b"".join([self._build_header_bytes(tensors_desc)] + chunks)
-
- def to_gpu_tensor(self) -> torch.Tensor:
- """Convert to a packed GPU tensor for raw-data connectors."""
- tensors_desc, chunks, data_offset, device = self._build_tensors_desc(cpu=False)
- if device is None:
- raise RuntimeError("No CUDA tensors found, use to_bytes() instead")
- header_prefix = self._build_header_bytes(tensors_desc)
- output = torch.empty(len(header_prefix) + data_offset, dtype=torch.uint8, device=device)
- output[: len(header_prefix)].copy_(torch.frombuffer(bytearray(header_prefix), dtype=torch.uint8))
- pos = len(header_prefix)
- for t_flat in chunks:
- n = t_flat.numel()
- output[pos : pos + n].copy_(t_flat)
- pos += n
- return output
-
- @staticmethod
- def _load_header_from_memoryview(raw_mv: memoryview) -> tuple[dict[str, Any], memoryview]:
- if len(raw_mv) < 4:
- raise ValueError("Corrupted KV payload: missing 4-byte header length")
-
- header_len = struct.unpack(">I", raw_mv[:4])[0]
- if header_len > len(raw_mv) - 4:
- raise ValueError(f"Corrupted KV payload: header_len={header_len} exceeds buffer size={len(raw_mv)}")
-
- return json.loads(bytes(raw_mv[4 : 4 + header_len])), raw_mv[4 + header_len :]
-
- @staticmethod
- def _load_header_from_tensor(gpu_tensor: torch.Tensor) -> tuple[dict[str, Any], int]:
- if gpu_tensor.dtype != torch.uint8 or gpu_tensor.dim() != 1:
- raise ValueError("Packed GPU KV payload must be a 1-D uint8 tensor")
-
- total_bytes = int(gpu_tensor.numel())
- if total_bytes < 4:
- raise ValueError("Corrupted KV payload: missing 4-byte header length")
-
- header_len = struct.unpack(">I", gpu_tensor[:4].cpu().numpy().tobytes())[0]
- if header_len > total_bytes - 4:
- raise ValueError(f"Corrupted KV payload: header_len={header_len} exceeds buffer size={total_bytes}")
-
- header_bytes = gpu_tensor[4 : 4 + header_len].cpu().numpy().tobytes()
- return json.loads(header_bytes), 4 + header_len
-
- @staticmethod
- def _validate_tensor_span(name: str, info: dict[str, Any], tensor_data_bytes: int) -> tuple[int, int]:
- offset = info["o"]
- nbytes = info["b"]
- if offset < 0 or nbytes < 0 or offset + nbytes > tensor_data_bytes:
- raise ValueError(
- f"Corrupted KV payload tensor span for {name}: "
- f"offset={offset}, bytes={nbytes}, tensor_data_bytes={tensor_data_bytes}"
- )
- return offset, nbytes
-
- @staticmethod
- def _resolve_torch_dtype(dtype_name: Any) -> torch.dtype:
- torch_dtype = _SAFE_TORCH_DTYPES.get(str(dtype_name))
- if torch_dtype is None:
- raise ValueError(f"Unsupported dtype in KV payload: {dtype_name}")
- return torch_dtype
-
- @staticmethod
- def _resolve_layer_idx(info: dict[str, Any], num_layers: int) -> int:
- layer_idx = info.get("i")
- if layer_idx is None:
- name = info.get("n")
- if isinstance(name, str) and name.startswith("key_cache_"):
- layer_idx = int(name.removeprefix("key_cache_"))
- elif isinstance(name, str) and name.startswith("value_cache_"):
- layer_idx = int(name.removeprefix("value_cache_"))
- else:
- raise ValueError(f"Invalid KV tensor name in payload: {name}")
-
- if not isinstance(layer_idx, int):
- raise ValueError(f"Invalid layer index in KV payload: {layer_idx}")
- if layer_idx < 0 or layer_idx >= num_layers:
- raise ValueError(f"Invalid layer index in KV payload: {layer_idx} (num_layers={num_layers})")
- return layer_idx
-
- @staticmethod
- def _populate_caches(header: dict[str, Any], get_tensor: callable) -> dict[str, Any]:
- """Shared deserialization loop for both CPU and GPU paths."""
- num_layers = header["nl"]
- key_cache: list[torch.Tensor | None] = [None] * num_layers
- value_cache: list[torch.Tensor | None] = [None] * num_layers
-
- for info in header["td"]:
- if info.get("x"):
- continue
- name: str = info["n"]
- torch_dtype = KVCacheTransferData._resolve_torch_dtype(info["d"])
- t = get_tensor(info).view(torch_dtype).reshape(info["s"])
- layer_idx = KVCacheTransferData._resolve_layer_idx(info, num_layers)
- if name.startswith("key_cache_"):
- key_cache[layer_idx] = t
- elif name.startswith("value_cache_"):
- value_cache[layer_idx] = t
-
- return {
- "request_id": header["rid"],
- "layer_blocks": {"key_cache": key_cache, "value_cache": value_cache},
- "block_ids": header["bids"],
- "metadata": header["meta"],
- }
-
- @staticmethod
- def from_bytes(raw: "bytes | bytearray | memoryview") -> dict[str, Any]:
- """Reconstruct KV cache data from the packed bytes format."""
- raw_mv = memoryview(raw) if not isinstance(raw, memoryview) else raw
- header, tensor_data_mv = KVCacheTransferData._load_header_from_memoryview(raw_mv)
- data_len = len(tensor_data_mv)
-
- def _get(info: dict) -> torch.Tensor:
- offset, nbytes = KVCacheTransferData._validate_tensor_span(info["n"], info, data_len)
- return torch.frombuffer(tensor_data_mv, dtype=torch.uint8, offset=offset, count=nbytes)
-
- return KVCacheTransferData._populate_caches(header, _get)
-
- @staticmethod
- def from_bytes_gpu(gpu_tensor: torch.Tensor) -> dict[str, Any]:
- """Reconstruct KV cache data from a packed GPU tensor."""
- header, data_start = KVCacheTransferData._load_header_from_tensor(gpu_tensor)
- data_len = int(gpu_tensor.numel()) - data_start
-
- def _get(info: dict) -> torch.Tensor:
- offset, nbytes = KVCacheTransferData._validate_tensor_span(info["n"], info, data_len)
- return gpu_tensor[data_start + offset : data_start + offset + nbytes].clone()
-
- return KVCacheTransferData._populate_caches(header, _get)
-
class OmniKVTransferManager:
"""Unified management for OmniConnector and KV cache transfer.
@@ -303,51 +79,11 @@ def __init__(self, config: OmniKVCacheConfig):
else (None, None)
)
- local_rank = get_local_tp_rank()
-
- if config.from_tp <= 1 and config.to_tp <= 1:
- detected_tp = get_tp_world_size()
- from_tp = detected_tp
- to_tp = detected_tp
- else:
- from_tp = config.from_tp
- to_tp = config.to_tp
-
- self._tp_topo = KVTPTopology(source_tp_size=from_tp, target_tp_size=to_tp, local_rank=local_rank)
-
- # Injectable hooks (compatible with PR #2677 OmniConnectorModelRunnerMixin).
- self.kv_send_key_builder: Callable | None = None
- self.kv_recv_key_builder: Callable | None = None
- self.kv_payload_merger: Callable | None = None
- self.kv_payload_slicer: Callable | None = None
-
- # Base sender endpoint (rank-0 host/port) stored during
- # update_sender_info(). Used by the receive path to construct
- # per-rank metadata for heterogeneous TP without querying a registry.
- self._sender_base_host: str | None = None
- self._sender_base_zmq_port: int | None = None
-
- if config.need_send_cache and config.connector_config:
- try:
- _ = self.connector
- logger.info("Sender connector eagerly initialized")
- except Exception as e:
- logger.warning("Failed to eagerly initialize sender connector: %s", e)
-
- # ------------------------------------------------------------------ #
- # Factory helpers
- # ------------------------------------------------------------------ #
-
@classmethod
def _create(cls, cfg: dict | None) -> "OmniKVTransferManager":
"""Create manager from raw config dict."""
if not cfg or not isinstance(cfg, dict):
return cls(OmniKVCacheConfig())
-
- rank_mapping = cfg.get("rank_mapping", {})
- if not isinstance(rank_mapping, dict):
- rank_mapping = {}
-
return cls(
OmniKVCacheConfig(
connector_config=cfg.get("connector_config"),
@@ -358,17 +94,18 @@ def _create(cls, cfg: dict | None) -> "OmniKVTransferManager":
need_recv_cache=cfg.get("need_recv_cache", False),
need_send_cache=cfg.get("need_send_cache", False),
recv_timeout=cfg.get("recv_timeout", 30.0),
- from_tp=int(rank_mapping.get("from_tp", 1)),
- to_tp=int(rank_mapping.get("to_tp", 1)),
)
)
@classmethod
- def from_od_config(cls, config: Any) -> "OmniKVTransferManager":
- """Create from model or OmniDiffusion config."""
+ def from_model_config(cls, config: Any) -> "OmniKVTransferManager":
+ """Create from model config (for AR model runner)."""
return cls._create(getattr(config, "omni_kv_config", None))
- from_model_config = from_od_config
+ @classmethod
+ def from_od_config(cls, config: Any) -> "OmniKVTransferManager":
+ """Create from OmniDiffusion config (for diffusion runner)."""
+ return cls._create(getattr(config, "omni_kv_config", None))
@classmethod
def from_vllm_config(cls, vllm_config: Any, model_config: Any) -> "OmniKVTransferManager":
@@ -403,320 +140,22 @@ def connector(self):
cfg = self.config.connector_config
if cfg and (c_type := cfg.get("type")):
try:
+ logger.info(f"Initializing OmniConnector with config: {cfg}")
c_extra = {k: v for k, v in cfg.items() if k != "type"}
- if c_type == "MooncakeTransferEngineConnector":
- base_port = c_extra.get("zmq_port", 50051)
- c_extra["from_stage"] = (
- str(self.config.from_stage) if self.config.from_stage is not None else "0"
- )
- c_extra["to_stage"] = str(self.config.to_stage) if self.config.to_stage is not None else "1"
-
- try:
- stage_int = int(self.config.from_stage) if self.config.from_stage is not None else 0
- except (TypeError, ValueError):
- stage_int = 0
- zmq_port = kv_zmq_port(base_port, stage_int, self._tp_topo.local_rank)
-
- if self.config.need_send_cache:
- c_extra["role"] = "sender"
- c_extra["zmq_port"] = zmq_port
- elif self.config.need_recv_cache:
- c_extra["role"] = "receiver"
- c_extra.setdefault("sender_host", c_extra.get("host", "127.0.0.1"))
- c_extra.setdefault("sender_zmq_port", zmq_port)
-
- logger.info(
- "Initializing OmniConnector type=%s role=%s",
- c_type,
- c_extra.get("role", "N/A"),
- )
self._connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=c_type, extra=c_extra))
- except Exception:
- logger.exception("Failed to initialize OmniConnector")
+ except Exception as e:
+ logger.error(f"Failed to initialize OmniConnector: {e}")
+ import traceback
+
+ traceback.print_exc()
+ # Cache failure sentinel to avoid repeated initialization attempts in hot paths.
self._connector = False
return self._connector if self._connector else None
- get_connector = property(lambda self: self.connector)
-
- def _resolve_sender_info(
- self, sender_info: dict[str, Any], sender_stage_id: str | int | None = None
- ) -> dict[str, Any] | None:
- if not sender_info:
- return None
-
- if "host" in sender_info:
- return sender_info
-
- if not isinstance(sender_info, dict):
- return None
-
- preferred_keys: list[str | int] = []
- if sender_stage_id is None:
- recv_from, _ = self.recv_stages
- sender_stage_id = recv_from
-
- if sender_stage_id is not None:
- preferred_keys.append(sender_stage_id)
- preferred_keys.append(str(sender_stage_id))
- try:
- preferred_keys.append(int(sender_stage_id))
- except (TypeError, ValueError):
- pass
-
- for key in dict.fromkeys(preferred_keys):
- info = sender_info.get(key)
- if isinstance(info, dict) and "host" in info:
- return info
-
- candidates = [info for info in sender_info.values() if isinstance(info, dict) and "host" in info]
- if len(candidates) == 1:
- return candidates[0]
-
- if candidates:
- logger.warning(
- "Ambiguous sender_info for sender_stage_id=%s: "
- "expected caller to resolve a single sender entry, got %s",
- sender_stage_id,
- sender_info,
- )
- return None
-
- @staticmethod
- def _clone_received_payload_tensors(data: dict[str, Any]) -> dict[str, Any]:
- if not isinstance(data, dict) or "layer_blocks" not in data:
- return data
-
- layer_blocks = data["layer_blocks"]
- for cache_name in ("key_cache", "value_cache"):
- cache_list = layer_blocks.get(cache_name, [])
- for idx, tensor in enumerate(cache_list):
- if isinstance(tensor, torch.Tensor):
- cache_list[idx] = tensor.clone()
- return data
-
- def _slice_transfer_data_for_target(self, kv_data: KVCacheTransferData, target_rank: int) -> KVCacheTransferData:
- """Pre-slice sender payload for one target rank when sender TP < receiver TP."""
- topo = self._tp_topo
- ratio = topo.target_tp_size // topo.source_tp_size
- offset_in_sender = target_rank % ratio
- metadata = dict(kv_data.metadata) if isinstance(kv_data.metadata, dict) else {}
- metadata["tp_head_slice"] = {
- "applied": True,
- "side": "sender",
- "target_rank": target_rank,
- "source_rank": topo.local_rank,
- "from_tp": topo.source_tp_size,
- "to_tp": topo.target_tp_size,
- "offset_in_shard": offset_in_sender,
- "num_slices": ratio,
- }
- return KVCacheTransferData(
- request_id=kv_data.request_id,
- layer_blocks=slice_layer_blocks(kv_data.layer_blocks, offset_in_sender, ratio),
- block_ids=list(kv_data.block_ids),
- metadata=metadata,
- )
-
- def _serialize_transfer_payload(self, kv_data: KVCacheTransferData) -> torch.Tensor | bytes | dict[str, Any]:
- """Serialize KV transfer data using the connector's fastest supported path."""
- if getattr(self.connector, "supports_raw_data", False):
- try:
- return kv_data.to_gpu_tensor()
- except Exception:
- pass
- try:
- return kv_data.to_bytes()
- except Exception:
- return kv_data.to_dict()
-
- @staticmethod
- def _collect_request_kv_payload(req: Any) -> dict[str, object]:
- """Collect request-side KV objects for object broadcast."""
- kv_payload: dict[str, object] = {}
- for attr in ("past_key_values", "kv_metadata"):
- val = getattr(req, attr, None)
- if val is not None:
- kv_payload[attr] = val
-
- if hasattr(req, "sampling_params") and req.sampling_params is not None:
- for key in list(vars(req.sampling_params).keys()):
- if key in ("past_key_values", "kv_metadata") or (
- key.startswith("cfg_")
- and (
- key.endswith("_past_key_values")
- or key.endswith("_kv_metadata")
- or key
- in (
- "cfg_kv_request_ids",
- "cfg_active_branch",
- "cfg_branch_roles",
- "cfg_branch_past_key_values",
- "cfg_branch_kv_metadata",
- )
- )
- ):
- val = getattr(req.sampling_params, key, None)
- if val is not None:
- kv_payload[f"sp.{key}"] = val
-
- return kv_payload
-
- @staticmethod
- def _apply_request_kv_payload(
- req: Any,
- kv_payload: dict[str, object],
- target_device: torch.device | None = None,
- ) -> None:
- """Apply a broadcast KV payload back onto a request object."""
- for attr in ("past_key_values", "kv_metadata"):
- val = kv_payload.get(attr)
- if val is not None:
- if target_device is not None:
- val = _move_to_device(val, target_device)
- setattr(req, attr, val)
-
- if hasattr(req, "sampling_params") and req.sampling_params is not None:
- for key, val in kv_payload.items():
- if key.startswith("sp."):
- if target_device is not None:
- val = _move_to_device(val, target_device)
- setattr(req.sampling_params, key[3:], val)
-
- @staticmethod
- def _discover_cfg_branch_roles(req: Any) -> list[str]:
- """Discover CFG branch roles in a stable order."""
- sampling_params = getattr(req, "sampling_params", None)
- if sampling_params is None:
- return []
-
- roles: list[str] = []
- branch_map = getattr(sampling_params, "cfg_branch_past_key_values", None) or {}
- for preferred_role in ("cfg_text", "cfg_img"):
- if (
- preferred_role in branch_map
- or getattr(sampling_params, f"{preferred_role}_past_key_values", None) is not None
- ):
- roles.append(preferred_role)
-
- for role in branch_map.keys():
- if role not in roles and branch_map.get(role) is not None:
- roles.append(role)
-
- for key in vars(sampling_params).keys():
- if not (key.startswith("cfg_") and key.endswith("_past_key_values")):
- continue
- role = key.removesuffix("_past_key_values")
- if role in ("cfg_branch",) or role in roles:
- continue
- if getattr(sampling_params, key, None) is not None:
- roles.append(role)
-
- return roles
-
- @classmethod
- def _build_cfg_rank_local_payloads(cls, req: Any, cfg_size: int) -> list[dict[str, object] | None]:
- """Build per-cfg-rank payloads so each rank receives only its branch KV."""
- full_payload = cls._collect_request_kv_payload(req)
- payloads: list[dict[str, object] | None] = []
-
- main_payload = {
- key: value
- for key, value in full_payload.items()
- if key in ("past_key_values", "kv_metadata", "sp.past_key_values", "sp.kv_metadata")
- }
- branch_roles = cls._discover_cfg_branch_roles(req)
- if branch_roles:
- main_payload["sp.cfg_branch_roles"] = list(branch_roles)
- main_payload["sp.cfg_active_branch"] = None
- payloads.append(main_payload or None)
-
- sampling_params = getattr(req, "sampling_params", None)
- branch_map = getattr(sampling_params, "cfg_branch_past_key_values", None) or {}
- branch_metadata_map = getattr(sampling_params, "cfg_branch_kv_metadata", None) or {}
-
- for role in branch_roles:
- if sampling_params is None:
- payloads.append(None)
- continue
-
- branch_kv = branch_map.get(role)
- if branch_kv is None:
- branch_kv = getattr(sampling_params, f"{role}_past_key_values", None)
- branch_metadata = branch_metadata_map.get(role)
- if branch_metadata is None:
- branch_metadata = getattr(sampling_params, f"{role}_kv_metadata", None)
- if branch_kv is None:
- payloads.append(None)
- continue
-
- local_payload = dict(main_payload)
- local_payload["sp.cfg_active_branch"] = role
- local_payload["sp.cfg_branch_roles"] = list(branch_roles)
- local_payload["sp.cfg_branch_past_key_values"] = {role: branch_kv}
- local_payload[f"sp.{role}_past_key_values"] = branch_kv
- if branch_metadata is not None:
- local_payload["sp.cfg_branch_kv_metadata"] = {role: branch_metadata}
- local_payload[f"sp.{role}_kv_metadata"] = branch_metadata
-
- payloads.append(local_payload)
-
- while len(payloads) < cfg_size:
- payloads.append(None)
-
- return payloads[:cfg_size]
-
- def update_sender_info(self, sender_info: dict[str, Any], sender_stage_id: str | int | None = None) -> None:
- """Update receiver-side sender info before loading remote KV cache.
-
- The orchestrator always reports rank-0's ZMQ port. When TP > 1 the
- receiver must offset the port so that each TP rank connects to the
- corresponding sender rank's port.
-
- The base host/port are also stored so that the receive path can
- construct per-rank metadata for heterogeneous TP scenarios.
- """
- if not self.config.need_recv_cache:
- return
-
- actual_info = self._resolve_sender_info(sender_info, sender_stage_id=sender_stage_id)
- if not actual_info or "host" not in actual_info:
- logger.warning("Invalid sender_info format: %s", sender_info)
- return
-
- sender_host = actual_info.get("host")
- base_zmq_port = actual_info.get("zmq_port")
-
- # Store base sender info for per-rank metadata construction.
- self._sender_base_host = sender_host
- if base_zmq_port is not None:
- self._sender_base_zmq_port = int(base_zmq_port)
-
- # --- Default sender: offset to match this receiver's corresponding sender rank ---
- zmq_port = base_zmq_port
- if zmq_port is not None and self._tp_topo.local_rank > 0:
- zmq_port = int(zmq_port) + self._tp_topo.local_rank * KV_RANK_PORT_STRIDE
-
- if self.config.connector_config:
- self.config.connector_config["sender_host"] = sender_host
- self.config.connector_config["sender_zmq_port"] = zmq_port
-
- if self._connector and hasattr(self._connector, "update_sender_info"):
- try:
- self._connector.update_sender_info(sender_host, zmq_port)
- except Exception:
- if hasattr(self._connector, "sender_host"):
- self._connector.sender_host = sender_host
- if hasattr(self._connector, "sender_zmq_port"):
- self._connector.sender_zmq_port = zmq_port
-
- logger.info(
- "Sender info updated: host=%s, base_port=%s, adjusted_port=%s (local_rank=%s)",
- sender_host,
- base_zmq_port,
- zmq_port,
- self._tp_topo.local_rank,
- )
+ def get_connector(self):
+ """Get connector (compatibility wrapper for existing code)."""
+ return self.connector
def handle_finished_requests_kv_transfer(
self,
@@ -764,8 +203,7 @@ def handle_finished_requests_kv_transfer(
custom_metadata = data.get("custom_metadata")
- # Extract KV cache from GPU blocks and keep it on-device when
- # possible so raw-data connectors can use the fast path.
+ # Extract KV cache from GPU blocks -> CPU tensors
kv_data = self._extract_kv_cache(
req_id, block_ids, seq_len, kv_caches, block_size, cache_dtype, custom_metadata
)
@@ -842,8 +280,9 @@ def _extract_kv_cache(
flat_k = flat_k[:seq_len]
flat_v = flat_v[:seq_len]
- key_cache[layer_idx] = flat_k.detach().contiguous()
- value_cache[layer_idx] = flat_v.detach().contiguous()
+ # Move to CPU
+ key_cache[layer_idx] = flat_k.detach().cpu().contiguous()
+ value_cache[layer_idx] = flat_v.detach().cpu().contiguous()
if not any(k is not None for k in key_cache):
return None
@@ -872,59 +311,14 @@ def _transfer_kv_cache(self, kv_data: KVCacheTransferData, transfer_req_id: str)
if not from_stage or not to_stage:
raise ValueError("Transfer stages (omni_from_stage, omni_to_stage) not configured")
- kv_data.request_id = transfer_req_id
- serialization_start = time.perf_counter()
- topo = self._tp_topo
- send_keys = build_rank_aware_send_keys(
- transfer_req_id, from_stage, to_stage, topo, hook=self.kv_send_key_builder
- )
- sender_slice_active = (
- topo.source_tp_size < topo.target_tp_size and len(send_keys) > 1 and not callable(self.kv_send_key_builder)
- )
- per_key_payloads: list[tuple[str, torch.Tensor | bytes | dict[str, Any]]] = []
+ # Prepare data and transfer with retry
+ data_dict = kv_data.to_dict()
+ data_dict["request_id"] = transfer_req_id
- if sender_slice_active:
- target_ranks = get_kv_target_ranks(topo)
- if len(target_ranks) != len(send_keys):
- logger.warning(
- "Skip sender-side KV slicing because target rank count does not match send key count: "
- "target_ranks=%s send_keys=%s",
- len(target_ranks),
- len(send_keys),
- )
- sender_slice_active = False
- else:
- for put_key, target_rank in zip(send_keys, target_ranks, strict=False):
- sliced_kv_data = self._slice_transfer_data_for_target(kv_data, target_rank)
- per_key_payloads.append((put_key, self._serialize_transfer_payload(sliced_kv_data)))
-
- if not per_key_payloads:
- transfer_data = self._serialize_transfer_payload(kv_data)
- per_key_payloads = [(put_key, transfer_data) for put_key in send_keys]
-
- serialization_ms = (time.perf_counter() - serialization_start) * 1000
- logger.info("KV cache serialized for %s in %.1f ms", transfer_req_id, serialization_ms)
-
- transfer_start = time.perf_counter()
- total_size = 0
- all_succeeded = True
- for put_key, transfer_data in per_key_payloads:
- success, size, _ = self._transfer_with_retry(from_stage, to_stage, put_key, transfer_data)
- total_size += size
- all_succeeded = all_succeeded and success
-
- elapsed = time.perf_counter() - transfer_start
-
- if all_succeeded:
- mbps = (total_size / 1024 / 1024) / elapsed if elapsed > 0 else 0
- logger.info(
- "KV transfer OK: %s, %s bytes across %s key(s), %.3fs, %.1f MB/s",
- transfer_req_id,
- total_size,
- len(send_keys),
- elapsed,
- mbps,
- )
+ success, size, _ = self._transfer_with_retry(from_stage, to_stage, f"kv_cache_{transfer_req_id}", data_dict)
+
+ if success:
+ logger.info(f"KV transfer OK: {transfer_req_id}, {size} bytes")
else:
logger.error(f"KV transfer FAILED: {transfer_req_id}")
@@ -932,8 +326,8 @@ def _transfer_with_retry(
self,
from_stage: str,
to_stage: str,
- put_key: str,
- data: "dict[str, Any] | bytes | torch.Tensor",
+ request_id: str,
+ data: dict[str, Any],
max_retries: int = 3,
) -> tuple[bool, int, dict[str, Any] | None]:
"""Transfer data with retry and exponential backoff.
@@ -941,7 +335,7 @@ def _transfer_with_retry(
Args:
from_stage: Source stage identifier
to_stage: Target stage identifier
- put_key: Pre-built connector key (rank-aware when TP > 1)
+ request_id: Request identifier for the key
data: Data to transfer
max_retries: Maximum number of retry attempts
@@ -950,12 +344,14 @@ def _transfer_with_retry(
"""
for attempt in range(max_retries):
try:
+ # Build the full key for connector
+ full_request_id = f"omni_{from_stage}_to_{to_stage}_{request_id}"
success, size, metadata = self.connector.put(
- from_stage=from_stage, to_stage=to_stage, put_key=put_key, data=data
+ from_stage=from_stage, to_stage=to_stage, put_key=full_request_id, data=data
)
if success:
return success, size, metadata
- logger.warning(f"Transfer attempt {attempt + 1} failed for {put_key}")
+ logger.warning(f"Transfer attempt {attempt + 1} failed for {request_id}")
except Exception as e:
logger.warning(f"Transfer attempt {attempt + 1} exception: {e}")
@@ -997,125 +393,46 @@ def receive_kv_cache_for_request(
timeout = self.config.recv_timeout
start_time = time.time()
- poll_interval = 0.01
- max_poll_interval = 0.5
- topo = self._tp_topo
- recv_key_pairs = build_rank_aware_recv_keys(
- request_id, from_stage, to_stage, topo, hook=self.kv_recv_key_builder
- )
- pending_pairs = list(recv_key_pairs)
- received_payloads: dict[str, tuple[dict[str, Any], int]] = {}
-
- logger.info(
- "Wait for KV cache for request %s from stage %s to %s via %s key(s)...",
- request_id,
- from_stage,
- to_stage,
- len(recv_key_pairs),
- )
+ logger.info(f"Wait for KV cache for request {request_id} from stage {from_stage} to {to_stage}...")
try:
while True:
- link_start = time.perf_counter()
- for get_key, from_rank in list(pending_pairs):
- # Construct per-rank metadata so the connector queries
- # the correct sender endpoint (heterogeneous TP path).
- # When from_rank is None (TP<=1), metadata stays None
- # and the connector falls back to its default sender.
- rank_metadata: dict[str, Any] | None = None
- if from_rank is not None and self._sender_base_host and self._sender_base_zmq_port is not None:
- rank_metadata = {
- "source_host": self._sender_base_host,
- "source_port": self._sender_base_zmq_port + from_rank * KV_RANK_PORT_STRIDE,
- }
-
- result = self.connector.get(
- from_stage=from_stage,
- to_stage=to_stage,
- get_key=get_key,
- metadata=rank_metadata,
- )
- if not result:
- continue
-
- raw_data, size = result
- managed_buffer = None
-
- if hasattr(raw_data, "tensor") and hasattr(raw_data, "release"):
- managed_buffer = raw_data
- try:
- buf_tensor = raw_data.tensor
- if buf_tensor.is_cuda:
- data = KVCacheTransferData.from_bytes_gpu(buf_tensor)
- raw_data.release()
- managed_buffer = None
- else:
- data = KVCacheTransferData.from_bytes(memoryview(buf_tensor.numpy()))
- data = self._clone_received_payload_tensors(data)
- raw_data.release()
- managed_buffer = None
- except Exception as e:
- logger.error("Failed to deserialize KV cache from ManagedBuffer: %s", e)
- if managed_buffer is not None:
- raw_data.release()
- return None, 0
- elif isinstance(raw_data, (bytes, bytearray)):
- data = KVCacheTransferData.from_bytes(raw_data)
- elif isinstance(raw_data, torch.Tensor) and raw_data.dtype == torch.uint8 and raw_data.dim() == 1:
- data = KVCacheTransferData.from_bytes(raw_data.cpu().numpy().tobytes())
- else:
- data = raw_data
-
- received_payloads[get_key] = (data, size)
- pending_pairs.remove((get_key, from_rank))
-
- if not pending_pairs and received_payloads:
- elapsed = time.time() - start_time
- link_ms = (time.perf_counter() - link_start) * 1000
- ordered_payloads = [received_payloads[key][0] for key, _ in recv_key_pairs]
- total_size = sum(received_payloads[key][1] for key, _ in recv_key_pairs)
-
- if len(ordered_payloads) == 1:
- data = ordered_payloads[0]
- else:
- data = merge_received_rank_shards(ordered_payloads, merger=self.kv_payload_merger)
- data = slice_received_rank_shard(data, topo, slicer=self.kv_payload_slicer)
-
- try:
- if isinstance(data, dict) and "layer_blocks" in data:
- layer_blocks = data["layer_blocks"]
- for cache_list in [
- layer_blocks.get("key_cache", []),
- layer_blocks.get("value_cache", []),
- ]:
- for i, tensor in enumerate(cache_list):
- if not isinstance(tensor, torch.Tensor):
- continue
- if target_device is not None and tensor.device != target_device:
- cache_list[i] = tensor.to(target_device).contiguous()
- except Exception:
- logger.exception("Failed to move KV cache tensors to target device")
-
- logger.info(
- "Successfully received KV cache for %s, %s bytes across %s key(s), wait=%.3fs, link=%.1fms",
- request_id,
- total_size,
- len(recv_key_pairs),
- elapsed,
- link_ms,
- )
- return data, total_size
+ # Build the full key for connector
+ full_request_id = f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}"
+ result = self.connector.get(
+ from_stage=from_stage,
+ to_stage=to_stage,
+ get_key=full_request_id,
+ )
+ if result:
+ data, size = result
+ logger.info(f"Successfully received KV cache for {request_id}, {size} bytes")
+
+ # Move tensors to target device if specified
+ if target_device is not None and isinstance(data, dict) and "layer_blocks" in data:
+ layer_blocks = data["layer_blocks"]
+ for cache_list in [
+ layer_blocks.get("key_cache", []),
+ layer_blocks.get("value_cache", []),
+ ]:
+ for i, tensor in enumerate(cache_list):
+ if isinstance(tensor, torch.Tensor) and tensor.device != target_device:
+ cache_list[i] = tensor.to(target_device).contiguous()
+
+ return data, size
if time.time() - start_time > timeout:
logger.error(f"Timeout waiting for KV cache for request {request_id} after {timeout}s")
return None, 0
- time.sleep(poll_interval)
- poll_interval = min(poll_interval * 2, max_poll_interval)
+ time.sleep(0.5)
+
+ except Exception as e:
+ logger.error(f"Error receiving KV cache for {request_id}: {e}")
+ import traceback
- except Exception:
- logger.exception("Error receiving KV cache for %s", request_id)
+ traceback.print_exc()
return None, 0
def apply_kv_cache_to_request(self, req: Any, data: dict[str, Any]) -> None:
@@ -1142,16 +459,6 @@ def apply_kv_cache_to_request(self, req: Any, data: dict[str, Any]) -> None:
if hasattr(req, "sampling_params") and req.sampling_params is not None:
req.sampling_params.kv_metadata = data["metadata"]
- @staticmethod
- def _resolve_request_id(req: Any) -> str | None:
- """Resolve the logical request ID used for KV transfer lookups."""
- request_id = getattr(req, "request_id", None)
- if request_id:
- return request_id
- if hasattr(req, "request_ids") and req.request_ids:
- return req.request_ids[0]
- return None
-
# Legacy compatibility method
def receive_kv_cache(self, req: Any, target_device: torch.device | None = None) -> bool:
"""Receive KV cache and populate request object (legacy interface).
@@ -1163,11 +470,11 @@ def receive_kv_cache(self, req: Any, target_device: torch.device | None = None)
Returns:
True if successful, False otherwise
"""
- kv_sender_info = getattr(req, "kv_sender_info", None)
- if kv_sender_info:
- self.update_sender_info(kv_sender_info, sender_stage_id=self.recv_stages[0])
+ request_id = getattr(req, "request_id", None)
+ if not request_id and hasattr(req, "request_ids") and req.request_ids:
+ # Adaptation for new OmniDiffusionRequest which has list of prompts/ids
+ request_id = req.request_ids[0]
- request_id = self._resolve_request_id(req)
if not request_id:
logger.warning("Request has no ID, cannot receive KV cache")
return False
@@ -1206,7 +513,9 @@ def receive_multi_kv_cache(
cfg_ids = getattr(getattr(req, "sampling_params", None), "cfg_kv_request_ids", None)
if cfg_ids and cfg_kv_collect_func:
- request_id = self._resolve_request_id(req)
+ request_id = getattr(req, "request_id", None) or (
+ req.request_ids[0] if hasattr(req, "request_ids") and req.request_ids else None
+ )
try:
cfg_kvs = cfg_kv_collect_func(
request_id,
@@ -1229,79 +538,73 @@ def receive_multi_kv_cache_distributed(
cfg_kv_collect_func: Callable | None = None,
target_device: torch.device | None = None,
) -> bool:
- """Distributed wrapper around :meth:`receive_multi_kv_cache`.
-
- TP-aware path selection:
- - world size 1: direct receive
- - TP active, cfg size 1: each rank independently receives
- - TP active, cfg size > 1: cfg-rank 0 receives, then broadcasts to
- peers that share the same TP rank
- - TP inactive: legacy rank-0 receive then world broadcast
+ """Broadcast-aware wrapper around :meth:`receive_multi_kv_cache`.
+
+ SharedMemory connector is single-reader: once rank 0 consumes the
+ segment it is deleted. For multi-GPU stages (e.g. sequence-parallel)
+ only rank 0 receives; the result is then broadcast to every other
+ rank via the world process-group.
+
+ For single-worker stages this is equivalent to calling
+ :meth:`receive_multi_kv_cache` directly.
"""
- from vllm_omni.diffusion.distributed.parallel_state import (
- get_cfg_group,
- get_classifier_free_guidance_rank,
- get_classifier_free_guidance_world_size,
- get_world_group,
- )
+ from vllm_omni.diffusion.distributed.parallel_state import get_world_group
world = get_world_group()
if world.world_size <= 1:
return self.receive_multi_kv_cache(req, cfg_kv_collect_func, target_device)
- topo = self._tp_topo
- tp_active = topo.source_tp_size > 1 or topo.target_tp_size > 1
- cfg_size = 1
- cfg_rank = 0
- cfg_group = None
- try:
- cfg_size = get_classifier_free_guidance_world_size()
- cfg_rank = get_classifier_free_guidance_rank()
- cfg_group = get_cfg_group()
- except Exception:
- cfg_size = 1
- cfg_rank = 0
- cfg_group = None
-
- if tp_active and cfg_size <= 1:
- logger.info(
- "Rank-aware KV receive: rank %s independently receiving (from_tp=%s, to_tp=%s)",
- topo.local_rank,
- topo.source_tp_size,
- topo.target_tp_size,
- )
- return self.receive_multi_kv_cache(req, cfg_kv_collect_func, target_device)
-
- if tp_active and cfg_size > 1 and cfg_group is not None:
- kv_payload: dict[str, object] | None = None
- if cfg_rank == 0:
- received = self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu"))
- rank_payloads = self._build_cfg_rank_local_payloads(req, cfg_size) if received else [None] * cfg_size
- kv_payload = rank_payloads[0]
- for dst_rank in range(1, cfg_size):
- cfg_group.send_object(rank_payloads[dst_rank], dst_rank)
- else:
- kv_payload = cfg_group.recv_object(0)
-
- if not kv_payload:
- return False
-
- self._apply_request_kv_payload(req, kv_payload, target_device)
- return True
-
- kv_payload: dict[str, object] | None = None
+ # --- rank 0: receive to CPU (needed for pickle-based broadcast) ---
if world.rank_in_group == 0:
- received = self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu"))
- if received:
- kv_payload = self._collect_request_kv_payload(req)
+ self.receive_multi_kv_cache(req, cfg_kv_collect_func, torch.device("cpu"))
- kv_payload = world.broadcast_object(kv_payload, src=0)
+ kv_payload: dict[str, object] = {}
+ for attr in ("past_key_values", "kv_metadata"):
+ val = getattr(req, attr, None)
+ if val is not None:
+ kv_payload[attr] = val
+ if hasattr(req, "sampling_params") and req.sampling_params is not None:
+ for key in list(vars(req.sampling_params).keys()):
+ if (key.startswith("cfg_") and key.endswith("_past_key_values")) or key in (
+ "past_key_values",
+ "kv_metadata",
+ ):
+ val = getattr(req.sampling_params, key, None)
+ if val is not None:
+ kv_payload[f"sp.{key}"] = val
+
+ payload_list = [kv_payload]
+ # Use broadcast_object_list (pickle-based) instead of broadcast_tensor_dict
+ # because the KV cache is a heterogeneous nested structure (NaiveCache objects
+ # with metadata + tensors), not a flat tensor dict. This runs once before
+ # the denoising loop so the serialization cost is negligible.
+ torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group)
+ kv_payload = payload_list[0]
+ else:
+ payload_list: list[dict[str, object] | None] = [None]
+ torch.distributed.broadcast_object_list(payload_list, src=world.ranks[0], group=world.cpu_group)
+ kv_payload = payload_list[0]
+
+ # --- apply on ALL ranks (rank 0 also needs CPU→GPU move) ---
if not kv_payload:
return False
- self._apply_request_kv_payload(req, kv_payload, target_device)
+ for attr in ("past_key_values", "kv_metadata"):
+ val = kv_payload.get(attr)
+ if val is not None:
+ if target_device is not None:
+ val = _move_to_device(val, target_device)
+ setattr(req, attr, val)
+
+ if hasattr(req, "sampling_params") and req.sampling_params is not None:
+ for key, val in kv_payload.items():
+ if key.startswith("sp."):
+ if target_device is not None:
+ val = _move_to_device(val, target_device)
+ setattr(req.sampling_params, key[3:], val)
+
return True
diff --git a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py
index 4535c2596dd..e8e00eeca2d 100644
--- a/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py
+++ b/vllm_omni/distributed/omni_connectors/transfer_adapter/chunk_transfer_adapter.py
@@ -8,8 +8,6 @@
import torch
from vllm.v1.request import Request, RequestStatus
-from vllm_omni.data_entry_keys import unflatten_payload
-
from ..factory import OmniConnectorFactory
from ..utils.config import ConnectorSpec
from ..utils.logging import get_connector_logger
@@ -60,7 +58,6 @@ def __init__(self, vllm_config: Any):
self.waiting_for_chunk_waiting_requests: deque[Any] = deque()
self.waiting_for_chunk_running_requests: deque[Any] = deque()
self.requests_with_ready_chunks = set()
- self.requests_origin_status = {}
@classmethod
def create_connector(cls, model_config: Any):
@@ -152,38 +149,26 @@ def _poll_single_request(self, request: Request):
# Update connector state
self.get_req_chunk[req_id] += 1
- meta = payload_data.get("meta", {})
if self.model_mode == "ar":
- merged_payload = self._update_request_payload(external_req_id, payload_data)
- request.additional_information = merged_payload
- if meta.get("finished"):
+ self._update_request_payload(external_req_id, payload_data)
+ request.additional_information = payload_data
+ if payload_data.get("finished"):
self.finished_requests.add(req_id)
else:
- if meta.get("finished"):
+ if payload_data.get("finished"):
self.finished_requests.add(req_id)
- new_ids = payload_data.get("codes", {}).get("audio", [])
+ new_ids = payload_data.get("code_predictor_codes", [])
request.prompt_token_ids = new_ids
- prev_info = getattr(request, "additional_information", None)
- info = dict(prev_info) if isinstance(prev_info, dict) else {}
- for key, value in payload_data.items():
- if key == "codes":
- continue
- if isinstance(value, dict):
- existing_sub = info.get(key)
- merged_sub = dict(existing_sub) if isinstance(existing_sub, dict) else {}
- for sk, sv in value.items():
- if key == "meta" and sk == "finished":
- continue
- merged_sub[sk] = sv
- info[key] = merged_sub
- continue
- info[key] = value
- request.additional_information = info
+ # Pass additional fields (like left_context_size) to the request
+ # Only pass chunk context metadata in additional_information
+ request.additional_information = {}
+ if "left_context_size" in payload_data:
+ request.additional_information["left_context_size"] = payload_data["left_context_size"]
request.num_computed_tokens = 0
# Empty chunk with more data expected: keep polling.
- if not new_ids and not meta.get("finished"):
+ if not new_ids and not payload_data.get("finished"):
return True
# Mark as finished for consumption
@@ -194,36 +179,33 @@ def _poll_single_request(self, request: Request):
return False
def _update_request_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]:
- """Update the stored payload for *req_id* with the latest chunk."""
+ """Update the payload data for a request in the connector.
+
+ Args:
+ connector: OmniConnectorBase instance
+ req_id: Request ID to update
+ payload_data: New payload data to store
+ """
if req_id not in self.request_payload:
self.request_payload[req_id] = payload_data
return payload_data
- origin = self.request_payload[req_id]
- raw_ok = payload_data.get("meta", {}).pop("override_keys", [])
- override_keys = {tuple(k) if isinstance(k, list) else k for k in raw_ok}
-
- for type_key, new_val in payload_data.items():
- if not isinstance(new_val, dict):
+ origin_payload = self.request_payload[req_id]
+ override_keys = payload_data.pop("override_keys", [])
+ for key, value in payload_data.items():
+ if key == "finished":
continue
- origin_sub = origin.get(type_key)
- if not isinstance(origin_sub, dict):
- continue
- for qual, value in new_val.items():
- if type_key == "meta" and qual == "finished":
- continue
- if (type_key, qual) in override_keys:
- continue
- if isinstance(value, torch.Tensor) and qual in origin_sub:
- new_val[qual] = torch.cat([origin_sub[qual], value], dim=0)
- elif isinstance(value, list) and qual in origin_sub:
- new_val[qual] = origin_sub[qual] + value
+ elif key in override_keys:
+ payload_data[key] = value
+ elif isinstance(value, torch.Tensor) and key in origin_payload:
+ payload_data[key] = torch.cat([origin_payload[key], value], dim=0)
+ elif isinstance(value, list) and key in origin_payload:
+ payload_data[key] = origin_payload[key] + value
self.request_payload[req_id] = payload_data
return payload_data
def _send_single_request(self, task: dict):
- raw_po = task["pooling_output"]
- pooling_output = unflatten_payload(raw_po) if isinstance(raw_po, dict) else raw_po
+ pooling_output = task["pooling_output"]
request = task["request"]
is_finished = task["is_finished"]
stage_id = self.connector.stage_id
@@ -258,23 +240,9 @@ def _send_single_request(self, task: dict):
if success:
self.put_req_chunk[external_req_id] += 1
logger.debug(f"[Stage-{stage_id}] Sent {connector_put_key}")
- finished_flag = payload_data.get("meta", {}).get("finished", payload_data.get("finished"))
- is_payload_finished = False
- if isinstance(finished_flag, torch.Tensor):
- is_payload_finished = finished_flag.numel() == 1 and bool(finished_flag.item())
- elif finished_flag is not None:
- is_payload_finished = bool(finished_flag)
-
- # Reclaim per-request async state only after the terminal payload
- # has been sent successfully. This avoids cleanup->save races.
- if is_payload_finished:
- self.cleanup(request.request_id, external_req_id)
if is_finished:
- self.code_prompt_token_ids.pop(external_req_id, None)
- cached_ic = getattr(self, "_cached_ic", None)
- if cached_ic is not None:
- cached_ic.pop(external_req_id, None)
+ self.cleanup_sender(external_req_id)
########################################################################
# Cleanup
@@ -293,7 +261,6 @@ def cleanup_receiver(self, request_id: str) -> None:
self.get_req_chunk.pop(request_id, None)
self.requests_with_ready_chunks.discard(request_id)
self.request_ids_mapping.pop(request_id, None)
- self.requests_origin_status.pop(request_id, None)
self._cancelled_load_reqs.add(request_id)
self._finished_load_reqs.discard(request_id)
@@ -423,7 +390,6 @@ def _process_chunk_queue(
self.requests_with_ready_chunks.add(request.request_id)
continue
queue.remove(request)
- self.requests_origin_status[request.request_id] = target_status
waiting_for_chunk_list.append(request)
def _clear_chunk_ready(self, scheduler_output: Any) -> None:
@@ -436,23 +402,3 @@ def _clear_chunk_ready(self, scheduler_output: Any) -> None:
for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
if req_id in self.requests_with_ready_chunks:
self.requests_with_ready_chunks.remove(req_id)
-
- def finish_requests(
- self, request_ids: Any, finished_status: RequestStatus, requests: dict[str, Request] | None = None
- ) -> list[tuple[str, int]]:
- assert RequestStatus.is_finished(finished_status)
- if isinstance(request_ids, str):
- request_ids = (request_ids,)
- elif request_ids is not None:
- request_ids = set(request_ids)
- else:
- request_ids = requests.keys()
-
- # First pass: collect requests to remove from queues
- for req_id in request_ids:
- request = requests.get(req_id) if requests else None
- if request is None or request.is_finished():
- # Invalid request ID.
- continue
- if req_id in self.requests_origin_status:
- request.status = self.requests_origin_status.pop(req_id)
diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py
index f012af3c9c3..aaa222b4c50 100644
--- a/vllm_omni/distributed/omni_connectors/utils/initialization.py
+++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py
@@ -19,22 +19,9 @@
logger = get_connector_logger(__name__)
-# Reserve a separate port range for KV-transfer sockets so they do not
-# collide with request-forwarding endpoints that share the same base port.
-KV_TRANSFER_PORT_OFFSET = 100
-
-# Port stride between TP ranks so each worker binds a unique ZMQ port
-# when TP > 1. Must be larger than the maximum number of pipeline stages.
-# Formula: zmq_port = base + KV_TRANSFER_PORT_OFFSET + rank * STRIDE + stage
-KV_RANK_PORT_STRIDE = 16
-
def initialize_connectors_from_config(
- config_path: str | Path | None = None,
- default_shm_threshold: int = 65536,
- purpose: str = "request_forwarding",
- caller_stage_id: int | str | None = None,
- is_sender: bool | None = None,
+ config_path: str | Path | None = None, default_shm_threshold: int = 65536
) -> tuple[OmniTransferConfig | None, dict[tuple[str, str], OmniConnectorBase]]:
"""
Initialize connectors from configuration file.
@@ -49,20 +36,12 @@ def initialize_connectors_from_config(
return None, {}
# create connectors from config
- connectors = create_connectors_from_config(
- transfer_config.connectors,
- purpose=purpose,
- caller_stage_id=caller_stage_id,
- is_sender=is_sender,
- )
+ connectors = create_connectors_from_config(transfer_config.connectors)
return transfer_config, connectors
def create_connectors_from_config(
connectors_config: dict[tuple[str, str], ConnectorSpec],
- purpose: str = "request_forwarding",
- caller_stage_id: int | str | None = None,
- is_sender: bool | None = None,
) -> dict[tuple[str, str], OmniConnectorBase]:
"""
Create connectors from config.
@@ -73,59 +52,12 @@ def create_connectors_from_config(
Returns:
A dictionary of connectors.
"""
- purpose_port_offsets = {
- "request_forwarding": 0,
- "kv_transfer": KV_TRANSFER_PORT_OFFSET,
- }
- port_offset = purpose_port_offsets.get(purpose, 0)
- orchestrator_port_offset = 200
-
connectors = {}
for edge_key, connector_spec in connectors_config.items():
- from_stage, to_stage = edge_key
try:
- if connector_spec.name == "MooncakeTransferEngineConnector":
- extra = dict(connector_spec.extra) if connector_spec.extra else {}
- base_port = extra.get("zmq_port", 50051)
- try:
- stage_offset = int(from_stage)
- except (TypeError, ValueError):
- stage_offset = 0
-
- if str(caller_stage_id) == "orchestrator":
- adjusted_port = base_port + orchestrator_port_offset + stage_offset
- else:
- adjusted_port = base_port + port_offset + stage_offset
- extra["zmq_port"] = adjusted_port
-
- if is_sender is not None:
- extra["role"] = "sender" if is_sender else "receiver"
- if not is_sender:
- extra.setdefault("sender_host", extra.get("host", "127.0.0.1"))
- extra.setdefault("sender_zmq_port", adjusted_port)
- elif caller_stage_id is not None:
- caller_str = str(caller_stage_id)
- if caller_str == from_stage:
- extra["role"] = "sender"
- elif caller_str == to_stage:
- extra["role"] = "receiver"
- extra.setdefault("sender_host", extra.get("host", "127.0.0.1"))
- extra.setdefault("sender_zmq_port", adjusted_port)
- else:
- extra["role"] = "sender"
- else:
- extra["role"] = extra.get("role", "auto")
-
- connector = OmniConnectorFactory.create_connector(ConnectorSpec(name=connector_spec.name, extra=extra))
- else:
- connector = OmniConnectorFactory.create_connector(connector_spec)
+ connector = OmniConnectorFactory.create_connector(connector_spec)
connectors[edge_key] = connector
- logger.info(
- "Created connector for %s -> %s: %s",
- from_stage,
- to_stage,
- type(connector).__name__,
- )
+ logger.info(f"Created connector for {edge_key[0]} -> {edge_key[1]}: {type(connector).__name__}")
except Exception as e:
raise RuntimeError(f"Failed to initialize connector for edge {edge_key}: {e}") from e
@@ -206,19 +138,6 @@ def load_omni_transfer_config(
if config_dict is None:
return None
- # Normalize new-schema (top-level ``connectors`` + ``stages``) into the
- # legacy ``runtime.connectors`` + ``stage_args`` shape the parser reads.
- if "stages" in config_dict and "stage_args" not in config_dict:
- normalized: dict[str, Any] = dict(config_dict)
- runtime = dict(normalized.get("runtime") or {})
- if "connectors" in normalized and "connectors" not in runtime:
- runtime["connectors"] = normalized["connectors"]
- if "edges" in normalized and "edges" not in runtime:
- runtime["edges"] = normalized["edges"]
- normalized["runtime"] = runtime
- normalized["stage_args"] = normalized["stages"]
- config_dict = normalized
-
# Parse connectors
connectors = {}
runtime_config = config_dict.get("runtime", {})
@@ -370,11 +289,7 @@ def initialize_orchestrator_connectors(
else:
default_shm_threshold = max(0, shm_threshold_bytes)
transfer_config, connectors = initialize_connectors_from_config(
- config_path,
- default_shm_threshold=default_shm_threshold,
- purpose="request_forwarding",
- caller_stage_id="orchestrator",
- is_sender=True,
+ config_path, default_shm_threshold=default_shm_threshold
)
return transfer_config, connectors
@@ -401,7 +316,6 @@ def get_stage_connector_config(
def build_stage_connectors(
stage_id: int,
connectors_config: dict[str, Any],
- purpose: str = "request_forwarding",
) -> dict[tuple[str, str], Any] | None:
"""Instantiate OmniConnectors for a stage based on config."""
if not connectors_config:
@@ -438,12 +352,7 @@ def build_stage_connectors(
try:
# Use unified connector creation logic
- connectors = create_connectors_from_config(
- stage_connector_specs,
- purpose=purpose,
- caller_stage_id=stage_id,
- is_sender=False,
- )
+ connectors = create_connectors_from_config(stage_connector_specs)
except Exception as exc: # pragma: no cover - defensive logging
# Fail fast so the stage does not start with missing connectors.
logger.exception("[Stage-%s] Failed to initialize connectors: %s", stage_id, exc)
diff --git a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py
index 12b9b3d4f77..2cb48a8b344 100644
--- a/vllm_omni/distributed/omni_connectors/utils/kv_utils.py
+++ b/vllm_omni/distributed/omni_connectors/utils/kv_utils.py
@@ -1,380 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Utility helpers for KV cache manipulation, TP routing, and merge/slice."""
-
-from __future__ import annotations
-
-import os
-from collections.abc import Callable
-from dataclasses import dataclass
-from typing import Any
+"""Utility helpers for KV cache manipulation."""
import torch
-from vllm.distributed.parallel_state import (
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
-)
from vllm.logger import init_logger
-from .initialization import KV_RANK_PORT_STRIDE, KV_TRANSFER_PORT_OFFSET
-
logger = init_logger(__name__)
LayerKV = torch.Tensor | tuple[torch.Tensor, torch.Tensor]
-# ------------------------------------------------------------------ #
-# TP Topology
-# ------------------------------------------------------------------ #
-
-
-@dataclass(frozen=True)
-class KVTPTopology:
- """Immutable descriptor for a KV-transfer parallel mapping.
-
- Captures sender/receiver parallel sizes and the local rank within
- that parallel dimension. Works for any divisible parallel dimension
- (TP, SP, Ring Attention).
- """
-
- source_tp_size: int
- target_tp_size: int
- local_rank: int
-
- def __post_init__(self) -> None:
- if self.source_tp_size <= 0 or self.target_tp_size <= 0:
- raise ValueError(
- f"Parallel sizes must be positive: "
- f"source_tp_size={self.source_tp_size}, target_tp_size={self.target_tp_size}"
- )
- if self.local_rank < 0:
- raise ValueError(f"local_rank must be non-negative, got {self.local_rank}")
-
- @property
- def is_heterogeneous(self) -> bool:
- return self.source_tp_size != self.target_tp_size
-
- @property
- def ratio(self) -> int:
- """Larger parallel size divided by smaller. Always >= 1."""
- return max(self.source_tp_size, self.target_tp_size) // min(self.source_tp_size, self.target_tp_size)
-
-
-# ------------------------------------------------------------------ #
-# Runtime TP detection
-# ------------------------------------------------------------------ #
-
-
-def get_local_tp_rank() -> int:
- """Return the TP-local rank of this worker process.
-
- Uses ``get_tensor_model_parallel_rank()`` which returns the rank
- within the TP group only, not the stage-global rank.
- """
- try:
- return get_tensor_model_parallel_rank()
- except Exception:
- logger.debug("TP parallel state not initialized, falling back to LOCAL_RANK env", exc_info=True)
- try:
- return int(os.environ.get("LOCAL_RANK", "0"))
- except (ValueError, TypeError):
- return 0
-
-
-def get_tp_world_size() -> int:
- """Return the TP world size (tensor-parallel dimension only).
-
- Uses ``get_tensor_model_parallel_world_size()`` so that
- cfg_parallel, SP, PP etc. are not included in the count.
- """
- try:
- return get_tensor_model_parallel_world_size()
- except Exception:
- logger.debug("TP parallel state not initialized, defaulting world_size=1", exc_info=True)
- return 1
-
-
-# ------------------------------------------------------------------ #
-# ZMQ port computation
-# ------------------------------------------------------------------ #
-
-
-def kv_zmq_port(base_port: int, from_stage: int, local_rank: int = 0) -> int:
- """Compute the ZMQ port for a KV-transfer connector.
-
- Each TP rank gets its own port so that TP > 1 deployments do not
- cause ``EADDRINUSE`` when multiple sender workers bind on the same
- host. The formula is backward-compatible: rank 0 produces the same
- port as the previous ``base + OFFSET + stage`` formula.
- """
- return base_port + KV_TRANSFER_PORT_OFFSET + local_rank * KV_RANK_PORT_STRIDE + from_stage
-
-
-# ------------------------------------------------------------------ #
-# TP topology validation and rank routing
-# ------------------------------------------------------------------ #
-
-
-def validate_kv_tp_topology(topo: KVTPTopology) -> None:
- """Reject heterogeneous TP mappings that cannot be routed losslessly."""
- larger = max(topo.source_tp_size, topo.target_tp_size)
- smaller = min(topo.source_tp_size, topo.target_tp_size)
- if larger % smaller != 0:
- raise ValueError(
- f"KV TP mapping must be divisible: "
- f"source_tp_size={topo.source_tp_size}, "
- f"target_tp_size={topo.target_tp_size}"
- )
-
-
-def get_kv_target_ranks(topo: KVTPTopology) -> list[int]:
- """Which remote ranks this local rank sends KV shards to (send side)."""
- validate_kv_tp_topology(topo)
- if topo.source_tp_size == topo.target_tp_size:
- return [topo.local_rank]
- if topo.source_tp_size > topo.target_tp_size:
- return [topo.local_rank // (topo.source_tp_size // topo.target_tp_size)]
- ratio = topo.target_tp_size // topo.source_tp_size
- return [topo.local_rank * ratio + i for i in range(ratio)]
-
-
-def get_kv_source_ranks(topo: KVTPTopology) -> list[int]:
- """Which remote ranks this local rank receives KV shards from (recv side)."""
- validate_kv_tp_topology(topo)
- if topo.source_tp_size == topo.target_tp_size:
- return [topo.local_rank]
- if topo.source_tp_size > topo.target_tp_size:
- ratio = topo.source_tp_size // topo.target_tp_size
- return [topo.local_rank * ratio + i for i in range(ratio)]
- return [topo.local_rank // (topo.target_tp_size // topo.source_tp_size)]
-
-
-# ------------------------------------------------------------------ #
-# Rank-aware connector key building
-# ------------------------------------------------------------------ #
-
-
-def get_kv_connector_key(
- req_id: str,
- from_stage: int | str,
- chunk_id: int,
- from_rank: int,
- to_rank: int,
-) -> str:
- """Build connector key that includes rank info for KV transfers.
-
- Format matches PR #2677: ``{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}``
- """
- return f"{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}"
-
-
-def build_rank_aware_send_keys(
- request_id: str,
- from_stage: str,
- to_stage: str,
- topo: KVTPTopology,
- hook: Callable | None = None,
-) -> list[str]:
- """Build send-side connector keys, checking injectable hook first."""
- if callable(hook):
- keys = list(hook(request_id, from_stage, to_stage))
- if keys:
- return keys
- if topo.source_tp_size <= 1 and topo.target_tp_size <= 1:
- return [f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}"]
- target_ranks = get_kv_target_ranks(topo)
- return [get_kv_connector_key(request_id, from_stage, 0, topo.local_rank, r) for r in target_ranks]
-
-
-def build_rank_aware_recv_keys(
- request_id: str,
- from_stage: str,
- to_stage: str,
- topo: KVTPTopology,
- hook: Callable | None = None,
-) -> list[tuple[str, int | None]]:
- """Build recv-side connector keys with sender rank info.
-
- Returns a list of ``(key, from_rank)`` tuples. ``from_rank`` is
- ``None`` when TP <= 1 (single sender, no per-rank routing needed).
- For TP > 1, ``from_rank`` identifies which sender rank owns the
- key so that the connector can route metadata queries to the
- correct endpoint.
- """
- if callable(hook):
- raw = list(hook(request_id, from_stage, to_stage))
- if raw:
- if isinstance(raw[0], tuple):
- return raw
- # Hook returned plain strings (e.g. OmniConnectorModelRunnerMixin.
- # get_rank_aware_kv_keys). Reconstruct from_rank from topology so
- # Mooncake connector can route metadata queries to the correct
- # sender endpoint in heterogeneous TP.
- # TODO: have the mixin return (key, from_rank) tuples directly
- # to avoid this indirect reconstruction.
- source_ranks = get_kv_source_ranks(topo)
- if len(raw) == len(source_ranks):
- return list(zip(raw, source_ranks))
- return [(k, None) for k in raw]
- if topo.source_tp_size <= 1 and topo.target_tp_size <= 1:
- return [(f"omni_{from_stage}_to_{to_stage}_kv_cache_{request_id}", None)]
- source_ranks = get_kv_source_ranks(topo)
- return [(get_kv_connector_key(request_id, from_stage, 0, r, topo.local_rank), r) for r in source_ranks]
-
-
-# ------------------------------------------------------------------ #
-# KV tensor head slicing (heterogeneous TP)
-# ------------------------------------------------------------------ #
-
-
-def slice_kv_tensor_heads(
- tensor: torch.Tensor | None,
- offset_in_shard: int,
- num_slices: int,
-) -> torch.Tensor | None:
- """Slice one KV tensor along its head dimension (dim 1)."""
- if tensor is None:
- return None
- if not isinstance(tensor, torch.Tensor):
- return tensor
- if tensor.dim() < 2:
- raise ValueError(f"Expected KV tensor with a head dimension, got shape={tuple(tensor.shape)}")
- if num_slices <= 0:
- raise ValueError(f"num_slices must be > 0, got {num_slices}")
- if not (0 <= offset_in_shard < num_slices):
- raise ValueError(f"offset_in_shard must be in [0, {num_slices}), got {offset_in_shard}")
-
- heads_in_shard = tensor.shape[1]
- if heads_in_shard % num_slices != 0:
- raise ValueError(
- "KV head count must be divisible for heterogeneous TP slicing: "
- f"heads_in_shard={heads_in_shard}, num_slices={num_slices}"
- )
-
- heads_per_slice = heads_in_shard // num_slices
- start = offset_in_shard * heads_per_slice
- end = start + heads_per_slice
- return tensor[:, start:end, ...].contiguous()
-
-
-def slice_layer_blocks(
- layer_blocks: dict[str, Any],
- offset_in_shard: int,
- num_slices: int,
-) -> dict[str, list[torch.Tensor | None]]:
- """Slice all KV layers for one logical receiver rank."""
- sliced_blocks: dict[str, list[torch.Tensor | None]] = {}
- for cache_name in ("key_cache", "value_cache"):
- cache_list = layer_blocks.get(cache_name, [])
- sliced_blocks[cache_name] = [
- slice_kv_tensor_heads(tensor, offset_in_shard, num_slices) for tensor in cache_list
- ]
- return sliced_blocks
-
-
-# ------------------------------------------------------------------ #
-# Multi-rank merge and receiver-side slice
-# ------------------------------------------------------------------ #
-
-
-def merge_received_rank_shards(
- payloads: list[dict[str, Any]],
- merger: Callable | None = None,
-) -> dict[str, Any] | None:
- """Merge multiple source-rank KV shards for one target rank.
-
- When *merger* is provided (injectable hook), it is called directly.
- Otherwise the default merges along the head dimension (dim 1).
- """
- if callable(merger):
- return merger(payloads)
- if not payloads:
- return None
- if len(payloads) == 1:
- return payloads[0]
-
- base_payload = payloads[0]
- if not isinstance(base_payload, dict) or "layer_blocks" not in base_payload:
- return base_payload
-
- merged: dict[str, Any] = {
- "request_id": base_payload.get("request_id"),
- "block_ids": list(base_payload.get("block_ids", [])),
- "metadata": dict(base_payload.get("metadata", {})),
- }
- merged_layer_blocks: dict[str, list[torch.Tensor | None]] = {}
-
- for cache_name in ("key_cache", "value_cache"):
- cache_lists = [payload.get("layer_blocks", {}).get(cache_name, []) for payload in payloads]
- num_layers = max((len(cache_list) for cache_list in cache_lists), default=0)
- merged_cache: list[torch.Tensor | None] = []
-
- for layer_idx in range(num_layers):
- layer_tensors = [
- cache_list[layer_idx]
- for cache_list in cache_lists
- if layer_idx < len(cache_list) and cache_list[layer_idx] is not None
- ]
- if not layer_tensors:
- merged_cache.append(None)
- elif len(layer_tensors) == 1 or not isinstance(layer_tensors[0], torch.Tensor):
- merged_cache.append(layer_tensors[0])
- else:
- merged_cache.append(torch.cat(layer_tensors, dim=1).contiguous())
-
- merged_layer_blocks[cache_name] = merged_cache
-
- merged["layer_blocks"] = merged_layer_blocks
- return merged
-
-
-def slice_received_rank_shard(
- payload: dict[str, Any] | None,
- topo: KVTPTopology,
- slicer: Callable | None = None,
-) -> dict[str, Any] | None:
- """Optionally slice a received payload to extract this rank's portion.
-
- Used when ``to_tp > from_tp``: the sender sent full heads and each
- receiver rank slices out its own subset.
- """
- if callable(slicer):
- return slicer(payload)
- if not payload or topo.target_tp_size <= topo.source_tp_size or "layer_blocks" not in payload:
- return payload
-
- metadata = payload.get("metadata", {})
- slice_metadata = metadata.get("tp_head_slice") if isinstance(metadata, dict) else None
- if isinstance(slice_metadata, dict) and slice_metadata.get("applied"):
- tagged_rank = slice_metadata.get("target_rank")
- if tagged_rank is not None and tagged_rank != topo.local_rank:
- logger.warning(
- "Received pre-sliced KV payload for unexpected target rank: expected=%s got=%s",
- topo.local_rank,
- tagged_rank,
- )
- return payload
-
- ratio = topo.target_tp_size // topo.source_tp_size
- offset_in_sender = topo.local_rank % ratio
- updated_metadata = dict(metadata) if isinstance(metadata, dict) else {}
- updated_metadata["tp_head_slice"] = {
- "applied": True,
- "side": "receiver",
- "target_rank": topo.local_rank,
- "from_tp": topo.source_tp_size,
- "to_tp": topo.target_tp_size,
- "offset_in_shard": offset_in_sender,
- "num_slices": ratio,
- }
- return {
- "request_id": payload.get("request_id"),
- "layer_blocks": slice_layer_blocks(payload["layer_blocks"], offset_in_sender, ratio),
- "block_ids": list(payload.get("block_ids", [])),
- "metadata": updated_metadata,
- }
-
-
def normalize_layer_kv(
layer_kv: LayerKV,
*,
diff --git a/vllm_omni/distributed/omni_coordinator/__init__.py b/vllm_omni/distributed/omni_coordinator/__init__.py
index 6894e311378..cbef920d4be 100644
--- a/vllm_omni/distributed/omni_coordinator/__init__.py
+++ b/vllm_omni/distributed/omni_coordinator/__init__.py
@@ -2,11 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .load_balancer import (
- LeastQueueLengthBalancer,
LoadBalancer,
LoadBalancingPolicy,
RandomBalancer,
- RoundRobinBalancer,
Task,
)
from .messages import InstanceEvent, InstanceInfo, InstanceList, StageStatus
@@ -26,6 +24,4 @@
"LoadBalancer",
"LoadBalancingPolicy",
"RandomBalancer",
- "RoundRobinBalancer",
- "LeastQueueLengthBalancer",
]
diff --git a/vllm_omni/distributed/omni_coordinator/load_balancer.py b/vllm_omni/distributed/omni_coordinator/load_balancer.py
index 41b03be1630..15a079b0a87 100644
--- a/vllm_omni/distributed/omni_coordinator/load_balancer.py
+++ b/vllm_omni/distributed/omni_coordinator/load_balancer.py
@@ -4,7 +4,6 @@
from __future__ import annotations
import random
-import threading
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, TypedDict
@@ -27,16 +26,11 @@ class Task(TypedDict, total=False):
class LoadBalancingPolicy(str, Enum):
"""Enumeration for load balancing policies.
- These policies are used by :class:`LoadBalancer` implementations to route
- tasks to a subset of available instances.
-
- TODO(NumberWan): Map enum values to balancer classes when OmniCoordinator
- integration lands. Tracked in https://github.com/vllm-project/vllm-omni/pull/2448
+ Only ``RANDOM`` is implemented. Additional policies (e.g. round-robin,
+ least-connections) can be added in the future.
"""
RANDOM = "random"
- ROUND_ROBIN = "round-robin"
- LEAST_QUEUE_LENGTH = "least-queue-length"
class LoadBalancer(ABC):
@@ -67,10 +61,10 @@ def select(self, task: Task, instances: list[InstanceInfo]) -> int:
class RandomBalancer(LoadBalancer):
"""Load balancer that selects an instance uniformly at random.
- It intentionally ignores the task payload and chooses a random index from
- the provided instance list. More sophisticated policies (e.g. round-robin,
- least-queue-length) can be implemented as additional subclasses of
- :class:`LoadBalancer`.
+ This is the initial and only policy supported. It intentionally ignores
+ the task payload and chooses a random index from the provided instance
+ list. More sophisticated policies (e.g. round-robin, least-connections)
+ can be implemented as additional subclasses of :class:`LoadBalancer`.
"""
def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002
@@ -80,69 +74,9 @@ def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG
return random.randrange(len(instances))
-class RoundRobinBalancer(LoadBalancer):
- """Load balancer that selects instances in a round-robin fashion.
-
- This implementation keeps a running index modulo ``len(instances)``. It
- therefore depends on the **order and stable meaning** of the ``instances``
- list between calls. If the list length or ordering changes, the sequence
- of picks may skip or repeat entries relative to a fixed set of backends.
-
- When instance membership changes dynamically, callers should reset routing
- state—for example by constructing a new ``RoundRobinBalancer`` or resetting
- ``_next_index``—similar to rebuilding ``itertools.cycle`` after mutating
- the instance list (as in vLLM's disaggregated proxy examples).
-
- Concurrency: ``select`` is synchronous and is expected to run on the
- coordinator asyncio event loop thread without ``await`` inside this
- method, so a single invocation is not interleaved with another on that
- thread. A :class:`threading.Lock` still serializes updates to
- ``_next_index`` for callers that might invoke ``select`` from multiple
- threads or alongside threaded infrastructure (e.g. ZMQ receive threads).
- """
-
- def __init__(self, start_index: int = 0) -> None:
- self._next_index = start_index
- self._lock = threading.Lock()
-
- def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002
- if not instances:
- raise ValueError("instances must not be empty")
-
- n = len(instances)
- with self._lock:
- idx = self._next_index % n
- self._next_index = (self._next_index + 1) % n
- return idx
-
-
-class LeastQueueLengthBalancer(LoadBalancer):
- """Select the instance with the smallest ``queue_length``.
-
- If multiple instances share the same minimum queue length, one of them is
- chosen uniformly at random.
-
- Raises:
- ValueError: If any instance has a negative ``queue_length``.
- """
-
- def select(self, task: Task, instances: list[InstanceInfo]) -> int: # noqa: ARG002
- if not instances:
- raise ValueError("instances must not be empty")
-
- queue_lengths = [inst.queue_length for inst in instances]
- if any(q < 0 for q in queue_lengths):
- raise ValueError("queue_length must be non-negative for all instances")
- min_q = min(queue_lengths)
- candidates = [i for i, q in enumerate(queue_lengths) if q == min_q]
- return random.choice(candidates)
-
-
__all__ = [
"Task",
"LoadBalancingPolicy",
"LoadBalancer",
"RandomBalancer",
- "RoundRobinBalancer",
- "LeastQueueLengthBalancer",
]
diff --git a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py
index cd3c99ab812..cd5c357bb4e 100644
--- a/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py
+++ b/vllm_omni/distributed/omni_coordinator/omni_coord_client_for_stage.py
@@ -45,10 +45,9 @@ def __init__(
self._status = StageStatus.UP
self._queue_length = 0
self._closed = False
- self._closing = False
self._heartbeat_interval = 5.0
self._stop_event = threading.Event()
- self._send_lock = threading.RLock()
+ self._send_lock = threading.Lock()
self._send_event("update")
@@ -58,45 +57,38 @@ def __init__(
)
self._heartbeat_thread.start()
- def _reconnect(self, max_retries: int = 3, retry_interval: float = 5.0) -> bool:
+ def _reconnect(self) -> bool:
"""Best-effort reconnect with up to ``max_retries`` attempts.
- Each attempt closes the current socket/context, sleeps ``retry_interval``
- seconds, then creates a new DEALER socket and reconnects to the coordinator.
+ Each attempt closes the current socket/context, sleeps 5 seconds,
+ then creates a new DEALER socket and reconnects to the coordinator.
+ Caller must hold ``_send_lock``.
Returns True on success, False if all attempts fail.
"""
- if max_retries <= 0:
- return False
+ while not self._stop_event.is_set() and not self._closed:
+ try:
+ self._socket.close(0)
+ except zmq.ZMQError:
+ pass
+ try:
+ self._ctx.term()
+ except zmq.ZMQError:
+ pass
- for attempt in range(1, max_retries + 1):
- with self._send_lock:
- if self._stop_event.is_set() or self._closed:
- return False
- try:
- self._socket.close(0)
- except zmq.ZMQError:
- pass
- try:
- self._ctx.term()
- except zmq.ZMQError:
- pass
+ time.sleep(5.0)
- try:
- self._ctx = zmq.Context()
- self._socket = self._ctx.socket(zmq.DEALER)
- self._socket.connect(self._coord_zmq_addr)
- return True
- except zmq.ZMQError as e:
- logger.error(
- "Stage client reconnect failed (attempt=%d/%d, coord=%s)",
- attempt,
- max_retries,
- self._coord_zmq_addr,
- exc_info=e,
- )
-
- if retry_interval > 0:
- time.sleep(retry_interval)
+ try:
+ self._ctx = zmq.Context()
+ self._socket = self._ctx.socket(zmq.DEALER)
+ self._socket.connect(self._coord_zmq_addr)
+ return True
+ except zmq.ZMQError as e:
+ logger.error(
+ "Stage client reconnect failed, will retry in 5s (coord=%s)",
+ self._coord_zmq_addr,
+ exc_info=e,
+ )
+ continue
return False
def _send_event(self, event_type: str) -> None:
@@ -110,20 +102,20 @@ def _send_event(self, event_type: str) -> None:
to 3 times (5s sleep each) and retries the send once after a
successful reconnect. Raises if reconnect or the retry send fails.
"""
- with self._send_lock:
- if self._closed:
- raise RuntimeError("Client already closed")
-
- event = InstanceEvent(
- input_addr=self._input_addr,
- output_addr=self._output_addr,
- stage_id=self._stage_id,
- event_type=event_type,
- status=self._status,
- queue_length=self._queue_length,
- )
- data = json.dumps(asdict(event)).encode("utf-8")
+ if self._closed:
+ raise RuntimeError("Client already closed")
+
+ event = InstanceEvent(
+ input_addr=self._input_addr,
+ output_addr=self._output_addr,
+ stage_id=self._stage_id,
+ event_type=event_type,
+ status=self._status,
+ queue_length=self._queue_length,
+ )
+ data = json.dumps(asdict(event)).encode("utf-8")
+ with self._send_lock:
try:
self._socket.send(data, flags=zmq.NOBLOCK)
return
@@ -132,7 +124,7 @@ def _send_event(self, event_type: str) -> None:
return
except (RuntimeError, zmq.ZMQError) as e:
# First send failed; try reconnecting a few times.
- if not self._reconnect(max_retries=3):
+ if not self._reconnect():
logger.error("Failed to send event and reconnect to coordinator", exc_info=e)
raise
@@ -157,16 +149,12 @@ def update_info(
if status is None and queue_length is None:
raise ValueError("At least one of status or queue_length must be provided")
- with self._send_lock:
- if self._closed or self._closing:
- raise RuntimeError("Client is closing or already closed")
-
- if status is not None:
- self._status = status
- if queue_length is not None:
- self._queue_length = queue_length
+ if status is not None:
+ self._status = status
+ if queue_length is not None:
+ self._queue_length = queue_length
- self._send_event("update")
+ self._send_event("update")
def _heartbeat_loop(self) -> None:
"""Periodically send heartbeat events while the client is alive."""
@@ -176,11 +164,8 @@ def _heartbeat_loop(self) -> None:
try:
self._send_event("heartbeat")
- except (RuntimeError, zmq.ZMQError) as e:
- if self._closed or self._stop_event.is_set():
- break
- logger.warning("Heartbeat send failed; will retry on next interval", exc_info=e)
- continue
+ except (RuntimeError, zmq.ZMQError):
+ break
def close(self) -> None:
"""Send a final down event and close the underlying socket."""
@@ -192,23 +177,17 @@ def close(self) -> None:
if hasattr(self, "_heartbeat_thread"):
self._heartbeat_thread.join(timeout=1.0)
- with self._send_lock:
- if self._closed:
- raise RuntimeError("Client already closed")
-
- self._closing = True
-
- # Mark status as DOWN and send one last update.
- self._status = StageStatus.DOWN
- try:
- self._send_event("update")
- except (RuntimeError, zmq.ZMQError):
- pass # Socket may already be broken, proceed with close
+ # Mark status as DOWN and send one last update.
+ self._status = StageStatus.DOWN
+ try:
+ self._send_event("update")
+ except zmq.ZMQError:
+ pass # Socket may already be broken, proceed with close
- # Close DEALER socket and terminate this client's context.
- self._socket.close(0)
- try:
- self._ctx.term()
- except zmq.ZMQError:
- pass
- self._closed = True
+ # Close DEALER socket and terminate this client's context.
+ self._socket.close(0)
+ try:
+ self._ctx.term()
+ except zmq.ZMQError:
+ pass
+ self._closed = True
diff --git a/vllm_omni/distributed/omni_coordinator/omni_coordinator.py b/vllm_omni/distributed/omni_coordinator/omni_coordinator.py
index 2c7c8fbb995..7ff608f2fa3 100644
--- a/vllm_omni/distributed/omni_coordinator/omni_coordinator.py
+++ b/vllm_omni/distributed/omni_coordinator/omni_coordinator.py
@@ -61,7 +61,6 @@ def __init__(
self._publish_min_interval: float = 0.1 # seconds
self._pending_broadcast: bool = False
- self._pending_lock = threading.Lock()
self._running = True
self._closed = False
@@ -85,13 +84,13 @@ def add_new_instance(self, event: InstanceEvent) -> None:
"""Add a new instance based on an incoming event."""
with self._lock:
self._add_new_instance_locked(event)
- self._schedule_broadcast()
+ self.publish_instance_list_update()
def update_instance_info(self, event: InstanceEvent) -> None:
"""Update an existing instance based on an incoming event."""
with self._lock:
self._update_instance_info_locked(event)
- self._schedule_broadcast()
+ self.publish_instance_list_update()
def remove_instance(self, event: InstanceEvent) -> None:
"""Mark an instance as removed / down based on an incoming event.
@@ -102,15 +101,10 @@ def remove_instance(self, event: InstanceEvent) -> None:
"""
with self._lock:
self._remove_instance_locked(event)
- self._schedule_broadcast()
+ self.publish_instance_list_update()
- def publish_instance_list_update(self) -> bool:
- """Publish the current active instance list to all subscribers.
-
- Returns:
- True if the PUB send succeeded, False if it was dropped (e.g.
- socket not ready when using ``zmq.NOBLOCK``).
- """
+ def publish_instance_list_update(self) -> None:
+ """Publish the current active instance list to all subscribers."""
active_list = self.get_active_instances()
payload = asdict(active_list)
data = json.dumps(payload).encode("utf-8")
@@ -119,18 +113,20 @@ def publish_instance_list_update(self) -> bool:
try:
# PUB socket is best-effort; drop update if not ready.
self._pub.send(data, flags=zmq.NOBLOCK)
- return True
except (zmq.Again, zmq.ZMQError):
# Silently ignore send failures; next update will catch up.
- return False
+ return
- def _schedule_broadcast(self) -> None:
- """Request a broadcast to be flushed by the periodic loop.
+ def _schedule_broadcast(self, force: bool) -> None:
+ """Schedule a broadcast, optionally bypassing throttling.
- All broadcast requests are coalesced via ``_pending_broadcast`` and
- flushed at most once per ``_publish_min_interval``.
+ When ``force`` is True, publish immediately. Otherwise, mark a pending
+ broadcast that will be flushed by the periodic loop at most once per
+ ``_publish_min_interval``.
"""
- with self._pending_lock:
+ if force:
+ self.publish_instance_list_update()
+ else:
self._pending_broadcast = True
def _mark_instance_error_locked(self, info: InstanceInfo) -> None:
@@ -156,8 +152,8 @@ def _check_heartbeat_timeouts(self) -> None:
for input_addr in to_delete:
del self._instances[input_addr]
if timed_out:
- # Instance liveness changed; request broadcast.
- self._schedule_broadcast()
+ # Instance liveness changed; force immediate broadcast.
+ self._schedule_broadcast(force=True)
def close(self) -> None:
"""Shut down background threads and close all ZMQ sockets."""
@@ -232,9 +228,9 @@ def _recv_loop(self) -> None:
def _periodic_loop(self) -> None:
"""Periodic loop to check heartbeat timeouts and flush broadcasts.
- Heartbeat timeouts are checked on their original cadence, while all
- broadcast requests are coalesced and flushed at most once per
- ``_publish_min_interval``.
+ Heartbeat timeouts are checked on their original cadence, while
+ queue_length / non-liveness updates are coalesced and flushed at
+ most once per ``_publish_min_interval``.
"""
heartbeat_interval = max(1.0, min(self._heartbeat_timeout / 2.0, 5.0))
loop_interval = self._publish_min_interval
@@ -247,18 +243,9 @@ def _periodic_loop(self) -> None:
self._check_heartbeat_timeouts()
last_heartbeat_check = now
- with self._pending_lock:
- has_pending_broadcast = self._pending_broadcast
-
- if not has_pending_broadcast:
- if self._stop_event.wait(timeout=loop_interval):
- break
- continue
-
- # Publish outside lock. Clear pending only on success.
- if self.publish_instance_list_update():
- with self._pending_lock:
- self._pending_broadcast = False
+ if self._pending_broadcast:
+ self._pending_broadcast = False
+ self.publish_instance_list_update()
if self._stop_event.wait(timeout=loop_interval):
break
@@ -280,23 +267,27 @@ def _handle_event(self, event: InstanceEvent) -> None:
info.status = StageStatus.UP
promote = True
if promote:
- self._schedule_broadcast()
+ self._schedule_broadcast(force=True)
return
# Check-and-act under single lock to avoid TOCTOU race (duplicate
# registration when concurrent events arrive for the same instance).
with self._lock:
+ force_broadcast = False
if input_addr not in self._instances:
self._add_new_instance_locked(event)
+ force_broadcast = True
else:
if event.status == StageStatus.DOWN:
self._remove_instance_locked(event)
+ force_broadcast = True
else:
self._update_instance_info_locked(event)
- # Any non-heartbeat state change that affects the active list
- # is coalesced and flushed via the periodic loop.
- self._schedule_broadcast()
+ # New instances / DOWN events are broadcast immediately; other
+ # updates (e.g. queue_length changes) are throttled via the
+ # periodic loop.
+ self._schedule_broadcast(force=force_broadcast)
except (KeyError, ValueError, TypeError) as e:
logger.warning("Dropping malformed event: %s", e)
diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py
index 6c92d7952de..c8a96e6d25d 100644
--- a/vllm_omni/engine/__init__.py
+++ b/vllm_omni/engine/__init__.py
@@ -79,10 +79,6 @@ class OmniEngineCoreRequest(EngineCoreRequest):
class OmniEngineCoreOutput(EngineCoreOutput):
pooling_output: dict[str, torch.Tensor] | None = None
- # Finished flag for streaming input segment
- is_segment_finished: bool | None = False
- # Streaming update prompt length
- new_prompt_len_snapshot: int | None = None
class OmniEngineCoreOutputs(EngineCoreOutputs):
diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py
index 3f16c329e27..b6637892624 100644
--- a/vllm_omni/engine/arg_utils.py
+++ b/vllm_omni/engine/arg_utils.py
@@ -3,15 +3,14 @@
import json
import os
import tempfile
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, field
from typing import Any
-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.arg_utils import EngineArgs
from vllm.logger import init_logger
from vllm_omni.config import OmniModelConfig
from vllm_omni.engine.output_modality import OutputModality
-from vllm_omni.platforms import current_omni_platform
from vllm_omni.plugins import load_omni_general_plugins
logger = init_logger(__name__)
@@ -21,8 +20,6 @@
_ARCH_TO_MODEL_TYPE: dict[str, str] = {
"CosyVoice3Model": "cosyvoice3",
"OmniVoiceModel": "omnivoice",
- "VoxCPM2TalkerForConditionalGeneration": "voxcpm2",
- "VoxCPMForConditionalGeneration": "voxcpm",
}
# Maps model architecture names to tokenizer subfolder paths within HF repos.
@@ -40,8 +37,9 @@ def _register_omni_hf_configs() -> None:
from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import (
Qwen3TTSConfig,
)
- from vllm_omni.transformers_utils.configs.voxcpm import VoxCPMConfig
- from vllm_omni.transformers_utils.configs.voxcpm2 import VoxCPM2Config
+ from vllm_omni.model_executor.models.voxtral_tts.configuration_voxtral_tts import (
+ VoxtralTTSConfig,
+ )
except Exception as exc: # pragma: no cover - best-effort optional registration
logger.warning("Skipping omni HF config registration due to import error: %s", exc)
return
@@ -58,8 +56,7 @@ def _register_omni_hf_configs() -> None:
("qwen3_tts", Qwen3TTSConfig),
("cosyvoice3", CosyVoice3Config),
("omnivoice", OmniVoiceConfig),
- ("voxcpm", VoxCPMConfig),
- ("voxcpm2", VoxCPM2Config),
+ ("voxtral_tts", VoxtralTTSConfig),
]:
try:
AutoConfig.register(model_type, config_cls)
@@ -89,11 +86,7 @@ class OmniEngineArgs(EngineArgs):
Adds omni-specific configuration fields for multi-stage pipeline
processing and output type specification.
Args:
- stage_id: Identifier for the stage in a multi-stage pipeline.
- Defaults to 0 for per-stage engine construction. The CLI-level
- single-stage selector remains optional on the parsed argparse
- namespace and should not be forwarded as a nullable per-stage
- engine argument.
+ stage_id: Identifier for the stage in a multi-stage pipeline (default: 0)
model_stage: Stage type identifier, e.g., "thinker" or "talker"
(default: "thinker")
model_arch: Model architecture name
@@ -112,21 +105,6 @@ class OmniEngineArgs(EngineArgs):
worker_type: Model Type, e.g., "ar" or "generation"
task_type: Default task type for TTS models (CustomVoice, VoiceDesign, or Base).
If not specified, will be inferred from model path.
- omni_master_address: TCP address that the OmniMasterServer (running
- inside AsyncOmniEngine) listens on for engine core registrations.
- Required when single-stage mode is active.
- omni_master_port: TCP port for the OmniMasterServer registration
- socket. Required when single-stage mode is active.
- stage_configs_path: Optional path to a JSON/YAML file containing
- stage configurations for the multi-stage pipeline. If None,
- stage configs are resolved from the model's default configuration.
- output_modalities: Optional list of output modality names to enable
- (e.g. ["text", "audio"]). If None, all modalities supported by
- the model are used.
- log_stats: Whether to log engine statistics. Defaults to False.
- custom_pipeline_args: Dictionary of arguments for custom pipeline
- initialization (e.g., ``{"pipeline_class": "my.Module"}``).
- Passed through to the diffusion stage engine.
"""
stage_id: int = 0
@@ -136,44 +114,13 @@ class OmniEngineArgs(EngineArgs):
hf_config_name: str | None = None
custom_process_next_stage_input_func: str | None = None
stage_connector_spec: dict[str, Any] = field(default_factory=dict)
- subtalker_sampling_params: dict[str, Any] | None = None
async_chunk: bool = False
omni_kv_config: dict | None = None
quantization_config: Any | None = None
worker_type: str | None = None
task_type: str | None = None
- worker_cls: str = None
- enable_sleep_mode: bool = False
- omni: bool = False
-
- @classmethod
- def _add_omni_specific_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
- try:
- parser.add_argument("--omni", action="store_true", default=False, help="Enable Omni engine features.")
- except argparse.ArgumentError:
- pass
- try:
- parser.add_argument(
- "--enable-sleep-mode", action="store_true", default=False, help="Enable GPU memory pool for sleep mode."
- )
- except argparse.ArgumentError:
- pass
- return parser
-
- omni_master_address: str | None = None
- omni_master_port: int | None = None
- stage_configs_path: str | None = None
- output_modalities: list[str] | None = None
- log_stats: bool = False
- custom_pipeline_args: dict[str, Any] | None = None
- has_sampling_extra_args: bool = False
def __post_init__(self) -> None:
- if self.worker_cls is None:
- if self.worker_type == "ar":
- self.worker_cls = current_omni_platform.get_omni_ar_worker_cls()
- elif self.worker_type == "generation":
- self.worker_cls = current_omni_platform.get_omni_generation_worker_cls()
load_omni_general_plugins()
super().__post_init__()
@@ -181,26 +128,8 @@ def __post_init__(self) -> None:
def from_cli_args(cls, args: argparse.Namespace) -> "OmniEngineArgs":
attrs = [attr.name for attr in dataclasses.fields(cls)]
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)})
- engine_args._explicit_fields = frozenset(
- attr for attr in attrs if hasattr(args, attr) and getattr(args, attr) is not None
- )
return engine_args
- @classmethod
- def create(cls, **explicit: Any) -> "OmniEngineArgs":
- """Tracks caller-set fields for ``Omni(..., engine_args=ea)``."""
- ea = cls(**explicit)
- ea._explicit_fields = frozenset(explicit.keys())
- return ea
-
- def explicit_kwargs(self) -> dict[str, Any]:
- explicit = getattr(self, "_explicit_fields", None)
- if explicit is None:
- return {
- f.name: getattr(self, f.name) for f in dataclasses.fields(self) if getattr(self, f.name) is not None
- }
- return {k: getattr(self, k) for k in explicit}
-
def _ensure_omni_models_registered(self):
if hasattr(self, "_omni_models_registered"):
return True
@@ -261,13 +190,6 @@ def create_model_config(self) -> OmniModelConfig:
if model_type is not None:
self.hf_overrides.setdefault("model_type", model_type)
- # Stage wrappers (e.g. Code2Wav) may need max_model_len larger
- # than the base checkpoint's text max_position_embeddings.
- if self.model_arch == "Qwen3TTSCode2Wav" and self.max_model_len is not None:
- self.hf_overrides.setdefault("talker_config", {}).setdefault(
- "max_position_embeddings", int(self.max_model_len)
- )
-
# For models whose HF config.json is empty or lacks model_type
# (e.g. CosyVoice3), AutoConfig.from_pretrained fails because it
# cannot determine which config class to use from the empty dict.
@@ -338,332 +260,12 @@ def create_model_config(self) -> OmniModelConfig:
hf_config_name=self.hf_config_name,
custom_process_next_stage_input_func=self.custom_process_next_stage_input_func,
stage_connector_config=stage_connector_config,
- subtalker_sampling_params=self.subtalker_sampling_params,
omni_kv_config=self.omni_kv_config,
task_type=self.task_type,
- has_sampling_extra_args=self.has_sampling_extra_args,
)
return omni_config
-
-@dataclass
-class OmniAsyncEngineArgs(AsyncEngineArgs, OmniEngineArgs):
- @classmethod
- def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
- parser = AsyncEngineArgs.add_cli_args(parser)
- parser = OmniEngineArgs._add_omni_specific_args(parser)
- return parser
-
@property
def output_modality(self) -> OutputModality:
"""Parse engine_output_type into a type-safe OutputModality flag."""
return OutputModality.from_string(self.engine_output_type)
-
-
-# ============================================================================
-# CLI argument routing
-# ============================================================================
-#
-# vLLM-Omni's CLI flags live in three buckets:
-#
-# ┌──────────────────┐ ┌──────────────────┐ ┌──────────────────┐
-# │ OrchestratorArgs │ │ OmniEngineArgs │ │ (upstream vllm) │
-# │ │ │ │ │ server/api │
-# │ stage_timeout │ │ max_num_seqs │ │ host, port │
-# │ worker_backend │ │ gpu_mem_util │ │ ssl_keyfile │
-# │ deploy_config │ │ dtype, quant │ │ api_key │
-# │ ... │ │ ... │ │ ... │
-# └──────────────────┘ └──────────────────┘ └──────────────────┘
-# │ │ │
-# ▼ ▼ ▼
-# orchestrator each stage uvicorn /
-# consumes engine FastAPI
-#
-# Fields in ``SHARED_FIELDS`` (e.g. ``model``, ``log_stats``) flow to BOTH
-# orchestrator and engine by design.
-#
-# Invariants enforced by ``tests/test_arg_utils.py``:
-#
-# 1. ``OrchestratorArgs`` ∩ ``OmniEngineArgs`` ⊆ ``SHARED_FIELDS``
-# 2. Every CLI flag is classifiable into one of the three buckets
-# 3. User-typed flags that match none of the above are logged as dropped
-#
-# Adding a new orchestrator-only flag → add a field to ``OrchestratorArgs``.
-# Everything else is automatic.
-
-
-@dataclass(frozen=True)
-class OrchestratorArgs:
- """CLI flags consumed by the orchestrator.
-
- Contract: every field here is either
- (a) orchestrator-only (never needed by a stage engine), OR
- (b) orchestrator-read-then-redistributed (e.g. ``async_chunk`` is read
- from CLI, written to ``DeployConfig``, then propagated to every
- stage via ``merge_pipeline_deploy`` — not via direct kwargs
- forwarding).
-
- Fields that BOTH orchestrator and engine genuinely need (e.g. ``model``,
- ``log_stats``) should be listed in ``SHARED_FIELDS`` below; ``split_kwargs``
- will copy them to both buckets.
- """
-
- # === Lifecycle ===
- stage_init_timeout: int = 300
- init_timeout: int = 600
-
- # === Cross-stage Communication ===
- shm_threshold_bytes: int = 65536
- batch_timeout: int = 10
-
- # === Cluster / Backend ===
- worker_backend: str = "multi_process"
- ray_address: str | None = None
-
- # === Config Files ===
- stage_configs_path: str | None = None
- deploy_config: str | None = None
- stage_overrides: str | None = None # raw JSON string; parsed downstream
-
- # === Mode Switches (orchestrator reads, DeployConfig redistributes) ===
- async_chunk: bool | None = None
-
- # === Observability ===
- log_stats: bool = False
-
- # === Headless Mode (also forwarded to engine — see SHARED_FIELDS) ===
- stage_id: int | None = None
-
- # === Pre-built Objects ===
- parallel_config: Any = None
-
- # === Multi-stage guards ===
- # --tokenizer is captured here so it does not propagate to every stage
- # uniformly (different stages often need different tokenizers, e.g.
- # qwen3_omni thinker vs talker). Users wanting a per-stage tokenizer
- # should set it in the deploy YAML.
- tokenizer: str | None = None
-
-
-# Fields that live in BOTH OrchestratorArgs and OmniEngineArgs by design.
-# Changes to this set are a review red flag — revisit the contract.
-SHARED_FIELDS: frozenset[str] = frozenset(
- {
- "model", # orch: detect model_type; engine: load weights
- "stage_id", # orch: route (headless); engine: identity
- "log_stats", # both want the flag
- "stage_configs_path", # orch: load legacy YAML; engine: may reference for validation
- }
-)
-
-_DEPLOY_ENGINE_ARG_OVERRIDE_FIELDS: frozenset[str] = frozenset(
- {
- # Capacity / scheduling.
- "async_scheduling",
- "max_model_len",
- "max_num_batched_tokens",
- "max_num_seqs",
- # Memory / parallelism.
- "data_parallel_size",
- "gpu_memory_utilization",
- "pipeline_parallel_size",
- "tensor_parallel_size",
- # Execution / loading.
- "enforce_eager",
- "distributed_executor_backend",
- "dtype",
- "quantization",
- "trust_remote_code",
- # Caching / chunking.
- "async_chunk",
- "enable_prefix_caching",
- "enable_chunked_prefill",
- # Model-specific engine extras.
- "subtalker_sampling_params",
- }
-)
-
-_DEPLOY_RUNTIME_OVERRIDE_FIELDS: frozenset[str] = frozenset(
- {
- "devices",
- }
-)
-
-
-def orchestrator_field_names() -> frozenset[str]:
- """Return the names of every field on OrchestratorArgs."""
- return frozenset(f.name for f in fields(OrchestratorArgs))
-
-
-def deploy_override_field_names() -> frozenset[str]:
- """Return kwargs whose parser defaults must not override deploy YAML."""
- return _DEPLOY_ENGINE_ARG_OVERRIDE_FIELDS | _DEPLOY_RUNTIME_OVERRIDE_FIELDS
-
-
-def internal_blacklist_keys() -> frozenset[str]:
- """Return the set of CLI keys that must never be forwarded as per-stage
- engine overrides.
-
- Derived from ``OrchestratorArgs`` fields minus ``SHARED_FIELDS``, so
- adding a new orchestrator-owned flag is a one-line change to the
- dataclass — this function updates automatically.
- """
- return orchestrator_field_names() - SHARED_FIELDS
-
-
-def split_kwargs(
- kwargs: dict[str, Any],
- *,
- engine_cls: type | None = None,
- user_typed: set[str] | None = None,
- strict: bool = False,
-) -> tuple[OrchestratorArgs, dict[str, Any]]:
- """Partition CLI kwargs into (orchestrator, engine) buckets.
-
- Args:
- kwargs: Raw dict, typically ``vars(args)``.
- engine_cls: Engine dataclass used to whitelist-filter the engine
- bucket. Defaults to ``OmniEngineArgs``. Pass a custom class
- for testing.
- user_typed: Keys the user actually typed on the command line. Used
- to warn when a user-typed flag is unclassifiable.
- strict: If True, raise ``ValueError`` on ambiguous (double-classified
- but not in ``SHARED_FIELDS``) fields. Default False to keep the
- rollout non-breaking; flip to True in tests and CI.
-
- Returns:
- ``(orchestrator_args, engine_kwargs)``. ``engine_kwargs`` has already
- been whitelist-filtered against ``engine_cls`` — safe to pass directly
- to ``engine_cls(**engine_kwargs)``.
- """
- if engine_cls is None:
- engine_cls = OmniEngineArgs
-
- orch_fields = orchestrator_field_names()
- engine_fields = {f.name for f in fields(engine_cls)}
-
- orch_kwargs: dict[str, Any] = {}
- engine_candidate: dict[str, Any] = {}
- shared_values: dict[str, Any] = {}
- unclassified: dict[str, Any] = {}
-
- for key, value in kwargs.items():
- in_orch = key in orch_fields
- in_engine = key in engine_fields
- is_shared = key in SHARED_FIELDS
-
- if is_shared:
- shared_values[key] = value
- elif in_orch and in_engine:
- # Declared in both but not marked shared → ambiguous.
- msg = (
- f"Field {key!r} is defined on both OrchestratorArgs and "
- f"{engine_cls.__name__} but is not in SHARED_FIELDS. "
- f"This causes double-routing. Either remove the duplicate or "
- f"add {key!r} to SHARED_FIELDS if the sharing is intentional."
- )
- if strict:
- raise ValueError(msg)
- logger.error(msg)
- # Default: treat as orchestrator-only to preserve existing behavior.
- orch_kwargs[key] = value
- elif in_orch:
- orch_kwargs[key] = value
- elif in_engine:
- engine_candidate[key] = value
- else:
- unclassified[key] = value
-
- # Warn on user-typed but unclassifiable flags so we don't silently drop
- # something the user cared about (fixes the class of bug that spawned #873).
- if unclassified and user_typed:
- user_typed_unknown = sorted(k for k in unclassified if k in user_typed)
- if user_typed_unknown:
- logger.warning(
- "CLI flags not consumed by vllm-omni and dropped before "
- "per-stage engine construction: %s. If these are vllm "
- "frontend/uvicorn flags (host, port, ssl_*, api_key, …) this "
- "is expected; otherwise check your spelling.",
- user_typed_unknown,
- )
-
- # Engine bucket: shared + engine-only. We do NOT pass through unclassified
- # fields — that's exactly the server/uvicorn noise we want to shed.
- engine_kwargs = {**shared_values, **engine_candidate}
-
- # Construct the orchestrator dataclass. Shared fields that OrchestratorArgs
- # also declares get copied into its constructor.
- orch_init: dict[str, Any] = dict(orch_kwargs)
- for key, value in shared_values.items():
- if key in orch_fields:
- orch_init[key] = value
- orch_args = OrchestratorArgs(**orch_init)
-
- return orch_args, engine_kwargs
-
-
-def derive_server_dests_from_vllm_parser() -> frozenset[str]:
- """Derive the set of argparse dests that belong to vllm's frontend/server.
-
- Returns every dest registered by ``make_arg_parser`` that is NOT a field
- of ``OmniEngineArgs`` and NOT a field of ``OrchestratorArgs``. Useful for
- CI tests to assert all CLI flags are classifiable without maintaining
- a hardcoded server list.
-
- Returns empty frozenset if vllm's parser cannot be built (e.g. in a
- minimal test environment).
- """
- try:
- from vllm.entrypoints.openai.cli_args import make_arg_parser
- from vllm.utils.argparse_utils import FlexibleArgumentParser
- except ImportError:
- logger.debug("Cannot import vllm parser — server-dest derivation skipped")
- return frozenset()
-
- try:
- parser = make_arg_parser(FlexibleArgumentParser())
- all_dests = {a.dest for a in parser._actions if a.dest and a.dest != "help"}
- except Exception as exc:
- logger.debug("Failed to build vllm parser: %s", exc)
- return frozenset()
-
- engine_fields = {f.name for f in fields(OmniEngineArgs)}
- orch_fields = orchestrator_field_names()
-
- return frozenset(all_dests - engine_fields - orch_fields - SHARED_FIELDS)
-
-
-def orchestrator_args_from_argparse(args: Any) -> OrchestratorArgs:
- """Build an ``OrchestratorArgs`` from an ``argparse.Namespace``.
-
- Only copies attributes that exist on the namespace — missing fields fall
- back to the dataclass default. Useful when the full parser is already
- built and ``vars(args)`` would include noise.
- """
- kwargs: dict[str, Any] = {}
- for f in fields(OrchestratorArgs):
- if hasattr(args, f.name):
- value = getattr(args, f.name)
- if value is not None or f.default is None:
- kwargs[f.name] = value
- return OrchestratorArgs(**kwargs)
-
-
-def nullify_stage_engine_defaults(parser: argparse.ArgumentParser) -> None:
- """Reset stage-level engine flag defaults to ``None``; preserve real
- default in help text. Only deploy-YAML override fields are touched.
- Idempotent."""
- override_dests = deploy_override_field_names()
-
- for action in parser._actions:
- if action.dest in ("help", "version") or not action.option_strings:
- continue
- if action.dest not in override_dests:
- continue
- if action.default is None or action.default is argparse.SUPPRESS:
- continue
- if action.help and "(default:" not in action.help and "%(default)" not in action.help:
- action.help = f"{action.help} (default: {action.default})"
- action.default = None
-
- parser._omni_nullified = True # type: ignore[attr-defined]
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index a37afd24b4f..092b341e42a 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -18,54 +18,39 @@
import uuid
import weakref
from collections.abc import Mapping, Sequence
-from contextlib import ExitStack
from dataclasses import asdict
from typing import TYPE_CHECKING, Any
+if TYPE_CHECKING:
+ from vllm_omni.engine.arg_utils import OmniEngineArgs
+
import janus
import torch
from omegaconf import OmegaConf
-from vllm import envs as vllm_envs
-from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
-from vllm_omni.config.stage_config import strip_parent_engine_args
from vllm_omni.diffusion.data import DiffusionParallelConfig
-from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
-from vllm_omni.diffusion.stage_diffusion_proc import (
- complete_diffusion_handshake,
- spawn_diffusion_proc,
-)
from vllm_omni.distributed.omni_connectors.utils.initialization import (
resolve_omni_kv_config_for_stage,
)
-from vllm_omni.engine import OmniEngineCoreRequest
+from vllm_omni.engine import (
+ OmniEngineCoreRequest,
+)
from vllm_omni.engine.orchestrator import Orchestrator
from vllm_omni.engine.output_processor import MultimodalOutputProcessor
-from vllm_omni.engine.serialization import (
- deserialize_additional_information,
- serialize_additional_information,
-)
-from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClientBase
+from vllm_omni.engine.serialization import serialize_additional_information
+from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient
from vllm_omni.engine.stage_engine_core_proc import (
complete_stage_handshake,
spawn_stage_core,
)
-from vllm_omni.engine.stage_engine_startup import (
- OmniMasterServer,
- connect_remote_engine_cores,
- launch_omni_core_engines,
- register_stage_with_omni_master,
-)
from vllm_omni.engine.stage_init_utils import (
StartedLlmStage,
- _inject_inferred_kv_tp_topology,
acquire_device_locks,
- build_diffusion_config,
build_engine_args_dict,
build_vllm_config,
cleanup_failed_stage_initialization,
@@ -74,55 +59,19 @@
finalize_initialized_stages,
get_stage_connector_spec,
initialize_diffusion_stage,
- inject_kv_stage_info,
load_omni_transfer_config_for_model,
prepare_engine_environment,
release_device_locks,
setup_stage_devices,
- terminate_alive_proc,
)
-from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin
from vllm_omni.entrypoints.utils import (
- inject_omni_kv_config,
load_and_resolve_stage_configs,
)
from vllm_omni.inputs.preprocess import OmniInputPreprocessor
from vllm_omni.platforms import current_omni_platform
-if TYPE_CHECKING:
- from vllm_omni.engine.arg_utils import OmniEngineArgs
-
logger = init_logger(__name__)
-_STARTUP_POLL_INTERVAL_S = 1.0
-
-
-# ============================================================================
-# Parent-EngineArgs field-routing contracts (consumed by
-# AsyncOmniEngine._strip_parent_engine_args when ``stage_configs_path`` is set).
-# ============================================================================
-
-# Fields that must survive the "equal to default → strip" filter because
-# diffusion stages need them even when equal to vllm's default value
-# (e.g. colocate worker setup relies on worker_extension_cls being forwarded).
-_PARENT_ARGS_KEEP: frozenset[str] = frozenset(
- {
- "worker_extension_cls",
- "allowed_local_media_path",
- "allowed_media_domains",
- }
-)
-
-# Omni orchestrator-level fields consumed by ``_resolve_stage_configs`` that
-# must never leak into per-stage EngineArgs (``stage_configs_path`` would
-# trigger the ``create_model_config`` guard).
-_PARENT_ARGS_STRIP: frozenset[str] = frozenset({"stage_configs_path"})
-
-# Fields always populated by callers (via ``from_cli_args`` / ``asdict``) so
-# their presence as an override is never a surprise — suppress the
-# "override ignored" warning for these.
-_PARENT_ARGS_NO_WARN: frozenset[str] = frozenset({"model"})
-
def _patch_generation_config_if_needed(model_config: Any) -> None:
"""Ensure try_get_generation_config won't crash for models whose HF
@@ -134,6 +83,39 @@ def _patch_generation_config_if_needed(model_config: Any) -> None:
model_config.try_get_generation_config = lambda: {}
+def _inject_kv_stage_info(stage_cfg: Any, stage_id: int) -> None:
+ """Inject stage_id and engine_input_source into omni_kv_config.
+
+ OmniKVTransferManager needs stage_id to compute recv_stages for the
+ receiving side. In the old Omni architecture, OmniDiffusion.__init__
+ performed this injection; replicate it here for AsyncOmniEngine.
+ """
+ try:
+ engine_args = stage_cfg.engine_args
+ if hasattr(engine_args, "get"):
+ omni_kv = engine_args.get("omni_kv_config", None)
+ else:
+ omni_kv = getattr(engine_args, "omni_kv_config", None)
+
+ if omni_kv is None:
+ return
+
+ if hasattr(omni_kv, "setdefault"):
+ omni_kv.setdefault("stage_id", stage_id)
+ elif hasattr(omni_kv, "__setitem__"):
+ if "stage_id" not in omni_kv:
+ omni_kv["stage_id"] = stage_id
+
+ engine_input_source = getattr(stage_cfg, "engine_input_source", None)
+ if engine_input_source is not None:
+ if hasattr(omni_kv, "setdefault"):
+ omni_kv.setdefault("engine_input_source", list(engine_input_source))
+ elif hasattr(omni_kv, "__setitem__") and "engine_input_source" not in omni_kv:
+ omni_kv["engine_input_source"] = list(engine_input_source)
+ except Exception as e:
+ logger.debug("Failed to inject stage info into omni_kv_config: %s", e)
+
+
def _inject_global_id(target: Any, request_id: str) -> None:
"""Inject global_request_id into a prompt dict's additional_information."""
if isinstance(target, dict):
@@ -188,38 +170,6 @@ def _upgrade_to_omni_request(
)
-def _apply_omni_final_stage_metadata(
- request: EngineCoreRequest,
- final_stage_id: int,
-) -> EngineCoreRequest:
- """Tag EngineCoreRequest so OmniARScheduler can skip DiT KV when final_stage_id is 0."""
- merged: dict[str, Any] = {}
- if isinstance(request, OmniEngineCoreRequest) and request.additional_information is not None:
- merged = deserialize_additional_information(request.additional_information)
- merged["omni_final_stage_id"] = final_stage_id
- payload = serialize_additional_information(merged)
- return OmniEngineCoreRequest(
- request_id=request.request_id,
- prompt_token_ids=request.prompt_token_ids,
- mm_features=request.mm_features,
- sampling_params=request.sampling_params,
- pooling_params=request.pooling_params,
- arrival_time=request.arrival_time,
- lora_request=request.lora_request,
- cache_salt=request.cache_salt,
- data_parallel_rank=request.data_parallel_rank,
- prompt_embeds=request.prompt_embeds,
- client_index=request.client_index,
- current_wave=request.current_wave,
- priority=request.priority,
- trace_headers=request.trace_headers,
- resumable=request.resumable,
- external_req_id=request.external_req_id,
- reasoning_ended=request.reasoning_ended,
- additional_information=payload,
- )
-
-
def _weak_shutdown_async_omni_engine(
orchestrator_thread: threading.Thread | None,
request_queue: janus.Queue[dict[str, Any]] | None,
@@ -270,7 +220,6 @@ def __init__(
stage_init_timeout: int = 300,
init_timeout: int = 600,
diffusion_batch_size: int = 1,
- single_stage_mode: bool = False,
**kwargs: Any,
) -> None:
self.model = model
@@ -279,45 +228,17 @@ def __init__(
logger.info(f"[AsyncOmniEngine] Initializing with model {model}")
- # Merge tracked engine_args fields into kwargs; explicit kwargs take priority.
+ # Merge typed engine_args fields into kwargs; explicit kwargs take priority.
if engine_args is not None:
- if not hasattr(engine_args, "_explicit_fields"):
- raise TypeError(
- "engine_args=OmniEngineArgs(...) is ambiguous under "
- "sentinel-default precedence. Use "
- "OmniEngineArgs.create(**explicit) or pass explicit kwargs "
- "directly."
- )
- ea_dict = engine_args.explicit_kwargs()
+ ea_dict = {
+ f.name: getattr(engine_args, f.name)
+ for f in dataclasses.fields(engine_args)
+ if not f.name.startswith("_")
+ }
# Remove model since it is passed as a positional arg already.
ea_dict.pop("model", None)
kwargs = {**ea_dict, **kwargs}
- # ------------------------------------------------------------------ #
- # Single-stage mode detection #
- # ------------------------------------------------------------------ #
- # Single-stage mode is enabled when the caller explicitly passes #
- # single_stage_mode=True, or when a stage_id is provided in the args. #
- _stage_id_kwarg = kwargs.get("stage_id")
- if isinstance(_stage_id_kwarg, int) and not single_stage_mode:
- single_stage_mode = True
-
- self.single_stage_mode: bool = single_stage_mode
- self._single_stage_id_filter: int | None = (
- int(_stage_id_kwarg) if single_stage_mode and isinstance(_stage_id_kwarg, int) else None
- )
- self._omni_master_address: str | None = kwargs.get("omni_master_address")
- self._omni_master_port: int | None = kwargs.get("omni_master_port")
- self._omni_master_server: OmniMasterServer | None = None
-
- if single_stage_mode:
- logger.info(
- "[AsyncOmniEngine] Single-stage mode enabled (stage_id_filter=%s, master=%s:%s)",
- self._single_stage_id_filter,
- self._omni_master_address,
- self._omni_master_port,
- )
-
self.config_path, self.stage_configs = self._resolve_stage_configs(model, kwargs)
self.num_stages = len(self.stage_configs)
@@ -352,7 +273,22 @@ def __init__(
name="orchestrator",
)
self.orchestrator_thread.start()
- self._wait_for_orchestrator_init(startup_future, startup_timeout)
+
+ # Wait for stage/runtime initialization result from orchestrator thread.
+ try:
+ startup_future.result(timeout=startup_timeout)
+ except concurrent.futures.TimeoutError as e:
+ try:
+ self.shutdown()
+ except Exception:
+ logger.exception("[AsyncOmniEngine] Failed to cleanup after orchestrator startup timeout")
+ raise TimeoutError(f"Orchestrator did not become ready within {startup_timeout}s") from e
+ except Exception:
+ try:
+ self.shutdown()
+ except Exception:
+ logger.exception("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure")
+ raise
# Stage runtime fields are assigned directly on self by the bootstrap thread.
self._weak_finalizer = weakref.finalize(
@@ -379,99 +315,60 @@ def _launch_llm_stage(
started_stage: StartedLlmStage | None = None
lock_fds: list[int] = []
device_control_env = current_omni_platform.device_control_env_var
+
try:
- proc = None
- handshake_address = None
- with ExitStack() as launch_stack:
- with llm_stage_launch_lock:
- previous_visible_devices = os.environ.get(device_control_env)
- try:
- setup_stage_devices(metadata.stage_id, metadata.runtime_cfg)
- engine_args_dict = build_engine_args_dict(
- stage_cfg,
- self.model,
- stage_connector_spec=stage_connector_spec,
- )
- omni_conn_cfg, omni_from, omni_to = omni_kv_connector
- if omni_conn_cfg:
- omni_kv = engine_args_dict.get("omni_kv_config") or {}
- if not isinstance(omni_kv, dict):
- omni_kv = dict(omni_kv)
- omni_kv["connector_config"] = omni_conn_cfg
- omni_kv["omni_from_stage"] = omni_from
- omni_kv["omni_to_stage"] = omni_to
- omni_kv.setdefault("stage_id", metadata.stage_id)
- engine_args_dict["omni_kv_config"] = omni_kv
- if self.stage_configs:
- _inject_inferred_kv_tp_topology(
- engine_args_dict.get("omni_kv_config"),
- metadata.stage_id,
- self.stage_configs,
- )
- vllm_config, executor_class = build_vllm_config(
- stage_cfg,
- self.model,
- stage_connector_spec=stage_connector_spec,
- engine_args_dict=engine_args_dict,
- )
- lock_fds = acquire_device_locks(
- metadata.stage_id,
- engine_args_dict,
- stage_init_timeout,
- )
- if self.single_stage_mode and self._omni_master_server is not None:
- engine_manager, coordinator, addresses = launch_stack.enter_context(
- launch_omni_core_engines(
- vllm_config=vllm_config,
- executor_class=executor_class,
- log_stats=False,
- omni_master_server=self._omni_master_server,
- stage_id=metadata.stage_id,
- stage_config=stage_cfg,
- )
- )
- started_stage = StartedLlmStage(
- stage_id=metadata.stage_id,
- metadata=metadata,
- vllm_config=vllm_config,
- executor_class=executor_class,
- addresses=addresses,
- engine_manager=engine_manager,
- coordinator=coordinator,
- )
- else:
- addresses, proc, handshake_address = spawn_stage_core(
- vllm_config=vllm_config,
- executor_class=executor_class,
- log_stats=False,
- )
- started_stage = StartedLlmStage(
- stage_id=metadata.stage_id,
- metadata=metadata,
- vllm_config=vllm_config,
- executor_class=executor_class,
- addresses=addresses,
- proc=proc,
- )
- logger.info("[AsyncOmniEngine] Stage %s engine launch started", metadata.stage_id)
- finally:
- if previous_visible_devices is None:
- current_omni_platform.unset_device_control_env_var()
- else:
- current_omni_platform.set_device_control_env_var(previous_visible_devices)
-
- # After StageEngineCoreProc has been spawned it carries its
- # stage-specific device visibility into descendants, so the
- # slow HELLO/READY handshake can run without holding the
- # process-wide launch lock.
- if self.single_stage_mode and self._omni_master_server is not None:
- launch_stack.close()
- else:
- assert proc is not None
- assert handshake_address is not None
- complete_stage_handshake(proc, handshake_address, addresses, vllm_config, stage_init_timeout)
- logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id)
+ with llm_stage_launch_lock:
+ previous_visible_devices = os.environ.get(device_control_env)
+ try:
+ setup_stage_devices(metadata.stage_id, metadata.runtime_cfg)
+ engine_args_dict = build_engine_args_dict(
+ stage_cfg,
+ self.model,
+ stage_connector_spec=stage_connector_spec,
+ )
+ omni_conn_cfg, omni_from, omni_to = omni_kv_connector
+ if omni_conn_cfg:
+ omni_kv = engine_args_dict.get("omni_kv_config") or {}
+ if not isinstance(omni_kv, dict):
+ omni_kv = dict(omni_kv)
+ omni_kv["connector_config"] = omni_conn_cfg
+ omni_kv["omni_from_stage"] = omni_from
+ omni_kv["omni_to_stage"] = omni_to
+ omni_kv.setdefault("stage_id", metadata.stage_id)
+ engine_args_dict["omni_kv_config"] = omni_kv
+ vllm_config, executor_class = build_vllm_config(
+ stage_cfg,
+ self.model,
+ stage_connector_spec=stage_connector_spec,
+ engine_args_dict=engine_args_dict,
+ )
+ lock_fds = acquire_device_locks(
+ metadata.stage_id,
+ engine_args_dict,
+ stage_init_timeout,
+ )
+ addresses, proc, handshake_address = spawn_stage_core(
+ vllm_config=vllm_config,
+ executor_class=executor_class,
+ log_stats=False,
+ )
+ started_stage = StartedLlmStage(
+ stage_id=metadata.stage_id,
+ metadata=metadata,
+ vllm_config=vllm_config,
+ executor_class=executor_class,
+ proc=proc,
+ addresses=addresses,
+ )
+ finally:
+ if previous_visible_devices is None:
+ current_omni_platform.unset_device_control_env_var()
+ else:
+ current_omni_platform.set_device_control_env_var(previous_visible_devices)
+ logger.info("[AsyncOmniEngine] Stage %s engine launch started", metadata.stage_id)
+ complete_stage_handshake(proc, handshake_address, addresses, vllm_config)
+ logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id)
assert started_stage is not None
return started_stage
except Exception:
@@ -482,138 +379,13 @@ def _launch_llm_stage(
if lock_fds:
release_device_locks(lock_fds)
- def _create_remote_llm_stage(
- self,
- stage_cfg: Any,
- metadata: Any,
- stage_connector_spec: dict[str, Any],
- stage_init_timeout: int,
- omni_master_server: OmniMasterServer,
- ) -> StartedLlmStage:
- """Attach to a remote engine core and wait for its startup handshake."""
- started_stage: StartedLlmStage | None = None
- try:
- raw_stage_cfg = omni_master_server.get_stage_config(
- metadata.stage_id,
- timeout_s=stage_init_timeout,
- )
- if raw_stage_cfg is None:
- raise ValueError(f"Remote stage {metadata.stage_id} registered without stage config")
- stage_cfg = OmegaConf.create(raw_stage_cfg)
- engine_args_dict = build_engine_args_dict(
- stage_cfg,
- self.model,
- stage_connector_spec=stage_connector_spec,
- )
- vllm_config, executor_class = build_vllm_config(
- stage_cfg,
- self.model,
- stage_connector_spec=stage_connector_spec,
- engine_args_dict=engine_args_dict,
- )
- vllm_config.parallel_config.data_parallel_size_local = 0
- launch_cm = connect_remote_engine_cores(
- vllm_config=vllm_config,
- omni_master_server=omni_master_server,
- stage_id=metadata.stage_id,
- )
- logger.info("[AsyncOmniEngine] Stage %s remote engine handshake started", metadata.stage_id)
- with launch_cm as (engine_manager, coordinator, addresses):
- started_stage = StartedLlmStage(
- stage_id=metadata.stage_id,
- metadata=metadata,
- vllm_config=vllm_config,
- executor_class=executor_class,
- engine_manager=engine_manager,
- coordinator=coordinator,
- addresses=addresses,
- )
- logger.info("[AsyncOmniEngine] Stage %s remote engine startup completed", metadata.stage_id)
- assert started_stage is not None
- return started_stage
- except Exception:
- if started_stage is not None:
- close_started_llm_stage(started_stage)
- raise
-
- def _launch_diffusion_stage(
- self,
- stage_cfg: Any,
- metadata: Any,
- omni_master_server: OmniMasterServer,
- ) -> StageDiffusionClient:
- """Launch a local diffusion stage on OmniMasterServer-allocated sockets."""
- proc = None
- try:
- od_config = build_diffusion_config(self.model, stage_cfg, metadata)
- handshake_address, request_address, response_address = register_stage_with_omni_master(
- omni_master_address=omni_master_server.address,
- omni_master_port=omni_master_server.port,
- omni_stage_id=metadata.stage_id,
- omni_stage_config=stage_cfg,
- return_addresses=True,
- )
- logger.info(
- "[AsyncOmniEngine] Stage %s diffusion registration completed",
- metadata.stage_id,
- )
- proc, _, _, _ = spawn_diffusion_proc(
- self.model,
- od_config,
- handshake_address=handshake_address,
- request_address=request_address,
- response_address=response_address,
- )
- complete_diffusion_handshake(proc, handshake_address)
- logger.info(
- "[AsyncOmniEngine] Stage %s diffusion startup completed",
- metadata.stage_id,
- )
- return StageDiffusionClient.from_addresses(
- metadata,
- request_address=request_address,
- response_address=response_address,
- proc=proc,
- batch_size=self.diffusion_batch_size,
- )
- except Exception:
- if proc is not None:
- terminate_alive_proc(proc)
- raise
-
- def _create_remote_diffusion_stage(
- self,
- metadata: Any,
- stage_init_timeout: int,
- omni_master_server: OmniMasterServer,
- ) -> StageDiffusionClient:
- """Attach to a remote diffusion stage registered with OmniMasterServer."""
- remote_stage_cfg = OmegaConf.create(
- omni_master_server.get_stage_config(
- metadata.stage_id,
- timeout_s=stage_init_timeout,
- )
- )
- remote_metadata = extract_stage_metadata(remote_stage_cfg)
- addresses = omni_master_server.get_zmq_addresses(metadata.stage_id)
- logger.info(
- "[AsyncOmniEngine] Stage %s remote diffusion startup completed",
- metadata.stage_id,
- )
- return StageDiffusionClient.from_addresses(
- remote_metadata,
- request_address=addresses.inputs[0],
- response_address=addresses.outputs[0],
- batch_size=self.diffusion_batch_size,
- )
-
def _attach_llm_stage(
self,
started: StartedLlmStage,
) -> tuple[Any, Any, Any, InputProcessor | None]:
"""Attach a READY LLM stage to the orchestrator event loop."""
- client_addresses: dict[str, str] = {
+ client_addresses = {
"input_address": started.addresses.inputs[0],
"output_address": started.addresses.outputs[0],
}
@@ -621,18 +393,14 @@ def _attach_llm_stage(
client_addresses["stats_update_address"] = started.addresses.frontend_stats_publish_address
try:
- stage_client = StageEngineCoreClientBase.make_async_mp_client(
+ stage_client = StageEngineCoreClient(
vllm_config=started.vllm_config,
executor_class=started.executor_class,
metadata=started.metadata,
client_addresses=client_addresses,
proc=started.proc,
- engine_manager=started.engine_manager,
- coordinator=started.coordinator,
)
started.proc = None
- started.engine_manager = None
- started.coordinator = None
except Exception:
close_started_llm_stage(started)
raise
@@ -688,7 +456,7 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
output_processors: list[Any | None] = [None] * num_stages
stage_vllm_configs: list[Any | None] = [None] * num_stages
input_processor: InputProcessor | None = None
- llm_stage_positions: list[int] = []
+ llm_stage_ids: list[int] = []
llm_launch_futures: dict[int, concurrent.futures.Future[StartedLlmStage]] = {}
started_llm_stages: dict[int, StartedLlmStage] = {}
llm_stage_launch_lock = threading.Lock()
@@ -702,104 +470,45 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
prepare_engine_environment()
omni_transfer_config = load_omni_transfer_config_for_model(self.model, self.config_path)
- # ------------------------------------------------------------------ #
- # Single-stage mode: start OmniMasterServer before launching stages. #
- # ------------------------------------------------------------------ #
- if self.single_stage_mode:
- if not self._omni_master_address or not self._omni_master_port:
- raise ValueError(
- "AsyncOmniEngine single_stage_mode requires both "
- "omni_master_address and omni_master_port to be set."
- )
- # Collect all configured stage IDs for pre-allocation.
- all_stage_ids: list[int] = []
- seen_stage_ids: set[int] = set()
- for i, sc in enumerate(self.stage_configs):
- stage_id = int(getattr(sc, "stage_id", i))
- if stage_id in seen_stage_ids:
- raise ValueError(
- f"Duplicate stage_id {stage_id!r} detected among configured stages; stage_ids must be unique."
- )
- seen_stage_ids.add(stage_id)
- all_stage_ids.append(stage_id)
- self._omni_master_server = OmniMasterServer(
- master_address=self._omni_master_address,
- master_port=self._omni_master_port,
- stage_ids=all_stage_ids,
- )
- self._omni_master_server.start()
- logger.info(
- "[AsyncOmniEngine] OmniMasterServer started for stages %s",
- all_stage_ids,
- )
-
try:
with concurrent.futures.ThreadPoolExecutor(
max_workers=max(1, llm_stage_count),
thread_name_prefix="llm-stage-launch",
) as launch_executor:
- for stage_idx, stage_cfg in enumerate(self.stage_configs):
+ for stage_id, stage_cfg in enumerate(self.stage_configs):
+ logger.info("[AsyncOmniEngine] Initializing stage %s", stage_id)
metadata = extract_stage_metadata(stage_cfg)
- configured_stage_id = metadata.stage_id
- logger.info("[AsyncOmniEngine] Initializing stage %s", configured_stage_id)
if metadata.prompt_expand_func is not None:
prompt_expand_func = metadata.prompt_expand_func
- if self.single_stage_mode:
- metadata.runtime_cfg = None
-
stage_connector_spec = get_stage_connector_spec(
omni_transfer_config=omni_transfer_config,
- stage_id=configured_stage_id,
+ stage_id=stage_id,
async_chunk=async_chunk,
)
- omni_kv_connector = resolve_omni_kv_config_for_stage(omni_transfer_config, configured_stage_id)
+ omni_kv_connector = resolve_omni_kv_config_for_stage(omni_transfer_config, stage_id)
if metadata.stage_type == "diffusion":
- is_remote_diffusion_stage = (
- self.single_stage_mode
- and self._single_stage_id_filter is not None
- and configured_stage_id != self._single_stage_id_filter
- )
- if is_remote_diffusion_stage:
- assert self._omni_master_server is not None
- stage_clients[stage_idx] = self._create_remote_diffusion_stage(
- metadata,
- stage_init_timeout,
- self._omni_master_server,
- )
- continue
-
with llm_stage_launch_lock:
previous_visible_devices = os.environ.get(device_control_env)
try:
- setup_stage_devices(configured_stage_id, metadata.runtime_cfg)
+ setup_stage_devices(stage_id, metadata.runtime_cfg)
omni_conn_cfg, omni_from, omni_to = omni_kv_connector
if omni_conn_cfg:
+ from vllm_omni.entrypoints.utils import inject_omni_kv_config
+
inject_omni_kv_config(stage_cfg, omni_conn_cfg, omni_from, omni_to)
- inject_kv_stage_info(stage_cfg, configured_stage_id, self.stage_configs)
- if self.single_stage_mode:
- assert self._omni_master_server is not None
- stage_clients[stage_idx] = self._launch_diffusion_stage(
- stage_cfg,
- metadata,
- self._omni_master_server,
- )
- else:
- use_inline = True if self.num_stages == 1 else False
- stage_clients[stage_idx] = initialize_diffusion_stage(
- configured_stage_id,
- self.model,
- stage_cfg,
- metadata,
- stage_init_timeout=stage_init_timeout,
- batch_size=self.diffusion_batch_size,
- use_inline=use_inline,
- )
+ _inject_kv_stage_info(stage_cfg, stage_id)
+ stage_clients[stage_id] = initialize_diffusion_stage(
+ self.model,
+ stage_cfg,
+ metadata,
+ batch_size=self.diffusion_batch_size,
+ )
logger.info(
"[AsyncOmniEngine] Stage %s initialized (diffusion, batch_size=%d)",
- configured_stage_id,
+ stage_id,
self.diffusion_batch_size,
)
finally:
@@ -809,58 +518,30 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
current_omni_platform.set_device_control_env_var(previous_visible_devices)
continue
- llm_stage_positions.append(stage_idx)
-
- # In single-stage mode, stages that don't match the local
- # stage_id filter are skipped.
- if (
- self.single_stage_mode
- and self._single_stage_id_filter is not None
- and configured_stage_id != self._single_stage_id_filter
- ):
- assert self._omni_master_server is not None
- llm_launch_futures[stage_idx] = launch_executor.submit(
- self._create_remote_llm_stage,
- stage_cfg,
- metadata,
- stage_connector_spec,
- stage_init_timeout,
- self._omni_master_server,
- )
- else:
- llm_launch_futures[stage_idx] = launch_executor.submit(
- self._launch_llm_stage,
- stage_cfg,
- metadata,
- stage_connector_spec,
- stage_init_timeout,
- llm_stage_launch_lock,
- omni_kv_connector,
- )
+ llm_stage_ids.append(stage_id)
+ llm_launch_futures[stage_id] = launch_executor.submit(
+ self._launch_llm_stage,
+ stage_cfg,
+ metadata,
+ stage_connector_spec,
+ stage_init_timeout,
+ llm_stage_launch_lock,
+ omni_kv_connector,
+ )
concurrent.futures.wait(list(llm_launch_futures.values()))
- for stage_idx in llm_stage_positions:
- started_llm_stages[stage_idx] = llm_launch_futures[stage_idx].result()
-
- attach_futures: dict[concurrent.futures.Future[tuple[Any, Any, Any, InputProcessor | None]], int] = {}
- with concurrent.futures.ThreadPoolExecutor(
- max_workers=max(1, len(llm_stage_positions)),
- thread_name_prefix="llm-stage-attach",
- ) as attach_executor:
- for stage_idx in llm_stage_positions:
- attach_futures[attach_executor.submit(self._attach_llm_stage, started_llm_stages[stage_idx])] = (
- stage_idx
- )
+ for stage_id in llm_stage_ids:
+ started_llm_stages[stage_id] = llm_launch_futures[stage_id].result()
- for future in concurrent.futures.as_completed(attach_futures):
- stage_idx = attach_futures[future]
- stage_client, output_processor, vllm_config, stage0_input_processor = future.result()
- stage_clients[stage_idx] = stage_client
- output_processors[stage_idx] = output_processor
- stage_vllm_configs[stage_idx] = vllm_config
- if stage0_input_processor is not None:
- input_processor = stage0_input_processor
+ for stage_id in llm_stage_ids:
+ started = started_llm_stages[stage_id]
+ stage_client, output_processor, vllm_config, stage0_input_processor = self._attach_llm_stage(started)
+ stage_clients[stage_id] = stage_client
+ output_processors[stage_id] = output_processor
+ stage_vllm_configs[stage_id] = vllm_config
+ if stage0_input_processor is not None:
+ input_processor = stage0_input_processor
initialized_stage_clients, default_sampling_params_list, stage_metadata = finalize_initialized_stages(
stage_clients,
@@ -877,13 +558,8 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
)
cleanup_failed_stage_initialization(
stage_clients,
- [started_llm_stages[stage_idx] for stage_idx in llm_stage_positions if stage_idx in started_llm_stages],
+ [started_llm_stages[stage_id] for stage_id in llm_stage_ids if stage_id in started_llm_stages],
)
- if self._omni_master_server is not None:
- try:
- self._omni_master_server.stop()
- except Exception:
- logger.exception("[AsyncOmniEngine] Failed to stop OmniMasterServer during stage-init cleanup")
raise
self.stage_clients = initialized_stage_clients
@@ -923,7 +599,6 @@ async def _run_orchestrator() -> None:
self._initialize_janus_queues()
self._initialize_stages(stage_init_timeout)
- pd_config = self._detect_pd_config()
orchestrator = Orchestrator(
request_async_queue=self.request_queue.async_q,
output_async_queue=self.output_queue.async_q,
@@ -932,7 +607,6 @@ async def _run_orchestrator() -> None:
stage_clients=self.stage_clients,
output_processors=self.output_processors,
stage_vllm_configs=self.stage_vllm_configs,
- pd_config=pd_config,
)
if not startup_future.done():
startup_future.set_result(asyncio.get_running_loop())
@@ -942,17 +616,13 @@ async def _run_orchestrator() -> None:
loop.run_until_complete(_run_orchestrator())
except Exception as e:
if not startup_future.done():
- wrapped = RuntimeError(f"Orchestrator initialization failed: {e}")
- wrapped.__cause__ = e
- startup_future.set_exception(wrapped)
+ startup_future.set_exception(RuntimeError(f"Orchestrator initialization failed: {e}"))
logger.exception("[AsyncOmniEngine] Orchestrator thread crashed")
- error_text = str(e) or "Orchestrator thread crashed"
try:
- error_msg = {"type": "error", "error": error_text, "fatal": True}
if self.output_queue is not None:
- self.output_queue.sync_q.put_nowait(error_msg)
+ self.output_queue.sync_q.put_nowait({"type": "error", "error": "Orchestrator thread crashed"})
if self.rpc_output_queue is not None:
- self.rpc_output_queue.sync_q.put_nowait(error_msg)
+ self.rpc_output_queue.sync_q.put_nowait({"type": "error", "error": "Orchestrator thread crashed"})
except Exception:
pass
raise
@@ -972,31 +642,6 @@ async def _run_orchestrator() -> None:
asyncio.set_event_loop(None)
loop.close()
- def _wait_for_orchestrator_init(self, startup_future: concurrent.futures.Future, startup_timeout: int) -> None:
- """
- Wait for orchestrator startup future to return ready. Raises exception on any failures to the init process.
- """
- deadline = time.monotonic() + startup_timeout
- while True:
- remaining = deadline - time.monotonic()
- if remaining <= 0:
- self._try_shutdown("[AsyncOmniEngine] Failed to cleanup after orchestrator startup timeout")
- raise TimeoutError(f"Orchestrator did not become ready within {startup_timeout}s")
- try:
- startup_future.result(
- timeout=min(remaining, _STARTUP_POLL_INTERVAL_S),
- )
- break
- except concurrent.futures.TimeoutError:
- if not self.orchestrator_thread.is_alive():
- self._try_shutdown("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure")
- if startup_future.done():
- startup_future.result() # re-raises the real exception
- raise RuntimeError("Orchestrator thread died during startup")
- except Exception:
- self._try_shutdown("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure")
- raise
-
# ---- request helpers ----
def _build_add_request_message(
@@ -1032,7 +677,6 @@ def _build_add_request_message(
original_prompt = prompt
stage_type = self.stage_metadata[0].get("stage_type")
- _preprocess_ms = 0.0
if stage_type != "diffusion" and not isinstance(prompt, EngineCoreRequest):
# Inject global_request_id into the raw prompt.
if isinstance(prompt, dict):
@@ -1042,7 +686,6 @@ def _build_add_request_message(
_inject_global_id(item, request_id)
# Full input processing (tokenization, multimodal, etc.)
- _t_preprocess = time.perf_counter()
request = self.input_processor.process_inputs(
request_id=request_id,
prompt=prompt,
@@ -1056,7 +699,6 @@ def _build_add_request_message(
data_parallel_rank=data_parallel_rank,
resumable=resumable,
)
- _preprocess_ms = (time.perf_counter() - _t_preprocess) * 1000.0
# TODO (Peiqi): add this for Qwen3-TTS only. Other models don't have
# additional_information field in the prompt.
request = _upgrade_to_omni_request(request, prompt)
@@ -1071,12 +713,9 @@ def _build_add_request_message(
# to match the key used in Orchestrator.request_states so that
# output routing (output.request_id lookup) can find the req_state.
request.external_req_id = request_id
- request = _apply_omni_final_stage_metadata(request, final_stage_id)
# Register with stage 0's output processor.
output_prompt_text = prompt_text
- if output_prompt_text is None and isinstance(original_prompt, dict):
- output_prompt_text = original_prompt.get("prompt")
self.output_processors[0].add_request(
request=request,
prompt=output_prompt_text,
@@ -1093,8 +732,6 @@ def _build_add_request_message(
"original_prompt": original_prompt,
"sampling_params_list": effective_sampling_params_list,
"final_stage_id": final_stage_id,
- "preprocess_ms": _preprocess_ms,
- "enqueue_ts": time.perf_counter(),
}
def _enqueue_cfg_companions(
@@ -1129,6 +766,7 @@ def _enqueue_cfg_companions(
params=companion_params,
supported_tasks=self.supported_tasks,
)
+ request = _upgrade_to_omni_request(request, companion_prompt)
request.external_req_id = cid
self.output_processors[0].add_request(
@@ -1188,66 +826,12 @@ def _normalize_cache_config(cache_backend: str | None, cache_config: Any | None)
cache_config = AsyncOmniEngine._get_default_cache_config(cache_backend)
return cache_config
- def _detect_pd_config(self) -> dict[str, Any] | None:
- """Detect PD (Prefill-Decode) disaggregation config from stage_configs.
- Returns a dict with 'pd_pair' and 'bootstrap_addr', or None.
- """
- pd_pair = PDDisaggregationMixin.detect_pd_separation_from_stage_configs(self.stage_configs)
- if pd_pair is None:
- return None
- prefill_idx, decode_idx = pd_pair
-
- # Extract bootstrap address from prefill stage engine_args
- bootstrap_addr: str | None = None
- try:
- prefill_cfg = self.stage_configs[prefill_idx]
- ea = getattr(prefill_cfg, "engine_args", None)
- kv_cfg = getattr(ea, "kv_transfer_config", None) if ea is not None else None
- if kv_cfg is not None:
- port = vllm_envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
- kv_ip = getattr(kv_cfg, "kv_ip", None) or "127.0.0.1"
- bootstrap_addr = f"http://{kv_ip}:{port}"
- except Exception as exc:
- logger.warning("[AsyncOmniEngine] Could not extract PD bootstrap address: %s", exc)
-
- logger.info(
- "[AsyncOmniEngine] PD disaggregation detected: prefill=stage-%d, decode=stage-%d, bootstrap=%s",
- prefill_idx,
- decode_idx,
- bootstrap_addr,
- )
- prefill_engine_id: str | None = None
- try:
- prefill_client = self.stage_clients[prefill_idx]
- kv_cfg = getattr(getattr(prefill_client, "vllm_config", None), "kv_transfer_config", None)
- prefill_engine_id = getattr(kv_cfg, "engine_id", None)
- except Exception as exc:
- logger.warning("[AsyncOmniEngine] Could not extract prefill engine_id: %s", exc)
-
- return {
- "pd_pair": (prefill_idx, decode_idx),
- "bootstrap_addr": bootstrap_addr,
- "prefill_engine_id": prefill_engine_id,
- }
-
@staticmethod
def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"""Create a default single-stage diffusion config from kwargs."""
# We temporally create a default config for diffusion stage.
# In the future, we should merge the default config with the user-provided config.
normalized_kwargs = dict(kwargs)
- default_sampling_params = normalized_kwargs.get("default_sampling_params")
- if isinstance(default_sampling_params, str):
- try:
- default_sampling_params = json.loads(default_sampling_params)
- except json.JSONDecodeError:
- logger.warning("Invalid default_sampling_params JSON, ignoring stage defaults.")
- default_sampling_params = None
- if not isinstance(default_sampling_params, dict):
- default_sampling_params = None
- stage_default_sampling_params = default_sampling_params.get("0", {}) if default_sampling_params else {}
- if normalized_kwargs.get("dtype") is None:
- normalized_kwargs["dtype"] = "auto"
# TODO: hack, convert dtype to string to avoid non-premitive omegaconf create error.
if "dtype" in normalized_kwargs and not isinstance(normalized_kwargs["dtype"], str):
@@ -1298,50 +882,6 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
num_devices = max(1, int(parallel_config.world_size))
devices = ",".join(str(i) for i in range(num_devices))
- stage_engine_args = {
- "max_num_seqs": 1,
- "parallel_config": parallel_config,
- "model_class_name": kwargs.get("model_class_name", None),
- "step_execution": kwargs.get("step_execution", False),
- "vae_use_slicing": kwargs.get("vae_use_slicing", False),
- "vae_use_tiling": kwargs.get("vae_use_tiling", False),
- "cache_backend": cache_backend,
- "cache_config": cache_config,
- "enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False),
- "enable_cpu_offload": kwargs.get("enable_cpu_offload", False),
- "enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False),
- "enforce_eager": kwargs.get("enforce_eager", False),
- "boundary_ratio": kwargs.get("boundary_ratio", None),
- "flow_shift": kwargs.get("flow_shift", None),
- "diffusion_load_format": kwargs.get("diffusion_load_format", "default"),
- "custom_pipeline_args": kwargs.get("custom_pipeline_args", None),
- "worker_extension_cls": kwargs.get("worker_extension_cls", None),
- "enable_sleep_mode": kwargs.get("enable_sleep_mode", False),
- "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True),
- "num_weight_load_threads": kwargs.get("num_weight_load_threads", 4),
- "quantization": kwargs.get("quantization", None),
- "enable_diffusion_pipeline_profiler": kwargs.get("enable_diffusion_pipeline_profiler", False),
- "enable_ar_profiler": kwargs.get("enable_ar_profiler", False),
- **(
- {
- "profiler_config": asdict(kwargs["profiler_config"])
- if hasattr(kwargs["profiler_config"], "__dataclass_fields__")
- else kwargs["profiler_config"]
- }
- if kwargs.get("profiler_config") is not None
- else {}
- ),
- }
- # Only set dtype if it was already explicitly passed and normalized
- if "dtype" in normalized_kwargs:
- stage_engine_args["dtype"] = normalized_kwargs["dtype"]
-
- # New split fields for diffusers adapter kwargs.
- if kwargs.get("diffusers_load_kwargs") is not None:
- stage_engine_args["diffusers_load_kwargs"] = kwargs["diffusers_load_kwargs"]
- if kwargs.get("diffusers_call_kwargs") is not None:
- stage_engine_args["diffusers_call_kwargs"] = kwargs["diffusers_call_kwargs"]
-
default_stage_cfg = [
{
"stage_id": 0,
@@ -1350,8 +890,37 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"process": True,
"devices": devices,
},
- "engine_args": stage_engine_args,
- "default_sampling_params": stage_default_sampling_params,
+ "engine_args": {
+ "max_num_seqs": 1,
+ "parallel_config": parallel_config,
+ "model_class_name": kwargs.get("model_class_name", None),
+ "step_execution": kwargs.get("step_execution", False),
+ "vae_use_slicing": kwargs.get("vae_use_slicing", False),
+ "vae_use_tiling": kwargs.get("vae_use_tiling", False),
+ "cache_backend": cache_backend,
+ "cache_config": cache_config,
+ "enable_cache_dit_summary": kwargs.get("enable_cache_dit_summary", False),
+ "enable_cpu_offload": kwargs.get("enable_cpu_offload", False),
+ "enable_layerwise_offload": kwargs.get("enable_layerwise_offload", False),
+ "enforce_eager": kwargs.get("enforce_eager", False),
+ "diffusion_load_format": kwargs.get("diffusion_load_format", "default"),
+ "custom_pipeline_args": kwargs.get("custom_pipeline_args", None),
+ "worker_extension_cls": kwargs.get("worker_extension_cls", None),
+ "enable_sleep_mode": kwargs.get("enable_sleep_mode", False),
+ "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True),
+ "num_weight_load_threads": kwargs.get("num_weight_load_threads", 4),
+ "quantization": kwargs.get("quantization", None),
+ "enable_diffusion_pipeline_profiler": kwargs.get("enable_diffusion_pipeline_profiler", False),
+ **(
+ {
+ "profiler_config": asdict(kwargs["profiler_config"])
+ if hasattr(kwargs["profiler_config"], "__dataclass_fields__")
+ else kwargs["profiler_config"]
+ }
+ if kwargs.get("profiler_config") is not None
+ else {}
+ ),
+ },
"final_output": True,
"final_output_type": "image",
}
@@ -1359,47 +928,10 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
default_stage_cfg[0]["engine_args"]["model_stage"] = "diffusion"
return default_stage_cfg
- @staticmethod
- def _strip_single_engine_args(kwargs: dict[str, Any]) -> dict[str, Any]:
- """Remove parent ``EngineArgs`` fields from *kwargs*.
-
- When ``stage_configs_path`` is set, per-stage engine args are defined
- in the YAML. Top-level single-engine fields (``compilation_config``,
- ``tensor_parallel_size``, …) must not leak into per-stage configs via
- the ``base_engine_args`` merge in ``load_stage_configs_from_yaml`` —
- they can cause type errors (e.g. ``compilation_config`` as a JSON
- string rejected by ``VllmConfig``) or silently override YAML values.
-
- Logs a warning for any parent field whose value differs from the
- dataclass default, so users know their explicit overrides are ignored.
- See the module-level ``_PARENT_ARGS_*`` constants for the routing
- contracts this method enforces.
- """
- parent_fields: dict[str, dataclasses.Field] = {f.name: f for f in dataclasses.fields(EngineArgs)}
- result, overridden = strip_parent_engine_args(
- kwargs,
- parent_fields=parent_fields,
- keep_keys=_PARENT_ARGS_KEEP,
- strip_keys=_PARENT_ARGS_STRIP,
- no_warn_keys=_PARENT_ARGS_NO_WARN,
- )
-
- if overridden:
- logger.warning(
- "stage_configs_path is set — the following top-level engine "
- "args are ignored (per-stage YAML takes precedence): %s",
- ", ".join(sorted(overridden)),
- )
-
- return result
-
def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[str, list[Any]]:
"""Resolve stage configs and inject defaults shared by orchestrator/headless."""
stage_configs_path = kwargs.get("stage_configs_path", None)
- deploy_config_path = kwargs.pop("deploy_config", None)
- stage_overrides_json = kwargs.pop("stage_overrides", None)
- kwargs.pop("_cli_explicit_keys", None)
explicit_stage_configs = kwargs.pop("stage_configs", None)
if explicit_stage_configs is not None:
logger.warning(
@@ -1407,42 +939,18 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
"Ignoring it and resolving stages from stage_configs_path/model factory."
)
- if stage_configs_path is not None:
- base_kwargs = self._strip_single_engine_args(kwargs)
- else:
- base_kwargs = kwargs
-
- # Parse --stage-overrides JSON string if provided
- stage_overrides = None
- if stage_overrides_json:
- if isinstance(stage_overrides_json, str):
- try:
- stage_overrides = json.loads(stage_overrides_json)
- except json.JSONDecodeError as exc:
- raise ValueError(
- f"--stage-overrides is not valid JSON: {exc}. Got: {stage_overrides_json!r}"
- ) from exc
- else:
- stage_overrides = stage_overrides_json
-
+ # Use the legacy config loading path (load_and_resolve_stage_configs).
+ # StageConfigFactory wiring will be done in config refactor [2/N].
config_path, stage_configs = load_and_resolve_stage_configs(
model,
stage_configs_path,
- base_kwargs,
+ kwargs,
default_stage_cfg_factory=lambda: self._create_default_diffusion_stage_cfg(kwargs),
- deploy_config_path=deploy_config_path,
- stage_overrides=stage_overrides,
)
# Inject diffusion LoRA-related knobs from kwargs if not present in the stage config.
for cfg in stage_configs:
try:
- if not hasattr(cfg, "engine_args") or cfg.engine_args is None:
- cfg.engine_args = OmegaConf.create({})
- global_sleep_mode = kwargs.get("enable_sleep_mode")
- if global_sleep_mode is not None:
- if not hasattr(cfg.engine_args, "enable_sleep_mode") or cfg.engine_args.enable_sleep_mode is None:
- cfg.engine_args.enable_sleep_mode = global_sleep_mode
if getattr(cfg, "stage_type", None) != "diffusion":
continue
if not hasattr(cfg, "engine_args") or cfg.engine_args is None:
@@ -1464,21 +972,6 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
or cfg.engine_args.quantization_config is None
):
cfg.engine_args.quantization_config = quantization_config
- # Inject profiler flags for diffusion stages
- for profiler_key in (
- "enable_diffusion_pipeline_profiler",
- "enable_ar_profiler",
- ):
- val = kwargs.get(profiler_key)
- if val:
- if not hasattr(cfg.engine_args, profiler_key) or not getattr(
- cfg.engine_args, profiler_key, False
- ):
- setattr(cfg.engine_args, profiler_key, val)
- quantization = kwargs.get("quantization")
- if quantization is not None:
- if not hasattr(cfg.engine_args, "quantization") or cfg.engine_args.quantization is None:
- cfg.engine_args.quantization = quantization
except Exception as e:
logger.warning("Failed to inject LoRA config for stage: %s", e)
@@ -1772,16 +1265,3 @@ def shutdown(self) -> None:
q.close()
except Exception:
pass
-
- if self._omni_master_server is not None:
- try:
- self._omni_master_server.stop()
- except Exception:
- logger.exception("[AsyncOmniEngine] Failed to stop OmniMasterServer during shutdown")
- self._omni_master_server = None
-
- def _try_shutdown(self, *args, **kwargs) -> None:
- try:
- self.shutdown()
- except Exception:
- logger.exception(*args, **kwargs)
diff --git a/vllm_omni/engine/cfg_companion_tracker.py b/vllm_omni/engine/cfg_companion_tracker.py
deleted file mode 100644
index b9dfae833e2..00000000000
--- a/vllm_omni/engine/cfg_companion_tracker.py
+++ /dev/null
@@ -1,125 +0,0 @@
-"""CFG companion request tracker for the Omni orchestrator.
-
-Encapsulates all bookkeeping for Classifier-Free Guidance companion
-requests (parent/companion ID mapping, completion tracking,
-deferred forwarding, and cleanup).
-"""
-
-from __future__ import annotations
-
-import logging
-from typing import Any
-
-from vllm_omni.inputs.data import OmniDiffusionSamplingParams
-
-logger = logging.getLogger(__name__)
-
-
-class CfgCompanionTracker:
- """Manages CFG companion request lifecycle in the orchestrator scheduling loop."""
-
- def __init__(self) -> None:
- self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id}
- self._companion_ids: set[str] = set()
- self._companion_to_parent: dict[str, str] = {} # companion -> parent
- self._done: dict[str, set[str]] = {} # parent -> completed companion ids
- self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result
-
- def is_companion(self, req_id: str) -> bool:
- return req_id in self._companion_ids
-
- def has_companions(self, parent_id: str) -> bool:
- return parent_id in self._companion_map
-
- def all_companions_done(self, parent_id: str) -> bool:
- role_map = self._companion_map.get(parent_id, {})
- done_set = self._done.get(parent_id, set())
- return all(cid in done_set for cid in role_map.values())
-
- def get_companion_request_ids(self, parent_id: str) -> dict[str, str]:
- """Return ``{role: companion_request_id}`` for a parent."""
- return self._companion_map.get(parent_id, {})
-
- def register_parent(self, parent_id: str) -> None:
- self._companion_map.setdefault(parent_id, {})
- self._done.setdefault(parent_id, set())
-
- def register_companion(self, parent_id: str, role: str, companion_id: str) -> None:
- self.register_parent(parent_id)
- self._companion_map[parent_id][role] = companion_id
- self._companion_ids.add(companion_id)
- self._companion_to_parent[companion_id] = parent_id
-
- def attach_cfg_request_ids(self, parent_id: str, sampling_params: Any) -> Any:
- cfg_ids = self.get_companion_request_ids(parent_id)
- if not cfg_ids:
- return sampling_params
-
- if isinstance(sampling_params, OmniDiffusionSamplingParams):
- sampling_params = sampling_params.clone()
- sampling_params.cfg_kv_request_ids = cfg_ids
- logger.info(
- "Attaching cfg_kv_request_ids=%s to request %s",
- cfg_ids,
- parent_id,
- )
- return sampling_params
-
- def on_companion_completed(self, companion_id: str) -> str | None:
- """Mark done. Returns parent_id only if parent is pending and all companions finished."""
- parent_id = self._companion_to_parent.get(companion_id)
- if parent_id is None:
- return None
- done_set = self._done.get(parent_id)
- assert done_set is not None, f"Companion {companion_id} completed before parent {parent_id} was registered"
- if companion_id in done_set:
- return None
- done_set.add(companion_id)
- logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id)
- if parent_id in self._pending_parents and self.all_companions_done(parent_id):
- return parent_id
- return None
-
- def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None:
- """Hold parent result while waiting for companions to finish."""
- # TODO: Add timeout/error recovery when the orchestrator grows a
- # companion-failure path. Today deferred parents are released only when
- # companions finish or the external layer aborts the request.
- self._pending_parents[parent_id] = {
- "engine_outputs": engine_outputs,
- "stage_id": stage_id,
- }
- logger.debug("Parent %s deferred, waiting for CFG companions", parent_id)
-
- def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None:
- return self._pending_parents.pop(parent_id, None)
-
- def cleanup_parent(self, parent_id: str) -> list[str]:
- companion_ids = list(self._companion_map.pop(parent_id, {}).values())
- for companion_id in companion_ids:
- self._companion_ids.discard(companion_id)
- self._companion_to_parent.pop(companion_id, None)
- self._done.pop(parent_id, None)
- self._pending_parents.pop(parent_id, None)
- return companion_ids
-
- def abort_parents(self, request_ids: list[str]) -> list[str]:
- all_request_ids = list(request_ids)
- seen = set(all_request_ids)
- parents_to_cleanup: set[str] = set()
-
- for req_id in request_ids:
- # The orchestrator calls this with parent request IDs. If a raw
- # companion ID is passed here, keep it as a direct abort target and
- # avoid tearing down parent tracking state implicitly.
- if req_id not in self._companion_ids:
- parents_to_cleanup.add(req_id)
-
- for parent_id in parents_to_cleanup:
- companion_ids = self.cleanup_parent(parent_id)
- for companion_id in companion_ids:
- if companion_id not in seen:
- seen.add(companion_id)
- all_request_ids.append(companion_id)
-
- return all_request_ids
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index e1bcf39a590..8ea9a5096ca 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -22,18 +22,15 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreOutputs
-from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.distributed.omni_connectors.adapter import compute_talker_prompt_ids_length
from vllm_omni.engine import (
OmniEngineCoreRequest,
)
-from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker
from vllm_omni.engine.serialization import serialize_additional_information
from vllm_omni.metrics.stats import StageRequestStats as StageRequestMetrics
from vllm_omni.metrics.stats import StageStats
from vllm_omni.metrics.utils import count_tokens_from_outputs
-from vllm_omni.outputs import OmniRequestOutput
logger = init_logger(__name__)
@@ -44,8 +41,6 @@ def build_engine_core_request_from_tokens(
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
model_config: ModelConfig | None = None,
- resumable: bool = False,
- mm_features: list | None = None,
) -> OmniEngineCoreRequest:
"""Build an OmniEngineCoreRequest directly from an OmniTokensPrompt.
@@ -80,15 +75,14 @@ def build_engine_core_request_from_tokens(
return OmniEngineCoreRequest(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
- mm_features=mm_features,
+ mm_features=None,
sampling_params=sampling_params,
pooling_params=pooling_params,
arrival_time=arrival_time,
- lora_request=getattr(params, "lora_request", None),
+ lora_request=None,
cache_salt=None,
data_parallel_rank=None,
prompt_embeds=prompt_embeds,
- resumable=resumable,
additional_information=additional_info_payload,
)
@@ -109,25 +103,6 @@ class OrchestratorRequestState:
# Metrics: timestamp when request was submitted to each stage
stage_submit_ts: dict[int, float] = field(default_factory=dict)
- mm_processor_kwargs: dict | None = None
- mm_features: list | None = None
-
- streaming: StreamingInputState = field(default_factory=lambda: StreamingInputState())
-
- # Per-request pipeline timing accumulator (milliseconds)
- pipeline_timings: dict[str, float] = field(default_factory=dict)
-
-
-@dataclass
-class StreamingInputState:
- # Flag of streaming input request
- enabled: bool = False
- # Flag of segment of streaming input finished
- segment_finished: bool = False
- # Streaming update prompt length
- new_prompt_len_snapshot: int | None = None
- # Model/bridge-specific runtime states (e.g., thinker->talker)
- bridge_states: dict[str, Any] = field(default_factory=dict)
class Orchestrator:
@@ -147,7 +122,6 @@ def __init__(
stage_vllm_configs: list[Any],
*,
async_chunk: bool = False,
- pd_config: dict[str, Any] | None = None,
) -> None:
self.request_async_queue = request_async_queue
self.output_async_queue = output_async_queue
@@ -160,21 +134,15 @@ def __init__(
self.output_processors: list[Any] = output_processors
self.stage_vllm_configs: list[Any] = stage_vllm_configs
- # PD disaggregation state
- self._pd_pair: tuple[int, int] | None = None
- self._pd_bootstrap_addr: str | None = None
- self._pd_prefill_engine_id: str | None = None
- self._pd_kv_params: dict[str, Any] = {}
- if pd_config is not None:
- self._pd_pair = pd_config.get("pd_pair")
- self._pd_bootstrap_addr = pd_config.get("bootstrap_addr")
- self._pd_prefill_engine_id = pd_config.get("prefill_engine_id")
-
# Per-request state
self.request_states: dict[str, OrchestratorRequestState] = {}
# CFG companion tracking
- self._cfg_tracker = CfgCompanionTracker()
+ self._companion_map: dict[str, dict[str, str]] = {}
+ self._companion_to_parent: dict[str, str] = {}
+ self._companion_ids: set[str] = set()
+ self._companion_done: dict[str, set[str]] = {}
+ self._deferred_parents: dict[str, dict[str, Any]] = {}
# Per-stage metrics accumulators.
self._batch_seq: list[int] = [0] * self.num_stages
@@ -184,8 +152,6 @@ def __init__(
# Shutdown coordination
self._shutdown_event = asyncio.Event()
self._stages_shutdown = False
- self._fatal_error: str | None = None
- self._fatal_error_stage_id: int | None = None
async def run(self) -> None:
"""Main entry point for the Orchestrator event loop."""
@@ -215,12 +181,6 @@ async def run(self) -> None:
except Exception:
pass
- # If a fatal error caused the shutdown, drain any pending
- # add_request messages that were never processed and broadcast
- # fatal error responses so callers are not left hanging.
- if self._fatal_error is not None:
- await self._drain_pending_requests_on_fatal()
-
self._shutdown_stages()
# Cancel any remaining tasks spawned by wait_for / gather so
@@ -284,40 +244,8 @@ async def _orchestration_loop(self) -> None:
output = stage_client.get_diffusion_output_nowait()
if output is not None:
idle = False
-
- if getattr(output, "error", None) is not None:
- await self.output_async_queue.put(
- {
- "type": "output",
- "request_id": output.request_id,
- "stage_id": stage_id,
- "engine_outputs": output,
- "metrics": None,
- "finished": True,
- }
- )
- self.request_states.pop(output.request_id, None)
- continue
-
req_state = self.request_states.get(output.request_id)
if req_state is not None:
- if getattr(output, "error", None) is not None:
- parent_id = self._companion_to_parent.get(output.request_id, output.request_id)
- await self.output_async_queue.put(
- {
- "type": "error",
- "request_id": parent_id,
- "stage_id": stage_id,
- "error": output.error,
- }
- )
- role_map = self._companion_map.get(parent_id, {})
- for cid in role_map.values():
- self.request_states.pop(cid, None)
- self._cleanup_companion_state(parent_id)
- self.request_states.pop(parent_id, None)
- continue
-
stage_metrics = self._build_stage_metrics(stage_id, output.request_id, [output], req_state)
await self._route_output(stage_id, output, req_state, stage_metrics)
continue
@@ -329,28 +257,6 @@ async def _orchestration_loop(self) -> None:
continue
except asyncio.CancelledError:
raise
- except EngineDeadError as e:
- logger.error(
- "[Orchestrator] Stage-%s is dead: %s",
- stage_id,
- e,
- )
- self._fatal_error = str(e)
- self._fatal_error_stage_id = stage_id
- for req_id, req_state in list(self.request_states.items()):
- if stage_id in req_state.stage_submit_ts:
- await self.output_async_queue.put(
- {
- "type": "error",
- "error": str(e),
- "fatal": True,
- "request_id": req_id,
- "stage_id": stage_id,
- }
- )
- self.request_states.pop(req_id, None)
- self._shutdown_event.set()
- raise
except Exception:
if self._shutdown_event.is_set():
return
@@ -411,7 +317,7 @@ async def _route_output(
# CFG companion handling: companions don't produce user-visible output
# and don't forward to the next stage directly.
- if finished and self._cfg_tracker.is_companion(req_id):
+ if finished and req_id in self._companion_ids:
await self._handle_cfg_companion_ready(req_id)
self.request_states.pop(req_id, None)
return
@@ -439,80 +345,63 @@ async def _route_output(
}
)
- # PD disaggregation: extract KV transfer params from prefill stage output
- if self._pd_pair is not None and finished and stage_id == self._pd_pair[0]:
- kv_params = getattr(output, "kv_transfer_params", None)
- if kv_params is not None:
- self._pd_kv_params[req_id] = kv_params if isinstance(kv_params, dict) else dict(kv_params)
- logger.debug(
- "[Orchestrator][PD] stored kv_transfer_params for req=%s (keys=%s)",
- req_id,
- list(self._pd_kv_params[req_id].keys()),
- )
- else:
- logger.warning(
- "[Orchestrator][PD] prefill stage output for req=%s has no kv_transfer_params; "
- "KV transfer may fail. Ensure apply_mooncake_connector_patch() was called.",
- req_id,
- )
-
if (
- (finished or (req_state.streaming.enabled and req_state.streaming.segment_finished))
+ finished
and stage_id < req_state.final_stage_id
and not self.async_chunk
- and (not self._next_stage_already_submitted(stage_id, req_state) or req_state.streaming.enabled)
+ and not self._next_stage_already_submitted(stage_id, req_state)
):
- if (
- finished
- and self._cfg_tracker.has_companions(req_id)
- and not self._cfg_tracker.all_companions_done(req_id)
- ):
- self._cfg_tracker.defer_parent(req_id, output, stage_id)
+ if req_id in self._companion_map and not self._all_companions_done(req_id):
+ self._deferred_parents[req_id] = {
+ "stage_id": stage_id,
+ "output": output,
+ }
else:
- await self._forward_to_next_stage(
- req_id,
- stage_id,
- output,
- req_state,
- is_streaming_session=req_state.streaming.enabled,
- is_final_update=False,
- )
- if req_state.streaming.enabled and finished:
- # For streaming sessions, send the terminal (resumable=False) update only on a finish
- await self._forward_to_next_stage(
- req_id,
- stage_id,
- output,
- req_state,
- is_streaming_session=True,
- is_final_update=True,
- )
+ await self._forward_to_next_stage(req_id, stage_id, output, req_state)
if finished and stage_id == req_state.final_stage_id:
- # PD: clean up any lingering KV params for this request
- self._pd_kv_params.pop(req_id, None)
- self._cfg_tracker.cleanup_parent(req_id)
+ self._cleanup_companion_state(req_id)
self.request_states.pop(req_id, None)
+ def _cleanup_companion_state(self, parent_id: str) -> None:
+ """Remove all companion tracking state for a completed parent."""
+ role_map = self._companion_map.pop(parent_id, {})
+ for cid in role_map.values():
+ self._companion_ids.discard(cid)
+ self._companion_to_parent.pop(cid, None)
+ self._companion_done.pop(parent_id, None)
+ self._deferred_parents.pop(parent_id, None)
+
+ def _all_companions_done(self, parent_id: str) -> bool:
+ """Check whether all CFG companions for a parent request have finished."""
+ role_map = self._companion_map.get(parent_id, {})
+ if not role_map:
+ return True
+ done_set = self._companion_done.get(parent_id, set())
+ return all(cid in done_set for cid in role_map.values())
+
def _next_stage_already_submitted(self, stage_id: int, req_state: OrchestratorRequestState) -> bool:
return (stage_id + 1) in req_state.stage_submit_ts
async def _handle_cfg_companion_ready(self, req_id: str) -> None:
"""Mark a CFG companion as done; if all companions are done, flush deferred parent."""
- parent_id = self._cfg_tracker.on_companion_completed(req_id)
+ parent_id = self._companion_to_parent.get(req_id)
if parent_id is None:
return
- deferred = self._cfg_tracker.pop_pending_parent(parent_id)
- if deferred is None:
+ done_set = self._companion_done.setdefault(parent_id, set())
+ if req_id in done_set:
return
- parent_state = self.request_states.get(parent_id)
- if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state):
- await self._forward_to_next_stage(
- parent_id,
- deferred["stage_id"],
- deferred["engine_outputs"],
- parent_state,
- )
+ done_set.add(req_id)
+ if parent_id in self._deferred_parents and self._all_companions_done(parent_id):
+ deferred = self._deferred_parents.pop(parent_id)
+ parent_state = self.request_states.get(parent_id)
+ if parent_state is not None and not self._next_stage_already_submitted(deferred["stage_id"], parent_state):
+ await self._forward_to_next_stage(
+ parent_id,
+ deferred["stage_id"],
+ deferred["output"],
+ parent_state,
+ )
async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineCoreOutputs) -> None:
"""Forward split requests once stage-0 KV is ready, not only when decode fully finishes."""
@@ -526,63 +415,21 @@ async def _handle_kv_ready_raw_outputs(self, stage_id: int, raw_outputs: EngineC
req_state = self.request_states.get(req_id)
if req_state is None:
continue
- if self._cfg_tracker.is_companion(req_id):
+ if req_id in self._companion_ids:
await self._handle_cfg_companion_ready(req_id)
continue
if stage_id >= req_state.final_stage_id:
continue
if self._next_stage_already_submitted(stage_id, req_state):
continue
- if self._cfg_tracker.has_companions(req_id) and not self._cfg_tracker.all_companions_done(req_id):
- self._cfg_tracker.defer_parent(req_id, raw_output, stage_id)
+ if req_id in self._companion_map and not self._all_companions_done(req_id):
+ self._deferred_parents[req_id] = {
+ "stage_id": stage_id,
+ "output": raw_output,
+ }
else:
await self._forward_to_next_stage(req_id, stage_id, raw_output, req_state)
- def _build_pd_decode_params(self, req_id: str, sp: Any) -> Any:
- """Build decode-side sampling params with KV transfer params for PD routing.
-
- Clones the sampling params and injects kv_transfer_params that tell the
- decode engine where to pull the KV cache from (prefill engine's bootstrap addr).
- """
- sp = sp.clone()
- if sp.extra_args is None:
- sp.extra_args = {}
-
- # Get KV params captured from the prefill output (must include remote_request_id).
- kv_prefill_params = self._pd_kv_params.pop(req_id, None)
- if not kv_prefill_params or "remote_request_id" not in kv_prefill_params:
- raise RuntimeError(
- f"[Orchestrator][PD] Missing prefill kv_transfer_params.remote_request_id for req={req_id}"
- )
-
- decode_kv_params: dict[str, Any] = {
- "transfer_id": f"xfer-{req_id}",
- }
-
- if self._pd_bootstrap_addr:
- decode_kv_params["remote_bootstrap_addr"] = self._pd_bootstrap_addr
-
- if self._pd_prefill_engine_id:
- decode_kv_params["remote_engine_id"] = self._pd_prefill_engine_id
-
- # Overlay params from prefill side (includes remote_request_id set by monkey patch).
- decode_kv_params.update(kv_prefill_params)
-
- # Ensure these flags are set correctly after any overlay.
- decode_kv_params["do_remote_prefill"] = True
- decode_kv_params["do_remote_decode"] = False
- if not decode_kv_params.get("transfer_id"):
- decode_kv_params["transfer_id"] = f"xfer-{req_id}"
-
- sp.extra_args["kv_transfer_params"] = decode_kv_params
-
- logger.debug(
- "[Orchestrator][PD] decode kv_transfer_params for req=%s: %s",
- req_id,
- decode_kv_params,
- )
- return sp
-
def _build_stage_metrics(
self,
stage_id: int,
@@ -628,42 +475,14 @@ def _build_stage_metrics(
total_token=self._agg_total_tokens[stage_id],
total_gen_time_ms=self._agg_total_gen_time_ms[stage_id],
),
- pipeline_timings=dict(req_state.pipeline_timings),
)
- def _build_kv_sender_info(self, sender_stage_ids: list[int]) -> dict[int, dict[str, Any]] | None:
- """Build per-request sender info for diffusion KV-transfer receivers."""
- sender_infos: dict[int, dict[str, Any]] = {}
- for sender_stage_id in dict.fromkeys(sender_stage_ids):
- if sender_stage_id < 0 or sender_stage_id >= self.num_stages:
- continue
-
- sender_stage = self.stage_clients[sender_stage_id]
- get_sender_info = getattr(sender_stage, "get_kv_sender_info", None)
- if not callable(get_sender_info):
- continue
-
- sender_info = get_sender_info()
- if not sender_info:
- logger.warning(
- "[Orchestrator] Stage-%s has no KV sender info available",
- sender_stage_id,
- )
- continue
-
- sender_infos[sender_stage_id] = sender_info
-
- return sender_infos or None
-
async def _forward_to_next_stage(
self,
req_id: str,
stage_id: int,
output: Any,
req_state: OrchestratorRequestState,
- *,
- is_streaming_session: bool = False,
- is_final_update: bool = False,
) -> None:
"""Forward output from current stage to the next stage.
@@ -673,122 +492,44 @@ async def _forward_to_next_stage(
next_stage_id = stage_id + 1
next_client = self.stage_clients[next_stage_id]
params = req_state.sampling_params_list[next_stage_id]
- next_stage_resumable = is_streaming_session and not is_final_update
if next_client.stage_type == "diffusion":
self.stage_clients[stage_id].set_engine_outputs([output])
if next_client.custom_process_input_func is not None:
- _t_ar2d = _time.perf_counter()
diffusion_prompt = next_client.custom_process_input_func(
self.stage_clients,
next_client.engine_input_source,
req_state.prompt,
False,
)
- _dt_ar2d = (_time.perf_counter() - _t_ar2d) * 1000
- req_state.pipeline_timings["ar2diffusion_ms"] = _dt_ar2d
- logger.info(
- "[Orchestrator] ar2diffusion req=%s wall_time=%.3fms stage=%d->%d",
- req_id,
- _dt_ar2d,
- stage_id,
- next_stage_id,
- )
if isinstance(diffusion_prompt, list):
- if not diffusion_prompt:
- error_output = OmniRequestOutput.from_error(
- req_id,
- f"Stage-{stage_id} produced no valid inputs for diffusion stage-{next_stage_id}",
- )
- logger.warning(
- "[Orchestrator] req=%s stage=%d produced empty diffusion inputs for stage=%d; "
- "routing terminal error output",
- req_id,
- stage_id,
- next_stage_id,
- )
- await self.output_async_queue.put(
- {
- "type": "output",
- "request_id": req_id,
- "stage_id": next_stage_id,
- "engine_outputs": error_output,
- "metrics": None,
- "finished": True,
- }
- )
- self._pd_kv_params.pop(req_id, None)
- self._cfg_tracker.cleanup_parent(req_id)
- self.request_states.pop(req_id, None)
- return
diffusion_prompt = diffusion_prompt[0]
else:
diffusion_prompt = req_state.prompt
- params = self._cfg_tracker.attach_cfg_request_ids(req_id, params)
+ # Attach CFG companion KV request IDs so the diffusion model
+ # runner can fetch companion KV caches alongside the primary one.
+ cfg_ids = self._companion_map.get(req_id)
+ if cfg_ids:
+ from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+ if isinstance(params, OmniDiffusionSamplingParams):
+ params = copy.deepcopy(params)
+ params.cfg_kv_request_ids = cfg_ids
+ logger.info(
+ "[Orchestrator] Attaching cfg_kv_request_ids=%s to req %s",
+ cfg_ids,
+ req_id,
+ )
- source_stage_ids = list(getattr(next_client, "engine_input_source", None) or [stage_id])
- kv_sender_info = self._build_kv_sender_info(sender_stage_ids=source_stage_ids)
if isinstance(diffusion_prompt, list):
await next_client.add_batch_request_async(
req_id,
diffusion_prompt,
params,
- kv_sender_info=kv_sender_info,
)
else:
- await next_client.add_request_async(
- req_id,
- diffusion_prompt,
- params,
- kv_sender_info=kv_sender_info,
- )
- req_state.stage_submit_ts[next_stage_id] = _time.time()
- return
-
- # PD disaggregation: prefill → decode routing uses original prompt + KV transfer params
- if self._pd_pair is not None and (stage_id, next_stage_id) == self._pd_pair:
- # Save prefill stage outputs so thinker2talker can merge embeddings later
- self.stage_clients[stage_id].set_engine_outputs([output])
-
- params = self._build_pd_decode_params(req_id, params)
-
- # Use the original user prompt for the decode stage (not processed embeddings)
- original_prompt = req_state.prompt
- raw_decode_inputs = [original_prompt] if not isinstance(original_prompt, list) else original_prompt
-
- decode_inputs: list[dict[str, Any]] = []
- for decode_input in raw_decode_inputs:
- if isinstance(decode_input, dict):
- decode_inputs.append(decode_input)
- continue
- prompt_token_ids = getattr(decode_input, "prompt_token_ids", None)
- if prompt_token_ids is None:
- raise TypeError(
- "[Orchestrator][PD] decode input must be dict or have prompt_token_ids, "
- f"got {type(decode_input).__name__} for req={req_id}"
- )
- decode_inputs.append({"prompt_token_ids": list(prompt_token_ids)})
-
- for decode_input in decode_inputs:
- request = build_engine_core_request_from_tokens(
- request_id=req_id,
- prompt=decode_input,
- params=params,
- model_config=self.stage_vllm_configs[next_stage_id].model_config,
- mm_features=req_state.mm_features, # Pass mm_features for M-RoPE
- )
- request.external_req_id = request.request_id
-
- self.output_processors[next_stage_id].add_request(
- request=request,
- prompt=None,
- parent_req=None,
- request_index=0,
- queue=None,
- )
- await next_client.add_request_async(request)
-
+ await next_client.add_request_async(req_id, diffusion_prompt, params)
req_state.stage_submit_ts[next_stage_id] = _time.time()
return
@@ -799,7 +540,6 @@ async def _forward_to_next_stage(
next_inputs = next_client.process_engine_inputs(
stage_list=self.stage_clients,
prompt=req_state.prompt,
- streaming_context=(req_state.streaming if req_state.streaming.enabled else None),
)
except Exception:
logger.exception(
@@ -811,17 +551,11 @@ async def _forward_to_next_stage(
# Build and submit requests for each input
for next_input in next_inputs:
- # Only AR thinker stages consume encoder mm_features; downstream
- # (talker/code2wav/…) must not see them (avoids encoder-cache misses).
- _ms = getattr(next_client, "model_stage", None)
- _mm_features = req_state.mm_features if _ms == "thinker" else None
request = build_engine_core_request_from_tokens(
request_id=req_id,
prompt=next_input,
params=params,
model_config=self.stage_vllm_configs[next_stage_id].model_config,
- mm_features=_mm_features,
- resumable=next_stage_resumable,
)
# TODO: Here we directly use the req id to assign.
@@ -863,13 +597,6 @@ async def _process_stage_outputs(self, stage_id: int, raw_outputs: EngineCoreOut
raw_outputs.timestamp,
None,
)
- for eco in raw_outputs.outputs:
- if not hasattr(eco, "request_id"):
- continue
- req_state = self.request_states.get(eco.request_id)
- if req_state:
- req_state.streaming.segment_finished = eco.is_segment_finished
- req_state.streaming.new_prompt_len_snapshot = eco.new_prompt_len_snapshot
if processed.reqs_to_abort:
await self.stage_clients[stage_id].abort_requests_async(processed.reqs_to_abort)
@@ -905,31 +632,19 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None:
# Track request state - use original_prompt so downstream stages
# (e.g. thinker2talker) can access the raw dict with multi_modal_data.
- request = prompt
- is_streaming = bool(getattr(request, "resumable", False))
req_state = OrchestratorRequestState(
request_id=request_id,
prompt=original_prompt,
sampling_params_list=sampling_params_list,
final_stage_id=final_stage_id,
- mm_features=getattr(prompt, "mm_features", None), # Save mm_features for PD
)
- req_state.streaming.enabled = is_streaming
req_state.stage_submit_ts[stage_id] = _time.time()
-
- # Per-request pipeline timings from caller thread
- _enqueue_ts = msg.get("enqueue_ts", 0.0)
- if _enqueue_ts > 0:
- req_state.pipeline_timings["queue_wait_ms"] = (_time.perf_counter() - _enqueue_ts) * 1000.0
- _preprocess_ms = msg.get("preprocess_ms", 0.0)
- if _preprocess_ms > 0:
- req_state.pipeline_timings["preprocess_ms"] = _preprocess_ms
-
self.request_states[request_id] = req_state
# Stage-0 prompt is already a fully-formed OmniEngineCoreRequest
# (pre-processed by AsyncOmniEngine.add_request, output processor
# already registered there) - submit directly.
+ request = prompt
stage_client = self.stage_clients[stage_id]
if stage_client.stage_type == "diffusion":
if isinstance(prompt, list):
@@ -965,7 +680,6 @@ async def _handle_streaming_update(self, msg: dict[str, Any]) -> None:
if "sampling_params_list" in msg and msg["sampling_params_list"]:
req_state.sampling_params_list = msg["sampling_params_list"]
- req_state.streaming.enabled = True
req_state.stage_submit_ts[stage_id] = _time.time()
stage_client = self.stage_clients[stage_id]
@@ -1017,14 +731,7 @@ async def _prewarm_async_chunk_stages(
params = req_state.sampling_params_list[next_stage_id]
if next_client.stage_type == "diffusion":
- source_stage_ids = list(getattr(next_client, "engine_input_source", None) or [next_stage_id - 1])
- kv_sender_info = self._build_kv_sender_info(sender_stage_ids=source_stage_ids)
- await next_client.add_request_async(
- request_id,
- req_state.prompt,
- params,
- kv_sender_info=kv_sender_info,
- )
+ await next_client.add_request_async(request_id, req_state.prompt, params)
req_state.stage_submit_ts[next_stage_id] = _time.time()
continue
@@ -1054,7 +761,13 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None:
companion_prompt = msg["prompt"]
sampling_params_list = msg["sampling_params_list"]
- self._cfg_tracker.register_companion(parent_id, role, companion_id)
+ # Register companion mapping
+ if parent_id not in self._companion_map:
+ self._companion_map[parent_id] = {}
+ self._companion_map[parent_id][role] = companion_id
+ self._companion_ids.add(companion_id)
+ self._companion_to_parent[companion_id] = parent_id
+ self._companion_done.setdefault(parent_id, set())
companion_state = OrchestratorRequestState(
request_id=companion_id,
@@ -1079,10 +792,22 @@ async def _handle_add_companion(self, msg: dict[str, Any]) -> None:
async def _handle_abort(self, msg: dict[str, Any]) -> None:
"""Handle an abort message from the main thread."""
request_ids = msg["request_ids"]
- all_ids_to_abort = self._cfg_tracker.abort_parents(request_ids)
+ # Also abort any CFG companions for aborted parents
+ companion_ids_to_abort: list[str] = []
+ for req_id in request_ids:
+ role_map = self._companion_map.pop(req_id, {})
+ for cid in role_map.values():
+ companion_ids_to_abort.append(cid)
+ self._companion_ids.discard(cid)
+ self._companion_to_parent.pop(cid, None)
+ self.request_states.pop(cid, None)
+ self._companion_done.pop(req_id, None)
+ self._deferred_parents.pop(req_id, None)
+
+ all_ids_to_abort = list(request_ids) + companion_ids_to_abort
for stage_id in range(self.num_stages):
await self.stage_clients[stage_id].abort_requests_async(all_ids_to_abort)
- for req_id in all_ids_to_abort:
+ for req_id in request_ids:
self.request_states.pop(req_id, None)
logger.info("[Orchestrator] Aborted request(s) %s", request_ids)
@@ -1151,55 +876,6 @@ async def _handle_collective_rpc(self, msg: dict[str, Any]) -> None:
}
)
- async def _drain_pending_requests_on_fatal(self) -> None:
- """Drain the request queue and broadcast fatal errors for any
- pending add_request messages that were never processed.
-
- Called from the ``run()`` finally block when a fatal error
- (e.g. ``EngineDeadError``) caused the orchestrator to shut down
- before the request handler could process all queued messages.
- Also broadcasts for any already-tracked requests still in
- ``request_states`` that were not yet notified.
- """
- assert self._fatal_error is not None
-
- notified: set[str] = set()
-
- # 1) Drain pending messages from the request queue.
- while True:
- try:
- msg = self.request_async_queue.get_nowait()
- except Exception:
- break
- if msg.get("type") == "add_request":
- req_id = msg["request_id"]
- await self.output_async_queue.put(
- {
- "type": "error",
- "error": self._fatal_error,
- "fatal": True,
- "request_id": req_id,
- "stage_id": self._fatal_error_stage_id,
- }
- )
- notified.add(req_id)
-
- # 2) Broadcast for any tracked requests not already notified
- # (e.g. request was registered but the EngineDeadError handler
- # missed it because it wasn't submitted to the dead stage yet).
- for req_id in list(self.request_states):
- if req_id not in notified:
- await self.output_async_queue.put(
- {
- "type": "error",
- "error": self._fatal_error,
- "fatal": True,
- "request_id": req_id,
- "stage_id": self._fatal_error_stage_id,
- }
- )
- self.request_states.pop(req_id, None)
-
def _shutdown_stages(self) -> None:
"""Shutdown all stage clients."""
if self._stages_shutdown:
diff --git a/vllm_omni/engine/output_modality.py b/vllm_omni/engine/output_modality.py
index c2615ca8c61..c4a1288a8ab 100644
--- a/vllm_omni/engine/output_modality.py
+++ b/vllm_omni/engine/output_modality.py
@@ -8,7 +8,7 @@
from __future__ import annotations
import re
-from enum import Enum, Flag, StrEnum, auto
+from enum import Enum, Flag, auto
_MODALITY_ALIASES: dict[str, str] = {
"speech": "audio",
@@ -21,27 +21,6 @@
}
-class OutputModalityNames(StrEnum):
- """Keys for output modalities.
-
- TODO: (Alex) Integrate this with the big-flag enum below + throughout the code
- for better type safety (currently only used for output processor).
- """
-
- TEXT = "text"
- IMAGE = "image"
- AUDIO = "audio"
- LATENT = "latent"
-
-
-# Specify which output modalities may be drained when handling delta messages.
-# For some types, e.g., latents, we need to be careful to ensure the full context
-# is passed as the stream yields due to assumptions in the I/O processing and model
-# when async chunk isn't enabled.
-NON_DRAINABLE_MODALITIES = {OutputModalityNames.TEXT, OutputModalityNames.LATENT}
-DRAINABLE_MODALITIES = {mod for mod in OutputModalityNames if mod not in NON_DRAINABLE_MODALITIES}
-
-
class OutputModality(Flag):
"""Bit-flag enum for output modalities.
diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py
index 84016fdb4a1..43d02e85b84 100644
--- a/vllm_omni/engine/output_processor.py
+++ b/vllm_omni/engine/output_processor.py
@@ -16,8 +16,6 @@
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import IterationStats
-from vllm_omni.data_entry_keys import unflatten_payload
-from vllm_omni.engine.output_modality import DRAINABLE_MODALITIES
from vllm_omni.outputs import OmniRequestOutput
logger = init_logger(__name__)
@@ -38,63 +36,64 @@ def __init__(
):
super().__init__(*args, **kwargs)
# Omni-specific: multimodal output accumulation
- # NOTE: Keys in mm_accumulated matter, because they dictate which
- # outputs are drained (i.e., only drain modality keys, don't drain
- # hidden states).
- self.mm_accumulated: dict[str, Any] = {}
-
- @staticmethod
- def _to_cpu(x):
- """Try to convert to CPU tensor if needed."""
- # TODO: Make this more robust and unify with other payload
- # building utils, we do this in multiple places.
- if isinstance(x, torch.Tensor):
- try:
- return x.detach().to("cpu", non_blocking=True).contiguous()
- except Exception:
- return x
- return x
+ self.mm_type: str | None = None
+ self.mm_accumulated: dict[str, Any] | None = None
def add_multimodal_tensor(self, payload: Any | None, mm_type: str | None) -> None:
if payload is None:
return
-
- mm_type = (mm_type or "").lower()
try:
- if isinstance(payload, dict):
- # Keep payload flat (dotted keys like "hidden_states.layer_0")
- # during accumulation so that all values are tensors/scalars and
- # the merge logic below works correctly. Unflatten happens
- # later in _consolidate_multimodal_tensors after concatenation.
+ if mm_type:
+ self.mm_type = (mm_type or "").lower()
+ # Normalize incoming payload to dict on CPU
+ def _to_cpu(x):
+ if isinstance(x, torch.Tensor):
+ try:
+ return x.detach().to("cpu", non_blocking=True).contiguous()
+ except Exception:
+ return x
+ return x
+
+ if isinstance(payload, dict):
incoming: dict[str, Any] = {}
- # TODO (Alex): Clean up and simplify key management
- target_key = mm_type or "hidden"
+ target_key = self.mm_type or "hidden"
for k, v in payload.items():
+ # Normalize producer keys to the modality name.
+ # AR runners produce {"hidden": ...} and generation
+ # runners produce {"model_outputs": ...}; remap both
+ # to the semantic modality key (e.g. "audio", "latent").
if k == "model_outputs":
k = target_key
elif k == "hidden" and target_key != "hidden":
k = target_key
- incoming[k] = self._to_cpu(v)
+ if isinstance(v, dict):
+ incoming[k] = {str(sk): _to_cpu(sv) for sk, sv in v.items()}
+ else:
+ incoming[k] = _to_cpu(v)
else:
- key = mm_type or "hidden"
- incoming = {key: self._to_cpu(payload)}
+ key = self.mm_type or "hidden"
+ incoming = {key: _to_cpu(payload)}
- if not self.mm_accumulated:
+ if self.mm_accumulated is None:
self.mm_accumulated = incoming
else:
+ # Merge keys; accumulate tensors in lists for deferred concatenation
for k, v in incoming.items():
if k not in self.mm_accumulated:
self.mm_accumulated[k] = v
else:
existing = self.mm_accumulated[k]
if isinstance(v, torch.Tensor) and isinstance(existing, torch.Tensor):
+ # Use list accumulation to avoid O(n²) repeated concatenation
self.mm_accumulated[k] = [existing, v]
elif isinstance(v, torch.Tensor) and isinstance(existing, list):
+ # Append to existing list
existing.append(v)
elif isinstance(v, dict) and isinstance(existing, dict):
+ # Merge nested dicts with list accumulation for tensors
for sk, sv in v.items():
if sk not in existing:
existing[sk] = sv
@@ -111,26 +110,17 @@ def add_multimodal_tensor(self, payload: Any | None, mm_type: str | None) -> Non
logger.exception("Error accumulating multimodal tensor")
def _consolidate_multimodal_tensors(self) -> None:
- """Consolidate accumulated tensor lists into single tensors via concatenation.
-
- Only DELTA drains modality keys per-step, so they will never be lists here and
- can be skipped. For CUMULATIVE and FINAL_ONLY, modality keys accumulate across
- steps and need consolidation.
- """
- if not self.mm_accumulated:
+ """Consolidate accumulated tensor lists into single tensors via concatenation."""
+ if self.mm_accumulated is None:
return
-
- skip_modality = self.output_kind == RequestOutputKind.DELTA
try:
for k, v in self.mm_accumulated.items():
- if skip_modality and k in DRAINABLE_MODALITIES:
- continue
if isinstance(v, list) and v and isinstance(v[0], torch.Tensor):
try:
if k == "audio":
- # Audio chunks are usually 2D (i.e., 1, N); concatenate
- # on the last axis to preserve the channel dimension.
- self.mm_accumulated[k] = torch.cat(v, dim=-1)
+ # Audio chunks can have inconsistent shapes, so torch.cat may fail here.
+ # Skip consolidation and leave the chunk list as-is (TODO: cat along dim=-1).
+ continue
elif k == "sr":
# Sample rate is a constant scalar, keep last value.
self.mm_accumulated[k] = v[-1]
@@ -150,13 +140,6 @@ def _consolidate_multimodal_tensors(self) -> None:
except Exception:
logger.exception("Error consolidating multimodal tensors")
- # Restore nested structure from flat dotted keys now that all tensor
- # lists have been concatenated into single tensors.
- try:
- self.mm_accumulated = unflatten_payload(self.mm_accumulated)
- except Exception:
- logger.exception("Error unflattening consolidated multimodal tensors")
-
# Override: do not route to pooling-only path; always create completion
# outputs, and attach pooling_result into the CompletionOutput.
def make_request_output(
@@ -197,13 +180,13 @@ def make_request_output(
)
finished = finish_reason is not None
- is_final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
- is_delta = self.output_kind == RequestOutputKind.DELTA
+ final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
- if not finished and is_final_only:
+ if not finished and final_only:
return None
- if finished or not is_delta:
+ # Consolidate accumulated tensors when finishing.
+ if finished:
self._consolidate_multimodal_tensors()
if self.stream_interval > 1:
@@ -220,7 +203,7 @@ def make_request_output(
):
return None
- if is_delta:
+ if self.output_kind == RequestOutputKind.DELTA:
# Send tokens from the offset in DELTA mode, otherwise all
# tokens are sent.
new_token_ids = self.detokenizer.output_token_ids[self.sent_tokens_offset :]
@@ -248,26 +231,19 @@ def _new_completion_output(
) -> Any:
# Reuse base text/logprobs logic, then annotate with pooling_result.
base_output = super()._new_completion_output(token_ids, finish_reason, stop_reason, routed_experts)
-
- # Inter-stage processors need the full cumulative token sequence.
- # In DELTA mode, base_output.token_ids only has the latest step's
- # tokens, so we always store a cumulative copy here.
- base_output.cumulative_token_ids = list(self.detokenizer.output_token_ids)
-
- if not hasattr(base_output, "multimodal_output"):
- setattr(base_output, "multimodal_output", {})
- if self.mm_accumulated:
- mm_out = getattr(base_output, "multimodal_output")
- if isinstance(mm_out, dict):
- for k, v in self.mm_accumulated.items():
- mm_out[k] = v
- else:
- setattr(base_output, "multimodal_output", self.mm_accumulated)
-
- if self.output_kind == RequestOutputKind.DELTA:
- for modality_key in DRAINABLE_MODALITIES:
- self.mm_accumulated.pop(modality_key, None)
-
+ try:
+ if self.mm_accumulated is not None:
+ # Attach accumulated multimodal dict on the completion output
+ if not hasattr(base_output, "multimodal_output"):
+ setattr(base_output, "multimodal_output", {})
+ mm_out = getattr(base_output, "multimodal_output")
+ if isinstance(mm_out, dict):
+ for k, v in self.mm_accumulated.items():
+ mm_out[k] = v
+ else:
+ setattr(base_output, "multimodal_output", self.mm_accumulated)
+ except Exception:
+ logger.exception("Error in _new_completion_output")
return base_output
diff --git a/vllm_omni/engine/serialization.py b/vllm_omni/engine/serialization.py
index 5b87074106c..386c97814f3 100644
--- a/vllm_omni/engine/serialization.py
+++ b/vllm_omni/engine/serialization.py
@@ -4,39 +4,110 @@
from typing import Any
+import numpy as np
+import torch
from vllm.logger import init_logger
-from vllm_omni.data_entry_keys import OmniPayload, deserialize_payload, serialize_payload
-from vllm_omni.engine import AdditionalInformationPayload
+from vllm_omni.engine import (
+ AdditionalInformationEntry,
+ AdditionalInformationPayload,
+)
logger = init_logger(__name__)
+def dtype_to_name(dtype: torch.dtype) -> str:
+ """Convert a torch dtype to a stable string name; unmapped dtypes use ``str(dtype)`` minus ``torch.``."""
+ mapping = {
+ torch.float32: "float32",
+ torch.float: "float32",
+ torch.float16: "float16",
+ torch.half: "float16",
+ torch.bfloat16: "bfloat16",
+ torch.float64: "float64",
+ torch.double: "float64",
+ torch.int64: "int64",
+ torch.long: "int64",
+ torch.int32: "int32",
+ torch.int: "int32",
+ torch.int16: "int16",
+ torch.short: "int16",
+ torch.int8: "int8",
+ torch.uint8: "uint8",
+ torch.bool: "bool",
+ }
+ return mapping.get(dtype, str(dtype).replace("torch.", ""))
+
+
def serialize_additional_information(
raw_info: dict[str, Any] | AdditionalInformationPayload | None,
*,
log_prefix: str | None = None,
) -> AdditionalInformationPayload | None:
- """Serialize omni request metadata for EngineCore transport.
-
- Delegates to ``serialize_payload`` which understands the nested
- ``OmniPayload`` TypedDict structure.
- """
+ """Serialize omni request metadata for EngineCore transport (tensors become CPU bytes + shape/dtype)."""
if raw_info is None:
return None
if isinstance(raw_info, AdditionalInformationPayload):
return raw_info
- payload: OmniPayload = raw_info # type: ignore[assignment]
- return serialize_payload(payload)
+ entries: dict[str, AdditionalInformationEntry] = {}
+ for key, value in raw_info.items():
+ if isinstance(value, torch.Tensor):
+ value_cpu = value.detach().to("cpu").contiguous()
+ entries[key] = AdditionalInformationEntry(
+ tensor_data=value_cpu.numpy().tobytes(),
+ tensor_shape=list(value_cpu.shape),
+ tensor_dtype=dtype_to_name(value_cpu.dtype),
+ )
+ continue
+
+ if isinstance(value, list):
+ entries[key] = AdditionalInformationEntry(list_data=value)
+ continue
+
+ entries[key] = AdditionalInformationEntry(scalar_data=value)
+
+ return AdditionalInformationPayload(entries=entries) if entries else None
def deserialize_additional_information(
- payload: dict | AdditionalInformationPayload | None,
+ payload: dict | AdditionalInformationPayload | object | None,
) -> dict:
- """Deserialize an *additional_information* payload into a plain dict."""
+ """Deserialize an *additional_information* payload into a plain dict.
+
+ Accepts:
+ - ``dict`` – returned as-is.
+ - ``AdditionalInformationPayload`` (or duck-typed with ``.entries``) – decoded entry-by-entry.
+ - ``None`` – returns ``{}``.
+ Malformed payloads (non-dict ``entries`` or a decode error) are logged and yield ``{}``.
+ """
+
if payload is None:
return {}
+
if isinstance(payload, dict):
return payload
- return deserialize_payload(payload) # type: ignore[return-value]
+
+ try:
+ entries = getattr(payload, "entries", None)
+ if not isinstance(entries, dict):
+ logger.error("Failed to decode additional_information payload, entries field not a dict")
+ return {}
+ info: dict[str, object] = {}
+ for k, entry in entries.items():
+ if getattr(entry, "tensor_data", None) is not None:
+ dt = np.dtype(getattr(entry, "tensor_dtype", "float32"))
+ arr = np.frombuffer(entry.tensor_data, dtype=dt)
+ arr = arr.reshape(getattr(entry, "tensor_shape", ()))
+ info[k] = torch.from_numpy(arr.copy())  # copy: own the memory, not a buffer view
+ elif getattr(entry, "list_data", None) is not None:
+ info[k] = entry.list_data
+ elif getattr(entry, "scalar_data", None) is not None:
+ info[k] = entry.scalar_data
+ else:
+ info[k] = None
+ return info
+ except Exception:
+ logger.exception("Failed to decode additional_information payload")
+
+ return {}
diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py
index 37b0e538a13..e08ce780112 100644
--- a/vllm_omni/engine/stage_engine_core_client.py
+++ b/vllm_omni/engine/stage_engine_core_client.py
@@ -6,22 +6,12 @@
from __future__ import annotations
-import multiprocessing.connection
-import socket
-import threading
-import weakref
from typing import TYPE_CHECKING, Any
-from urllib.parse import urlparse
-import psutil
from vllm.logger import init_logger
from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import AsyncMPClient, DPLBAsyncMPClient
-from vllm.v1.engine.exceptions import EngineDeadError
+from vllm.v1.engine.core_client import AsyncMPClient
-from vllm_omni.distributed.omni_connectors.utils.initialization import (
- KV_TRANSFER_PORT_OFFSET,
-)
from vllm_omni.engine.stage_init_utils import StageMetadata
if TYPE_CHECKING:
@@ -31,57 +21,19 @@
logger = init_logger(__name__)
-SHUTDOWN_TIMEOUT_S = 5
+class StageEngineCoreClient(AsyncMPClient):
+ """Stage async client that inherits from vLLM's AsyncMPClient.
-class StageEngineCoreClientBase:
- """Shared stage-aware behavior for async EngineCore clients.
-
- The concrete transport/load-balancing behavior is supplied by the
- multiprocessing client subclass in the MRO.
-
- Fully reuses the underlying vLLM async MP client ``__init__`` for:
+ Fully reuses AsyncMPClient for:
- ZMQ setup, sockets
- outputs_queue, output_queue_task
- All utility methods (get_output_async, abort_requests_async, etc.)
The subprocess is spawned externally via ``spawn_stage_core`` /
``complete_stage_handshake`` from *stage_engine_core_proc.py*.
- In single-stage CLI mode, the client may instead attach to an
- ``engine_manager`` / ``coordinator`` pair created elsewhere.
"""
- @staticmethod
- def make_async_mp_client(
- vllm_config: Any,
- executor_class: type,
- metadata: StageMetadata,
- client_addresses: dict[str, str] | None = None,
- proc: Any = None,
- engine_manager: Any = None,
- coordinator: Any = None,
- client_count: int = 1,
- client_index: int = 0,
- ) -> StageEngineCoreClient | DPLBStageEngineCoreClient:
- """Create the appropriate stage async client for the DP mode."""
- parallel_config = vllm_config.parallel_config
- client_args = dict(
- vllm_config=vllm_config,
- executor_class=executor_class,
- metadata=metadata,
- client_addresses=client_addresses,
- proc=proc,
- engine_manager=engine_manager,
- coordinator=coordinator,
- client_count=client_count,
- client_index=client_index,
- )
-
- if parallel_config.data_parallel_size > 1 and not parallel_config.data_parallel_external_lb:
- return DPLBStageEngineCoreClient(**client_args)
-
- return StageEngineCoreClient(**client_args)
-
def __init__(
self,
vllm_config: Any,
@@ -124,16 +76,9 @@ def __init__(
self.engine_outputs: Any = None
self._proc = proc
- self.client_addresses = dict(client_addresses or {})
- self._omni_kv_config = getattr(getattr(vllm_config, "model_config", None), "omni_kv_config", None)
- self._kv_sender_host = self._resolve_contact_host()
- self._kv_sender_info: dict[str, Any] | None = None
- self._kv_sender_initialized = False
- client_name = self.__class__.__name__
logger.info(
- "[%s] Stage-%s initializing EngineCore",
- client_name,
+ "[StageEngineCoreClient] Stage-%s initializing EngineCore",
self.stage_id,
)
try:
@@ -145,213 +90,34 @@ def __init__(
client_count=client_count,
client_index=client_index,
)
- if engine_manager is not None:
- self.resources.engine_manager = engine_manager
- if coordinator is not None:
- self.resources.coordinator = coordinator
except Exception:
logger.exception(
- "[%s] Stage-%s EngineCore init failed",
- client_name,
+ "[StageEngineCoreClient] Stage-%s EngineCore init failed",
self.stage_id,
)
try:
self.shutdown()
except Exception as shutdown_error:
logger.warning(
- "[%s] Stage-%s cleanup after init failure failed: %s",
- client_name,
+ "[StageEngineCoreClient] Stage-%s cleanup after init failure failed: %s",
self.stage_id,
shutdown_error,
)
raise
-
- self._initialize_kv_sender_endpoint()
-
- if self._proc is not None:
- self._start_proc_monitor()
-
logger.info(
- "[%s] Stage-%s EngineCore running",
- client_name,
+ "[StageEngineCoreClient] Stage-%s EngineCore running",
self.stage_id,
)
- def _start_proc_monitor(self) -> None:
- """Start a daemon thread that watches the subprocess sentinel.
-
- When the subprocess dies without sending the ZMQ ``ENGINE_CORE_DEAD``
- sentinel (e.g. SIGKILL, segfault, OOM-killer), this thread sets
- ``resources.engine_dead`` so subsequent calls raise
- ``EngineDeadError``.
- """
- proc = self._proc
- resources_ref = weakref.ref(self.resources)
- stage_id = self.stage_id
-
- def _monitor() -> None:
- try:
- multiprocessing.connection.wait([proc.sentinel])
- except Exception:
- return
- resources = resources_ref()
- if resources is None or resources.engine_dead:
- return
- resources.engine_dead = True
- logger.error(
- "[StageEngineCoreClient] Stage-%s subprocess died unexpectedly (exit code %s).",
- stage_id,
- proc.exitcode,
- )
-
- t = threading.Thread(
- target=_monitor,
- daemon=True,
- name=f"StageCoreProcMonitor-{stage_id}",
- )
- t.start()
-
- def check_health(self) -> None:
- """Raise ``EngineDeadError`` if the stage subprocess is dead.
-
- Called by ``OmniBase.check_health()`` and transitively by the
- ``/health`` HTTP endpoint.
- """
- if self.resources.engine_dead:
- raise EngineDeadError(f"Stage-{self.stage_id} engine core is dead")
- if self._proc is not None and not self._proc.is_alive():
- self.resources.engine_dead = True
- raise EngineDeadError(f"Stage-{self.stage_id} subprocess is not alive (exit code {self._proc.exitcode})")
-
# ==================== Overrides ====================
async def add_request_async(self, request: EngineCoreRequest) -> None:
"""Add request to the stage engine core."""
- logger.info(
- "[%s] Stage-%s adding request: %s",
- self.__class__.__name__,
- self.stage_id,
- request.request_id,
- )
+ logger.info("[StageEngineCoreClient] Stage-%s adding request: %s", self.stage_id, request.request_id)
await super().add_request_async(request)
# ==================== Stage Methods ====================
- @staticmethod
- def _detect_local_ip() -> str | None:
- """Best-effort local IP detection for cross-node connector bootstrap."""
- try:
- with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
- sock.connect(("8.8.8.8", 80))
- return sock.getsockname()[0]
- except Exception:
- try:
- return socket.gethostbyname(socket.gethostname())
- except Exception:
- return None
-
- def _resolve_contact_host(self) -> str | None:
- """Resolve a routable host for this stage from its client addresses."""
- for key in ("input_address", "output_address", "stats_update_address"):
- address = self.client_addresses.get(key)
- if not address:
- continue
- host = urlparse(address).hostname
- if host in {None, "", "*", "0.0.0.0", "::"}:
- continue
- if host in {"localhost", "127.0.0.1"}:
- detected = self._detect_local_ip()
- if detected:
- return detected
- continue
- return host
- return self._detect_local_ip()
-
- def _get_kv_connector_config(self) -> dict[str, Any] | None:
- omni_kv_config = getattr(self, "_omni_kv_config", None)
- if not isinstance(omni_kv_config, dict):
- return None
- connector_config = omni_kv_config.get("connector_config")
- if not isinstance(connector_config, dict):
- return None
- return connector_config
-
- def _resolve_sender_host_from_config(self, connector_config: dict[str, Any]) -> str | None:
- host = connector_config.get("sender_host") or connector_config.get("host")
- if host in {None, "", "auto", "*", "0.0.0.0", "::"}:
- return self._resolve_contact_host()
- return str(host)
-
- def _initialize_kv_sender_endpoint(self) -> None:
- if self._kv_sender_initialized:
- return
- self._kv_sender_initialized = True
- connector_config = self._get_kv_connector_config()
- if connector_config is None or connector_config.get("role") != "sender":
- return
-
- sender_host = self._resolve_sender_host_from_config(connector_config)
- if sender_host is not None:
- self._kv_sender_host = sender_host
-
- sender_port = connector_config.get("sender_zmq_port")
- if sender_port is None:
- base_port = connector_config.get("zmq_port")
- if base_port is None:
- return
-
- omni_kv_config = getattr(self, "_omni_kv_config", None)
- from_stage = self.stage_id
- if isinstance(omni_kv_config, dict):
- from_stage = omni_kv_config.get("omni_from_stage", from_stage)
-
- try:
- # Orchestrator always reports rank-0's port; receiver
- # workers add their own local_rank * KV_RANK_PORT_STRIDE.
- sender_port = int(base_port) + KV_TRANSFER_PORT_OFFSET + int(from_stage)
- except (TypeError, ValueError):
- logger.warning(
- "[StageEngineCoreClient] Stage-%s could not resolve sender_zmq_port "
- "from base_port=%s and from_stage=%s",
- self.stage_id,
- base_port,
- from_stage,
- )
- return
-
- if self._kv_sender_host is None:
- return
-
- self._kv_sender_info = {
- "host": str(self._kv_sender_host),
- "zmq_port": int(sender_port),
- }
-
- def get_kv_sender_info(
- self,
- *,
- base_port: int = 50051,
- kv_transfer_port_offset: int = KV_TRANSFER_PORT_OFFSET,
- ) -> dict[str, Any] | None:
- """Build sender bootstrap info for diffusion KV transfer receivers.
-
- ``base_port`` and ``kv_transfer_port_offset`` are only used by the
- legacy fallback path when no connector-level sender endpoint is
- configured in ``omni_kv_config``.
- """
- if self._kv_sender_info is not None:
- return dict(self._kv_sender_info)
-
- if self._kv_sender_host is None:
- self._kv_sender_host = self._resolve_contact_host()
- if self._kv_sender_host is None:
- return None
- # rank-0 base port; receiver workers adjust per KV_RANK_PORT_STRIDE.
- return {
- "host": self._kv_sender_host,
- "zmq_port": base_port + kv_transfer_port_offset + int(self.stage_id),
- }
-
def set_engine_outputs(self, engine_outputs: EngineCoreOutput) -> None:
"""Set engine outputs (called by orchestrator)."""
self.engine_outputs = engine_outputs
@@ -360,21 +126,11 @@ def process_engine_inputs(
self,
stage_list: list[Any],
prompt: OmniTokensPrompt | list[OmniTokensPrompt] | None = None,
- streaming_context: Any | None = None,
) -> list[OmniTokensPrompt]:
"""Process inputs from upstream stages."""
from vllm_omni.inputs.data import OmniTokensPrompt
if self.custom_process_input_func is not None:
- # Keep legacy arg call for non-streaming processors.
- if bool(getattr(streaming_context, "enabled", False)):
- return self.custom_process_input_func(
- stage_list,
- self.engine_input_source,
- prompt,
- self.requires_multimodal_data,
- streaming_context,
- )
return self.custom_process_input_func(
stage_list,
self.engine_input_source,
@@ -395,7 +151,7 @@ def process_engine_inputs(
return [
OmniTokensPrompt(
- prompt_token_ids=so.outputs[0].cumulative_token_ids,
+ prompt_token_ids=so.outputs[0].token_ids,
multi_modal_data=(mm_data[so.request_id] if self.requires_multimodal_data else None),
)
for so in source_outputs
@@ -410,9 +166,9 @@ async def collective_rpc_async(
) -> Any:
"""Forward control RPCs to the underlying AsyncMPClient stage engine.
- Each stage client already represents one logical stage, so stage-scoped
- control operations should be executed here and then fanned in-core
- across the workers managed by this EngineCore client.
+ Each ``StageEngineCoreClient`` already represents one logical stage, so
+ stage-scoped control operations should be executed here and then fanned
+ in-core across the workers managed by this EngineCore client.
"""
return await super().collective_rpc_async(
method=method,
@@ -422,43 +178,10 @@ async def collective_rpc_async(
)
def shutdown(self) -> None:
- """Shutdown managed resources and any externally spawned subprocess."""
- child_procs: list[psutil.Process] = []
- if self._proc is not None and self._proc.pid is not None:
- try:
- child_procs = psutil.Process(self._proc.pid).children(recursive=True)
- except psutil.Error:
- child_procs = []
-
- try:
- super().shutdown()
- finally:
- if self._proc is not None and self._proc.is_alive():
- self._proc.terminate()
- self._proc.join(timeout=SHUTDOWN_TIMEOUT_S)
- if self._proc.is_alive():
- self._proc.kill()
- self._proc.join(timeout=SHUTDOWN_TIMEOUT_S)
-
- alive_children = [proc for proc in child_procs if proc.is_running()]
- for proc in alive_children:
- try:
- proc.terminate()
- except psutil.Error:
- pass
- _, still_alive = psutil.wait_procs(alive_children, timeout=SHUTDOWN_TIMEOUT_S)
- for proc in still_alive:
- try:
- proc.kill()
- except psutil.Error:
- pass
- # The process handle is no longer reliable after best-effort cleanup.
- self._proc = None
-
-
-class StageEngineCoreClient(StageEngineCoreClientBase, AsyncMPClient):
- """Stage async client backed by vLLM's ``AsyncMPClient``."""
-
-
-class DPLBStageEngineCoreClient(StageEngineCoreClientBase, DPLBAsyncMPClient):
- """Stage async client backed by vLLM's ``DPLBAsyncMPClient``."""
+ """Shut down ZMQ connections, then terminate (and, if still alive after 5s, kill) the subprocess."""
+ super().shutdown()
+ if self._proc is not None and self._proc.is_alive():
+ self._proc.terminate()
+ self._proc.join(timeout=5)
+ if self._proc.is_alive():
+ self._proc.kill()
diff --git a/vllm_omni/engine/stage_engine_core_proc.py b/vllm_omni/engine/stage_engine_core_proc.py
index 2ab8b37dd5f..05d8f107c23 100644
--- a/vllm_omni/engine/stage_engine_core_proc.py
+++ b/vllm_omni/engine/stage_engine_core_proc.py
@@ -23,12 +23,10 @@
get_mp_context,
set_process_title,
)
-from vllm.v1.engine import EngineCoreRequestType
-from vllm.v1.engine.core import EngineCoreProc, EngineShutdownState
+from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.utils import (
EngineHandshakeMetadata,
EngineZmqAddresses,
- SignalCallback,
get_engine_zmq_addresses,
)
from vllm.v1.utils import shutdown
@@ -39,6 +37,8 @@
logger = init_logger(__name__)
+_HANDSHAKE_POLL_TIMEOUT_S = 600
+
class StageEngineCoreProc(EngineCoreProc):
"""Stage-specific engine core process for vLLM-Omni.
@@ -56,9 +56,18 @@ def run_stage_core(
**kwargs: Any,
) -> None:
"""Launch StageEngineCoreProc busy loop in background process."""
- signal_callback: SignalCallback | None = None
+ shutdown_requested = False
maybe_register_config_serialize_by_value()
+ def signal_handler(signum: int, frame: Any) -> None:
+ nonlocal shutdown_requested
+ if not shutdown_requested:
+ shutdown_requested = True
+ raise SystemExit()
+
+ signal.signal(signal.SIGTERM, signal_handler)
+ signal.signal(signal.SIGINT, signal_handler)
+
engine_core: StageEngineCoreProc | None = None
try:
vllm_config: VllmConfig = kwargs["vllm_config"]
@@ -81,19 +90,6 @@ def run_stage_core(
engine_index=dp_rank,
**kwargs,
)
-
- def wakeup_engine() -> None:
- engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None))
-
- signal_callback = SignalCallback(wakeup_engine)
-
- def signal_handler(signum: int, frame: Any) -> None:
- engine_core.shutdown_state = EngineShutdownState.REQUESTED
- signal_callback.trigger()
-
- signal.signal(signal.SIGTERM, signal_handler)
- signal.signal(signal.SIGINT, signal_handler)
-
engine_core.run_busy_loop()
except SystemExit:
@@ -107,10 +103,6 @@ def signal_handler(signum: int, frame: Any) -> None:
engine_core._send_engine_dead()
raise
finally:
- signal.signal(signal.SIGTERM, signal.SIG_DFL)
- signal.signal(signal.SIGINT, signal.SIG_DFL)
- if signal_callback is not None:
- signal_callback.stop()
if engine_core is not None:
engine_core.shutdown()
@@ -153,14 +145,13 @@ def complete_stage_handshake(
handshake_address: str,
addresses: EngineZmqAddresses,
vllm_config: VllmConfig,
- handshake_timeout: int,
) -> None:
"""Perform the HELLO/INIT/READY handshake with an already-spawned proc.
On failure the process is terminated before re-raising.
"""
try:
- _perform_handshake(proc, handshake_address, addresses, vllm_config, handshake_timeout)
+ _perform_handshake(proc, handshake_address, addresses, vllm_config)
except Exception:
shutdown([proc])
raise
@@ -171,7 +162,6 @@ def _perform_handshake(
handshake_address: str,
addresses: EngineZmqAddresses,
vllm_config: VllmConfig,
- handshake_timeout: int,
) -> None:
"""Run the HELLO / INIT / READY handshake with the subprocess."""
with zmq_socket_ctx(handshake_address, zmq.ROUTER, bind=True) as handshake_socket:
@@ -179,7 +169,7 @@ def _perform_handshake(
poller.register(handshake_socket, zmq.POLLIN)
poller.register(proc.sentinel, zmq.POLLIN)
- identity, msg = _recv(poller, handshake_socket, proc, "HELLO", handshake_timeout)
+ identity, msg = _recv(poller, handshake_socket, proc, "HELLO")
if msg.get("status") != "HELLO":
raise RuntimeError(f"Expected HELLO, got: {msg}")
@@ -189,7 +179,7 @@ def _perform_handshake(
)
handshake_socket.send_multipart([identity, msgspec.msgpack.encode(init_payload)])
- identity, msg = _recv(poller, handshake_socket, proc, "READY", handshake_timeout)
+ identity, msg = _recv(poller, handshake_socket, proc, "READY")
if msg.get("status") != "READY":
raise RuntimeError(f"Expected READY, got: {msg}")
num_gpu_blocks = msg.get("num_gpu_blocks")
@@ -202,18 +192,13 @@ def _recv(
handshake_socket: zmq.Socket,
proc: BaseProcess,
expected: str,
- timeout_s: int = 600,
) -> tuple[bytes, dict]:
"""Wait for one handshake message; raise if the process dies first."""
- timeout_ms = timeout_s * 1000
+ timeout_ms = _HANDSHAKE_POLL_TIMEOUT_S * 1000
while True:
events = dict(poller.poll(timeout=timeout_ms))
if not events:
- raise TimeoutError(
- f"Timed out waiting for {expected} from StageEngineCoreProc after {timeout_s}s. "
- f"This typically indicates model loading or initialization is taking too long. "
- f"Consider increasing `stage_init_timeout` for large models."
- )
+ raise TimeoutError(f"Timed out waiting for {expected} from StageEngineCoreProc")
if handshake_socket in events:
identity, raw = handshake_socket.recv_multipart()
return identity, msgspec.msgpack.decode(raw)
diff --git a/vllm_omni/engine/stage_engine_startup.py b/vllm_omni/engine/stage_engine_startup.py
deleted file mode 100644
index 6af66c71f34..00000000000
--- a/vllm_omni/engine/stage_engine_startup.py
+++ /dev/null
@@ -1,599 +0,0 @@
-"""Helpers for launching and handshaking omni engine cores."""
-
-from __future__ import annotations
-
-import contextlib
-import dataclasses
-import threading
-from collections.abc import Iterator
-from dataclasses import dataclass
-from typing import Any
-
-import msgspec
-import zmq
-from omegaconf import OmegaConf
-from vllm.config import CacheConfig, VllmConfig
-from vllm.logger import init_logger
-from vllm.utils.network_utils import get_open_port, zmq_socket_ctx
-from vllm.v1.engine.coordinator import DPCoordinator
-from vllm.v1.engine.utils import (
- STARTUP_POLL_PERIOD_MS,
- CoreEngine,
- CoreEngineProcManager,
- CoreEngineState,
- EngineHandshakeMetadata,
- EngineZmqAddresses,
- wait_for_engine_startup,
-)
-from vllm.v1.executor import Executor
-
-logger = init_logger(__name__)
-
-# Poll period (ms) used by the registration/handshake loop.
-_POLL_PERIOD_MS = 5_000
-# Default timeout (s) for a stage to send READY.
-_DEFAULT_STARTUP_TIMEOUT_S = 300
-
-
-def _serialize_stage_config(stage_config: Any) -> Any:
- """Convert a stage config to msgpack-friendly builtins."""
- if stage_config is None or isinstance(stage_config, (str, bytes, int, float, bool)):
- return stage_config
-
- if OmegaConf.is_config(stage_config):
- return _serialize_stage_config(OmegaConf.to_container(stage_config, resolve=True))
-
- if dataclasses.is_dataclass(stage_config):
- return _serialize_stage_config(dataclasses.asdict(stage_config))
-
- if isinstance(stage_config, dict):
- return {key: _serialize_stage_config(value) for key, value in stage_config.items() if not callable(value)}
-
- if isinstance(stage_config, (list, tuple, set)):
- return [_serialize_stage_config(item) for item in stage_config if not callable(item)]
-
- if hasattr(stage_config, "items"):
- return {key: _serialize_stage_config(value) for key, value in stage_config.items() if not callable(value)}
-
- if hasattr(stage_config, "__dict__"):
- return {
- key: _serialize_stage_config(value)
- for key, value in vars(stage_config).items()
- if not key.startswith("_") and not callable(value)
- }
-
- return stage_config
-
-
-# ---------------------------------------------------------------------------
-# Per-stage address allocation
-# ---------------------------------------------------------------------------
-
-
-@dataclass
-class StageAllocation:
- """ZMQ addresses reserved for a single stage."""
-
- # Per-stage handshake socket (OmniMasterServer binds, engine connects)
- handshake_bind_address: str
- handshake_connect_address: str
- # Input channel: client binds ROUTER, engine connects DEALER
- input_bind_address: str
- input_connect_address: str
- # Output channel: client binds PULL, engine connects PUSH
- output_bind_address: str
- output_connect_address: str
-
-
-@dataclass(frozen=True)
-class StageCoordinatorAddresses:
- """Optional DP coordinator addresses registered for a stage."""
-
- coordinator_input: str | None = None
- coordinator_output: str | None = None
- frontend_stats_publish_address: str | None = None
-
-
-# ---------------------------------------------------------------------------
-# OmniMasterServer
-# ---------------------------------------------------------------------------
-
-
-class OmniMasterServer:
- """Registration server for single-stage engine startup."""
-
- def __init__(
- self,
- master_address: str,
- master_port: int,
- stage_ids: list[int],
- ) -> None:
- self._address = master_address
- self._port = master_port
- self._allocations: dict[int, StageAllocation] = {}
- self._stage_configs: dict[int, Any] = {}
- self._stage_coordinator_addresses: dict[int, StageCoordinatorAddresses] = {}
- self._stage_config_events: dict[int, threading.Event] = {}
- self._thread: threading.Thread | None = None
- self._stop_event = threading.Event()
-
- for sid in stage_ids:
- self._stage_config_events[sid] = threading.Event()
- self._stage_coordinator_addresses[sid] = StageCoordinatorAddresses()
- hs_port = get_open_port()
- inp_port = get_open_port()
- out_port = get_open_port()
- self._allocations[sid] = StageAllocation(
- handshake_bind_address=f"tcp://{master_address}:{hs_port}",
- handshake_connect_address=f"tcp://{master_address}:{hs_port}",
- input_bind_address=f"tcp://{master_address}:{inp_port}",
- input_connect_address=f"tcp://{master_address}:{inp_port}",
- output_bind_address=f"tcp://{master_address}:{out_port}",
- output_connect_address=f"tcp://{master_address}:{out_port}",
- )
-
- logger.info(
- "[OmniMasterServer] Pre-allocated addresses for stages %s (master=%s:%d)",
- list(stage_ids),
- master_address,
- master_port,
- )
-
- # ------------------------------------------------------------------
- # Public helpers
- # ------------------------------------------------------------------
- @property
- def address(self) -> str:
- """Return the registration address exposed to stage launchers."""
- return self._address
-
- @property
- def port(self) -> int:
- """Return the registration port exposed to stage launchers."""
- return self._port
-
- def get_allocation(self, stage_id: int) -> StageAllocation:
- """Return the full address allocation for *stage_id*."""
- return self._allocations[stage_id]
-
- def register_stage_config(
- self,
- stage_id: int,
- stage_config: Any,
- coordinator_addresses: StageCoordinatorAddresses | None = None,
- ) -> None:
- """Store the latest stage registration payload for *stage_id*."""
- if stage_id not in self._allocations:
- raise KeyError(stage_id)
- self._stage_configs[stage_id] = stage_config
- if coordinator_addresses is not None:
- self._stage_coordinator_addresses[stage_id] = coordinator_addresses
- self._stage_config_events[stage_id].set()
-
- def get_stage_config(self, stage_id: int, timeout_s: float | None = None) -> Any:
- """Return the stage config for *stage_id*, waiting if necessary."""
- if stage_id not in self._allocations:
- raise KeyError(stage_id)
-
- if stage_id in self._stage_configs:
- return self._stage_configs[stage_id]
-
- if not self._stage_config_events[stage_id].wait(timeout=timeout_s):
- raise TimeoutError(f"Timed out waiting for stage config for stage {stage_id}.")
-
- return self._stage_configs[stage_id]
-
- def get_stage_coordinator_addresses(
- self,
- stage_id: int,
- timeout_s: float | None = None,
- ) -> StageCoordinatorAddresses:
- """Return the registered coordinator addresses for *stage_id*."""
- if stage_id not in self._allocations:
- raise KeyError(stage_id)
-
- if not self._stage_config_events[stage_id].is_set():
- if not self._stage_config_events[stage_id].wait(timeout=timeout_s):
- raise TimeoutError(f"Timed out waiting for stage registration for stage {stage_id}.")
-
- return self._stage_coordinator_addresses[stage_id]
-
- def get_client_addresses(self, stage_id: int) -> dict[str, str]:
- """Return the addresses the client-side sockets should *bind* to."""
- alloc = self._allocations[stage_id]
- return {
- "input_address": alloc.input_bind_address,
- "output_address": alloc.output_bind_address,
- }
-
- def get_zmq_addresses(self, stage_id: int) -> EngineZmqAddresses:
- """Return EngineZmqAddresses using the *bind* (client) side addresses."""
- alloc = self._allocations[stage_id]
- return EngineZmqAddresses(
- inputs=[alloc.input_bind_address],
- outputs=[alloc.output_bind_address],
- )
-
- def get_engine_zmq_addresses(self, stage_id: int) -> EngineZmqAddresses:
- """Return EngineZmqAddresses using the *connect* (engine) addresses."""
- alloc = self._allocations[stage_id]
- return EngineZmqAddresses(
- inputs=[alloc.input_connect_address],
- outputs=[alloc.output_connect_address],
- )
-
- # ------------------------------------------------------------------
- # Lifecycle
- # ------------------------------------------------------------------
-
- def start(self) -> None:
- """Start the background server thread."""
- self._thread = threading.Thread(
- target=self._run,
- name="OmniMasterServer",
- daemon=True,
- )
- self._thread.start()
- logger.info(
- "[OmniMasterServer] Listening on tcp://%s:%d",
- self.address,
- self.port,
- )
-
- def stop(self) -> None:
- """Signal stop and join the background thread."""
- self._stop_event.set()
- if self._thread is not None:
- self._thread.join(timeout=10)
-
- # ------------------------------------------------------------------
- # Internal server logic
- # ------------------------------------------------------------------
-
- def _run(self) -> None:
- ctx = zmq.Context()
- try:
- self._serve(ctx)
- except Exception:
- logger.exception("[OmniMasterServer] Server thread crashed")
- finally:
- ctx.term()
-
- def _serve(self, ctx: zmq.Context) -> None: # type: ignore[type-arg]
- # Registration socket for the initial stage registration.
- # Per-stage handshake sockets are bound by the launch helpers.
- reg_socket: zmq.Socket = ctx.socket(zmq.ROUTER) # type: ignore[attr-defined]
- reg_socket.bind(f"tcp://{self.address}:{self.port}")
-
- poller = zmq.Poller()
- poller.register(reg_socket, zmq.POLLIN)
-
- pending: set[int] = set(self._allocations.keys())
-
- while pending and not self._stop_event.is_set():
- events: list[tuple[zmq.Socket, int]] = poller.poll(_POLL_PERIOD_MS) # type: ignore[assignment]
- if not events:
- logger.debug("[OmniMasterServer] Still waiting for registration from stages: %s", pending)
- continue
-
- for sock, _ in events:
- if sock is reg_socket:
- sid = self._handle_registration(reg_socket)
- if sid is not None:
- pending.discard(sid)
-
- # Cleanup
- reg_socket.close(linger=0)
- logger.info("[OmniMasterServer] All stages registered; server thread exiting.")
-
- def _handle_registration(self, reg_socket: zmq.Socket) -> int | None: # type: ignore[type-arg]
- """Receive a stage registration and reply with the handshake address.
-
- Returns the registered stage_id on success, or None on failure.
- """
- frames = reg_socket.recv_multipart()
- if len(frames) < 2:
- logger.warning(
- "[OmniMasterServer] Unexpected registration frame count: %d",
- len(frames),
- )
- return None
- identity = frames[0]
- msg_bytes = frames[-1]
- try:
- msg = msgspec.msgpack.decode(msg_bytes)
- except Exception as exc:
- logger.warning("[OmniMasterServer] Failed to decode registration message: %s", exc)
- return None
-
- stage_id: int | None = msg.get("stage_id")
- if stage_id not in self._allocations:
- logger.warning(
- "[OmniMasterServer] Received registration for unknown stage_id=%s",
- stage_id,
- )
- return None
-
- self.register_stage_config(
- stage_id,
- msg.get("stage_config"),
- coordinator_addresses=StageCoordinatorAddresses(
- coordinator_input=msg.get("coordinator_input"),
- coordinator_output=msg.get("coordinator_output"),
- frontend_stats_publish_address=msg.get("frontend_stats_publish_address"),
- ),
- )
-
- alloc = self._allocations[stage_id]
- response = msgspec.msgpack.encode(
- {
- "handshake_address": alloc.handshake_connect_address,
- "input_address": alloc.input_bind_address,
- "output_address": alloc.output_bind_address,
- }
- )
- # ROUTER-DEALER: reply is [identity, payload] (no empty delimiter).
- reg_socket.send_multipart([identity, response])
- logger.info(
- "[OmniMasterServer] Stage %d registered; assigned handshake=%s",
- stage_id,
- alloc.handshake_connect_address,
- )
- return stage_id
-
-
-def register_stage_with_omni_master(
- *,
- omni_master_address: str,
- omni_master_port: int,
- omni_stage_id: int,
- omni_stage_config: Any = None,
- coordinator: DPCoordinator | None = None,
- return_addresses: bool = False,
-) -> str | tuple[str, str, str]:
- """Register a stage with the omni master server.
-
- Returns the per-stage handshake address by default. When
- ``return_addresses`` is true, also returns the stage input/output
- addresses allocated by the master.
- """
-
- reg_ctx = zmq.Context()
- try:
- reg_sock: zmq.Socket = reg_ctx.socket(zmq.DEALER) # type: ignore[attr-defined]
- try:
- reg_sock.connect(f"tcp://{omni_master_address}:{omni_master_port}")
- payload = {
- "stage_id": omni_stage_id,
- "stage_config": _serialize_stage_config(omni_stage_config),
- }
- if coordinator is not None:
- coordinator_input, coordinator_output = coordinator.get_engine_socket_addresses()
- payload["coordinator_input"] = coordinator_input
- payload["coordinator_output"] = coordinator_output
- payload["frontend_stats_publish_address"] = coordinator.get_stats_publish_address()
-
- reg_sock.send(msgspec.msgpack.encode(payload))
- timeout_ms = _DEFAULT_STARTUP_TIMEOUT_S * 1_000
- if not reg_sock.poll(timeout=timeout_ms):
- raise RuntimeError(
- f"Timed out waiting for registration "
- f"response from OmniMasterServer "
- f"({omni_master_address}:{omni_master_port}) "
- f"for stage {omni_stage_id}."
- )
- response_bytes = reg_sock.recv()
- response = msgspec.msgpack.decode(response_bytes)
- handshake_address: str = response["handshake_address"]
- input_address: str = response["input_address"]
- output_address: str = response["output_address"]
- logger.info(
- "Stage %d registered; handshake_address=%s",
- omni_stage_id,
- handshake_address,
- )
- finally:
- reg_sock.close(linger=0)
- finally:
- reg_ctx.term()
-
- if return_addresses:
- return handshake_address, input_address, output_address
- return handshake_address
-
-
-def _wait_for_omni_engine_startup(
- handshake_socket: zmq.Socket,
- engine_addresses: EngineZmqAddresses,
- engines: list[CoreEngine],
- cache_config: CacheConfig,
-) -> None:
- """Wait for omni-managed engines to finish the HELLO/READY handshake."""
- conn_pending = len(engines)
- start_pending = 0
-
- poller = zmq.Poller()
- poller.register(handshake_socket, zmq.POLLIN)
-
- while conn_pending or start_pending:
- events = poller.poll(STARTUP_POLL_PERIOD_MS)
- if not events:
- logger.debug(
- "[omni] Waiting for %d engine(s) to connect, %d to start.",
- conn_pending,
- start_pending,
- )
- continue
-
- eng_identity, msg_bytes = handshake_socket.recv_multipart()
- eng_index = int.from_bytes(eng_identity, "little")
- engine = next((e for e in engines if e.identity == eng_identity), None)
- if engine is None:
- raise RuntimeError(f"[omni] Handshake message from unexpected engine rank: {eng_index}")
-
- msg = msgspec.msgpack.decode(msg_bytes)
- status: str = msg["status"]
-
- if status == "HELLO" and engine.state == CoreEngineState.NEW:
- init_message = msgspec.msgpack.encode(
- EngineHandshakeMetadata(addresses=engine_addresses, parallel_config={})
- )
- handshake_socket.send_multipart((eng_identity, init_message), copy=False)
- conn_pending -= 1
- start_pending += 1
- engine.state = CoreEngineState.CONNECTED
- logger.debug("[omni] HELLO from engine %d", eng_index)
-
- elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
- num_gpu_blocks = (cache_config.num_gpu_blocks or 0) + msg["num_gpu_blocks"]
- cache_config.num_gpu_blocks = num_gpu_blocks
- if engine_addresses.frontend_stats_publish_address is None:
- engine_addresses.frontend_stats_publish_address = msg.get("dp_stats_address")
- start_pending -= 1
- engine.state = CoreEngineState.READY
- logger.debug("[omni] READY from engine %d (num_gpu_blocks=%d)", eng_index, msg["num_gpu_blocks"])
-
- else:
- raise RuntimeError(f"[omni] Unexpected status '{status}' from engine {eng_index} in state {engine.state}.")
-
-
-@contextlib.contextmanager
-def connect_remote_engine_cores(
- vllm_config: VllmConfig,
- omni_master_server: OmniMasterServer,
- stage_id: int,
-) -> Iterator[tuple[None, DPCoordinator | None, EngineZmqAddresses]]:
- """Wait for remote engine cores to connect through the omni handshake."""
- addresses = omni_master_server.get_zmq_addresses(stage_id)
- parallel_config = vllm_config.parallel_config
- # Mirror the engine-count logic from launch_omni_core_engines.
- remote_engine_count = (
- parallel_config.data_parallel_size_local
- if parallel_config.data_parallel_size_local is not None and parallel_config.data_parallel_size_local > 0
- else max(1, parallel_config.data_parallel_size)
- )
- start_index = parallel_config.data_parallel_rank if parallel_config.data_parallel_rank is not None else 0
- coordinator = None
-
- registered_coordinator_addresses = omni_master_server.get_stage_coordinator_addresses(stage_id)
- addresses.coordinator_input = registered_coordinator_addresses.coordinator_input
- addresses.coordinator_output = registered_coordinator_addresses.coordinator_output
- addresses.frontend_stats_publish_address = registered_coordinator_addresses.frontend_stats_publish_address
-
- engines_to_handshake = [CoreEngine(index=start_index + i, local=False) for i in range(remote_engine_count)]
-
- logger.info(
- "Waiting for %d remote engine(s) for stage %d",
- remote_engine_count,
- stage_id,
- )
-
- handshake_bind_address = omni_master_server.get_allocation(stage_id).handshake_bind_address
-
- with zmq_socket_ctx(handshake_bind_address, zmq.ROUTER, bind=True) as handshake_socket:
- yield None, coordinator, addresses
-
- _wait_for_omni_engine_startup(
- handshake_socket,
- addresses,
- engines_to_handshake,
- vllm_config.cache_config,
- )
-
-
-@contextlib.contextmanager
-def launch_omni_core_engines(
- vllm_config: VllmConfig,
- executor_class: type[Executor],
- log_stats: bool,
- omni_master_server: OmniMasterServer,
- stage_id: int,
- stage_config: Any = None,
-) -> Iterator[tuple[CoreEngineProcManager, DPCoordinator | None, EngineZmqAddresses]]:
- """Launch local engine cores using the omni registration flow."""
- addresses = omni_master_server.get_zmq_addresses(stage_id)
- parallel_config = vllm_config.parallel_config
- # Determine the number of local engines and their ranks.
- local_engine_count = (
- parallel_config.data_parallel_size_local
- if parallel_config.data_parallel_size_local is not None and parallel_config.data_parallel_size_local > 0
- else max(1, parallel_config.data_parallel_size)
- )
- dp_rank = parallel_config.data_parallel_rank if parallel_config.data_parallel_rank is not None else 0
- local_start_index = 0
- start_index = dp_rank
-
- # Run the DP Coordinator process with rank 0 when in online DP mode.
- # The coordinator is needed for:
- # 1. Internal/hybrid LB: collecting and publishing queue stats
- # 2. MoE models: wave coordination in addition to stats
- run_coordinator = vllm_config.needs_dp_coordinator and dp_rank == 0
-
- if run_coordinator:
- coordinator = DPCoordinator(
- parallel_config,
- enable_wave_coordination=vllm_config.model_config.is_moe,
- )
-
- addresses.coordinator_input, addresses.coordinator_output = coordinator.get_engine_socket_addresses()
- addresses.frontend_stats_publish_address = coordinator.get_stats_publish_address()
-
- logger.info(
- "[omni] Started DP Coordinator process for stage %d (PID: %d)",
- stage_id,
- coordinator.proc.pid,
- )
- else:
- coordinator = None
-
- logger.info(
- "Starting %d local engine(s) for stage %d (dp_rank=%d)",
- local_engine_count,
- stage_id,
- dp_rank,
- )
-
- # Register the stage once and reuse the returned per-stage handshake
- # address for all local engine-core processes.
- handshake_address = register_stage_with_omni_master(
- omni_master_address=omni_master_server.address,
- omni_master_port=omni_master_server.port,
- omni_stage_id=stage_id,
- omni_stage_config=stage_config,
- coordinator=coordinator,
- )
-
- # One CoreEngine entry per local engine so wait_for_engine_startup can
- # track the HELLO/READY handshake for each of them.
- engines_to_handshake = [CoreEngine(index=start_index + i, local=True) for i in range(local_engine_count)]
-
- # Bind the pre-allocated handshake socket for this stage.
- handshake_bind_address = omni_master_server.get_allocation(stage_id).handshake_bind_address
-
- with zmq_socket_ctx(handshake_bind_address, zmq.ROUTER, bind=True) as handshake_socket:
- local_engine_manager = CoreEngineProcManager(
- local_engine_count=local_engine_count,
- start_index=start_index,
- local_start_index=local_start_index,
- vllm_config=vllm_config,
- local_client=True,
- handshake_address=handshake_address,
- executor_class=executor_class,
- log_stats=log_stats,
- )
-
- yield local_engine_manager, coordinator, addresses
-
- # Wait for all local engine-core processes to complete the
- # standard HELLO/READY handshake — mirrors launch_core_engines.
- coordinated_dp = parallel_config.data_parallel_size > 1 and vllm_config.model_config.is_moe
- wait_for_engine_startup(
- handshake_socket,
- addresses,
- engines_to_handshake,
- parallel_config,
- coordinated_dp,
- vllm_config.cache_config,
- local_engine_manager,
- coordinator.proc if coordinator else None,
- )
diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py
index 9fdcd6216ee..f71afad83b4 100644
--- a/vllm_omni/engine/stage_init_utils.py
+++ b/vllm_omni/engine/stage_init_utils.py
@@ -13,8 +13,8 @@
import multiprocessing as mp
import os
import time
-from collections.abc import Callable, Sequence
-from dataclasses import dataclass, replace
+from collections.abc import Callable
+from dataclasses import dataclass
from typing import Any, Literal
from vllm.logger import init_logger
@@ -23,13 +23,11 @@
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.executor import Executor
-from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.engine.arg_utils import OmniEngineArgs
from vllm_omni.entrypoints.stage_utils import _to_dict, set_stage_devices
from vllm_omni.entrypoints.utils import filter_dataclass_kwargs, resolve_model_config_path
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams
from vllm_omni.platforms import current_omni_platform
-from vllm_omni.quantization.inc_config import OmniINCConfig
logger = init_logger(__name__)
@@ -103,144 +101,6 @@ def resolve_worker_cls(engine_args: dict[str, Any]) -> None:
raise ValueError(f"Unknown worker_type: {worker_type}")
-def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any:
- """Read *key* from *obj* regardless of whether it's a dict or object."""
- if hasattr(obj, "get"):
- return obj.get(key, default)
- return getattr(obj, key, default)
-
-
-def _tp_size_for_stage(stage_configs: Sequence[Any], stage_id: Any) -> int | None:
- """Resolve tensor_parallel_size for *stage_id* from the loaded stage configs."""
- id_strs = {str(stage_id)}
- try:
- id_strs.add(str(int(stage_id)))
- except (TypeError, ValueError):
- pass
-
- for stage_cfg in stage_configs:
- if str(getattr(stage_cfg, "stage_id", None)) not in id_strs:
- continue
- engine_args = getattr(stage_cfg, "engine_args", None)
- if engine_args is None:
- return 1
- parallel_config = _get_attr_or_item(engine_args, "parallel_config")
- if parallel_config is not None:
- tp = _get_attr_or_item(parallel_config, "tensor_parallel_size", 1)
- else:
- tp = _get_attr_or_item(engine_args, "tensor_parallel_size", 1)
- try:
- return max(1, int(tp))
- except (TypeError, ValueError):
- return 1
- return None
-
-
-def _inject_inferred_kv_tp_topology(
- omni_kv: Any,
- stage_id: int,
- stage_configs: Sequence[Any],
- engine_input_source: Sequence[int] | None = None,
-) -> None:
- """Infer adjacent-stage TP topology and inject it into omni_kv_config.
-
- This keeps heterogeneous TP working without requiring user-authored
- rank_mapping blocks in config files.
- """
- if omni_kv is None:
- return
-
- if hasattr(omni_kv, "get"):
- need_send = bool(omni_kv.get("need_send_cache", False))
- need_recv = bool(omni_kv.get("need_recv_cache", False))
- omni_from_stage = omni_kv.get("omni_from_stage")
- omni_to_stage = omni_kv.get("omni_to_stage")
- rank_mapping = omni_kv.get("rank_mapping")
- else:
- need_send = bool(getattr(omni_kv, "need_send_cache", False))
- need_recv = bool(getattr(omni_kv, "need_recv_cache", False))
- omni_from_stage = getattr(omni_kv, "omni_from_stage", None)
- omni_to_stage = getattr(omni_kv, "omni_to_stage", None)
- rank_mapping = getattr(omni_kv, "rank_mapping", None)
-
- if not need_send and not need_recv:
- return
-
- current_tp = _tp_size_for_stage(stage_configs, stage_id)
- if current_tp is None:
- return
-
- peer_stage_id = None
- from_tp = None
- to_tp = None
- if str(omni_from_stage) == str(stage_id):
- peer_stage_id = omni_to_stage
- from_tp = current_tp
- to_tp = _tp_size_for_stage(stage_configs, peer_stage_id)
- elif str(omni_to_stage) == str(stage_id):
- peer_stage_id = omni_from_stage
- from_tp = _tp_size_for_stage(stage_configs, peer_stage_id)
- to_tp = current_tp
- elif need_recv and engine_input_source:
- peer_stage_id = engine_input_source[0]
- from_tp = _tp_size_for_stage(stage_configs, peer_stage_id)
- to_tp = current_tp
-
- if from_tp is None or to_tp is None:
- return
-
- if not isinstance(rank_mapping, dict):
- rank_mapping = {}
- rank_mapping.setdefault("from_tp", int(from_tp))
- rank_mapping.setdefault("to_tp", int(to_tp))
-
- if hasattr(omni_kv, "__setitem__"):
- omni_kv["rank_mapping"] = rank_mapping
- else:
- setattr(omni_kv, "rank_mapping", rank_mapping)
-
-
-def inject_kv_stage_info(stage_cfg: Any, stage_id: int, stage_configs: Sequence[Any] | None = None) -> None:
- """Inject stage_id, engine_input_source, and inferred TP topology into omni_kv_config.
-
- When *stage_configs* is provided, also infers from_tp/to_tp for
- heterogeneous TP topologies so the KV transfer manager can compute
- rank mappings automatically.
- """
- try:
- engine_args = stage_cfg.engine_args
- if hasattr(engine_args, "get"):
- omni_kv = engine_args.get("omni_kv_config", None)
- else:
- omni_kv = getattr(engine_args, "omni_kv_config", None)
-
- if omni_kv is None:
- return
-
- if hasattr(omni_kv, "setdefault"):
- omni_kv.setdefault("stage_id", stage_id)
- elif hasattr(omni_kv, "__setitem__"):
- if "stage_id" not in omni_kv:
- omni_kv["stage_id"] = stage_id
-
- engine_input_source = getattr(stage_cfg, "engine_input_source", None)
- if engine_input_source is not None:
- if hasattr(omni_kv, "setdefault"):
- omni_kv.setdefault("engine_input_source", list(engine_input_source))
- elif hasattr(omni_kv, "__setitem__") and "engine_input_source" not in omni_kv:
- omni_kv["engine_input_source"] = list(engine_input_source)
-
- if stage_configs:
- _inject_inferred_kv_tp_topology(
- omni_kv,
- stage_id=stage_id,
- stage_configs=stage_configs,
- engine_input_source=engine_input_source,
- )
- except Exception as e:
- logger.debug("Failed to inject stage info into omni_kv_config: %s", e)
-
-
@dataclass
class StageMetadata:
"""Lightweight stage attributes extracted from stage_config."""
@@ -269,10 +129,8 @@ class StartedLlmStage:
metadata: Any
vllm_config: Any
executor_class: type
+ proc: Any
addresses: Any
- proc: Any = None
- engine_manager: Any = None
- coordinator: Any = None
def extract_stage_metadata(stage_config: Any) -> StageMetadata:
@@ -280,20 +138,6 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
stage_id: int = stage_config.stage_id
stage_type: Literal["llm", "diffusion"] = getattr(stage_config, "stage_type", "llm")
engine_args = stage_config.engine_args
-
- if current_omni_platform.is_rocm():
- if engine_args.get("attention_backend") is None:
- from vllm._aiter_ops import rocm_aiter_ops
-
- if rocm_aiter_ops.is_enabled():
- engine_args["attention_backend"] = "ROCM_AITER_FA"
- # Before vLLM v0.19.0, the default attention backend is TRITON_ATTN for ROCm.
- # Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm.
- # However, the compatibility of ROCM_ATTN with Omni is not guaranteed.
- # Therefore, we still use TRITON_ATTN as the default attention backend,
- # when the selected_backend is not specified.
- engine_args["attention_backend"] = "TRITON_ATTN"
-
runtime_cfg = getattr(stage_config, "runtime", {})
engine_input_source: list[int] = getattr(stage_config, "engine_input_source", [])
final_output: bool = getattr(stage_config, "final_output", False)
@@ -304,9 +148,8 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
default_sampling_params: OmniSamplingParams = SPClass(**default_sp)
custom_process_input_func: Callable | None = None
- _cpif_path = getattr(stage_config, "custom_process_input_func", None)
- if _cpif_path:
- mod_path, fn_name = _cpif_path.rsplit(".", 1)
+ if hasattr(stage_config, "custom_process_input_func"):
+ mod_path, fn_name = stage_config.custom_process_input_func.rsplit(".", 1)
custom_process_input_func = getattr(importlib.import_module(mod_path), fn_name)
prompt_expand_func: Callable | None = None
@@ -412,10 +255,6 @@ def build_engine_args_dict(
if stage_type != "diffusion":
resolve_worker_cls(engine_args_dict)
- # Check whether the stage's default_sampling_params defines extra_args.
- default_sp = _to_dict(getattr(stage_config, "default_sampling_params", {}))
- engine_args_dict["has_sampling_extra_args"] = bool(default_sp.get("extra_args"))
-
return engine_args_dict
@@ -424,7 +263,6 @@ def build_vllm_config(
model: str,
stage_connector_spec: dict[str, Any] | None = None,
engine_args_dict: dict[str, Any] | None = None,
- headless: bool = False,
) -> tuple[Any, type]:
"""Build engine args, then create VllmConfig and executor_class.
@@ -440,38 +278,16 @@ def build_vllm_config(
filtered_engine_args_dict = filter_dataclass_kwargs(OmniEngineArgs, engine_args_dict)
omni_engine_args = OmniEngineArgs(**filtered_engine_args_dict)
-
- # Multi-stage pipelines (qwen3_tts code2wav, etc.) set max_model_len
- # larger than HF max_position_embeddings by design. vLLM's validator
- # rejects that without the env flag.
- if filtered_engine_args_dict.get("max_model_len") is not None and not os.environ.get(
- "VLLM_ALLOW_LONG_MAX_MODEL_LEN"
- ):
- os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
- logger.debug(
- "Auto-set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for stage %s (max_model_len=%s).",
- stage_config.stage_id,
- filtered_engine_args_dict["max_model_len"],
- )
-
- vllm_config = omni_engine_args.create_engine_config(
- usage_context=UsageContext.LLM_CLASS,
- headless=headless,
- )
+ vllm_config = omni_engine_args.create_engine_config(usage_context=UsageContext.LLM_CLASS)
executor_class = Executor.get_class(vllm_config)
- # Upgrade vanilla INCConfig to OmniINCConfig for multi-stage models.
- upgraded = OmniINCConfig.maybe_upgrade(vllm_config.quant_config)
- if upgraded is not vllm_config.quant_config:
- vllm_config = replace(vllm_config, quant_config=upgraded)
-
return vllm_config, executor_class
def acquire_device_locks(
stage_id: int,
engine_args_dict: dict[str, Any],
- stage_init_timeout: int,
+ stage_init_timeout: int = 300,
) -> list[int]:
"""Acquire exclusive file locks on devices needed by this stage.
@@ -520,13 +336,7 @@ def acquire_device_locks(
num_devices = current_omni_platform.get_device_count()
physical_devices = list(range(num_devices))
- if len(physical_devices) < num_devices_per_stage:
- raise RuntimeError(
- f"Stage {stage_id} requires {num_devices_per_stage} device(s) based on parallel_config, "
- f"but only {len(physical_devices)} device(s) are available: {physical_devices}"
- )
-
- num_devices_to_lock = num_devices_per_stage
+ num_devices_to_lock = min(num_devices_per_stage, len(physical_devices))
devices_to_lock = sorted(physical_devices[:num_devices_to_lock])
logger.debug(
@@ -601,22 +411,12 @@ def release_device_locks(lock_fds: list[int]) -> None:
def load_omni_transfer_config_for_model(model: str, config_path: str | None) -> Any:
- """Load omni transfer config from an explicit path or resolved model config.
-
- Resolves ``base_config`` inheritance (CI overlay → base deploy YAML) so
- that connectors defined in the base config are visible to the transfer
- config parser.
- """
+ """Load omni transfer config from an explicit path or resolved model config."""
from vllm_omni.distributed.omni_connectors import load_omni_transfer_config
try:
resolved_config_path = config_path or resolve_model_config_path(model)
- if resolved_config_path is None:
- return None
- from vllm_omni.config.stage_config import resolve_deploy_yaml
-
- resolved_dict = resolve_deploy_yaml(resolved_config_path)
- return load_omni_transfer_config(config_dict=resolved_dict)
+ return load_omni_transfer_config(resolved_config_path)
except Exception as e:
logger.warning("[stage_init] Failed to load transfer config: %s", e)
return None
@@ -639,45 +439,11 @@ def get_stage_connector_spec(
return {}
-def build_diffusion_config(
- model: str,
- stage_cfg: Any,
- metadata: StageMetadata,
-) -> Any:
- """Build diffusion config for a stage."""
- from vllm_omni.diffusion.data import OmniDiffusionConfig
-
- engine_args_dict = build_engine_args_dict(stage_cfg, model)
- od_config = OmniDiffusionConfig.from_kwargs(**engine_args_dict)
-
- num_devices_per_stage = od_config.parallel_config.world_size
- device_control_env = current_omni_platform.device_control_env_var
- visible_devices_str = os.environ.get(device_control_env) if device_control_env else None
- if visible_devices_str:
- physical_devices = [device.strip() for device in visible_devices_str.split(",") if device.strip()]
- else:
- physical_devices = list(range(current_omni_platform.get_device_count()))
-
- if len(physical_devices) < num_devices_per_stage:
- raise ValueError(
- f"Stage {metadata.stage_id} requires {num_devices_per_stage} device(s) based on parallel_config, "
- f"but {len(physical_devices)} device(s) are available: {physical_devices}"
- )
-
- od_config.num_gpus = num_devices_per_stage
- if metadata.cfg_kv_collect_func is not None:
- od_config.cfg_kv_collect_func = metadata.cfg_kv_collect_func
- return od_config
-
-
def initialize_diffusion_stage(
- stage_id: int,
model: str,
stage_cfg: Any,
metadata: StageMetadata,
- stage_init_timeout: int,
batch_size: int = 1,
- use_inline: bool = False,
) -> Any:
"""Build a diffusion stage client.
@@ -685,25 +451,20 @@ def initialize_diffusion_stage(
model: Model name or path.
stage_cfg: Stage configuration.
metadata: Extracted stage metadata.
- stage_init_timeout: Timeout in seconds for stage initialization handshake
batch_size: Maximum number of requests to batch together in the
diffusion engine. Passed through to ``StageDiffusionClient``
and ultimately to ``AsyncOmni``.
- use_inline: If True, uses the inline diffusion client instead of subprocess.
"""
- from vllm_omni.diffusion.stage_diffusion_client import create_diffusion_client
+ from vllm_omni.diffusion.data import OmniDiffusionConfig
+ from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
- engine_args = _to_dict(stage_cfg.engine_args)
- engine_args.pop("stage_id", None)
od_config = OmniDiffusionConfig.from_kwargs(
- stage_id=stage_id,
model=model,
- **engine_args,
+ **_to_dict(stage_cfg.engine_args),
)
if metadata.cfg_kv_collect_func is not None:
od_config.cfg_kv_collect_func = metadata.cfg_kv_collect_func
- od_config = build_diffusion_config(model, stage_cfg, metadata)
- return create_diffusion_client(model, od_config, metadata, stage_init_timeout, batch_size, use_inline)
+ return StageDiffusionClient(model, od_config, metadata, batch_size=batch_size)
def _shutdown_or_close_resource(resource: Any, resource_name: str, stage_id: int) -> None:
@@ -736,18 +497,17 @@ def _shutdown_or_close_resource(resource: Any, resource_name: str, stage_id: int
def close_started_llm_stage(started: StartedLlmStage) -> None:
- """Release resources owned by a launched stage that never attached."""
- if started.proc is not None:
- try:
- terminate_alive_proc(started.proc)
- except Exception as cleanup_error:
- logger.warning(
- "[stage_init] Failed to terminate process for stage %s: %s",
- started.stage_id,
- cleanup_error,
- )
- _shutdown_or_close_resource(started.engine_manager, "engine manager", started.stage_id)
- _shutdown_or_close_resource(started.coordinator, "coordinator", started.stage_id)
+ """Terminate the subprocess owned by a launched stage that never attached."""
+ if started.proc is None:
+ return
+ try:
+ terminate_alive_proc(started.proc)
+ except Exception as cleanup_error:
+ logger.warning(
+ "[stage_init] Failed to terminate process for stage %s: %s",
+ started.stage_id,
+ cleanup_error,
+ )
def finalize_initialized_stages(
diff --git a/vllm_omni/entrypoints/__init__.py b/vllm_omni/entrypoints/__init__.py
index b273929a8ec..7b09adf9398 100644
--- a/vllm_omni/entrypoints/__init__.py
+++ b/vllm_omni/entrypoints/__init__.py
@@ -5,21 +5,8 @@
vLLM-Omni entrypoints module.
"""
-
-def __getattr__(name: str):
- # Lazy imports to avoid eagerly loading heavy modules (engine,
- # model_loader, pynvml) when the package is imported in lightweight
- # contexts such as model-architecture inspection subprocesses.
- if name == "AsyncOmni":
- from vllm_omni.entrypoints.async_omni import AsyncOmni
-
- return AsyncOmni
- if name == "Omni":
- from vllm_omni.entrypoints.omni import Omni
-
- return Omni
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
-
+from vllm_omni.entrypoints.async_omni import AsyncOmni
+from vllm_omni.entrypoints.omni import Omni
__all__ = [
"AsyncOmni",
diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py
index 056f56c003b..129ef3c99d8 100644
--- a/vllm_omni/entrypoints/async_omni.py
+++ b/vllm_omni/entrypoints/async_omni.py
@@ -9,7 +9,6 @@
import asyncio
import time
-import uuid
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
@@ -25,85 +24,22 @@
from vllm.tasks import SupportedTask
from vllm.v1.engine.exceptions import EngineDeadError
-from vllm_omni.diffusion.data import OmniACK, OmniSleepTask, OmniWakeTask
from vllm_omni.entrypoints.client_request_state import ClientRequestState
-from vllm_omni.entrypoints.omni_base import (
- OmniBase,
- OmniEngineDeadError,
-)
-from vllm_omni.inputs.data import OmniSamplingParams
+from vllm_omni.entrypoints.omni_base import OmniBase
from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics
from vllm_omni.outputs import OmniRequestOutput
-from vllm_omni.platforms import current_omni_platform
if TYPE_CHECKING:
from vllm.inputs.preprocess import InputPreprocessor
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import PauseMode
- from vllm_omni.inputs.data import OmniPromptType
+ from vllm_omni.inputs.data import OmniPromptType, OmniSamplingParams
logger = init_logger(__name__)
_FINAL_OUTPUT_IDLE_SLEEP_S = 0.001
-class AsyncEventResolver:
- """
- A generic signal aggregator designed for synchronized handshakes in
- distributed or multi-stage environments. Supports waiting for a specified
- number (expected_count) of worker signals in both inline and multiprocess modes.
- """
-
- def __init__(self, orchestrator=None):
- self._pending_tasks: dict[str, dict] = {}
- self.orchestrator = orchestrator
- self._lock = asyncio.Lock()
-
- def watch_task(self, task_id: str, expected_count: int = 1) -> asyncio.Future:
- loop = asyncio.get_running_loop()
- fut = loop.create_future()
- self._pending_tasks[task_id] = {
- "future": fut,
- "expected_count": expected_count,
- "received": [],
- "start_time": time.time(),
- }
- return fut
-
- async def resolve(self, ack: OmniACK):
- tid = getattr(ack, "task_id", None)
-
- if tid is None and isinstance(ack, dict):
- tid = ack.get("task_id")
-
- async with self._lock:
- task_info = self._pending_tasks.get(tid)
- if task_info is None:
- logger.warning(f"Received stray ACK for task_id {tid}. Task might have timed out.")
- return
-
- task_info["received"].append(ack)
- current_count = len(task_info["received"])
- expected = task_info["expected_count"]
-
- orchestrator = self.orchestrator
- if orchestrator and hasattr(orchestrator, "metrics") and orchestrator.metrics:
- freed = getattr(ack, "freed_bytes", 0)
- if freed == 0 and isinstance(ack, dict):
- freed = ack.get("freed_bytes", 0)
- orchestrator.metrics.record_vram_reclaimed(freed)
-
- logger.info(f"[Resolver] Task {tid} progress: {current_count}/{expected} ACKs received.")
-
- if current_count >= expected:
- self._pending_tasks.pop(tid)
- fut = task_info["future"]
- if not fut.done():
- elapsed = time.time() - task_info["start_time"]
- logger.info(f"[Resolver] Task {tid} completed successfully in {elapsed:.2f}s.")
- fut.set_result(task_info["received"])
-
-
class AsyncOmni(EngineClient, OmniBase):
"""Asynchronous unified entry point for multi-stage pipelines using AsyncOmniEngine.
@@ -140,8 +76,9 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None:
self._paused: bool = False
self._is_sleeping: bool = False
self.final_output_task: asyncio.Task | None = None
- self.event_resolver = AsyncEventResolver(orchestrator=self)
+
self.config_path = self.engine.config_path
+ self.stage_configs = self.engine.stage_configs
self.tts_max_instructions_length = kwargs.get("tts_max_instructions_length", None)
self.input_processor = self.engine.input_processor
@@ -272,20 +209,7 @@ async def generate(
# Start final output dispatcher on the first call to generate()
self._final_output_handler()
- # Expand sampling params for PD disaggregation (user may provide N-1 params)
- if (
- sampling_params_list is not None
- and isinstance(sampling_params_list, Sequence)
- and not isinstance(sampling_params_list, (str, bytes))
- ):
- sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list))
-
- # Set the output kind to delta output if sampling params were omitted,
- # since AsyncOmni is typically used for streaming.
- sampling_params_list = self.resolve_sampling_params_list(
- sampling_params_list,
- allow_delta_coercion=True,
- )
+ sampling_params_list = self.resolve_sampling_params_list(sampling_params_list)
# Track per-request metrics
wall_start_ts = time.time()
@@ -304,27 +228,20 @@ async def generate(
req_state.metrics = metrics
self.request_states[request_id] = req_state
- # PD disaggregation: modify prefill-stage sampling params per request
- req_sp_list = list(sampling_params_list)
- pd_pair = self._get_pd_separation_pair()
- if pd_pair is not None:
- p_id = pd_pair[0]
- req_sp_list[p_id] = self._prepare_prefill_sampling_params(request_id, req_sp_list[p_id])
-
# Add request(s) to stage 0. For streaming inputs, submit
# chunks incrementally through streaming_update.
if isinstance(prompt, AsyncGenerator):
input_stream_task = await self._add_streaming_input_request(
request_id=request_id,
input_stream=prompt,
- sampling_params_list=req_sp_list,
+ sampling_params_list=sampling_params_list,
final_stage_id=final_stage_id_for_e2e,
)
else:
await self.engine.add_request_async(
request_id=request_id,
prompt=prompt,
- sampling_params_list=req_sp_list,
+ sampling_params_list=sampling_params_list,
final_stage_id=final_stage_id_for_e2e,
)
submit_ts = time.time()
@@ -373,14 +290,15 @@ async def _add_streaming_input_request(
# only check thinker's sampling params now
stage0_params = sampling_params_list[0]
self._validate_streaming_input_sampling_params(stage0_params)
+
req_state = self.request_states[request_id]
- has_submitted_first_chunk = False
- # NOTE: InputProcessor in vLLM should generally do this too, but for
- # now we do it defensively. TODO (Alex) ensure clones/copying are optimized
if not stage0_params.skip_clone:
stage0_params = stage0_params.clone()
stage0_params.skip_clone = True
+ stage0_params.output_kind = RequestOutputKind.DELTA
+
+ has_submitted_first_chunk = False
async def handle_inputs() -> None:
nonlocal has_submitted_first_chunk
@@ -503,12 +421,6 @@ async def _process_orchestrator_results(
stage_id = result.get("stage_id", 0)
- if result.get("type") == "error" and result.get("fatal"):
- raise OmniEngineDeadError(
- result.get("error", ""),
- error_stage_id=result.get("stage_id"),
- )
-
# Check for errors
if "error" in result:
logger.error(
@@ -519,8 +431,6 @@ async def _process_orchestrator_results(
)
raise RuntimeError(result)
- self._check_engine_output_error(result, request_id, stage_id)
-
# Process the result (constructs OmniRequestOutput)
output_to_yield = self._process_single_result(
result,
@@ -566,22 +476,6 @@ async def _final_output_loop():
await asyncio.sleep(_FINAL_OUTPUT_IDLE_SLEEP_S)
continue
- if isinstance(msg, dict) and msg.get("type") == "ack":
- ack_data = msg.get("ack")
- tid = getattr(ack_data, "task_id", "unknown")
- logger.info(f"[{self._name}] Intercepted wrapped ACK for task {tid}")
- await self.event_resolver.resolve(ack_data)
- continue
- if isinstance(msg, OmniACK):
- logger.info(f"[{self._name}] Intercepted raw ACK object: {msg.task_id}")
- await self.event_resolver.resolve(msg)
- continue
- if hasattr(msg, "task_id"):
- tid = getattr(msg, "task_id")
- logger.info(f"[{self._name}] Intercepted task-ID object: {tid}")
- await self.event_resolver.resolve(msg)
- continue
-
should_continue, _, stage_id, req_state = self._handle_output_message(msg)
if should_continue:
continue
@@ -593,28 +487,6 @@ async def _final_output_loop():
except asyncio.CancelledError:
raise
- except OmniEngineDeadError as e:
- logger.error("[AsyncOmni] Engine dead: %s", e)
- for req_state in list(self.request_states.values()):
- error_msg = {
- "type": "error",
- "error": str(e),
- "fatal": True,
- "request_id": req_state.request_id,
- }
- if e.error_stage_id is not None:
- error_msg["stage_id"] = e.error_stage_id
- await req_state.queue.put(error_msg)
- except EngineDeadError as e:
- logger.error("[AsyncOmni] Engine dead: %s", e)
- for req_state in list(self.request_states.values()):
- error_msg = {
- "type": "error",
- "error": str(e),
- "fatal": True,
- "request_id": req_state.request_id,
- }
- await req_state.queue.put(error_msg)
except Exception as e:
logger.exception("[AsyncOmni] final_output_loop failed.")
for req_state in list(self.request_states.values()):
@@ -770,68 +642,21 @@ async def reset_prefix_cache(
logger.warning("[AsyncOmni] reset_prefix_cache not yet supported with Orchestrator process")
return True
- async def sleep(
- self, stage_ids: list[int] | None = None, level: int = 2, mode: PauseMode = "abort"
- ) -> list[OmniACK]:
- self._final_output_handler()
- if stage_ids is None:
- stage_ids = list(range(len(self.engine.stage_clients)))
- total_workers = 0
- for sid in stage_ids:
- client = self.engine.stage_clients[sid]
- # During the Diffusion phase, regardless of the TP amount,
- # currently only a summary ACK is reported at Rank 0.
- if getattr(client, "stage_type", "") == "diffusion":
- total_workers += 1
- else:
- config = self.engine.stage_vllm_configs[sid]
- actual_tp = config.parallel_config.tensor_parallel_size if config else 1
- total_workers += actual_tp
-
- task_id = str(uuid.uuid4())
- self.event_resolver.watch_task(task_id, expected_count=total_workers)
- logger.info(f"[{self._name}] Sleep initiated (Task: {task_id}). Awaiting {total_workers} ACKs...")
- task = OmniSleepTask(level=level, task_id=task_id)
- rpc_results = await self.collective_rpc(method="handle_sleep_task", args=(task,), stage_ids=stage_ids)
- final_acks = []
- for stage_res in rpc_results:
- worker_acks = stage_res if isinstance(stage_res, list) else [stage_res]
- for ack in worker_acks:
- if ack is not None:
- await self.event_resolver.resolve(ack)
- final_acks.append(ack)
+ async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
+ """Sleep all stages.
+
+ Best-effort: unsupported stages will emit a TODO result.
+ """
self._is_sleeping = True
- return final_acks
-
- async def wake_up(self, stage_ids: list[int] | None = None, tags: list[str] | None = None) -> list[OmniACK]:
- self._final_output_handler()
- if stage_ids is None:
- stage_ids = list(range(len(self.engine.stage_clients)))
- total_workers = 0
- for sid in stage_ids:
- client = self.engine.stage_clients[sid]
- if getattr(client, "stage_type", "") == "diffusion":
- total_workers += 1
- else:
- config = self.engine.stage_vllm_configs[sid]
- total_workers += config.parallel_config.tensor_parallel_size if config else 1
- task_id = str(uuid.uuid4())
- self.event_resolver.watch_task(task_id, expected_count=total_workers)
- logger.info(f"[{self._name}] Wake-up initiated (Task: {task_id}). Awaiting {total_workers} ACKs...")
- task = OmniWakeTask(tags=tags, task_id=task_id)
- rpc_results = await self.collective_rpc(method="handle_wake_task", args=(task,), stage_ids=stage_ids)
- final_acks = []
- for stage_res in rpc_results:
- worker_acks = stage_res if isinstance(stage_res, list) else [stage_res]
- for ack in worker_acks:
- if ack is not None:
- await self.event_resolver.resolve(ack)
- final_acks.append(ack)
- current_omni_platform.synchronize()
- await asyncio.sleep(0.1)
+ await self.collective_rpc(method="sleep", args=(level,))
+
+ async def wake_up(self, tags: list[str] | None = None) -> None:
+ """Wake up all stages.
+
+ Best-effort: unsupported stages will emit a TODO result.
+ """
self._is_sleeping = False
- logger.info(f"[{self._name}] All {len(final_acks)}/{total_workers} workers reported WARM for task {task_id}.")
- return final_acks
+ await self.collective_rpc(method="wake_up", args=(tags,))
async def is_sleeping(self) -> bool:
"""Return whether all stages are sleeping.
@@ -881,25 +706,12 @@ async def pin_lora(self, adapter_id: int) -> bool:
@property
def is_running(self) -> bool:
"""Check if the engine is running."""
- orchestrator_alive = self.engine.is_alive()
- task_alive = self.final_output_task is not None and not self.final_output_task.done()
- return orchestrator_alive and task_alive
+ return self.final_output_task is not None and not self.final_output_task.done()
@property
def errored(self) -> bool:
- """Whether the engine is in a non-recoverable error state.
-
- Delegates to ``OmniBase.errored`` which checks the orchestrator
- thread and all stage clients. Redeclared here to satisfy the
- ``EngineClient`` abstract-property requirement (Python's ABC
- mechanism does not resolve abstract methods from sibling MRO
- entries).
- """
- return OmniBase.errored.fget(self) # type: ignore[union-attr]
-
- @property
- def _name(self) -> str:
- return "AsyncOrchestrator"
+ """Whether orchestrator thread has stopped unexpectedly."""
+ return not self.engine.is_alive()
@property
def is_stopped(self) -> bool:
@@ -909,7 +721,7 @@ def is_stopped(self) -> bool:
@property
def dead_error(self) -> BaseException:
"""EngineClient abstract property implementation."""
- return OmniEngineDeadError()
+ return EngineDeadError()
# ==================== EngineClient Interface ====================
diff --git a/vllm_omni/entrypoints/cfg_companion_tracker.py b/vllm_omni/entrypoints/cfg_companion_tracker.py
new file mode 100644
index 00000000000..9c2e835f074
--- /dev/null
+++ b/vllm_omni/entrypoints/cfg_companion_tracker.py
@@ -0,0 +1,233 @@
+"""CFG companion request tracker for the Omni orchestrator.
+
+Encapsulates all bookkeeping for Classifier-Free Guidance companion
+requests (prompt expansion, parent/companion ID mapping, completion
+tracking, deferred forwarding, failure propagation, and timeouts)
+so that ``Omni._run_generation`` stays clean.
+"""
+
+from __future__ import annotations
+
+import copy
+import logging
+import os
+import time
+from collections.abc import Callable, Sequence
+from typing import Any
+
+from vllm_omni.distributed.omni_connectors.adapter import try_send_via_connector
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams
+
+logger = logging.getLogger(__name__)
+
+
+class CfgCompanionTracker:
+ """Manages CFG companion request lifecycle in the orchestrator scheduling loop."""
+
+ def __init__(
+ self,
+ prompt_expand_func: Callable[..., Any] | None,
+ stage0_sampling_params: Any,
+ timeout_s: float | None = None,
+ ) -> None:
+ self._expand_func = prompt_expand_func
+ self._sp0 = stage0_sampling_params
+ self._timeout_s = (
+ timeout_s if timeout_s is not None else float(os.environ.get("VLLM_CFG_PENDING_TIMEOUT_S", "120"))
+ )
+
+ self._companion_map: dict[str, dict[str, str]] = {} # parent -> {role: companion_id}
+ self._companion_ids: set[str] = set()
+ self._companion_to_parent: dict[str, str] = {} # companion -> parent
+ self._done: dict[str, set[str]] = {} # parent -> completed companion ids
+ self._pending_parents: dict[str, dict[str, Any]] = {} # parent -> deferred result
+ self._failed_parents: set[str] = set()
+
+ @property
+ def is_active(self) -> bool:
+ return bool(self._companion_ids)
+
+ @property
+ def num_companions(self) -> int:
+ return len(self._companion_ids)
+
+ @property
+ def stage0_sampling_params(self) -> Any:
+ return self._sp0
+
+ def expand_prompts(
+ self,
+ request_id_to_prompt: dict[str, Any],
+ ) -> list[tuple[str, Any]]:
+ """Expand user prompts into ``(companion_id, prompt)`` pairs via model-specific func."""
+ if not self._expand_func:
+ return []
+
+ pairs: list[tuple[str, Any]] = []
+ for rid, prompt in request_id_to_prompt.items():
+ expanded = self._expand_func(prompt, self._sp0)
+ if not expanded:
+ continue
+ role_map: dict[str, str] = {}
+ for ep in expanded:
+ cid = f"{rid}{ep.request_id_suffix}"
+ role_map[ep.role] = cid
+ self._companion_ids.add(cid)
+ self._companion_to_parent[cid] = rid
+ pairs.append((cid, ep.prompt))
+ self._companion_map[rid] = role_map
+ self._done[rid] = set()
+
+ logger.debug(
+ "CFG expansion: %d parent(s) -> %d companion(s)",
+ len(self._companion_map),
+ len(self._companion_ids),
+ )
+ return pairs
+
+ def is_companion(self, req_id: str) -> bool:
+ return req_id in self._companion_ids
+
+ def has_companions(self, parent_id: str) -> bool:
+ return parent_id in self._companion_map
+
+ def all_companions_done(self, parent_id: str) -> bool:
+ role_map = self._companion_map.get(parent_id, {})
+ done_set = self._done.get(parent_id, set())
+ return all(cid in done_set for cid in role_map.values())
+
+ def get_companion_request_ids(self, parent_id: str) -> dict[str, str]:
+ """Return ``{role: companion_request_id}`` for a parent."""
+ return self._companion_map.get(parent_id, {})
+
+ def is_parent_failed(self, parent_id: str) -> bool:
+ return parent_id in self._failed_parents
+
+ # -- Lifecycle events --
+
+ def on_companion_error(self, companion_id: str) -> tuple[str | None, bool]:
+ """Record failure. Returns ``(parent_id, parent_was_aborted)``."""
+ parent_id = self._companion_to_parent.get(companion_id)
+ if parent_id is None:
+ return None, False
+ self._failed_parents.add(parent_id)
+ logger.error("CFG companion %s failed; marking parent %s as failed", companion_id, parent_id)
+ aborted = parent_id in self._pending_parents
+ if aborted:
+ self._pending_parents.pop(parent_id, None)
+ return parent_id, aborted
+
+ def on_companion_completed(self, companion_id: str) -> str | None:
+ """Mark done. Returns parent_id only if parent is pending and all companions finished."""
+ parent_id = self._companion_to_parent.get(companion_id)
+ if parent_id is None:
+ return None
+ self._done[parent_id].add(companion_id)
+ logger.debug("CFG companion %s completed (parent=%s)", companion_id, parent_id)
+ if parent_id in self._pending_parents and self.all_companions_done(parent_id):
+ return parent_id
+ return None
+
+ def consume_parent_failure(self, parent_id: str) -> None:
+ self._failed_parents.discard(parent_id)
+
+ # -- Deferred parent management --
+
+ def defer_parent(self, parent_id: str, engine_outputs: Any, stage_id: int) -> None:
+ """Hold parent result while waiting for companions to finish."""
+ self._pending_parents[parent_id] = {
+ "engine_outputs": engine_outputs,
+ "stage_id": stage_id,
+ "pending_since": time.time(),
+ }
+ logger.debug("Parent %s deferred, waiting for CFG companions", parent_id)
+
+ def pop_pending_parent(self, parent_id: str) -> dict[str, Any] | None:
+ return self._pending_parents.pop(parent_id, None)
+
+ def check_timeouts(self) -> list[str]:
+ """Return and remove parent IDs that exceeded the pending timeout."""
+ if not self._pending_parents:
+ return []
+ now = time.time()
+ timed_out: list[str] = []
+ for pid in list(self._pending_parents):
+ pending_since = self._pending_parents[pid].get("pending_since", now)
+ if now - pending_since > self._timeout_s:
+ self._pending_parents.pop(pid)
+ self._failed_parents.discard(pid)
+ timed_out.append(pid)
+ logger.error("Parent %s timed out waiting for CFG companions (>%.0fs)", pid, self._timeout_s)
+ return timed_out
+
+ # -- Forward parent with CFG KV --
+
+ def forward_parent_with_cfg(
+ self,
+ req_id: str,
+ parent_result: dict[str, Any],
+ stage_list: Sequence[Any],
+ connectors: dict[tuple[str, str], Any],
+ sampling_params_list: Sequence[OmniSamplingParams],
+ request_id_to_prompt: dict[str, Any],
+ final_stage_id_to_prompt: dict[str, int],
+ metrics: Any,
+ remaining_by_stage: list[int],
+ ) -> bool:
+ """Forward a parent request to the next stage with CFG KV request IDs attached."""
+ stage_id = parent_result["stage_id"]
+ next_stage_id = stage_id + 1
+ if next_stage_id > final_stage_id_to_prompt.get(req_id, 0):
+ return True
+
+ next_stage = stage_list[next_stage_id]
+ try:
+ with metrics.stage_postprocess_timer(stage_id, req_id):
+ next_inputs = next_stage.process_engine_inputs(
+ stage_list,
+ [request_id_to_prompt[req_id]],
+ source_outputs_override=parent_result["engine_outputs"],
+ )
+ except Exception as e:
+ logger.exception(
+ "Process engine inputs error for req %s at stage %d: %s",
+ req_id,
+ next_stage_id,
+ e,
+ )
+ return False
+
+ sp_next = copy.deepcopy(sampling_params_list[next_stage_id])
+ if isinstance(sp_next, OmniDiffusionSamplingParams):
+ sp_next.cfg_kv_request_ids = self.get_companion_request_ids(req_id)
+ logger.info(
+ "Attaching cfg_kv_request_ids=%s to request %s",
+ sp_next.cfg_kv_request_ids,
+ req_id,
+ )
+
+ connector_key = (str(stage_id), str(next_stage_id))
+ connector = connectors.get(connector_key)
+ sent_via_connector = False
+ if connector:
+ sent_via_connector = try_send_via_connector(
+ connector=connector,
+ stage_id=stage_id,
+ next_stage_id=next_stage_id,
+ req_id=req_id,
+ next_inputs=next_inputs,
+ sampling_params=sp_next,
+ original_prompt=request_id_to_prompt[req_id],
+ next_stage_queue_submit_fn=stage_list[next_stage_id].submit,
+ metrics=metrics,
+ )
+
+ if not sent_via_connector:
+ raise RuntimeError(
+ f"Failed to send CFG request {req_id} to stage-{next_stage_id} via connector. "
+ "Configure a connector for this edge or inspect connector logs for details."
+ )
+
+ logger.debug("Forwarded CFG-enabled request %s to stage-%d", req_id, next_stage_id)
+ remaining_by_stage[next_stage_id] += 1
+ return True
diff --git a/vllm_omni/entrypoints/chat_utils.py b/vllm_omni/entrypoints/chat_utils.py
index 4c3d311ec50..8970e589844 100644
--- a/vllm_omni/entrypoints/chat_utils.py
+++ b/vllm_omni/entrypoints/chat_utils.py
@@ -2,7 +2,7 @@
async def extract_audio_from_video_async(video_url: str) -> tuple[np.ndarray, int | float]:
- """Extract audio from a video URL using vllm's load_audio.
+ """Extract audio from a video URL using librosa.
Returns a (audio_array, sample_rate) tuple compatible with audio format.
All blocking I/O operations are run in a thread pool.
@@ -26,9 +26,9 @@ def _write_temp_file_sync(data: bytes, suffix: str) -> str:
return temp_file.name
def _load_audio_sync(file_path: str) -> tuple[np.ndarray, int | float]:
- from vllm.multimodal.media.audio import load_audio
+ import librosa
- return load_audio(file_path, sr=16000)
+ return librosa.load(file_path, sr=16000)
def _cleanup_file_sync(file_path: str) -> None:
try:
diff --git a/vllm_omni/entrypoints/cli/benchmark/serve.py b/vllm_omni/entrypoints/cli/benchmark/serve.py
index cee900f618e..906e8851a4a 100644
--- a/vllm_omni/entrypoints/cli/benchmark/serve.py
+++ b/vllm_omni/entrypoints/cli/benchmark/serve.py
@@ -1,5 +1,4 @@
import argparse
-import os
from vllm.benchmarks.serve import add_cli_args
@@ -7,153 +6,15 @@
from vllm_omni.entrypoints.cli.benchmark.base import OmniBenchmarkSubcommandBase
-def add_daily_omni_cli_args(parser: argparse.ArgumentParser) -> None:
- """Add CLI arguments specific to Daily-Omni dataset.
-
- This function should be called by the CLI entrypoint to add additional
- arguments for daily-omni benchmark support.
-
- Args:
- parser: The ArgumentParser instance to extend
- """
- # Daily-Omni specific arguments
- daily_omni_group = parser.add_argument_group("Daily-Omni Dataset Options")
-
- daily_omni_group.add_argument(
- "--daily-omni-qa-json",
- type=str,
- default=None,
- help="Path to local upstream qa.json. When set, QA rows are read from this file and "
- "the HuggingFace dataset is not loaded (no network). Use with --daily-omni-video-dir "
- "for fully offline runs. --dataset-path / Hub split flags are then ignored for QA loading.",
- )
- daily_omni_group.add_argument(
- "--daily-omni-video-dir",
- type=str,
- default=None,
- help="Root directory of extracted Daily-Omni videos (contents of Videos.tar: "
- "each video_id in its own subdir with {video_id}_video.mp4). "
- "If omitted, Videos.tar is downloaded from the Hugging Face dataset repo on first multimodal "
- "request. "
- "When using file URLs, you MUST start the vLLM server with "
- "--allowed-local-media-path set to this same directory (or a parent), "
- "otherwise requests fail with 'Cannot load local files without "
- "--allowed-local-media-path'.",
- )
- daily_omni_group.add_argument(
- "--daily-omni-inline-local-video",
- action="store_true",
- default=False,
- help="For local videos only: embed MP4 as base64 data URLs in benchmark "
- "requests so the server does not need --allowed-local-media-path. "
- "Increases request size and client memory; use for small --num-prompts. "
- "When using --daily-omni-input-mode audio or all, local WAV files are "
- "embedded the same way.",
- )
- daily_omni_group.add_argument(
- "--daily-omni-input-mode",
- type=str,
- choices=["all", "visual", "audio"],
- default="all",
- help="Daily-Omni input protocol (mirrors upstream Lliar-liar/Daily-Omni "
- "--input_mode). 'visual': video only (default). 'audio': WAV only, "
- "requires {video_id}/{video_id}_audio.wav under --daily-omni-video-dir. "
- "'all': video + WAV together. Sets mm_processor_kwargs.use_audio_in_video=false "
- "and matches official separate video/audio streams.",
- )
- daily_omni_group.add_argument(
- "--daily-omni-save-eval-items",
- action="store_true",
- default=False,
- help="Include per-request Daily-Omni accuracy rows (gold/predicted/correct) "
- "in the saved JSON under key daily_omni_eval_items. "
- "Alternatively set env DAILY_OMNI_SAVE_EVAL_ITEMS=1.",
- )
-
- # Note: --dataset-name daily-omni via get_samples patch; use either Hub (--dataset-path
- # liarliar/Daily-Omni) or local --daily-omni-qa-json (offline).
-
-
-def add_seed_tts_cli_args(parser: argparse.ArgumentParser) -> None:
- """CLI for Seed-TTS zero-shot TTS benchmark (``--dataset-name seed-tts``)."""
- g = parser.add_argument_group("Seed-TTS Dataset Options")
- g.add_argument(
- "--seed-tts-locale",
- type=str,
- choices=["en", "zh"],
- default="en",
- help="Which Seed-TTS split to load: en/meta.lst or zh/meta.lst under the dataset root.",
- )
- g.add_argument(
- "--seed-tts-root",
- type=str,
- default=None,
- help="Override root directory that contains en/ and zh/ (meta.lst + prompt-wavs). "
- "If set, --dataset-path can still name the HF repo for logging; this path is used for files.",
- )
- g.add_argument(
- "--seed-tts-file-ref-audio",
- action="store_true",
- default=False,
- help="Send ref_audio as file:// URIs (smaller HTTP bodies). Requires the API server "
- "to be started with --allowed-local-media-path covering the Seed-TTS dataset root. "
- "Default is inline data:audio/wav;base64 so Qwen3-TTS works without that flag.",
- )
- g.add_argument(
- "--seed-tts-inline-ref-audio",
- action="store_true",
- default=False,
- help=argparse.SUPPRESS,
- )
- g.add_argument(
- "--seed-tts-system-prompt",
- type=str,
- default=None,
- help="Override chat system message for --backend openai-chat-omni (Qwen3-Omni TTS). "
- "Default follows official Qwen3-Omni identity + zero-shot voice-clone instructions.",
- )
- g.add_argument(
- "--seed-tts-wer-eval",
- action="store_true",
- default=False,
- help="Keep synthesized audio as 24 kHz mono PCM for WER (works with "
- "--backend openai-audio-speech or openai-chat-omni). Scoring follows "
- "zhaochenyang20/seed-tts-eval (Whisper-large-v3 / Paraformer-zh + jiwer). "
- "Sets SEED_TTS_WER_EVAL=1. Install: pip install 'vllm-omni[seed-tts-eval]'. "
- "Optional: SEED_TTS_EVAL_DEVICE, SEED_TTS_HF_WHISPER_MODEL.",
- )
- g.add_argument(
- "--seed-tts-wer-save-items",
- action="store_true",
- default=False,
- help="Include per-utterance ASR rows in the saved JSON under key seed_tts_wer_eval_items. "
- "Or set SEED_TTS_WER_SAVE_ITEMS=1.",
- )
-
-
class OmniBenchmarkServingSubcommand(OmniBenchmarkSubcommandBase):
"""The `serve` subcommand for vllm bench."""
name = "serve"
- help = "Benchmark the online serving throughput. Supports Daily-Omni and Seed-TTS datasets."
+ help = "Benchmark the online serving throughput."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
-
- # Add Daily-Omni specific arguments
- add_daily_omni_cli_args(parser)
- add_seed_tts_cli_args(parser)
-
- for action in parser._actions:
- if action.dest == "dataset_name" and action.choices is not None:
- extra = [
- c for c in ("daily-omni", "seed-tts", "seed-tts-text", "seed-tts-design") if c not in action.choices
- ]
- if extra:
- action.choices = list(action.choices) + extra
-
- # Update help messages for omni-specific features
for action in parser._actions:
if action.dest == "percentile_metrics":
action.help = (
@@ -187,10 +48,4 @@ def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
@staticmethod
def cmd(args: argparse.Namespace) -> None:
- if getattr(args, "daily_omni_save_eval_items", False):
- os.environ["DAILY_OMNI_SAVE_EVAL_ITEMS"] = "1"
- if getattr(args, "seed_tts_wer_eval", False):
- os.environ["SEED_TTS_WER_EVAL"] = "1"
- if getattr(args, "seed_tts_wer_save_items", False):
- os.environ["SEED_TTS_WER_SAVE_ITEMS"] = "1"
main(args)
diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py
index 9f2bef26776..b72df41cdd5 100644
--- a/vllm_omni/entrypoints/cli/serve.py
+++ b/vllm_omni/entrypoints/cli/serve.py
@@ -8,8 +8,6 @@
import argparse
import json
import os
-import signal
-from types import FrameType
from typing import Any
import uvloop
@@ -19,7 +17,6 @@
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser
-from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
from vllm_omni.entrypoints.cli.logo import log_logo
from vllm_omni.entrypoints.openai.api_server import omni_run_server
@@ -80,9 +77,6 @@ class OmniServeCommand(CLISubcommand):
"""The `serve` subcommand for the vLLM CLI."""
name = "serve"
- # Parser stashed at subparser_init so ``cmd`` can resolve each user-typed
- # flag to its real ``dest`` via the parser's action table.
- _parser: FlexibleArgumentParser | None = None
@staticmethod
def cmd(args: argparse.Namespace) -> None:
@@ -134,17 +128,6 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
action="store_true",
help="Enable vLLM-Omni mode for multi-modal and diffusion models",
)
-
- try:
- omni_config_group.add_argument(
- "--enable-sleep-mode",
- action="store_true",
- default=False,
- help="Enable GPU memory pool for sleep mode.",
- )
- except argparse.ArgumentError:
- pass
-
omni_config_group.add_argument(
"--task-type",
type=str,
@@ -153,33 +136,11 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
help="Default task type for TTS models (CustomVoice, VoiceDesign, or Base). "
"If not specified, will be inferred from model path.",
)
- # TODO(@lishunyang12): deprecate once all models migrate to --deploy-config
omni_config_group.add_argument(
"--stage-configs-path",
type=str,
default=None,
- help="[Deprecated — will be removed in a future release] Path to a legacy "
- "stage configs YAML (stage_args format). Prefer --deploy-config for new-format deploy YAMLs.",
- )
- omni_config_group.add_argument(
- "--deploy-config",
- type=str,
- default=None,
- help="Path to a deploy config YAML (new format with stages/engine_args). "
- "Mutually exclusive with --stage-configs-path.",
- )
- omni_config_group.add_argument(
- "--stage-overrides",
- type=str,
- default=None,
- help="Per-stage JSON overrides. Example: "
- '\'{"0": {"gpu_memory_utilization": 0.8}, "2": {"enforce_eager": true}}\'',
- )
- omni_config_group.add_argument(
- "--async-chunk",
- action=argparse.BooleanOptionalAction,
- default=None,
- help="Override the deploy YAML's ``async_chunk:`` bool. Unset leaves the YAML value in force.",
+ help="Path to the stage configs file. If not specified, the stage configs will be loaded from the model.",
)
omni_config_group.add_argument(
"--stage-id",
@@ -262,40 +223,6 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
default=None,
help="Override the diffusion pipeline class name (e.g. LTX2ImageToVideoPipeline).",
)
- omni_config_group.add_argument(
- "--diffusion-load-format",
- dest="diffusion_load_format",
- type=str,
- default=None,
- choices=["default", "custom_pipeline", "dummy", "diffusers"],
- help=(
- "How to load the diffusion pipeline: native/registry (default), "
- "custom_pipeline, dummy, or diffusers for the HF diffusers adapter."
- ),
- )
- omni_config_group.add_argument(
- "--diffusers-load-kwargs",
- dest="diffusers_load_kwargs",
- type=json.loads,
- default="{}",
- help=(
- "JSON object passed to DiffusionPipeline.from_pretrained()."
- "It overrides corresponding parameters in the standard vLLM-Omni interface."
- '(e.g. \'{"use_safetensors": true, "variant": "fp16"}\').'
- ),
- )
- omni_config_group.add_argument(
- "--diffusers-call-kwargs",
- dest="diffusers_call_kwargs",
- type=json.loads,
- default="{}",
- help=(
- "JSON object passed to pipeline.__call__(). "
- "Useful for model-specific sampling parameters not covered by the vLLM-Omni interface."
- "During request time, it is overridden by corresponding parameters in the vLLM-Omni interface."
- '(e.g. \'{"num_inference_steps": 30, "guidance_scale": 7.5}\').'
- ),
- )
omni_config_group.add_argument(
"--usp",
"--ulysses-degree",
@@ -477,16 +404,6 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
- omni_config_group.add_argument(
- "--enable-ar-profiler",
- action="store_true",
- help="Enable AR stage profiler to include AR stage timing in stage_durations.",
- )
- # Stash via type(self) so the docs hook (which execs this function in a
- # sandboxed globals dict via ``DummySelf``) doesn't fail on a NameError.
- type(self)._parser = serve_parser
-
- nullify_stage_engine_defaults(serve_parser)
return serve_parser
@@ -502,230 +419,23 @@ def _create_default_diffusion_stage_cfg(args: argparse.Namespace) -> list[dict[s
def run_headless(args: argparse.Namespace) -> None:
- """Run a single stage in headless mode."""
- from vllm.v1.engine.coordinator import DPCoordinator
- from vllm.v1.engine.utils import CoreEngineProcManager
- from vllm.v1.executor.multiproc_executor import MultiprocExecutor
- from vllm.version import __version__ as VLLM_VERSION
-
- from vllm_omni.diffusion.stage_diffusion_proc import (
- complete_diffusion_handshake,
- spawn_diffusion_proc,
- )
- from vllm_omni.distributed.omni_connectors.utils.initialization import resolve_omni_kv_config_for_stage
- from vllm_omni.engine.stage_engine_startup import register_stage_with_omni_master
- from vllm_omni.engine.stage_init_utils import (
- build_diffusion_config,
- build_engine_args_dict,
- build_vllm_config,
- extract_stage_metadata,
- get_stage_connector_spec,
- inject_kv_stage_info,
- load_omni_transfer_config_for_model,
- prepare_engine_environment,
- terminate_alive_proc,
- )
- from vllm_omni.entrypoints.utils import inject_omni_kv_config, load_and_resolve_stage_configs
-
- model = args.model
- stage_id: int | None = args.stage_id
- omni_master_address: str | None = args.omni_master_address
- omni_master_port: int | None = args.omni_master_port
-
- if stage_id is None:
- raise ValueError("--stage-id is required in headless mode")
- if omni_master_address is None or omni_master_port is None:
- raise ValueError("--omni-master-address and --omni-master-port are required in headless mode")
- if getattr(args, "api_server_count", 0) and args.api_server_count > 1:
- raise ValueError("api_server_count can't be set in headless mode")
- if args.worker_backend != "multi_process":
- raise ValueError("headless mode requires worker_backend=multi_process")
-
- args_dict = vars(args).copy()
- args_dict.pop("_cli_explicit_keys", None)
- config_path, stage_configs = load_and_resolve_stage_configs(
- model,
- args_dict.get("stage_configs_path"),
- args_dict,
- )
-
- # Locate the stage config that matches stage_id.
- stage_cfg = None
- for cfg in stage_configs:
- if getattr(cfg, "stage_id", None) == stage_id:
- stage_cfg = cfg
- break
- if stage_cfg is None:
- raise ValueError(
- f"No stage config found for stage_id={stage_id}. "
- f"Available stage ids: {[getattr(c, 'stage_id', None) for c in stage_configs]}"
- )
-
- prepare_engine_environment()
- omni_transfer_config = load_omni_transfer_config_for_model(model, config_path)
- omni_conn_cfg, omni_from, omni_to = resolve_omni_kv_config_for_stage(omni_transfer_config, stage_id)
-
- if getattr(stage_cfg, "stage_type", "llm") == "diffusion":
- metadata = extract_stage_metadata(stage_cfg)
- if omni_conn_cfg:
- inject_omni_kv_config(stage_cfg, omni_conn_cfg, omni_from, omni_to)
- inject_kv_stage_info(stage_cfg, stage_id)
- od_config = build_diffusion_config(model, stage_cfg, metadata)
-
- logger.info(
- "[Headless] Launching diffusion stage %d via OmniMasterServer at %s:%d",
- stage_id,
- omni_master_address,
- omni_master_port,
- )
-
- proc = None
- try:
- handshake_address, request_address, response_address = register_stage_with_omni_master(
- omni_master_address=omni_master_address,
- omni_master_port=omni_master_port,
- omni_stage_id=stage_id,
- omni_stage_config=stage_cfg,
- return_addresses=True,
- )
- proc, _, _, _ = spawn_diffusion_proc(
- model,
- od_config,
- handshake_address=handshake_address,
- request_address=request_address,
- response_address=response_address,
- )
- complete_diffusion_handshake(proc, handshake_address)
- proc.join()
- if proc.exitcode not in (None, 0):
- raise RuntimeError(f"Diffusion stage {stage_id} exited with code {proc.exitcode}")
- return
- finally:
- logger.info("[Headless] Shutting down stage %d.", stage_id)
- if proc is not None and proc.is_alive():
- terminate_alive_proc(proc)
-
- stage_connector_spec = get_stage_connector_spec(
- omni_transfer_config=omni_transfer_config,
- stage_id=stage_id,
- async_chunk=False,
- )
+ """Run a single stage in headless mode.
- # Device assignment is managed externally (e.g. CUDA_VISIBLE_DEVICES);
- # runtime_cfg is intentionally ignored in headless mode.
- engine_args_dict = build_engine_args_dict(
- stage_cfg,
- model,
- stage_connector_spec=stage_connector_spec,
- )
-
- # Inject omni KV connector config so the engine runner can initialize the
- # correct connector (sender/receiver role, type, addresses, etc.).
- if omni_conn_cfg:
- omni_kv = engine_args_dict.get("omni_kv_config") or {}
- if not isinstance(omni_kv, dict):
- omni_kv = dict(omni_kv)
- omni_kv["connector_config"] = omni_conn_cfg
- omni_kv["omni_from_stage"] = omni_from
- omni_kv["omni_to_stage"] = omni_to
- omni_kv.setdefault("stage_id", stage_id)
- engine_args_dict["omni_kv_config"] = omni_kv
-
- vllm_config, executor_class = build_vllm_config(
- stage_cfg,
- model,
- stage_connector_spec=stage_connector_spec,
- engine_args_dict=engine_args_dict,
- headless=True,
- )
- parallel_config = vllm_config.parallel_config
- local_engine_count = parallel_config.data_parallel_size_local
-
- if local_engine_count <= 0:
- raise ValueError("data_parallel_size_local must be > 0 in headless mode")
-
- shutdown_requested = False
-
- def signal_handler(signum: int, frame: FrameType | None) -> None:
- nonlocal shutdown_requested
- logger.debug("Received %d signal.", signum)
- if not shutdown_requested:
- shutdown_requested = True
- raise SystemExit
-
- signal.signal(signal.SIGTERM, signal_handler)
- signal.signal(signal.SIGINT, signal_handler)
-
- if parallel_config.node_rank_within_dp > 0:
- head_node_address = f"{parallel_config.master_addr}:{parallel_config.master_port}"
- logger.info(
- "Launching vLLM-Omni (v%s) headless multiproc executor, "
- "with head node address %s for torch.distributed process group.",
- VLLM_VERSION,
- head_node_address,
- )
-
- executor = MultiprocExecutor(vllm_config, monitor_workers=False)
- executor.start_worker_monitor(inline=True)
- return
-
- dp_rank = parallel_config.data_parallel_rank if parallel_config.data_parallel_rank is not None else 0
- coordinator = None
- if vllm_config.needs_dp_coordinator and dp_rank == 0:
- coordinator = DPCoordinator(
- parallel_config,
- enable_wave_coordination=vllm_config.model_config.is_moe,
- )
- logger.info(
- "[Headless] Started DP Coordinator process for stage %d (PID: %d)",
- stage_id,
- coordinator.proc.pid,
- )
-
- logger.info(
- "[Headless] Launching %d engine core(s) for stage %d via OmniMasterServer at %s:%d",
- local_engine_count,
- stage_id,
- omni_master_address,
- omni_master_port,
- )
+ .. deprecated:: 0.x.x
+ Headless mode is deprecated and will be removed in a future version.
+ It is only compatible with the old OmniStage-based runtime.
+ The current AsyncOmniEngine-based runtime does not support headless mode.
- # Headless mode launches all local engine cores for a single stage.
- # The OmniMasterServer allocates one handshake endpoint per stage, so we
- # register the stage once here and let every local engine core reuse the
- # returned handshake address directly.
- handshake_address = register_stage_with_omni_master(
- omni_master_address=omni_master_address,
- omni_master_port=omni_master_port,
- omni_stage_id=stage_id,
- omni_stage_config=stage_cfg,
- coordinator=coordinator,
+ Raises:
+ RuntimeError: Always raises an error indicating headless mode is deprecated.
+ """
+ raise RuntimeError(
+ "Headless mode is deprecated and not supported in the current runtime. "
+ "Please use the standard orchestrator mode (without --headless flag). "
+ "If you need distributed deployment, consider using Ray backend or "
+ "other distributed serving solutions."
)
- engine_manager = None
- log_stats = bool(getattr(args, "log_stats", False))
- if getattr(args, "disable_log_stats", False):
- log_stats = False
-
- try:
- engine_manager = CoreEngineProcManager(
- local_engine_count=local_engine_count,
- start_index=dp_rank,
- local_start_index=0,
- vllm_config=vllm_config,
- local_client=False,
- handshake_address=handshake_address,
- executor_class=executor_class,
- log_stats=log_stats,
- )
- engine_manager.join_first()
- finally:
- logger.info("[Headless] Shutting down stage %d.", stage_id)
- if engine_manager is not None:
- engine_manager.shutdown()
- if coordinator is not None:
- coordinator.shutdown()
-
def cmd_init() -> list[CLISubcommand]:
return [OmniServeCommand()]
diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py
index 223c208af98..a3bfe98ce2c 100644
--- a/vllm_omni/entrypoints/omni.py
+++ b/vllm_omni/entrypoints/omni.py
@@ -66,13 +66,6 @@ def generate(
py_generator: bool = False,
use_tqdm: bool | Callable[..., tqdm] = True,
) -> Generator[OmniRequestOutput, None, None] | list[OmniRequestOutput]:
- # Expand sampling params for PD disaggregation (user may provide N-1 params)
- if (
- sampling_params_list is not None
- and isinstance(sampling_params_list, Sequence)
- and not isinstance(sampling_params_list, (str, bytes))
- ):
- sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list))
sampling_params_list = self.resolve_sampling_params_list(sampling_params_list)
try:
if py_generator:
@@ -132,17 +125,10 @@ def _run_generation(
req_state.metrics = metrics
self.request_states[req_id] = req_state
- # PD disaggregation: modify stage-0 (prefill) sampling params per request
- req_sp_list = list(sampling_params_list)
- pd_pair = self._get_pd_separation_pair()
- if pd_pair is not None:
- p_id = pd_pair[0]
- req_sp_list[p_id] = self._prepare_prefill_sampling_params(req_id, req_sp_list[p_id])
-
self.engine.add_request(
request_id=req_id,
prompt=prompt,
- sampling_params_list=req_sp_list,
+ sampling_params_list=sampling_params_list,
final_stage_id=final_stage_id,
)
submit_ts = time.time()
@@ -166,8 +152,6 @@ def _run_generation(
logger.warning("[Omni] Received output for unknown/finished request_id=%s", req_id)
continue
- self._check_engine_output_error(msg, req_id, stage_id)
-
if req_state.metrics is None:
continue
output_to_yield = self._process_single_result(
diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py
index 4147c802765..96df0591ea1 100644
--- a/vllm_omni/entrypoints/omni_base.py
+++ b/vllm_omni/entrypoints/omni_base.py
@@ -1,22 +1,20 @@
from __future__ import annotations
-import argparse
import os
-import sys
import time
import types
import weakref
from collections.abc import Sequence
+from pprint import pformat
from typing import TYPE_CHECKING, Any, Literal
import huggingface_hub
from vllm.logger import init_logger
-from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
+from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
from vllm_omni.entrypoints.client_request_state import ClientRequestState
-from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin
-from vllm_omni.entrypoints.utils import coerce_param_message_types, get_final_stage_id_for_e2e
+from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e
from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
from vllm_omni.outputs import OmniRequestOutput
@@ -27,23 +25,6 @@
logger = init_logger(__name__)
-class OmniEngineDeadError(EngineDeadError):
- _DEFAULT_MESSAGE = EngineDeadError().args[0]
- error_stage_id: int | None
-
- def __init__(
- self,
- message: str | None = None,
- *,
- error_stage_id: int | None = None,
- suppress_context: bool = False,
- ) -> None:
- resolved_message = message or self._DEFAULT_MESSAGE
- Exception.__init__(self, resolved_message)
- self.__suppress_context__ = suppress_context
- self.error_stage_id = error_stage_id
-
-
def _weak_shutdown_engine(engine: AsyncOmniEngine) -> None:
"""Best-effort engine cleanup for GC finalization."""
try:
@@ -84,65 +65,28 @@ def omni_snapshot_download(model_id: str) -> str:
OutputMessageHandleResult = tuple[Literal[True], None, None, None] | tuple[Literal[False], str, int, ClientRequestState]
-class OmniBase(PDDisaggregationMixin):
+class OmniBase:
"""Shared runtime foundation for AsyncOmni and Omni."""
- @classmethod
- def from_cli_args(
- cls,
- args: argparse.Namespace,
- *,
- parser: argparse.ArgumentParser | None = None,
- **overrides: Any,
- ) -> OmniBase:
- """Build from argparse. If ``parser`` is passed and not yet nullified,
- un-typed engine fields are reset to ``None``."""
- kwargs: dict[str, Any] = {k: v for k, v in vars(args).items() if not k.startswith("_")}
-
- if parser is not None and not getattr(parser, "_omni_nullified", False):
- from vllm_omni.engine.arg_utils import (
- deploy_override_field_names,
- )
- from vllm_omni.entrypoints.utils import detect_explicit_cli_keys
-
- explicit = detect_explicit_cli_keys(sys.argv[1:], parser) or set()
- override_dests = deploy_override_field_names()
- for key in list(kwargs):
- if key in override_dests and key not in explicit:
- kwargs[key] = None
-
- kwargs.update(overrides)
- return cls(**kwargs)
-
def __init__(
self,
model: str,
**kwargs: Any,
) -> None:
engine_args: OmniEngineArgs | None = kwargs.pop("engine_args", None)
-
stage_init_timeout = kwargs.pop("stage_init_timeout", 300)
init_timeout = kwargs.pop("init_timeout", 600)
log_stats = kwargs.pop("log_stats", False)
- self._enable_ar_profiler = kwargs.pop("enable_ar_profiler", False)
- # NOTE: read-only lookup — must NOT pop. Popping here drops the key
- # before it reaches ``StageConfigFactory._create_from_registry``, so
- # ``--no-async-chunk`` (``async_chunk=False``) silently fails to
- # override the deploy YAML's ``async_chunk: true`` default.
- async_chunk = kwargs.get("async_chunk")
+ async_chunk = kwargs.pop("async_chunk", False)
output_modalities = kwargs.pop("output_modalities", None)
diffusion_batch_size: int = kwargs.pop("diffusion_batch_size", 1)
if "log_requests" in kwargs:
raise TypeError("`log_requests` has been removed in Omni/AsyncOmni. Use `log_stats`.")
model = omni_snapshot_download(model)
- self.__dict__["_name"] = self.__class__.__name__
self.model = model
self.log_stats = log_stats
- # Provisional value (mirrors the CLI/caller kwarg); the engine resolves
- # pipeline + deploy YAML + CLI precedence below and the final value is
- # re-assigned from ``self.engine.async_chunk`` after init.
- self.async_chunk = bool(async_chunk) if async_chunk is not None else False
+ self.async_chunk = async_chunk
self.output_modalities = output_modalities or []
self.tts_batch_max_items: int = kwargs.pop("tts_batch_max_items", 32)
@@ -160,11 +104,7 @@ def __init__(
self._weak_finalizer = weakref.finalize(self, _weak_shutdown_engine, self.engine)
et = time.time()
logger.info("[%s] AsyncOmniEngine initialized in %.2f seconds", self.__class__.__name__, et - st)
- # Authoritative: ``AsyncOmniEngine`` resolves (pipeline + deploy YAML +
- # CLI overrides) through ``StageConfigFactory`` and stores the final
- # value on ``engine.async_chunk``; mirror it here so ``--no-async-chunk``
- # (explicit ``False``) is not fallen-back-through by ``or``.
- self.async_chunk = bool(getattr(self.engine, "async_chunk", False))
+ self.async_chunk = bool(self.async_chunk or getattr(self.engine, "async_chunk", False))
self.request_states: dict[str, ClientRequestState] = {}
@@ -185,61 +125,24 @@ def __init__(
model,
)
- # PD disaggregation state (detects if a prefill/decode stage pair is configured)
- self._init_pd_state()
-
@property
def num_stages(self) -> int:
return self.engine.num_stages
- @property
- def stage_configs(self) -> list:
- """Expose engine stage configs for PD disaggregation detection and validation."""
- return self.engine.stage_configs
-
@property
def is_running(self) -> bool:
return self.engine.is_alive()
- @property
- def errored(self) -> bool:
- """Whether the engine is in a non-recoverable error state.
-
- True when the orchestrator thread is dead **or** any stage client
- has been marked dead (e.g. diffusion worker OOM / process death).
-
- Checks both ``_engine_dead`` (StageDiffusionClient) and
- ``resources.engine_dead`` (StageEngineCoreClient / AsyncMPClient)
- since the two client types store the flag differently.
- """
- if not self.engine.is_alive():
- return True
- for stage_client in self.engine.stage_clients:
- if getattr(stage_client, "_engine_dead", False):
- return True
- resources = getattr(stage_client, "resources", None)
- if resources is not None and getattr(resources, "engine_dead", False):
- return True
- return False
-
def check_health(self) -> None:
if not self.engine.is_alive():
raise EngineDeadError("Orchestrator process is not alive")
- for stage_client in self.engine.stage_clients:
- if hasattr(stage_client, "check_health"):
- stage_client.check_health()
def resolve_sampling_params_list(
self,
sampling_params_list: Sequence[Any] | Any | None,
- allow_delta_coercion: bool = False,
) -> Sequence[Any]:
if sampling_params_list is None:
normalized = self.default_sampling_params_list
- # Set the output kind to delta since no params were specified
- if allow_delta_coercion:
- normalized = coerce_param_message_types(list(normalized), is_streaming=True)
-
elif isinstance(sampling_params_list, Sequence) and not isinstance(sampling_params_list, (str, bytes)):
normalized = sampling_params_list
elif self.num_stages == 1:
@@ -255,6 +158,8 @@ def _log_summary_and_cleanup(self, request_id: str) -> None:
try:
if req_state is None or req_state.metrics is None:
return
+ summary = req_state.metrics.build_and_log_summary()
+ logger.info("[Summary] %s", pformat(summary, sort_dicts=False))
except Exception:
logger.exception(
"[%s] Failed to build/log summary for req=%s",
@@ -301,14 +206,7 @@ def _handle_output_message(
return True, None, None, None
if msg_type == "error":
- error_text = msg.get("error", "Orchestrator returned an error message")
- stage_id = msg.get("stage_id")
- if msg.get("fatal"):
- raise OmniEngineDeadError(
- error_text,
- error_stage_id=stage_id,
- )
- raise RuntimeError(error_text)
+ raise RuntimeError(msg.get("error", "Orchestrator returned an error message"))
if msg_type != "output":
logger.warning("[%s] got unexpected msg type: %s", self.__class__.__name__, msg_type)
@@ -337,37 +235,6 @@ def _handle_output_message(
return False, req_id, stage_id, req_state
- def _check_engine_output_error(
- self,
- result: dict[str, Any],
- request_id: str,
- stage_id: int,
- ) -> None:
- """Raise if ``engine_outputs`` carries an error field.
-
- Raises :class:`EngineDeadError` when ``self.errored`` indicates the
- engine is unrecoverable, otherwise raises :class:`EngineGenerateError`
- (recoverable, single-request failure).
- """
- engine_outputs = result.get("engine_outputs")
- error_text = getattr(engine_outputs, "error", None)
- if error_text is None:
- return
- logger.error(
- "[%s] Stage error for req=%s stage-%s: %s",
- self.__class__.__name__,
- request_id,
- stage_id,
- error_text,
- )
- # NOTE: O(n_stages) check for every error.
- if self.errored:
- raise OmniEngineDeadError(
- error_text,
- error_stage_id=stage_id,
- )
- raise EngineGenerateError(error_text)
-
def _process_single_result(
self,
result: dict[str, Any],
@@ -381,30 +248,6 @@ def _process_single_result(
engine_outputs = result.get("engine_outputs")
stage_durations = getattr(result["engine_outputs"], "stage_durations", {})
peak_memory_mb = getattr(result["engine_outputs"], "peak_memory_mb", 0.0)
-
- # Merge AR stage timing from OrchestratorAggregator.stage_events
- if self._enable_ar_profiler:
- ar_events = metrics.stage_events.get(str(req_id), [])
- for evt in ar_events:
- if evt.stage_id != stage_id:
- stage_durations[f"ar_stage_{evt.stage_id}"] = evt.stage_gen_time_ms / 1000.0
-
- # Merge pipeline timings from Orchestrator into stage_durations
- _m = result.get("metrics")
- if _m is not None and hasattr(_m, "pipeline_timings") and _m.pipeline_timings:
- for key, value in _m.pipeline_timings.items():
- if key not in stage_durations:
- stage_durations[key] = value
-
- # Merge per-stage gen times into stage_durations
- for evt in metrics.stage_events.get(str(req_id), []):
- key = f"stage_{evt.stage_id}_gen_ms"
- if key not in stage_durations:
- stage_durations[key] = evt.stage_gen_time_ms
- # Current stage gen time (not yet in stage_events at this point)
- if _m is not None:
- stage_durations.setdefault(f"stage_{stage_id}_gen_ms", _m.stage_gen_time_ms)
-
finished = engine_outputs.finished
submit_ts = result.get("stage_submit_ts")
@@ -439,11 +282,6 @@ def _process_single_result(
final_output_type=stage_meta["final_output_type"],
request_output=engine_outputs,
images=images,
- trajectory_latents=getattr(engine_outputs, "trajectory_latents", None),
- trajectory_timesteps=getattr(engine_outputs, "trajectory_timesteps", None),
- trajectory_log_probs=getattr(engine_outputs, "trajectory_log_probs", None),
- trajectory_decoded=getattr(engine_outputs, "trajectory_decoded", None),
- _custom_output=getattr(engine_outputs, "_custom_output", {}),
stage_durations=stage_durations,
peak_memory_mb=peak_memory_mb,
)
diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py
index 646bbd6f913..627174b20e9 100644
--- a/vllm_omni/entrypoints/openai/api_server.py
+++ b/vllm_omni/entrypoints/openai/api_server.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
-import dataclasses
import io
import json
import multiprocessing
@@ -16,11 +15,9 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from http import HTTPStatus
-from numbers import Integral
from typing import Annotated, Any, Literal, cast
import httpx
-import numpy as np
import vllm.envs as envs
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, Request, UploadFile, WebSocket
from fastapi.responses import FileResponse, JSONResponse, Response, StreamingResponse
@@ -32,7 +29,7 @@
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.launcher import serve_http, terminate_if_errored
+from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.openai.api_server import build_app as build_openai_app
@@ -51,11 +48,11 @@
ModelCard,
ModelList,
ModelPermission,
- RequestResponseMetadata,
)
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.orca_metrics import metrics_header
+from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.openai.server_utils import get_uvicorn_log_config
@@ -76,7 +73,6 @@
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
- create_error_response,
load_aware_call,
process_lora_modules,
with_cancellation,
@@ -86,7 +82,6 @@
from vllm.tool_parsers import ToolParserManager
from vllm.utils import random_uuid
from vllm.utils.system_utils import decorate_logs
-from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.errors import InvalidInputReferenceError
@@ -112,12 +107,10 @@
VideoListResponse,
VideoResponse,
)
-from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech
from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler
from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo, ReferenceImage
-from vllm_omni.entrypoints.openai.serving_video_stream import OmniStreamingVideoHandler
from vllm_omni.entrypoints.openai.storage import STORAGE_MANAGER
from vllm_omni.entrypoints.openai.stores import VIDEO_STORE, VIDEO_TASKS
from vllm_omni.entrypoints.openai.utils import get_stage_type, parse_lora_request
@@ -127,7 +120,6 @@
logger = init_logger(__name__)
router = APIRouter()
-MAX_UINT32_SEED = 2**32 - 1
profiler_router = APIRouter()
@@ -204,81 +196,6 @@ def _remove_route_from_app(app, path: str, methods: set[str] | None = None):
app.routes.remove(route)
-def _register_omni_exception_handlers(app) -> None:
- """Override upstream vLLM exception handlers with Omni-aware versions.
-
- The upstream ``engine_error_handler`` is designed for ``AsyncLLM`` (single
- EngineCore process). Omni uses a multi-stage orchestrator with different
- health semantics, so we register our own handlers that:
-
- - Log multi-stage diagnostic info (orchestrator liveness, per-stage health)
- when an ``EngineDeadError`` is caught.
- - Call ``terminate_if_errored``
- - Return an OpenAI-compatible error JSON response.
- """
-
- async def omni_engine_error_handler(
- req: Request,
- exc: EngineDeadError | EngineGenerateError,
- ):
- request_id = _get_request_id_from_request(req)
-
- if req.app.state.args.log_error_stack:
- logger.exception("Engine Exception caught. Request id: %s", request_id)
-
- return _create_engine_error_json_response(req, exc)
-
- app.exception_handler(EngineGenerateError)(omni_engine_error_handler)
- app.exception_handler(EngineDeadError)(omni_engine_error_handler)
-
-
-def _get_request_id_from_request(req: Request) -> str | None:
- return req.state.request_metadata.request_id if hasattr(req.state, "request_metadata") else None
-
-
-def _build_engine_error_payload(
- exc: EngineDeadError | EngineGenerateError,
- *,
- request_id: str | None,
-) -> tuple[dict[str, Any], int]:
- err = create_error_response(exc)
- payload = err.model_dump()
- error_body = payload.get("error", {})
-
- error_body["request_id"] = request_id
- error_body["error_stage_id"] = getattr(exc, "error_stage_id", None)
-
- return payload, err.error.code
-
-
-def _create_engine_error_json_response(
- req: Request,
- exc: EngineDeadError | EngineGenerateError,
-) -> JSONResponse:
- request_id = _get_request_id_from_request(req)
- error_stage_id = getattr(exc, "error_stage_id", None)
- engine = req.app.state.engine_client
-
- if isinstance(exc, EngineDeadError):
- # Log Omni-specific diagnostic information for dead engines.
- orchestrator_alive = engine.engine.is_alive() if hasattr(engine, "engine") else "N/A"
- logger.error(
- "EngineDeadError: orchestrator_alive=%s, errored=%s, request_id=%s, error_stage_id=%s",
- orchestrator_alive,
- engine.errored,
- request_id,
- error_stage_id,
- )
-
- terminate_if_errored(
- server=req.app.state.server,
- engine=engine,
- )
-
- payload, status_code = _build_engine_error_payload(exc, request_id=request_id)
- return JSONResponse(content=payload, status_code=status_code)
-
-
class _DiffusionServingModels:
"""Minimal OpenAIServingModels implementation for diffusion-only servers.
@@ -388,10 +305,6 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None,
_remove_route_from_app(app, "/v1/models", {"GET"}) # Remove upstream /v1/models to use omni's handler
app.include_router(router)
- # OMNI: Override upstream exception handlers with Omni-aware versions
- # that understand the multi-stage orchestrator lifecycle.
- _register_omni_exception_handlers(app)
-
await omni_init_app_state(engine_client, app.state, args)
# Conditionally register profiler endpoints based on stage YAML configs
@@ -440,10 +353,6 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None,
try:
await shutdown_task
finally:
- state = getattr(app, "state", None)
- serving_speech = getattr(state, "openai_serving_speech", None) if state is not None else None
- if serving_speech is not None:
- serving_speech.shutdown()
sock.close()
@@ -607,7 +516,6 @@ async def omni_init_app_state(
stage_configs=diffusion_stage_configs,
)
state.openai_streaming_speech = None
- state.openai_streaming_video = None
state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False)
state.server_load_metrics = 0
@@ -909,14 +817,6 @@ async def omni_init_app_state(
state.openai_streaming_speech = OmniStreamingSpeechHandler(
speech_service=state.openai_serving_speech,
)
- state.openai_streaming_video = (
- OmniStreamingVideoHandler(
- chat_service=state.openai_serving_chat,
- engine_client=engine_client,
- )
- if state.openai_serving_chat is not None
- else None
- )
state.openai_serving_realtime = OpenAIServingRealtime(
engine_client=engine_client,
models=state.openai_serving_models,
@@ -931,7 +831,6 @@ async def omni_init_app_state(
state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0
- state.sleeping_stages = set()
def Omnivideo(request: Request) -> OmniOpenAIServingVideo | None:
@@ -971,8 +870,6 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
return base_server.create_error_response(message="The model does not support Chat Completions API")
try:
generator = await handler.create_chat_completion(request, raw_request)
- except (EngineGenerateError, EngineDeadError) as exc:
- return _create_engine_error_json_response(raw_request, exc)
except Exception as e:
logger.exception("Chat completion failed: %s", e)
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e
@@ -1068,8 +965,6 @@ async def create_speech(request: OpenAICreateSpeechRequest, raw_request: Request
status_code=result.error.code if result.error else 400,
)
return result
- except (EngineGenerateError, EngineDeadError) as exc:
- return _create_engine_error_json_response(raw_request, exc)
except Exception as e:
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e
@@ -1104,8 +999,6 @@ async def create_speech_batch(request: BatchSpeechRequest, raw_request: Request)
status_code=result.error.code if result.error else 400,
)
return JSONResponse(content=result.model_dump())
- except (EngineGenerateError, EngineDeadError) as exc:
- return _create_engine_error_json_response(raw_request, exc)
except ValueError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)) from e
except Exception as e:
@@ -1302,42 +1195,9 @@ async def streaming_speech(websocket: WebSocket):
await handler.handle_session(websocket)
-@router.websocket("/v1/video/chat/stream")
-async def streaming_video_chat(websocket: WebSocket):
- """WebSocket endpoint for streaming video input chat."""
- handler = getattr(websocket.app.state, "openai_streaming_video", None)
- if handler is None:
- await websocket.accept()
- await websocket.send_json(
- {
- "type": "error",
- "message": "Streaming video chat is not available",
- }
- )
- await websocket.close()
- return
- await handler.handle_session(websocket)
-
-
@router.websocket("/v1/realtime")
async def realtime_websocket(websocket: WebSocket):
"""WebSocket endpoint for OpenAI-style realtime interactions."""
- engine_client = getattr(websocket.app.state, "engine_client", None)
- if engine_client is not None and getattr(engine_client, "async_chunk", False):
- await websocket.accept()
- await websocket.send_json(
- {
- "type": "error",
- "error": (
- "The /v1/realtime API is not supported when async_chunk is enabled on the server. "
- "Use a stage configuration with async_chunk disabled and restart the server before using "
- "this endpoint."
- ),
- "code": "unsupported",
- }
- )
- await websocket.close()
- return
serving = getattr(websocket.app.state, "openai_serving_realtime", None)
if serving is None:
await websocket.accept()
@@ -1360,26 +1220,31 @@ async def realtime_websocket(websocket: WebSocket):
async def health(raw_request: Request) -> JSONResponse:
"""Health check endpoint that works for both LLM and diffusion modes.
- Returns 200 OK if the server is healthy, 503 if the engine is dead.
- Mirrors vLLM upstream's /health which catches EngineDeadError -> 503.
+ Returns 200 OK if the server is healthy.
+ For LLM mode: delegates to engine_client health check
+ For diffusion mode: checks if diffusion_engine is running
"""
- engine_client = getattr(raw_request.app.state, "engine_client", None) or getattr(
- raw_request.app.state, "diffusion_engine", None
- )
- if engine_client is None:
+ # Check if we're in diffusion mode
+ diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None)
+ if diffusion_engine is not None:
+ # Diffusion mode health check
+ if hasattr(diffusion_engine, "is_running") and diffusion_engine.is_running:
+ return JSONResponse(content={"status": "healthy"})
return JSONResponse(
- content={"status": "unhealthy", "reason": "No engine initialized"},
+ content={"status": "unhealthy", "reason": "Diffusion engine is not running"},
status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
)
- try:
+ # LLM mode - delegate to engine_client
+ engine_client = getattr(raw_request.app.state, "engine_client", None)
+ if engine_client is not None:
await engine_client.check_health()
return JSONResponse(content={"status": "healthy"})
- except EngineDeadError:
- return JSONResponse(
- content={"status": "unhealthy"},
- status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
- )
+
+ return JSONResponse(
+ content={"status": "unhealthy", "reason": "No engine initialized"},
+ status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
+ )
# Remove existing models endpoint if present (from vllm imports)
@@ -1414,8 +1279,7 @@ async def show_available_models(raw_request: Request) -> JSONResponse:
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
-@with_cancellation
-async def generate_images(request: ImageGenerationRequest, raw_request: Request):
+async def generate_images(request: ImageGenerationRequest, raw_request: Request) -> ImageGenerationResponse:
"""Generate images from text prompts using diffusion models.
OpenAI DALL-E compatible endpoint for text-to-image generation.
@@ -1434,76 +1298,20 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
# Get engine client (AsyncOmni) from app state
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)
+ # Validate model field (warn if mismatch, don't error)
if request.model is not None and request.model != model_name:
- raise HTTPException(
- status_code=HTTPStatus.BAD_REQUEST.value,
- detail=(f"Model mismatch: request specifies '{request.model}' but server is running '{model_name}'."),
+ logger.warning(
+ f"Model mismatch: request specifies '{request.model}' but "
+ f"server is running '{model_name}'. Using server model."
)
try:
- # Unify request construction for any multi-stage pipeline to avoid
- # divergence between /v1/images and /v1/chat/completions.
- if len(stage_configs) > 1:
- chat_handler = getattr(raw_request.app.state, "openai_serving_chat", None)
- if chat_handler is None:
- logger.warning("openai_serving_chat is not initialized for multi-stage /v1/images/generations")
- raise HTTPException(
- status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
- detail="openai_serving_chat is not initialized for multi-stage image generation.",
- )
-
- effective_seed = request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED)
- extra_body: dict[str, Any] = {
- "seed": effective_seed,
- "num_outputs_per_prompt": request.n,
- }
- if request.size is not None:
- parse_size(request.size)
- width, height = parse_size(request.size)
- app_state_args = getattr(raw_request.app.state, "args", None)
- _check_max_generated_image_size(app_state_args, width, height)
- extra_body["size"] = request.size
- if request.negative_prompt is not None:
- extra_body["negative_prompt"] = request.negative_prompt
- if request.num_inference_steps is not None:
- extra_body["num_inference_steps"] = request.num_inference_steps
- if request.guidance_scale is not None:
- extra_body["guidance_scale"] = request.guidance_scale
- if request.true_cfg_scale is not None:
- extra_body["true_cfg_scale"] = request.true_cfg_scale
- if request.generator_device is not None:
- extra_body["generator_device"] = request.generator_device
- if request.lora is not None:
- # Keep /images validation semantics: invalid LoRA should fail with 400.
- _parse_lora_request(request.lora)
- extra_body["lora"] = request.lora
-
- generation_result = await chat_handler.generate_diffusion_images(
- prompt=request.prompt,
- extra_body=extra_body,
- request_id=f"img_gen-{random_uuid()}",
- )
- if isinstance(generation_result, ErrorResponse):
- return JSONResponse(
- status_code=generation_result.error.code if generation_result.error else 400,
- content=generation_result.model_dump(),
- )
- flat_images, _, _ = generation_result
- image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images]
- return ImageGenerationResponse(created=int(time.time()), data=image_data)
-
# Build params - pass through user values directly
prompt: OmniTextPrompt = {"prompt": request.prompt}
if request.negative_prompt is not None:
prompt["negative_prompt"] = request.negative_prompt
gen_params = OmniDiffusionSamplingParams(num_outputs_per_prompt=request.n)
- extra_args = {}
- if request.use_system_prompt is not None:
- extra_args["use_system_prompt"] = request.use_system_prompt
- if request.system_prompt is not None:
- extra_args["system_prompt"] = request.system_prompt
- if extra_args:
- gen_params.extra_args = extra_args
+
# Parse per-request LoRA (compatible with chat's extra_body.lora shape).
lora_request, lora_scale = _parse_lora_request(request.lora)
_update_if_not_none(gen_params, "lora_request", lora_request)
@@ -1516,20 +1324,6 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
size_str = f"{width}x{height}"
else:
size_str = "model default"
-
- # Keep AR stage target grid in sync with requested output size.
- # GLM-Image consumes target_h/target_w via mm_processor_kwargs.
- if width is not None and height is not None:
- prompt["mm_processor_kwargs"] = {
- "target_h": height,
- "target_w": width,
- }
- # Backward-compatible fallback for processors reading top-level fields.
- prompt["height"] = height
- prompt["width"] = width
- app_state_args = getattr(raw_request.app.state, "args", None)
- _check_max_generated_image_size(app_state_args, width, height)
-
_update_if_not_none(gen_params, "width", width)
_update_if_not_none(gen_params, "height", height)
@@ -1542,13 +1336,12 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
# This fixes issues where using the default global generator
# might produce blurry images in some environments.
_update_if_not_none(
- gen_params, "seed", request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED)
+ gen_params, "seed", request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
)
_update_if_not_none(gen_params, "generator_device", request.generator_device)
_update_if_not_none(gen_params, "layers", request.layers)
request_id = f"img_gen-{random_uuid()}"
- raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id)
logger.info(f"Generating {request.n} image(s) {size_str}")
@@ -1572,27 +1365,16 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
logger.info(f"Successfully generated {len(images)} image(s)")
- # Determine output format (default to png)
- output_format = _choose_output_format(request.output_format or "png", None)
-
- # Encode images to base64 with the specified format
- image_data = [
- ImageData(b64_json=_encode_image_base64_with_compression(img, format=output_format), revised_prompt=None)
- for img in images
- ]
+ # Encode images to base64
+ image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in images]
- response_kwargs = {
- "created": int(time.time()),
- "data": image_data,
- "output_format": output_format,
- }
- if request.size:
- response_kwargs["size"] = size_str
- return ImageGenerationResponse(**response_kwargs)
+ return ImageGenerationResponse(
+ created=int(time.time()),
+ data=image_data,
+ )
- except (EngineGenerateError, EngineDeadError) as exc:
- return _create_engine_error_json_response(raw_request, exc)
except HTTPException:
+ # Re-raise HTTPExceptions as-is
raise
except ValueError as e:
logger.error(f"Validation error: {e}")
@@ -1628,14 +1410,10 @@ async def edit_images(
background: str | None = Form("auto"),
output_compression: Annotated[int, Form(ge=0, le=100)] = 100,
user: str | None = Form(None), # unused now
- # vllm-omni extensions for image editing
- mask_image: str | UploadFile | None = None,
- reference_image: str | UploadFile | None = None,
# vllm-omni extensions for diffusion control
negative_prompt: str | None = Form(None),
num_inference_steps: int | None = Form(None),
guidance_scale: float | None = Form(None),
- strength: float | None = Form(None),
true_cfg_scale: float | None = Form(None),
seed: int | None = Form(None),
generator_device: str | None = Form(None),
@@ -1651,9 +1429,8 @@ async def edit_images(
# 1. get engine and model
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)
if model is not None and model != model_name:
- raise HTTPException(
- status_code=HTTPStatus.BAD_REQUEST.value,
- detail=(f"Model mismatch: request specifies '{model}' but server is running '{model_name}'."),
+ logger.warning(
+ f"Model mismatch: request specifies '{model}' but server is running '{model_name}'. Using server model."
)
# 2. get output format & compression
output_format = _choose_output_format(output_format, background)
@@ -1676,35 +1453,15 @@ async def edit_images(
input_images_list.extend(urls)
if not input_images_list:
raise HTTPException(status_code=422, detail="Field 'image' or 'url' is required")
- # Reject oversized multi-image edit requests before fetching or decoding
- # any inputs. This keeps over-limit URL requests from burning network,
- # CPU, and memory on work that will be rejected anyway.
- max_input_images = _get_max_edit_input_images(raw_request, engine_client)
- if max_input_images is not None and len(input_images_list) > max_input_images:
- detail = (
- "Received multiple input images. Only a single image is supported by this model."
- if max_input_images == 1
- else (
- f"Received {len(input_images_list)} input images. "
- f"At most {max_input_images} images are supported by this model."
- )
- )
+ pil_images = await _load_input_images(input_images_list)
+ if len(pil_images) > 1 and not _supports_multimodal_image_inputs(raw_request, engine_client):
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
- detail=detail,
+ detail="Received multiple input images. Only a single image is supported by this model.",
)
- pil_images = await _load_input_images(input_images_list)
prompt["multi_modal_data"] = {}
prompt["multi_modal_data"]["image"] = pil_images
- if mask_image is not None:
- loaded = await _load_input_images([mask_image])
- prompt["multi_modal_data"]["mask_image"] = loaded[0]
-
- if reference_image is not None:
- loaded = await _load_input_images([reference_image])
- prompt["multi_modal_data"]["reference_image"] = loaded[0]
-
# 3 Build sample params
gen_params = OmniDiffusionSamplingParams()
# 3.0 Init with system default values
@@ -1752,6 +1509,7 @@ async def edit_images(
)
# 3.3 Parse and add size if provided
+ max_generated_image_size = getattr(app_state_args, "max_generated_image_size", None)
width, height = None, None
if size.lower() == "auto":
if resolution is None:
@@ -1761,117 +1519,54 @@ async def edit_images(
else:
width, height = parse_size(size)
- _check_max_generated_image_size(app_state_args, width, height, resolution)
+ # Check max_generated_image_size
+ if max_generated_image_size is not None:
+ if width is not None and height is not None:
+ if width * height > max_generated_image_size:
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST.value,
+ detail=f"Requested image size {width}x{height} exceeds the maximum allowed "
+ f"size of {max_generated_image_size} pixels.",
+ )
+ elif resolution is not None:
+ # When resolution is set, the output size is resolution * resolution
+ if resolution * resolution > max_generated_image_size:
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST.value,
+ detail=f"Requested resolution {resolution} (max {resolution}x{resolution} pixels) "
+ f"exceeds the maximum allowed size of {max_generated_image_size} pixels.",
+ )
size_str = f"{width}x{height}" if width is not None and height is not None else "auto"
-
- # Keep AR stage target grid in sync with requested output size.
- # GLM-Image consumes target_h/target_w via mm_processor_kwargs.
- if width is not None and height is not None:
- prompt["mm_processor_kwargs"] = {
- "target_h": height,
- "target_w": width,
- }
- # Backward-compatible fallback for processors reading top-level fields.
- prompt["height"] = height
- prompt["width"] = width
-
_update_if_not_none(gen_params, "width", width)
_update_if_not_none(gen_params, "height", height)
# 3.4 Add optional parameters ONLY if provided
_update_if_not_none(gen_params, "num_inference_steps", num_inference_steps)
_update_if_not_none(gen_params, "guidance_scale", guidance_scale)
- _update_if_not_none(gen_params, "strength", strength)
_update_if_not_none(gen_params, "true_cfg_scale", true_cfg_scale)
# If seed is not provided, generate a random one to ensure
# a proper generator is initialized in the backend.
# This fixes issues where using the default global generator
# might produce blurry images in some environments.
- _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, MAX_UINT32_SEED))
+ _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, 2**32 - 1))
_update_if_not_none(gen_params, "generator_device", generator_device)
_update_if_not_none(gen_params, "layers", layers)
_update_if_not_none(gen_params, "resolution", resolution)
- # 4. Generate images
+ # 4. Generate images using AsyncOmni (multi-stage mode)
request_id = f"img_edit-{random_uuid()}"
- raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id)
logger.info(f"Generating {n} image(s) {size_str}")
+ result = await _generate_with_async_omni(
+ engine_client=engine_client,
+ gen_params=gen_params,
+ stage_configs=stage_configs,
+ prompt=prompt,
+ request_id=request_id,
+ )
- if len(stage_configs) > 1:
- # Multi-stage pipeline (e.g. GLM-Image AR+Diffusion): route through
- # the chat handler so the AR stage gets correct max_tokens and
- # target_h/w (same path as /v1/images/generations).
- chat_handler = getattr(raw_request.app.state, "openai_serving_chat", None)
- if chat_handler is None:
- raise HTTPException(
- status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
- detail="openai_serving_chat is not initialized for multi-stage image editing.",
- )
-
- # Encode input images to base64 for generate_diffusion_images.
- import base64
- import io as _io
-
- ref_b64_list: list[str] = []
- for _img in pil_images:
- buf = _io.BytesIO()
- _img.save(buf, format="PNG")
- ref_b64_list.append(base64.b64encode(buf.getvalue()).decode())
-
- effective_seed = seed if seed is not None else random.randint(0, MAX_UINT32_SEED)
- extra_body: dict[str, Any] = {
- "seed": effective_seed,
- "num_outputs_per_prompt": n,
- }
- if width is not None:
- extra_body["width"] = width
- if height is not None:
- extra_body["height"] = height
- if negative_prompt is not None:
- extra_body["negative_prompt"] = negative_prompt
- if num_inference_steps is not None:
- extra_body["num_inference_steps"] = num_inference_steps
- if guidance_scale is not None:
- extra_body["guidance_scale"] = guidance_scale
- if strength is not None:
- extra_body["strength"] = strength
- if true_cfg_scale is not None:
- extra_body["true_cfg_scale"] = true_cfg_scale
- if layers is not None:
- extra_body["layers"] = layers
- if resolution is not None:
- extra_body["resolution"] = resolution
- if lora is not None:
- # Validate LoRA, then pass through.
- lora_dict = _get_lora_from_json_str(lora)
- _parse_lora_request(lora_dict)
- extra_body["lora"] = lora_dict
-
- prompt_text = prompt.get("prompt", "")
- generation_result = await chat_handler.generate_diffusion_images(
- prompt=prompt_text,
- extra_body=extra_body,
- reference_images=ref_b64_list,
- request_id=request_id,
- )
- if isinstance(generation_result, ErrorResponse):
- raise HTTPException(
- status_code=generation_result.error.code if generation_result.error else 400,
- detail=generation_result.message,
- )
- images, _, _ = generation_result
- else:
- # Single-stage diffusion: use the direct path.
- result = await _generate_with_async_omni(
- engine_client=engine_client,
- gen_params=gen_params,
- stage_configs=stage_configs,
- prompt=prompt,
- request_id=request_id,
- )
- images = _extract_images_from_result(result)
-
+ # 5. Extract images from result
+ images = _extract_images_from_result(result)
logger.info(f"Successfully generated {len(images)} image(s)")
# Encode images to base64
@@ -1892,9 +1587,8 @@ async def edit_images(
size=size_str,
)
- except (EngineGenerateError, EngineDeadError) as exc:
- return _create_engine_error_json_response(raw_request, exc)
except HTTPException:
+ # Re-raise HTTPExceptions as-is
raise
except ValueError as e:
logger.error(f"Validation error: {e}")
@@ -1945,39 +1639,18 @@ def _get_engine_and_model(raw_request: Request):
return engine_client, model_name, normalized_stage_configs
-def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any:
+def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) -> bool:
diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client
get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None)
- return (
+ od_config = (
get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None)
)
-
-def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int | None:
- od_config = _get_diffusion_od_config(raw_request, engine_client)
if od_config is None:
# Preserve the existing compatibility behavior when the diffusion
# config is not exposed on the serving surface.
- return None
-
- supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", None)
- if not isinstance(supports_multimodal_inputs, bool):
- # Older serving surfaces and mocked engines may expose a placeholder
- # object instead of a real diffusion config. Treat that as "unknown"
- # so existing single-image flows keep working.
- return None
-
- if not supports_multimodal_inputs:
- return 1
-
- max_input_images = getattr(od_config, "max_multimodal_image_inputs", None)
- if max_input_images is None:
- return None
- if isinstance(max_input_images, bool) or not isinstance(max_input_images, Integral):
- return None
- if max_input_images < 1:
- return None
- return int(max_input_images)
+ return True
+ return bool(getattr(od_config, "supports_multimodal_inputs", False))
def _get_lora_from_json_str(lora_body):
@@ -2061,67 +1734,11 @@ async def _generate_with_async_omni(
return result
-def _check_max_generated_image_size(
- app_state_args: Any,
- width: int | None,
- height: int | None,
- resolution: int | None = None,
-) -> None:
- """Raise 400 if the requested image size exceeds --max-generated-image-size."""
- max_generated_image_size = getattr(app_state_args, "max_generated_image_size", None)
- # Check max_generated_image_size
- if max_generated_image_size is None:
- return
- if width is not None and height is not None:
- if width * height > max_generated_image_size:
- raise HTTPException(
- status_code=HTTPStatus.BAD_REQUEST.value,
- detail=f"Requested image size {width}x{height} exceeds the maximum allowed "
- f"size of {max_generated_image_size} pixels.",
- )
- elif resolution is not None:
- # When resolution is set, the output size is resolution * resolution
- if resolution * resolution > max_generated_image_size:
- raise HTTPException(
- status_code=HTTPStatus.BAD_REQUEST.value,
- detail=f"Requested resolution {resolution} (max {resolution}x{resolution} pixels) "
- f"exceeds the maximum allowed size of {max_generated_image_size} pixels.",
- )
-
-
def _update_if_not_none(object: Any, key: str, val: Any) -> None:
if val is not None:
setattr(object, key, val)
-def _normalize_image(image: Any) -> Any:
- """Normalize a single image output to a PIL-compatible format."""
- if isinstance(image, Image.Image):
- return image
- if not isinstance(image, np.ndarray):
- raise ValueError(f"Unsupported image type: {type(image)}")
- if not np.issubdtype(image.dtype, np.integer) and not np.issubdtype(image.dtype, np.floating):
- raise ValueError(f"Unsupported dtype: {image.dtype}")
- if isinstance(image, np.ndarray):
- while image.ndim > 3:
- image = image[0]
- if image.min() < 0:
- if image.min() < -1.01 or image.max() > 1.01:
- logger.warning(
- f"Image float range [{image.min():.2f}, {image.max():.2f}] outside expected [-1, 1]. "
- f"Clipping to [-1, 1] before normalization."
- )
- image = np.clip(image, -1.0, 1.0) * 0.5 + 0.5
- elif image.max() > 1.01:
- logger.warning(
- f"Image float range [{image.min():.2f}, {image.max():.2f}] outside expected [0, 1]. "
- f"Clipping to [0, 1] before normalization."
- )
- image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
- image = Image.fromarray(image)
- return image
-
-
def _extract_images_from_result(result: Any) -> list[Any]:
images = []
if hasattr(result, "images") and result.images:
@@ -2132,10 +1749,6 @@ def _extract_images_from_result(result: Any) -> list[Any]:
images = request_output["images"]
elif hasattr(request_output, "images") and request_output.images:
images = request_output.images
- # Handle when generate more than one image
- if images and isinstance(images[0], np.ndarray) and images[0].shape[0] > 1 and images[0].ndim == 5:
- # Unwrap batch: (N, T, H, W, C) -> [img1, img2, ...]
- images = list(images[0])
# Flatten nested lists (e.g., from layered models like Qwen-Image-Layered).
# Note: This only flattens one level deep. Deeper nesting is not supported.
flattened = []
@@ -2144,7 +1757,7 @@ def _extract_images_from_result(result: Any) -> list[Any]:
flattened.extend(img)
else:
flattened.append(img)
- return [_normalize_image(img) for img in flattened]
+ return flattened
async def _load_input_images(
@@ -2314,46 +1927,16 @@ def video_response_from_request(model_name: str, req: VideoGenerationRequest) ->
return resp
-def _status_code_for_video_failure(error: VideoError | None) -> int:
- if error is None:
- return HTTPStatus.INTERNAL_SERVER_ERROR.value
-
- if isinstance(error.code, int):
- if 400 <= error.code < 600:
- return error.code
- return HTTPStatus.INTERNAL_SERVER_ERROR.value
-
- if error.code == "HTTPException":
- status_text, _, _ = error.message.partition(":")
- try:
- status_code = int(status_text)
- except ValueError:
- return HTTPStatus.INTERNAL_SERVER_ERROR.value
- if 400 <= status_code < 600:
- return status_code
- return HTTPStatus.INTERNAL_SERVER_ERROR.value
-
- if error.code == "EngineDeadError":
- return HTTPStatus.INTERNAL_SERVER_ERROR.value
- if error.code == "EngineGenerateError":
- return HTTPStatus.INTERNAL_SERVER_ERROR.value
+async def decode_and_save_video_output(output: Any, file_name: str) -> str:
+ if not output.b64_json:
+ raise RuntimeError(f"Video output for {file_name} did not include b64_json content.")
- return HTTPStatus.INTERNAL_SERVER_ERROR.value
-
-
-def _video_error_from_exception(exc: Exception) -> VideoError:
- if isinstance(exc, HTTPException):
- message = str(exc.detail) if exc.detail else str(exc)
- return VideoError(code=exc.status_code, message=message)
-
- if isinstance(exc, (EngineGenerateError, EngineDeadError)):
- err = create_error_response(exc)
- return VideoError(code=err.error.code, message=err.error.message)
+ try:
+ video_bytes = base64.b64decode(output.b64_json)
+ except Exception as decode_exc:
+ raise RuntimeError(f"Failed to decode generated video payload for {file_name}") from decode_exc
- return VideoError(
- code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
- message=str(exc),
- )
+ return await STORAGE_MANAGER.save(video_bytes, file_name)
def _cleanup_video(video_id: str, output_path: str | None):
@@ -2369,7 +1952,6 @@ async def _run_video_generation_job(
request: VideoGenerationRequest,
video_id: str,
reference_image: ReferenceImage | None = None,
- app_state: Any | None = None,
) -> None:
job = await VIDEO_STORE.get(video_id)
if job is None:
@@ -2380,12 +1962,15 @@ async def _run_video_generation_job(
started_at = time.perf_counter()
output_path = None
try:
- video_bytes, stage_durations, peak_memory_mb = await handler.generate_video_bytes(
- request, video_id, reference_image=reference_image
- )
+ response = await handler.generate_videos(request, video_id, reference_image=reference_image)
+ if not response.data:
+ raise RuntimeError("Video generation completed but returned no outputs.")
+
+ if (video_count := len(response.data)) > 1:
+ logger.warning("Video request %s generated %s outputs but we only expected one.", video_id, video_count)
file_name = f"{video_id}.{job.file_extension}"
- output_path = await STORAGE_MANAGER.save(video_bytes, file_name)
+ output_path = await decode_and_save_video_output(response.data[0], file_name)
logger.info("Video request %s persisted %s output file.", video_id, output_path)
await VIDEO_STORE.update_fields(
@@ -2396,40 +1981,19 @@ async def _run_video_generation_job(
"file_name": file_name,
"completed_at": int(time.time()),
"inference_time_s": time.perf_counter() - started_at,
- "stage_durations": stage_durations,
- "peak_memory_mb": peak_memory_mb,
- },
- )
- except (EngineGenerateError, EngineDeadError) as exc:
- logger.exception("Video generation failed (engine error) for id=%s", video_id)
-
- _cleanup_video(video_id, output_path)
- await VIDEO_STORE.update_fields(
- video_id,
- {
- "status": VideoGenerationStatus.FAILED,
- "completed_at": int(time.time()),
- "error": _video_error_from_exception(exc),
- "inference_time_s": time.perf_counter() - started_at,
},
)
- # Background tasks can't propagate exceptions to FastAPI handlers.
- # Actively signal shutdown when the engine is dead.
- if app_state is not None and isinstance(exc, EngineDeadError):
- terminate_if_errored(
- server=app_state.server,
- engine=app_state.engine_client,
- )
except Exception as exc:
logger.exception("Video generation failed for id=%s", video_id)
_cleanup_video(video_id, output_path)
+ # TODO: It would be better to have a finite collection of errors to return rather than the exception name
await VIDEO_STORE.update_fields(
video_id,
{
"status": VideoGenerationStatus.FAILED,
"completed_at": int(time.time()),
- "error": _video_error_from_exception(exc),
+ "error": VideoError(code=type(exc).__name__, message=str(exc)),
"inference_time_s": time.perf_counter() - started_at,
},
)
@@ -2463,10 +2027,6 @@ async def _parse_video_form(
true_cfg_scale: float | None = Form(default=None),
seed: int | None = Form(default=None),
negative_prompt: str | None = Form(default=None),
- enable_frame_interpolation: bool | None = Form(default=None),
- frame_interpolation_exp: int | None = Form(default=None, ge=1),
- frame_interpolation_scale: float | None = Form(default=None, gt=0.0),
- frame_interpolation_model_path: str | None = Form(default=None),
lora: str | None = Form(default=None),
extra_params: str | None = Form(default=None),
) -> tuple[VideoGenerationRequest, "OmniOpenAIServingVideo", str, ReferenceImage | None]:
@@ -2503,10 +2063,6 @@ async def _parse_video_form(
"true_cfg_scale": true_cfg_scale,
"seed": seed,
"negative_prompt": negative_prompt,
- "enable_frame_interpolation": enable_frame_interpolation,
- "frame_interpolation_exp": frame_interpolation_exp,
- "frame_interpolation_scale": frame_interpolation_scale,
- "frame_interpolation_model_path": frame_interpolation_model_path,
"lora": _parse_form_json(lora, expected_type=dict),
"extra_params": _parse_form_json(extra_params, expected_type=dict),
}
@@ -2524,12 +2080,10 @@ async def _parse_video_form(
app_model_name, app_stage_configs = _resolve_video_runtime_context(raw_request)
effective_model_name = handler.model_name or app_model_name or request.model or "unknown"
if request.model is not None and effective_model_name is not None and request.model != effective_model_name:
- raise HTTPException(
- status_code=HTTPStatus.BAD_REQUEST.value,
- detail=(
- f"Model mismatch: request specifies '{request.model}' but server is running "
- f"'{effective_model_name}'."
- ),
+ logger.warning(
+ "Model mismatch: request specifies '%s' but server is running '%s'. Using server model.",
+ request.model,
+ effective_model_name,
)
handler.set_stage_configs_if_missing(app_stage_configs)
except HTTPException:
@@ -2560,7 +2114,6 @@ async def _parse_video_form(
},
)
async def create_video(
- raw_request: Request,
ctx: tuple[VideoGenerationRequest, OmniOpenAIServingVideo, str, ReferenceImage | None] = Depends(_parse_video_form),
) -> VideoResponse:
"""Create an asynchronous video generation job.
@@ -2571,9 +2124,7 @@ async def create_video(
request, handler, effective_model_name, reference_image = ctx
ref = video_response_from_request(effective_model_name, request)
await VIDEO_STORE.upsert(ref.id, ref)
- task = asyncio.create_task(
- _run_video_generation_job(handler, request, ref.id, reference_image, app_state=raw_request.app.state)
- )
+ task = asyncio.create_task(_run_video_generation_job(handler, request, ref.id, reference_image))
await VIDEO_TASKS.upsert(ref.id, task)
return ref
@@ -2588,7 +2139,6 @@ async def create_video(
},
)
async def create_video_sync(
- raw_request: Request,
ctx: tuple[VideoGenerationRequest, OmniOpenAIServingVideo, str, ReferenceImage | None] = Depends(_parse_video_form),
) -> Response:
"""Synchronous video generation endpoint.
@@ -2602,10 +2152,9 @@ async def create_video_sync(
"""
request, handler, effective_model_name, reference_image = ctx
request_id = f"video_sync-{random_uuid()}"
- raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id)
started_at = time.perf_counter()
try:
- video_bytes, stage_durations, peak_memory_mb = await asyncio.wait_for(
+ video_bytes = await asyncio.wait_for(
handler.generate_video_bytes(request, request_id, reference_image=reference_image),
timeout=VIDEO_SYNC_TIMEOUT_S,
)
@@ -2614,8 +2163,6 @@ async def create_video_sync(
status_code=HTTPStatus.GATEWAY_TIMEOUT.value,
detail=f"Video generation timed out after {VIDEO_SYNC_TIMEOUT_S}s.",
)
- except (EngineGenerateError, EngineDeadError) as exc:
- return _create_engine_error_json_response(raw_request, exc)
except HTTPException:
raise
except Exception as exc:
@@ -2633,8 +2180,6 @@ async def create_video_sync(
"X-Request-Id": request_id,
"X-Model": effective_model_name,
"X-Inference-Time-S": f"{inference_time_s:.3f}",
- "X-Stage-Durations": json.dumps(stage_durations, separators=(",", ":")),
- "X-Peak-Memory-MB": f"{peak_memory_mb:.3f}",
},
)
@@ -2676,8 +2221,8 @@ async def list_videos(
return VideoListResponse(data=jobs, has_more=has_more, first_id=first_id, last_id=last_id)
-@router.get("/v1/videos/{video_id}", response_model=None)
-async def retrieve_video(video_id: str) -> VideoResponse | JSONResponse:
+@router.get("/v1/videos/{video_id}")
+async def retrieve_video(video_id: str) -> VideoResponse:
"""Retrieve metadata for a previously created video job.
Args:
@@ -2692,15 +2237,6 @@ async def retrieve_video(video_id: str) -> VideoResponse | JSONResponse:
job = await VIDEO_STORE.get(video_id)
if job is None:
raise HTTPException(status_code=404, detail="Video not found")
- if job.status == VideoGenerationStatus.FAILED:
- status_code = _status_code_for_video_failure(job.error)
- content = job.model_dump(mode="json")
- if content.get("error") is not None:
- content["error"]["code"] = status_code
- return JSONResponse(
- content=content,
- status_code=status_code,
- )
return job
@@ -2837,67 +2373,3 @@ async def stop_profile(raw_request: Request, request: ProfileRequest | None = No
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to stop profiler: {str(e)}"
)
-
-
-class OmniSleepRequest(BaseModel):
- stage_ids: list[int]
- level: int = 2
-
-
-class OmniWakeupRequest(BaseModel):
- stage_ids: list[int]
-
-
-@router.post("/v1/omni/sleep")
-async def omni_sleep(request: OmniSleepRequest, raw_request: Request):
- engine_client = raw_request.app.state.engine_client
- sleeping_set = raw_request.app.state.sleeping_stages
- if not hasattr(engine_client, "sleep"):
- raise HTTPException(status_code=501, detail="Engine does not support sleep")
- acks = await engine_client.sleep(stage_ids=request.stage_ids, level=request.level)
- for sid in request.stage_ids:
- sleeping_set.add(sid)
- return {"status": "SUCCESS", "acks": [dataclasses.asdict(a) if dataclasses.is_dataclass(a) else a for a in acks]}
-
-
-@router.post("/v1/omni/wakeup")
-async def omni_wakeup(request: OmniWakeupRequest, raw_request: Request):
- engine_client = raw_request.app.state.engine_client
- sleeping_set = raw_request.app.state.sleeping_stages
- if not any(sid in sleeping_set for sid in request.stage_ids):
- return {"status": "SKIPPED", "reason": "Target stages are not sleeping."}
- if not hasattr(engine_client, "wake_up"):
- raise HTTPException(status_code=501, detail="Engine does not support wake_up")
- acks = await engine_client.wake_up(stage_ids=request.stage_ids)
- for sid in request.stage_ids:
- if sid in sleeping_set:
- sleeping_set.remove(sid)
- return {"status": "SUCCESS", "acks": [dataclasses.asdict(a) if dataclasses.is_dataclass(a) else a for a in acks]}
-
-
-if __name__ == "__main__":
- import argparse
- import asyncio
-
- from vllm.entrypoints.openai.cli_args import make_arg_parser
-
- from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults
-
- parser = argparse.ArgumentParser(description="vLLM-Omni OpenAI-Compatible REST API server")
- parser = make_arg_parser(parser)
- registered_flags = set()
- for action in parser._actions:
- registered_flags.update(action.option_strings)
- if "--omni" not in registered_flags:
- parser.add_argument("--omni", action="store_true", default=False, help="Enable vLLM-Omni mode.")
- if "--enable-sleep-mode" not in registered_flags:
- parser.add_argument(
- "--enable-sleep-mode", action="store_true", default=False, help="Enable GPU memory pool for sleep mode."
- )
- nullify_stage_engine_defaults(parser)
- args = parser.parse_args()
- if not hasattr(args, "model_tag"):
- setattr(args, "model_tag", args.model)
- if hasattr(args, "model_tag") and args.model_tag is None:
- args.model_tag = args.model
- asyncio.run(omni_run_server(args))
diff --git a/vllm_omni/entrypoints/openai/audio_utils_mixin.py b/vllm_omni/entrypoints/openai/audio_utils_mixin.py
index b626f7eeb20..13df32ebe00 100644
--- a/vllm_omni/entrypoints/openai/audio_utils_mixin.py
+++ b/vllm_omni/entrypoints/openai/audio_utils_mixin.py
@@ -1,8 +1,6 @@
from io import BytesIO
import numpy as np
-import torch
-import torchaudio
from vllm.logger import init_logger
from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio
@@ -12,6 +10,11 @@
except ImportError:
soundfile = None
+try:
+ import librosa
+except ImportError:
+ librosa = None
+
logger = init_logger(__name__)
@@ -71,53 +74,20 @@ def create_audio(self, audio_obj: CreateAudio) -> AudioResponse:
return AudioResponse(audio_data=audio_data, media_type=media_type)
def _apply_speed_adjustment(self, audio_tensor: np.ndarray, speed: float, sample_rate: int):
- """Apply speed adjustment to the audio tensor while preserving pitch.
-
- Uses torchaudio's phase vocoder (Spectrogram → TimeStretch →
- InverseSpectrogram) to stretch/compress audio in time without
- changing pitch.
- """
+ """Apply speed adjustment to the audio tensor while preserving pitch."""
if speed == 1.0:
return audio_tensor, sample_rate
+ if librosa is None:
+ raise ImportError("librosa is required for speed adjustment. Please install it with: pip install librosa")
+
try:
+ # librosa.effects.time_stretch requires a float audio tensor.
if not np.issubdtype(audio_tensor.dtype, np.floating):
audio_tensor = audio_tensor.astype(np.float32)
- # Stereo numpy arrays use channels-last (T, C);
- # torch expects channels-first (C, T).
- channels_last = audio_tensor.ndim == 2
- if channels_last:
- waveform = torch.from_numpy(audio_tensor.T)
- else:
- waveform = torch.from_numpy(audio_tensor).unsqueeze(0)
-
- # Match librosa.stft defaults: n_fft=2048, hop_length=n_fft//4
- n_fft = 2048
- hop_length = n_fft // 4
- to_spec = torchaudio.transforms.Spectrogram(
- n_fft=n_fft,
- hop_length=hop_length,
- power=None,
- )
- stretch = torchaudio.transforms.TimeStretch(
- n_freq=n_fft // 2 + 1,
- hop_length=hop_length,
- )
- to_wave = torchaudio.transforms.InverseSpectrogram(
- n_fft=n_fft,
- hop_length=hop_length,
- )
-
- spec = to_spec(waveform)
- stretched = stretch(spec, speed)
- expected_length = int(audio_tensor.shape[0] / speed)
- result = to_wave(stretched, length=expected_length)
-
- result = result.squeeze(0).numpy()
- if channels_last:
- result = result.T
- return result, sample_rate
+ stretched_audio = librosa.effects.time_stretch(y=audio_tensor, rate=speed)
+ return stretched_audio, sample_rate
except Exception as e:
logger.error(f"An error occurred during speed adjustment: {e}")
raise ValueError("Failed to apply speed adjustment.") from e
diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py
index 59b5777a874..89d2dc02f65 100644
--- a/vllm_omni/entrypoints/openai/protocol/audio.py
+++ b/vllm_omni/entrypoints/openai/protocol/audio.py
@@ -1,8 +1,8 @@
import math
-from typing import Any, Literal
+from typing import Literal
import numpy as np
-from pydantic import AliasChoices, BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
_MAX_EMBEDDING_DIM = 8192
@@ -10,12 +10,8 @@
class OpenAICreateSpeechRequest(BaseModel):
input: str
model: str | None = None
- # Accept both "voice" (OpenAI convention) and "speaker" (model/internal
- # convention) as input keys. Intentionally global — all TTS backends
- # (Qwen3-TTS, Voxtral, Fish Speech) use this field for the speaker name.
voice: str | None = Field(
default=None,
- validation_alias=AliasChoices("voice", "speaker"),
description="Speaker/voice to use. For Qwen3-TTS: vivian, ryan, aiden, etc.",
)
instructions: str | None = Field(
@@ -74,10 +70,6 @@ class OpenAICreateSpeechRequest(BaseModel):
ge=0,
description="Per-request initial chunk size override. If null, computed dynamically based on server load.",
)
- extra_params: dict[str, Any] | None = Field(
- default=None,
- description=("Optional model-specific parameters passed directly to the model's extra_args."),
- )
@field_validator("stream_format")
@classmethod
diff --git a/vllm_omni/entrypoints/openai/protocol/images.py b/vllm_omni/entrypoints/openai/protocol/images.py
index 0fb22a548cf..5f76bbd6b8e 100644
--- a/vllm_omni/entrypoints/openai/protocol/images.py
+++ b/vllm_omni/entrypoints/openai/protocol/images.py
@@ -81,24 +81,6 @@ def validate_layers(cls, v):
# vllm-omni extensions for diffusion control
negative_prompt: str | None = Field(default=None, description="Text describing what to avoid in the image")
- system_prompt: str | None = Field(
- default=None, description="Custom system prompt. Used when --use_system_prompt is custom"
- )
- use_system_prompt: str | None = Field(
- default=None,
- description="System prompt type. Options: None, dynamic, en_vanilla, "
- "en_recaption, en_think_recaption, en_unified, custom",
- )
-
- @field_validator("use_system_prompt")
- @classmethod
- def validate_use_system_prompt(cls, v):
- """Validate system prompt type."""
- valid_types = [None, "dynamic", "en_vanilla", "en_recaption", "en_think_recaption", "en_unified", "custom"]
- if v not in valid_types:
- raise ValueError(f"Invalid use_system_prompt type: {v}. Must be one of: {valid_types[1:] + [None]}")
- return v
-
num_inference_steps: int | None = Field(
default=None,
ge=1,
@@ -139,12 +121,6 @@ def validate_use_system_prompt(cls, v):
vae_use_slicing: bool | None = Field(default=False, description="Enable VAE slicing")
vae_use_tiling: bool | None = Field(default=False, description="Enable VAE tiling")
- # Output format for generated images
- output_format: str | None = Field(
- default=None,
- description="Output image format: 'png', 'jpeg', or 'webp'. Defaults to 'png'.",
- )
-
class ImageData(BaseModel):
"""Single generated image data"""
diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py
index d46c8d43d6b..e180bef2292 100644
--- a/vllm_omni/entrypoints/openai/protocol/videos.py
+++ b/vllm_omni/entrypoints/openai/protocol/videos.py
@@ -150,29 +150,6 @@ class VideoGenerationRequest(BaseModel):
)
seed: int | None = Field(default=None, description="Random seed for reproducibility")
- # vllm-omni extensions for post-generation frame interpolation.
- enable_frame_interpolation: bool = Field(
- default=False,
- description="Enable post-generation RIFE frame interpolation before MP4 encoding.",
- )
- frame_interpolation_exp: int = Field(
- default=1,
- ge=1,
- description="Interpolation exponent: 1=2x temporal resolution, 2=4x, etc.",
- )
- frame_interpolation_scale: float = Field(
- default=1.0,
- gt=0.0,
- description="RIFE inference scale. Use 0.5 for high-resolution inputs to save memory.",
- )
- frame_interpolation_model_path: str | None = Field(
- default=None,
- description=(
- "Local directory or Hugging Face repo ID containing RIFE flownet.pkl weights. "
- "Defaults to elfgum/RIFE-4.22.lite."
- ),
- )
-
# vllm-omni extension for per-request LoRA.
lora: dict[str, Any] | None = Field(
default=None,
@@ -224,18 +201,10 @@ class VideoGenerationResponse(BaseModel):
created: int = Field(..., description="Unix timestamp of when the generation completed")
data: list[VideoData] = Field(..., description="Array of generated videos")
- stage_durations: dict[str, float] = Field(
- default_factory=dict,
- description="Profiler stage durations reported by the diffusion pipeline.",
- )
- peak_memory_mb: float = Field(
- default=0.0,
- description="Peak device memory usage in MB reported by the diffusion pipeline.",
- )
class VideoError(BaseModel):
- code: int | str = Field(..., description="A machine-readable error code that was returned.")
+ code: str = Field(..., description="A machine-readable error code that was returned.")
message: str = Field(..., description="A human-readable description of the error that was returned.")
@@ -281,14 +250,6 @@ class VideoResponse(BaseModel):
description="Filename of the saved output video files for this job.",
)
inference_time_s: float | None = Field(default=None, description="End-to-end inference time in seconds.")
- stage_durations: dict[str, float] = Field(
- default_factory=dict,
- description="Profiler stage durations reported by the diffusion pipeline.",
- )
- peak_memory_mb: float = Field(
- default=0.0,
- description="Peak device memory usage in MB reported by the diffusion pipeline.",
- )
@property
def file_extension(self) -> str:
diff --git a/vllm_omni/entrypoints/openai/realtime_connection.py b/vllm_omni/entrypoints/openai/realtime_connection.py
deleted file mode 100644
index 9fc2a1ee3a0..00000000000
--- a/vllm_omni/entrypoints/openai/realtime_connection.py
+++ /dev/null
@@ -1,203 +0,0 @@
-from __future__ import annotations
-
-import asyncio
-import base64
-import json
-from collections.abc import AsyncGenerator
-from typing import cast
-from uuid import uuid4
-
-import numpy as np
-from vllm.entrypoints.openai.engine.protocol import UsageInfo
-from vllm.entrypoints.openai.realtime.connection import RealtimeConnection as VllmRealtimeConnection
-from vllm.entrypoints.openai.realtime.protocol import TranscriptionDelta, TranscriptionDone
-from vllm.logger import init_logger
-
-from vllm_omni.entrypoints.async_omni import AsyncOmni
-from vllm_omni.entrypoints.utils import coerce_param_message_types
-
-logger = init_logger(__name__)
-
-
-class RealtimeConnection(VllmRealtimeConnection):
- """Omni realtime connection with audio-only server events.
-
- Reuses upstream vLLM websocket/session lifecycle and only customizes
- generation output handling to emit audio deltas.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.engine = cast(AsyncOmni, self.serving.engine_client)
- self._realtime_audio_ref: np.ndarray | None = None
-
- async def start_generation(self):
- await super().start_generation()
-
- @staticmethod
- def _tensor_to_numpy(value) -> np.ndarray | None:
- if value is None:
- return None
- if isinstance(value, np.ndarray):
- arr = value
- elif hasattr(value, "detach"):
- arr = value.detach().float().cpu().numpy()
- else:
- try:
- arr = np.asarray(value)
- except Exception:
- return None
- if arr.ndim > 1:
- arr = arr.reshape(-1)
- return arr.astype(np.float32, copy=False)
-
- @staticmethod
- def _numpy_audio_prefix_match(prev: np.ndarray, curr: np.ndarray) -> bool:
- n = prev.shape[0]
- if n == 0:
- return True
- if curr.shape[0] < n:
- return False
- return bool(np.allclose(curr[:n], prev, rtol=1e-3, atol=2e-4))
-
- def _raw_waveform_to_deltas(self, arr: np.ndarray) -> list[np.ndarray]:
- """Convert one streaming PCM f32 chunk into incremental piece(s) for the client.
-
- Some engine paths emit a growing cumulative waveform each step; others emit
- true per-step deltas. We support both without duplicating audio on the client.
- """
- if arr.size == 0:
- return []
- ref = self._realtime_audio_ref
- if ref is None:
- self._realtime_audio_ref = arr.copy()
- return [arr]
- if self._numpy_audio_prefix_match(ref, arr):
- delta = arr[ref.shape[0] :]
- self._realtime_audio_ref = arr.copy()
- return [delta] if delta.size > 0 else []
- # True per-step delta (not a prefix extension of what we have seen).
- self._realtime_audio_ref = np.concatenate([ref, arr])
- return [arr]
-
- def _extract_audio_chunks(self, output) -> tuple[list[np.ndarray], int]:
- mm = getattr(output, "multimodal_output", None)
- if not isinstance(mm, dict):
- return [], 24000
-
- sr = mm.get("sr") or mm.get("sample_rate") or mm.get("audio_sample_rate") or 24000
- key = "audio" if "audio" in mm else ("model_outputs" if "model_outputs" in mm else None)
- if key is None:
- return [], int(sr)
-
- raw_audio = mm.get(key)
- chunks: list[np.ndarray] = []
- if isinstance(raw_audio, (list, tuple)):
- if len(raw_audio) > 0:
- arr = self._tensor_to_numpy(raw_audio[-1])
- if arr is not None and arr.size > 0:
- chunks.extend(self._raw_waveform_to_deltas(arr))
- else:
- arr = self._tensor_to_numpy(raw_audio)
- if arr is not None and arr.size > 0:
- chunks.extend(self._raw_waveform_to_deltas(arr))
- return chunks, int(sr)
-
- @staticmethod
- def _pcm16_b64(audio_f32: np.ndarray) -> str:
- clipped = np.clip(audio_f32, -1.0, 1.0)
- pcm16 = (clipped * 32767.0).astype(np.int16)
- return base64.b64encode(pcm16.tobytes()).decode("utf-8")
-
- async def _run_generation(
- self,
- streaming_input_gen: AsyncGenerator,
- input_stream: asyncio.Queue[list[int]],
- ):
- request_id = f"rt-{self.connection_id}-{uuid4()}"
- sent_audio = False
- audio_done_sent = False
- full_text = ""
- prompt_token_ids_len = 0
- completion_tokens_len = 0
- self._realtime_audio_ref = None
-
- # Coerce cumulative outputs to delta outputs; this ensures
- # we don't emit redundant MM data & drain after emitting.
- sampling_params_list = list(self.engine.default_sampling_params_list)
- sampling_params_list = coerce_param_message_types(
- sampling_params_list,
- is_streaming=True,
- )
-
- try:
- result_gen = self.engine.generate(
- prompt=streaming_input_gen,
- request_id=request_id,
- sampling_params_list=sampling_params_list,
- )
-
- async for output in result_gen:
- # Handle delta texts; this is very similar to the client from vLLM
- if output.outputs and len(output.outputs) > 0:
- first_output = output.outputs[0]
- new_token_ids = list(first_output.token_ids)
- new_tokens_len = len(new_token_ids)
-
- if not prompt_token_ids_len and output.prompt_token_ids:
- prompt_token_ids_len = len(output.prompt_token_ids)
-
- if new_tokens_len:
- input_stream.put_nowait(new_token_ids)
-
- delta_text = first_output.text or ""
- full_text += delta_text
-
- # append output to input if there was any delta text
- if delta_text:
- await self.send(TranscriptionDelta(delta=delta_text))
-
- completion_tokens_len += new_tokens_len
-
- # Handle audio chunking; this is Omni specific
- audio_chunks, sample_rate = self._extract_audio_chunks(output)
-
- for chunk in audio_chunks:
- sent_audio = True
- await self.send_json(
- {
- "type": "response.audio.delta",
- "audio": self._pcm16_b64(chunk),
- "format": "pcm16",
- "sample_rate_hz": sample_rate,
- }
- )
-
- if not self._is_connected:
- break
-
- usage = UsageInfo(
- prompt_tokens=prompt_token_ids_len,
- completion_tokens=completion_tokens_len,
- total_tokens=prompt_token_ids_len + completion_tokens_len,
- )
- await self.send(TranscriptionDone(text=full_text, usage=usage))
-
- if sent_audio:
- await self.send_json({"type": "response.audio.done", "has_audio": True})
- audio_done_sent = True
- except Exception as e:
- logger.exception("Error in generation: %s", e)
- await self.send_error(str(e), "processing_error")
- finally:
- # Always send terminal event so clients don't hang forever.
- if self._is_connected and not audio_done_sent:
- try:
- await self.send_json({"type": "response.audio.done", "has_audio": sent_audio})
- except Exception:
- logger.exception("Failed to send response.audio.done")
- while not self.audio_queue.empty():
- self.audio_queue.get_nowait()
-
- async def send_json(self, payload: dict):
- await self.websocket.send_text(json.dumps(payload))
diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py
index 42d421db87b..e84a49aac2e 100644
--- a/vllm_omni/entrypoints/openai/serving_chat.py
+++ b/vllm_omni/entrypoints/openai/serving_chat.py
@@ -16,7 +16,6 @@
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.protocol.chat_completion import OmniChatCompletionResponse
-from vllm_omni.entrypoints.utils import coerce_param_message_types
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
try:
@@ -33,7 +32,6 @@
get_history_tool_calls_cnt,
make_tool_call_id,
)
-from vllm.entrypoints.launcher import terminate_if_errored
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
@@ -82,18 +80,12 @@
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.utils.collection_utils import as_list
-from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.image_api_utils import validate_layered_layers
from vllm_omni.entrypoints.openai.protocol import OmniChatCompletionStreamResponse
from vllm_omni.entrypoints.openai.protocol.audio import AudioResponse, CreateAudio
-from vllm_omni.entrypoints.openai.utils import (
- get_stage_type,
- get_supported_speakers_from_hf_config,
- parse_lora_request,
- validate_requested_speaker,
-)
+from vllm_omni.entrypoints.openai.utils import parse_lora_request
from vllm_omni.lora.request import LoRARequest
from vllm_omni.outputs import OmniRequestOutput
@@ -114,7 +106,6 @@ class OmniOpenAIServingChat(OpenAIServingChat, AudioMixin):
_diffusion_mode: bool = False
_diffusion_engine: AsyncOmni | None = None
_diffusion_model_name: str = ""
- _supported_speakers: set[str] | None = None
@classmethod
def for_diffusion(
@@ -141,18 +132,6 @@ def for_diffusion(
instance._diffusion_model_name = model_name
return instance
- def _get_supported_speakers(self) -> set[str]:
- """Load supported speakers from model config (cached)."""
- if self._supported_speakers is not None:
- return self._supported_speakers
- try:
- self._supported_speakers = get_supported_speakers_from_hf_config(self.model_config.hf_config)
- return self._supported_speakers
- except Exception as e:
- logger.warning("Could not load speakers from model config: %s", e)
- self._supported_speakers = set()
- return self._supported_speakers
-
async def create_chat_completion(
self,
request: ChatCompletionRequest,
@@ -281,10 +260,7 @@ async def create_chat_completion(
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
- message = str(e)
- if e.__cause__ is not None:
- message = f"{message} {e.__cause__}"
- return self.create_error_response(message)
+ return self.create_error_response(f"{e} {e.__cause__}")
request_id = f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
@@ -298,8 +274,6 @@ async def create_chat_completion(
)
num_inference_steps = None
- cfg_text_scale = None
- cfg_img_scale = None
# Omni multistage image generation: Stage-0 (AR) should receive a clean
# text prompt (and optional conditioning image/size) so the model's own
# processor can construct the correct inputs.
@@ -307,9 +281,20 @@ async def create_chat_completion(
# effectively unconditioned and produce nonsense images.
if request.modalities and ("image" in request.modalities):
try:
- extracted_prompt, reference_images = self._extract_diffusion_prompt_and_images_from_messages(
- request.messages
- )
+ messages_as_dicts: list[dict[str, Any]] = []
+ for msg in request.messages:
+ if hasattr(msg, "model_dump"):
+ messages_as_dicts.append(msg.model_dump())
+ elif isinstance(msg, dict):
+ messages_as_dicts.append(msg)
+ else:
+ messages_as_dicts.append(
+ {
+ "role": getattr(msg, "role", "user"),
+ "content": getattr(msg, "content", ""),
+ }
+ )
+ extracted_prompt, reference_images = self._extract_diffusion_prompt_and_images(messages_as_dicts)
if not extracted_prompt:
return self.create_error_response("No text prompt found in messages")
@@ -320,33 +305,39 @@ async def create_chat_completion(
extra_body = getattr(request, "extra_body", None)
if not extra_body:
extra_body = request.model_extra or {}
-
- height, width = self._resolve_height_width_from_extra_body(extra_body)
-
+ height = extra_body.get("height")
+ width = extra_body.get("width")
num_inference_steps = extra_body.get("num_inference_steps")
if num_inference_steps is not None:
try:
num_inference_steps = int(num_inference_steps)
except Exception:
num_inference_steps = None
-
+ if "size" in extra_body:
+ try:
+ size_str = extra_body["size"]
+ if isinstance(size_str, str) and "x" in size_str.lower():
+ w, h = size_str.lower().split("x")
+ width, height = int(w), int(h)
+ except Exception:
+ pass
negative_prompt = extra_body.get("negative_prompt")
- cfg_text_scale = extra_body.get("cfg_text_scale")
- cfg_img_scale = extra_body.get("cfg_img_scale")
engine_prompt_image: dict[str, Any] | None = None
+ is_img2img = False
if reference_images:
# Best-effort decode first reference image for i2i.
try:
img_bytes = base64.b64decode(reference_images[0])
img = Image.open(BytesIO(img_bytes))
engine_prompt_image = {"img2img": img}
+ is_img2img = True
except Exception:
engine_prompt_image = None
# Override the prompts produced by chat-template preprocessing.
tprompt: OmniTextPrompt = {"prompt": extracted_prompt}
- if engine_prompt_image:
+ if is_img2img:
tprompt["modalities"] = ["img2img"]
else:
tprompt["modalities"] = ["image"]
@@ -362,13 +353,6 @@ async def create_chat_completion(
tprompt["mm_processor_kwargs"] = mm_processor_kwargs
if engine_prompt_image is not None:
tprompt["multi_modal_data"] = engine_prompt_image
- # Provide multi_modal_uuids so that newer vLLM versions
- # can validate multi_modal_data / multi_modal_uuids
- # consistency. After the multimodal processor consumes
- # the image data, the uuids remain as a stable reference.
- tprompt["multi_modal_uuids"] = {
- k: [f"{request_id}-{k}-{i}"] for i, k in enumerate(engine_prompt_image)
- }
engine_prompts = [tprompt]
# Store height/width for applying to diffusion stage sampling params later
@@ -392,24 +376,15 @@ async def create_chat_completion(
# Use standard OpenAI API parameters for comprehension stage
sampling_params_list = self._build_sampling_params_list_from_request(request)
- # If this is a streaming (output) request, coerce cumulative outputs
- # to delta to ensure emitted outputs are correctly drained. Otherwise
- # convert cumulative to Final Only to ensure the output is correct.
- sampling_params_list = coerce_param_message_types(sampling_params_list, request.stream)
-
# Apply user-specified overrides to diffusion stage(s) for image generation
- for idx, sp in enumerate(sampling_params_list):
- if hasattr(sp, "height") and _image_gen_height is not None:
- sp.height = _image_gen_height
- if hasattr(sp, "width") and _image_gen_width is not None:
- sp.width = _image_gen_width
- if hasattr(sp, "num_inference_steps") and num_inference_steps is not None:
- sp.num_inference_steps = num_inference_steps
- if hasattr(sp, "extra_args") and sp.extra_args is not None:
- if cfg_text_scale is not None:
- sp.extra_args["cfg_text_scale"] = cfg_text_scale
- if cfg_img_scale is not None:
- sp.extra_args["cfg_img_scale"] = cfg_img_scale
+ if _image_gen_height is not None or _image_gen_width is not None or num_inference_steps is not None:
+ for idx, sp in enumerate(sampling_params_list):
+ if hasattr(sp, "height") and _image_gen_height is not None:
+ sp.height = _image_gen_height
+ if hasattr(sp, "width") and _image_gen_width is not None:
+ sp.width = _image_gen_width
+ if hasattr(sp, "num_inference_steps") and num_inference_steps is not None:
+ sp.num_inference_steps = num_inference_steps
self._log_inputs(
request_id,
@@ -443,7 +418,6 @@ async def create_chat_completion(
tokenizer,
request_metadata,
reasoning_parser,
- raw_request=raw_request,
)
try:
@@ -541,7 +515,20 @@ async def _preprocess_chat(
# containing image tokens.
req_modalities = getattr(request, "modalities", [])
if req_modalities and ("image" in req_modalities):
- extracted_prompt, _ = self._extract_diffusion_prompt_and_images_from_messages(messages)
+ messages_as_dicts: list[dict[str, Any]] = []
+ for msg in messages:
+ if hasattr(msg, "model_dump"):
+ messages_as_dicts.append(msg.model_dump())
+ elif isinstance(msg, dict):
+ messages_as_dicts.append(msg)
+ else:
+ messages_as_dicts.append(
+ {
+ "role": getattr(msg, "role", "user"),
+ "content": getattr(msg, "content", ""),
+ }
+ )
+ extracted_prompt, _ = self._extract_diffusion_prompt_and_images(messages_as_dicts)
if extracted_prompt:
engine_prompt["prompt"] = extracted_prompt
@@ -553,11 +540,10 @@ async def _preprocess_chat(
engine_prompt["cache_salt"] = request.cache_salt
speaker = getattr(request, "speaker", None)
- normalized = validate_requested_speaker(speaker, self._get_supported_speakers())
- if normalized is not None:
+ if speaker is not None and isinstance(speaker, str) and speaker.strip():
if "additional_information" not in engine_prompt or engine_prompt["additional_information"] is None:
engine_prompt["additional_information"] = {}
- engine_prompt["additional_information"]["speaker"] = [normalized]
+ engine_prompt["additional_information"]["speaker"] = [speaker.lower().strip()]
language = getattr(request, "language", None)
if language is not None and isinstance(language, str) and language.strip():
@@ -565,16 +551,6 @@ async def _preprocess_chat(
engine_prompt["additional_information"] = {}
engine_prompt["additional_information"]["language"] = [language.strip()]
- # Style instruction — used by Ming-flash-omni instruct TTS path
- # (ming_task="instruct"). For the omni speech path the thinker2talker
- # bridge drops this field to match upstream omni_audio_generation
- # which hardcodes instruction=None.
- instructions = getattr(request, "instructions", None)
- if instructions is not None and isinstance(instructions, str) and instructions.strip():
- if "additional_information" not in engine_prompt or engine_prompt["additional_information"] is None:
- engine_prompt["additional_information"] = {}
- engine_prompt["additional_information"]["instruction"] = instructions.strip()
-
return conversation, [engine_prompt]
async def _inject_audio_from_video_urls(
@@ -711,10 +687,6 @@ def _apply_request_overrides(
Starts with YAML defaults and only overrides fields that the user
explicitly provided (non-None values) in the request.
- For models needing spatial metadata (e.g. GLM-Image), target_h/w is
- injected into extra_args so the runner can build M-RoPE position grids.
- max_tokens is NOT computed dynamically — it uses the deploy YAML default.
-
Args:
default_params: Default SamplingParams from stage config YAML.
request: The chat completion request containing user-provided values.
@@ -724,41 +696,13 @@ def _apply_request_overrides(
"""
params = default_params.clone()
- # Only apply fields explicitly provided by user, not protocol defaults.
- # Pydantic v2 uses `model_fields_set`; keep v1 fallback for compatibility.
- explicit_fields = getattr(request, "model_fields_set", None)
- if explicit_fields is None:
- explicit_fields = getattr(request, "__fields_set__", set())
-
for field_name in self._OPENAI_SAMPLING_FIELDS:
- if field_name not in explicit_fields:
- continue
-
value = getattr(request, field_name, None)
- if (value is not None and not isinstance(value, list)) or (isinstance(value, list) and len(value) > 0):
+ if value is not None:
setattr(params, field_name, value)
- # For GLM-Image: compute max_tokens from height/width with mode-aware
- # budgeting (t2i vs i2i).
- extra_body = getattr(request, "extra_body", {}) or {}
- height, width = self._resolve_height_width_from_extra_body(extra_body)
-
- if height is not None and width is not None:
- # Keep target size in stage-0 sampling params so runner/model can
- # build deterministic M-RoPE grids for t2i (no MM features).
- extra_args = dict(getattr(params, "extra_args", {}) or {})
- extra_args["target_h"] = int(height)
- extra_args["target_w"] = int(width)
- params.extra_args = extra_args
-
return params
- @staticmethod
- def _set_if_supported(obj: Any, **kwargs: Any) -> None:
- for key, value in kwargs.items():
- if value is not None and hasattr(obj, key):
- setattr(obj, key, value)
-
def _build_sampling_params_list_from_request(
self,
request: ChatCompletionRequest,
@@ -822,7 +766,6 @@ async def chat_completion_stream_generator(
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None,
- raw_request: Request | None = None,
):
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
@@ -1013,7 +956,12 @@ async def chat_completion_stream_generator(
cur_channel = harmony_parser.current_channel
cur_recipient = harmony_parser.current_recipient
else:
- delta_text = output.text or ""
+ # output.text is cumulative, extract only the delta portion
+ previous_text = previous_texts[i] if previous_texts else ""
+ if output.text is not None:
+ delta_text = output.text[len(previous_text) :]
+ else:
+ delta_text = ""
if not delta_text and not output.token_ids and not previous_num_tokens[i]:
# Chunked prefill case, don't return empty chunks
@@ -1530,21 +1478,6 @@ async def chat_completion_stream_generator(
delta=False,
)
- except EngineDeadError as e:
- logger.error(
- "EngineDeadError during streaming for request %s: %s",
- request_id,
- e,
- )
- data = self.create_streaming_error_response(e)
- yield f"data: {data}\n\n"
- # Actively signal shutdown instead of waiting for the watchdog
- # (5s polling interval).
- if raw_request is not None:
- terminate_if_errored(
- server=raw_request.app.state.server,
- engine=self.engine_client,
- )
except Exception as e:
logger.exception("Error in chat completion stream generator.")
data = self.create_streaming_error_response(e)
@@ -1616,7 +1549,6 @@ async def chat_completion_full_generator(
role,
reasoning_parser,
)
- final_res = omni_outputs.request_output
elif omni_outputs.final_output_type == "audio":
choices_data = self._create_audio_choice(omni_outputs, role, request, stream=False)
elif omni_outputs.final_output_type == "image":
@@ -1625,13 +1557,7 @@ async def chat_completion_full_generator(
logger.warning(f"Unsupported final output type: {omni_outputs.final_output_type}")
continue
if omni_outputs.metrics:
- response_metrics = dict(omni_outputs.metrics)
- if omni_outputs.final_output_type == "image":
- # Expose diffusion profiler metrics on the top-level response for benchmarks / clients.
- if response_metrics is None:
- response_metrics = {}
- response_metrics.setdefault("stage_durations", omni_outputs.stage_durations or {})
- response_metrics.setdefault("peak_memory_mb", float(omni_outputs.peak_memory_mb or 0.0))
+ response_metrics = omni_outputs.metrics
choices.extend(choices_data)
response = OmniChatCompletionResponse(
@@ -1925,8 +1851,7 @@ def _create_audio_choice(
final_res = omni_outputs.request_output
# OMNI: Access multimodal_output from CompletionOutput (outputs[0]), not from RequestOutput
# Reference: examples/offline_inference/qwen3_omni/end2end.py line 421
- mm_output = final_res.outputs[0].multimodal_output
- audio_data = mm_output.get("audio")
+ audio_data = final_res.outputs[0].multimodal_output.get("audio")
if isinstance(audio_data, list):
if stream:
audio_tensor = audio_data[-1]
@@ -1940,20 +1865,9 @@ def _create_audio_choice(
if audio_tensor.ndim > 1:
audio_tensor = audio_tensor.flatten()
- # Prefer the talker-reported sample rate when present. Qwen3-Omni
- # omits "sr" and runs at 24kHz; Ming-flash-omni surfaces a 44.1kHz
- # AudioVAE rate via multimodal_output["sr"].
- sr_raw = mm_output.get("sr")
- if sr_raw is None:
- sample_rate = 24000
- elif hasattr(sr_raw, "item"):
- sample_rate = int(sr_raw.item())
- else:
- sample_rate = int(sr_raw)
-
audio_obj = CreateAudio(
audio_tensor=audio_tensor,
- sample_rate=sample_rate,
+ sample_rate=24000,
response_format="wav",
speed=1.0,
stream_format="audio",
@@ -2113,261 +2027,6 @@ def _create_image_choice(
return choices
# ==================== Diffusion Mode Methods ====================
- def _build_multistage_generation_inputs(
- self,
- *,
- engine: AsyncOmni,
- prompt: str,
- extra_body: dict[str, Any],
- reference_images: list[Image.Image],
- gen_params: OmniDiffusionSamplingParams,
- ) -> tuple[OmniTextPrompt, list[Any]]:
- """Build the shared multistage generation prompt and stage params."""
- stage_configs = getattr(engine, "stage_configs", None) or []
- default_params_list = list(getattr(engine, "default_sampling_params_list", []) or [])
-
- height = gen_params.height
- width = gen_params.width
- seed = gen_params.seed
- generator_device = gen_params.generator_device
- num_outputs_per_prompt = gen_params.num_outputs_per_prompt
- num_inference_steps = extra_body.get("num_inference_steps")
- guidance_scale = extra_body.get("guidance_scale")
- true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale")
- negative_prompt = extra_body.get("negative_prompt")
- num_frames = extra_body.get("num_frames")
- guidance_scale_2 = extra_body.get("guidance_scale_2")
- lora_body = extra_body.get("lora")
- layers = extra_body.get("layers")
- resolution = extra_body.get("resolution")
-
- engine_prompt_data: dict[str, Any] | None = None
- modalities = ["image"]
- if reference_images:
- if len(reference_images) == 1:
- engine_prompt_data = {"img2img": reference_images[0]}
- modalities = ["img2img"]
- else:
- engine_prompt_data = {"image": reference_images}
-
- engine_prompt: OmniTextPrompt = {"prompt": prompt}
- engine_prompt["modalities"] = modalities
- if negative_prompt is not None:
- engine_prompt["negative_prompt"] = negative_prompt
-
- mm_processor_kwargs: dict[str, Any] = {}
- if height is not None:
- mm_processor_kwargs["target_h"] = height
- if width is not None:
- mm_processor_kwargs["target_w"] = width
- if mm_processor_kwargs:
- engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
- if engine_prompt_data is not None:
- engine_prompt["multi_modal_data"] = engine_prompt_data
- # Provide multi_modal_uuids so that newer vLLM versions can
- # validate multi_modal_data / multi_modal_uuids consistency.
- engine_prompt["multi_modal_uuids"] = {k: [f"img-{k}-{i}"] for i, k in enumerate(engine_prompt_data)}
-
- comprehension_idx = None
- for idx, stage in enumerate(stage_configs):
- if getattr(stage, "is_comprehension", False):
- comprehension_idx = idx
- break
-
- sampling_params_list: list[Any] = []
- for idx, stage_cfg in enumerate(stage_configs):
- stage_type = get_stage_type(stage_cfg)
- if idx < len(default_params_list):
- default_stage_params = default_params_list[idx]
- if hasattr(default_stage_params, "clone"):
- try:
- default_stage_params = default_stage_params.clone()
- except Exception:
- pass
- elif stage_type == "diffusion":
- default_stage_params = gen_params.clone()
- else:
- default_stage_params = SamplingParams()
-
- if (
- comprehension_idx is not None
- and idx == comprehension_idx
- and seed is not None
- and hasattr(default_stage_params, "seed")
- ):
- default_stage_params.seed = seed
-
- # Inject target_h/w into comprehension (AR) stage sampling params
- # for models that need M-RoPE position pre-computation (e.g.
- # GLM-Image). max_tokens is handled via the deploy YAML default
- # (upper-bound ceiling) rather than computed dynamically here.
- if comprehension_idx is not None and idx == comprehension_idx and height is not None and width is not None:
- extra_args = getattr(default_stage_params, "extra_args", None)
- if extra_args is None:
- extra_args = {}
- default_stage_params.extra_args = extra_args
- extra_args["target_h"] = int(height)
- extra_args["target_w"] = int(width)
-
- if stage_type == "diffusion":
- self._set_if_supported(
- default_stage_params,
- height=height,
- width=width,
- seed=seed,
- generator_device=generator_device,
- num_outputs_per_prompt=num_outputs_per_prompt,
- num_inference_steps=num_inference_steps,
- guidance_scale=guidance_scale,
- true_cfg_scale=true_cfg_scale,
- num_frames=num_frames,
- guidance_scale_2=guidance_scale_2,
- layers=layers,
- resolution=resolution,
- )
- if lora_body and isinstance(lora_body, dict):
- try:
- lora_req, lora_scale = parse_lora_request(lora_body)
- if lora_req is not None:
- default_stage_params.lora_request = lora_req
- if lora_scale is not None:
- default_stage_params.lora_scale = lora_scale
- except Exception as e: # pragma: no cover - safeguard
- logger.warning("Failed to parse LoRA request: %s", e)
-
- sampling_params_list.append(default_stage_params)
-
- return engine_prompt, sampling_params_list
-
- async def generate_diffusion_images(
- self,
- *,
- prompt: str,
- extra_body: dict[str, Any] | None = None,
- reference_images: list[str] | None = None,
- request_id: str | None = None,
- ) -> tuple[list[Image.Image], dict[str, Any], float] | ErrorResponse:
- """Generate diffusion images and return raw images plus generation stats."""
- if request_id is None:
- request_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
- if extra_body is None:
- extra_body = {}
- if reference_images is None:
- reference_images = []
-
- engine = self._diffusion_engine if self._diffusion_engine is not None else self.engine_client
-
- height, width = self._resolve_height_width_from_extra_body(extra_body)
-
- seed = extra_body.get("seed")
- generator_device = extra_body.get("generator_device")
- negative_prompt = extra_body.get("negative_prompt")
- num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1)
- lora_body = extra_body.get("lora")
-
- pil_images: list[Image.Image] = []
- for img_b64 in reference_images:
- try:
- img_bytes = base64.b64decode(img_b64)
- pil_images.append(Image.open(BytesIO(img_bytes)))
- except Exception as e:
- logger.warning("Failed to decode reference image: %s", e)
-
- gen_params = OmniDiffusionSamplingParams(
- height=height,
- width=width,
- num_outputs_per_prompt=num_outputs_per_prompt,
- seed=seed,
- )
- self._set_if_supported(
- gen_params,
- generator_device=generator_device,
- num_inference_steps=extra_body.get("num_inference_steps"),
- guidance_scale=extra_body.get("guidance_scale"),
- true_cfg_scale=extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale"),
- num_frames=extra_body.get("num_frames"),
- guidance_scale_2=extra_body.get("guidance_scale_2"),
- layers=extra_body.get("layers"),
- resolution=extra_body.get("resolution"),
- strength=extra_body.get("strength"),
- )
-
- if lora_body and isinstance(lora_body, dict):
- try:
- lora_req, lora_scale = parse_lora_request(lora_body)
- if lora_req is not None:
- gen_params.lora_request = lora_req
- if lora_scale is not None:
- gen_params.lora_scale = lora_scale
- except Exception as e: # pragma: no cover - safeguard
- logger.warning("Failed to parse LoRA request: %s", e)
-
- gen_prompt: OmniTextPrompt = {
- "prompt": prompt,
- "negative_prompt": negative_prompt,
- }
- if pil_images:
- if len(pil_images) == 1:
- gen_prompt["multi_modal_data"] = {"image": pil_images[0]}
- else:
- od_config = getattr(engine, "od_config", None)
- supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False)
- if od_config is None:
- supports_multimodal_inputs = True
- if supports_multimodal_inputs:
- gen_prompt["multi_modal_data"] = {"image": pil_images}
- else:
- return self._create_error_response(
- "Multiple input images are not supported by the current diffusion model. "
- "For multi-image editing, start the server with Qwen-Image-Edit-2509 "
- "and send multiple images in the user message content.",
- status_code=400,
- )
-
- if isinstance(engine, AsyncOmni):
- diffusion_engine = cast(AsyncOmni, engine)
- stage_configs = getattr(diffusion_engine, "stage_configs", None) or []
- if len(stage_configs) > 1:
- engine_prompt, sampling_params_list = self._build_multistage_generation_inputs(
- engine=diffusion_engine,
- prompt=prompt,
- extra_body=extra_body,
- reference_images=pil_images,
- gen_params=gen_params,
- )
- else:
- engine_prompt = gen_prompt
- sampling_params_list = [gen_params]
-
- result = None
- async for output in diffusion_engine.generate(
- prompt=engine_prompt,
- sampling_params_list=sampling_params_list,
- request_id=request_id,
- ):
- result = output
- if result is None:
- return self._create_error_response("No output generated from AsyncOmni", status_code=500)
- else:
- result = await engine.generate(
- prompt=gen_prompt,
- sampling_params=gen_params,
- request_id=request_id,
- )
-
- images = getattr(result.request_output, "images", [])
- stage_durations = result.stage_durations
- peak_memory_mb = result.peak_memory_mb
-
- flat_images: list[Image.Image] = []
- for item in images:
- if isinstance(item, list):
- flat_images.extend(item)
- else:
- flat_images.append(item)
-
- return flat_images, stage_durations, peak_memory_mb
-
async def _create_diffusion_chat_completion(
self,
request: ChatCompletionRequest,
@@ -2410,7 +2069,16 @@ async def _create_diffusion_chat_completion(
extra_body = request.model_extra or {}
# Parse size if provided (supports "1024x1024" format)
- height, width = self._resolve_height_width_from_extra_body(extra_body)
+ height = extra_body.get("height")
+ width = extra_body.get("width")
+ if "size" in extra_body:
+ try:
+ size_str = extra_body["size"]
+ if isinstance(size_str, str) and "x" in size_str.lower():
+ w, h = size_str.lower().split("x")
+ width, height = int(w), int(h)
+ except ValueError:
+ logger.warning("Invalid size format: %s", extra_body.get("size"))
# Get request parameters from extra_body.
# Avoid hardcoded defaults here — let each pipeline's forward()
@@ -2419,8 +2087,6 @@ async def _create_diffusion_chat_completion(
num_inference_steps = extra_body.get("num_inference_steps")
guidance_scale = extra_body.get("guidance_scale")
true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale")
- cfg_text_scale = extra_body.get("cfg_text_scale")
- cfg_img_scale = extra_body.get("cfg_img_scale")
seed = extra_body.get("seed")
negative_prompt = extra_body.get("negative_prompt")
num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1)
@@ -2475,10 +2141,6 @@ async def _create_diffusion_chat_completion(
gen_params.guidance_scale = guidance_scale
if true_cfg_scale is not None:
gen_params.true_cfg_scale = true_cfg_scale
- if cfg_text_scale is not None:
- gen_params.extra_args["cfg_text_scale"] = cfg_text_scale
- if cfg_img_scale is not None:
- gen_params.extra_args["cfg_img_scale"] = cfg_img_scale
if num_frames is not None:
gen_params.num_frames = num_frames
if guidance_scale_2 is not None:
@@ -2499,7 +2161,7 @@ async def _create_diffusion_chat_completion(
except Exception as e: # pragma: no cover - safeguard
logger.warning("Failed to parse LoRA request: %s", e)
- # Add reference image if provided (from messages content)
+ # Add reference image if provided
if pil_images:
if len(pil_images) == 1:
gen_prompt["multi_modal_data"] = {}
@@ -2523,30 +2185,10 @@ async def _create_diffusion_chat_completion(
# Generate image
diffusion_engine = cast(AsyncOmni, self._diffusion_engine)
- stage_configs = list(getattr(diffusion_engine, "stage_configs", []) or [])
- default_params_list = list(getattr(diffusion_engine, "default_sampling_params_list", []) or [])
-
- sampling_params_list: list[Any] = []
- for idx, stage_cfg in enumerate(stage_configs):
- if get_stage_type(stage_cfg) == "diffusion":
- sampling_params_list.append(gen_params)
- continue
-
- default_stage_params = default_params_list[idx] if idx < len(default_params_list) else SamplingParams()
- if hasattr(default_stage_params, "clone"):
- try:
- default_stage_params = default_stage_params.clone()
- except Exception as e:
- logger.warning("Failed to clone default params for stage %d: %s", idx, e)
- sampling_params_list.append(default_stage_params)
-
- if not sampling_params_list:
- sampling_params_list = [gen_params]
-
result = None
async for output in diffusion_engine.generate(
prompt=gen_prompt,
- sampling_params_list=sampling_params_list,
+ sampling_params_list=[gen_params], # Pass as single-stage params
request_id=request_id,
):
result = output
@@ -2690,48 +2332,6 @@ def _extract_diffusion_prompt_and_images(
prompt = " ".join(prompt_parts).strip()
return prompt, images
- def _extract_diffusion_prompt_and_images_from_messages(
- self,
- messages: list[Any],
- ) -> tuple[str, list[str]]:
- """Normalize mixed message types and extract prompt + reference images once."""
- return self._extract_diffusion_prompt_and_images(self._messages_to_dicts(messages))
-
- @staticmethod
- def _messages_to_dicts(messages: list[Any]) -> list[dict[str, Any]]:
- """Normalize request messages to plain dicts."""
- out: list[dict[str, Any]] = []
- for msg in messages:
- if hasattr(msg, "model_dump"):
- out.append(msg.model_dump())
- elif isinstance(msg, dict):
- out.append(msg)
- else:
- out.append(
- {
- "role": getattr(msg, "role", "user"),
- "content": getattr(msg, "content", ""),
- }
- )
- return out
-
- @staticmethod
- def _resolve_height_width_from_extra_body(extra_body: dict[str, Any]) -> tuple[Any, Any]:
- """Extract generation height/width with optional size string fallback."""
- height = extra_body.get("height")
- width = extra_body.get("width")
-
- if "size" in extra_body and (height is None or width is None):
- try:
- size_str = extra_body["size"]
- if isinstance(size_str, str) and "x" in size_str.lower():
- w, h = size_str.lower().split("x")
- width, height = int(w), int(h)
- except ValueError:
- logger.warning("Invalid size format: %s", extra_body.get("size"))
-
- return height, width
-
def _create_error_response(
self,
message: str,
diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py
index eb8a12c66ef..f051268824b 100644
--- a/vllm_omni/entrypoints/openai/serving_speech.py
+++ b/vllm_omni/entrypoints/openai/serving_speech.py
@@ -6,29 +6,22 @@
import os
import re
import struct
+import tempfile
import time
-from concurrent.futures import ThreadPoolExecutor
-from http import HTTPStatus
from pathlib import Path
from typing import Any
import numpy as np
import soundfile as sf
import torch
-from fastapi import HTTPException, Request, UploadFile
+from fastapi import Request, UploadFile
from fastapi.responses import Response, StreamingResponse
from transformers.utils.hub import cached_file
-from vllm.entrypoints.launcher import terminate_if_errored
-from vllm.entrypoints.openai.engine.protocol import (
- ErrorResponse,
- RequestResponseMetadata,
-)
+from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.logger import init_logger
from vllm.multimodal.media import MediaConnector
from vllm.utils import random_uuid
-from vllm.utils.async_utils import make_async
-from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.protocol.audio import (
@@ -40,18 +33,11 @@
SpeechBatchItem,
SpeechBatchItemResult,
)
-from vllm_omni.entrypoints.utils import coerce_param_message_types
from vllm_omni.model_executor.models.fish_speech.prompt_utils import (
build_fish_text_only_prompt_ids,
estimate_fish_voice_clone_prompt_len_from_normalized,
normalize_fish_voice_clone_texts,
)
-from vllm_omni.model_executor.models.ming_flash_omni.prompt_utils import (
- DEFAULT_PROMPT as MING_DEFAULT_PROMPT,
-)
-from vllm_omni.model_executor.models.ming_flash_omni.prompt_utils import (
- create_instruction as ming_create_instruction,
-)
from vllm_omni.outputs import OmniRequestOutput
logger = init_logger(__name__)
@@ -62,18 +48,12 @@
_FISH_TTS_MODEL_STAGES = {"fish_speech_slow_ar"}
_COSYVOICE3_TTS_MODEL_STAGES = {"cosyvoice3_talker"}
_OMNIVOICE_TTS_MODEL_STAGES = {"omnivoice_generator"}
-_VOXCPM_TTS_MODEL_STAGES = {"latent_generator", "vae"}
-_VOXCPM2_TTS_MODEL_STAGES = {"latent_generator"}
-_MING_TTS_MODEL_STAGES = {"ming_tts"}
_TTS_MODEL_STAGES: set[str] = (
_VOXTRAL_TTS_MODEL_STAGES
| _QWEN3_TTS_MODEL_STAGES
| _FISH_TTS_MODEL_STAGES
| _COSYVOICE3_TTS_MODEL_STAGES
| _OMNIVOICE_TTS_MODEL_STAGES
- | _VOXCPM_TTS_MODEL_STAGES
- | _VOXCPM2_TTS_MODEL_STAGES
- | _MING_TTS_MODEL_STAGES
)
_TTS_LANGUAGES: set[str] = {
"Auto",
@@ -173,7 +153,6 @@ def _validate_path_within_directory(file_path: Path, directory: Path) -> bool:
class OmniOpenAIServingSpeech(OpenAIServing, AudioMixin):
_diffusion_mode: bool = False
- _tts_executor: ThreadPoolExecutor | None = None
@classmethod
def for_diffusion(
@@ -231,8 +210,6 @@ def __init__(self, *args, **kwargs):
"Re-upload voices after each restart if needed."
)
self._tts_tokenizer = None
- self._voxcpm2_tokenizer = None
- self._voxcpm2_split_map: dict[int, list[int]] = {}
logger.info(f"Loaded {len(self.supported_speakers)} supported speakers: {sorted(self.supported_speakers)}")
@@ -242,14 +219,6 @@ def __init__(self, *args, **kwargs):
# Load speech tokenizer codec parameters for prompt length estimation
self._codec_frame_rate: float | None = self._load_codec_frame_rate()
- # Shared thread pool executor for blocking TTS preprocessing
- # operations. max_workers=1 serializes tokenizer access to avoid
- # Rust RefCell "Already borrowed" errors from concurrent use.
- self._tts_executor = ThreadPoolExecutor(max_workers=1)
- self._build_voxtral_prompt_async = make_async(self._build_voxtral_prompt, executor=self._tts_executor)
- self._build_fish_speech_prompt_async = make_async(self._build_fish_speech_prompt, executor=self._tts_executor)
- self._estimate_prompt_len_async = make_async(self._estimate_prompt_len, executor=self._tts_executor)
-
def _load_codec_frame_rate(self) -> float | None:
"""Load codec frame rate from speech tokenizer config for prompt length estimation."""
try:
@@ -283,12 +252,6 @@ def _load_codec_frame_rate(self) -> float | None:
pass
return None
- def shutdown(self) -> None:
- """Shut down the TTS thread pool executor."""
- if self._tts_executor is not None:
- self._tts_executor.shutdown(wait=False, cancel_futures=True)
- self._tts_executor = None
-
def _find_tts_stage(self):
"""Find and return the TTS stage config, or None if not found."""
for stage in self.engine_client.stage_configs:
@@ -301,11 +264,6 @@ def _detect_tts_model_type(self) -> str | None:
if self._tts_stage is None:
return None
model_stage = getattr(self._tts_stage.engine_args, "model_stage", None)
- model_arch = getattr(self._tts_stage.engine_args, "model_arch", None)
- if model_arch == "VoxCPM2TalkerForConditionalGeneration":
- return "voxcpm2"
- if model_arch == "VoxCPMForConditionalGeneration":
- return "voxcpm"
if model_stage in _QWEN3_TTS_MODEL_STAGES:
return "qwen3_tts"
if model_stage in _VOXTRAL_TTS_MODEL_STAGES:
@@ -316,14 +274,6 @@ def _detect_tts_model_type(self) -> str | None:
return "cosyvoice3"
if model_stage in _OMNIVOICE_TTS_MODEL_STAGES:
return "omnivoice"
- if model_stage in (_VOXCPM_TTS_MODEL_STAGES | _VOXCPM2_TTS_MODEL_STAGES):
- has_vae_stage = any(
- getattr(getattr(stage, "engine_args", None), "model_stage", None) == "vae"
- for stage in self.engine_client.stage_configs
- )
- return "voxcpm" if has_vae_stage or model_stage == "vae" else "voxcpm2"
- if model_stage in _MING_TTS_MODEL_STAGES:
- return "ming_flash_omni_tts"
return None
def _compute_max_instructions_length(self) -> int:
@@ -347,14 +297,7 @@ def _compute_max_instructions_length(self) -> int:
def _load_supported_speakers(self) -> set[str]:
"""Load supported speakers (case-insensitive) from the model configuration."""
- if self._tts_model_type == "ming_flash_omni_tts":
- # Ming-flash-omni drives speaker selection via the caption JSON
- # (audio_sequence[0]["说话人"]) rather than a spk_id table, so there
- # is no static speaker list to surface here.
- return set()
try:
- if self._tts_model_type == "voxcpm":
- return set()
if self._tts_model_type == "voxtral_tts":
config = self.engine_client.model_config.hf_config.audio_config
else:
@@ -414,8 +357,6 @@ def _estimate_ref_code_len(self, ref_audio: object) -> int | None:
def _estimate_prompt_len(self, tts_params: dict[str, Any]) -> int:
"""Estimate prompt length so the placeholder matches model-side embeddings."""
try:
- if self._tts_model_type == "voxcpm":
- return 1
from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
Qwen3TTSTalkerForConditionalGeneration,
)
@@ -479,25 +420,6 @@ def _estimate_fish_prompt_len(self, text: str, ref_text: str, ref_audio: object)
logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", e)
return 2048
- async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
- """Build prefill prompt for VoxCPM2 TTS (`prompt_token_ids` padded to full prefill length)."""
- from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import build_voxcpm2_prompt
-
- self._voxcpm2_encode("") # lazy-init tokenizer + split_map
- ref_audio = None
- ref_sr = None
- if request.ref_audio is not None:
- ref_audio, ref_sr = await self._resolve_ref_audio(request.ref_audio)
- return build_voxcpm2_prompt(
- hf_config=self.engine_client.model_config.hf_config,
- tokenizer=self._voxcpm2_tokenizer,
- split_map=self._voxcpm2_split_map,
- text=request.input,
- ref_audio=ref_audio,
- ref_sr=ref_sr,
- ref_text=request.ref_text,
- )
-
def _get_uploaded_audio_data(self, voice_name: str) -> str | None:
"""Get base64 encoded audio data for uploaded voice."""
voice_name_lower = voice_name.lower()
@@ -528,48 +450,6 @@ def _get_uploaded_audio_data(self, voice_name: str) -> str | None:
logger.error(f"Could not read audio file for voice {voice_name}: {e}")
return None
- def _get_uploaded_speaker_embedding(self, voice_name: str) -> list[float] | None:
- """Load pre-computed speaker embedding for an uploaded voice.
-
- Returns the embedding as a list of floats, or None if the voice
- was not uploaded with an embedding (i.e. it has audio instead).
- """
- voice_name_lower = voice_name.lower()
- if voice_name_lower not in self.uploaded_speakers:
- return None
-
- speaker_info = self.uploaded_speakers[voice_name_lower]
- if speaker_info.get("embedding_source") != "direct":
- return None
-
- cache_file = speaker_info.get("cache_file")
- if not cache_file or not Path(cache_file).exists():
- logger.warning("Embedding file not found for voice %s: %s", voice_name, cache_file)
- return None
-
- if not _validate_path_within_directory(Path(cache_file), self.uploaded_speakers_dir):
- logger.error("Cache file path traversal detected for voice %s: %s", voice_name, cache_file)
- return None
-
- try:
- from safetensors.torch import load_file
- except ImportError:
- logger.error(
- "The 'safetensors' package is required to load speaker embeddings. "
- "Install it with: pip install safetensors"
- )
- return None
-
- try:
- tensors = load_file(cache_file)
- if "speaker_embedding" not in tensors:
- logger.warning("Key 'speaker_embedding' not found in %s for voice %s", cache_file, voice_name)
- return None
- return tensors["speaker_embedding"].squeeze().tolist()
- except Exception as e:
- logger.error("Could not load embedding for voice %s: %s", voice_name, e)
- return None
-
async def upload_voice(
self,
audio_file: UploadFile,
@@ -849,68 +729,8 @@ def _validate_tts_request(self, request: OpenAICreateSpeechRequest) -> str | Non
return self._validate_fish_tts_request(request)
if self._tts_model_type == "cosyvoice3":
return self._validate_cosyvoice3_request(request)
- if self._tts_model_type == "voxcpm":
- return self._validate_voxcpm_request(request)
- if self._tts_model_type == "voxcpm2":
- return None # VoxCPM2 accepts any text input
- if self._tts_model_type == "ming_flash_omni_tts":
- return self._validate_ming_tts_request(request)
return self._validate_qwen_tts_request(request)
- def _voxcpm2_encode(self, text: str) -> list[int]:
- """Tokenize text for VoxCPM2, splitting multichar Chinese tokens."""
- from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import (
- build_cjk_split_map,
- split_multichar_chinese,
- )
-
- if self._voxcpm2_tokenizer is None:
- from transformers import AutoTokenizer
-
- model_name = self.engine_client.model_config.model
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- self._voxcpm2_split_map = build_cjk_split_map(tokenizer)
- self._voxcpm2_tokenizer = tokenizer
- logger.info("VoxCPM2 serving: built multichar split map (%d entries)", len(self._voxcpm2_split_map))
-
- ids = self._voxcpm2_tokenizer.encode(text, add_special_tokens=True)
- return split_multichar_chinese(ids, self._voxcpm2_split_map)
-
- def _validate_ming_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None:
- """Validate Ming-flash-omni standalone-talker request parameters."""
- if not request.input or not request.input.strip():
- return "Input text cannot be empty"
- if request.instructions is not None:
- if not isinstance(request.instructions, str):
- return "instructions must be a string"
- if len(request.instructions) > self._max_instructions_length:
- return f"instructions exceeds max length {self._max_instructions_length}"
-
- if request.task_type is not None:
- return "'task_type' is not supported for Ming-flash-omni TTS"
- if request.language is not None:
- return "'language' is not supported for Ming-flash-omni TTS (language is inferred from input text)"
- if request.x_vector_only_mode is not None:
- return "'x_vector_only_mode' is not supported for Ming-flash-omni TTS"
- if request.initial_codec_chunk_frames is not None:
- return "'initial_codec_chunk_frames' is not supported for Ming-flash-omni TTS"
-
- # Per-request voice cloning from raw audio is not yet wired up: Ming
- # extracts spk_emb / prompt_wav_lat / prompt_wav_emb model-side via
- # register_prompt_wav() at engine init. For ad-hoc cloning, callers
- # should pre-compute speaker_embedding and pass it directly.
- if request.ref_audio is not None:
- return (
- "'ref_audio' is not yet supported for Ming-flash-omni TTS; "
- "use a preset 'voice' or 'speaker_embedding' instead"
- )
- if request.ref_text is not None:
- return "'ref_text' is not yet supported for Ming-flash-omni TTS"
-
- if request.max_new_tokens is not None and request.max_new_tokens <= 0:
- return "'max_new_tokens' must be a positive integer"
- return None
-
def _validate_ref_audio_format(self, ref_audio: str) -> str | None:
"""Validate ref_audio is a supported URI format. Returns error or None."""
if not (
@@ -948,43 +768,6 @@ def _validate_voxtral_tts_request(self, request: OpenAICreateSpeechRequest) -> s
return None
- def _validate_voxcpm_request(self, request: OpenAICreateSpeechRequest) -> str | None:
- """Validate VoxCPM request parameters. Returns error message or None."""
- if not request.input or not request.input.strip():
- return "Input text cannot be empty"
-
- if request.voice is not None:
- return "'voice' is not supported for VoxCPM"
- if request.instructions is not None:
- return "'instructions' is not supported for VoxCPM"
- if request.language is not None:
- return "'language' is not supported for VoxCPM"
- if request.task_type not in (None, "Base"):
- return "VoxCPM only supports plain TTS or voice cloning with ref_audio/ref_text"
- if request.x_vector_only_mode is not None:
- return "'x_vector_only_mode' is not supported for VoxCPM"
- if request.speaker_embedding is not None:
- return "'speaker_embedding' is not supported for VoxCPM"
- if request.initial_codec_chunk_frames is not None:
- return "'initial_codec_chunk_frames' is not supported for VoxCPM"
-
- if request.ref_audio is not None:
- fmt_err = self._validate_ref_audio_format(request.ref_audio)
- if fmt_err:
- return fmt_err
- if not request.ref_text or not request.ref_text.strip():
- return "Voice cloning requires 'ref_text' (transcript of the reference audio)"
- elif request.ref_text is not None:
- return "'ref_text' requires 'ref_audio' for VoxCPM voice cloning"
-
- if request.max_new_tokens is not None:
- if request.max_new_tokens < _TTS_MAX_NEW_TOKENS_MIN:
- return f"max_new_tokens must be at least {_TTS_MAX_NEW_TOKENS_MIN}"
- if request.max_new_tokens > _TTS_MAX_NEW_TOKENS_MAX:
- return f"max_new_tokens cannot exceed {_TTS_MAX_NEW_TOKENS_MAX}"
-
- return None
-
def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None:
"""Validate Qwen TTS request parameters. Returns error message or None."""
# Infer Base task when ref_audio or ref_text is provided without explicit task_type.
@@ -1058,17 +841,11 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str
# voice is not None
voice_lower = request.voice.lower()
if voice_lower in self.uploaded_speakers:
- # Check if data file exists for uploaded speaker
+ # Check if audio file exists for uploaded speaker
speaker_info = self.uploaded_speakers[voice_lower]
file_path = Path(speaker_info["file_path"])
if not file_path.exists():
- return f"Data file for uploaded speaker '{request.voice}' not found on disk"
- # For embedding-uploaded voices, verify the cache is ready
- if speaker_info.get("embedding_source") == "direct":
- cache_file = speaker_info.get("cache_file")
- if not cache_file or not Path(cache_file).exists():
- status = speaker_info.get("cache_status", "unknown")
- return f"Speaker embedding for '{request.voice}' is not yet ready (cache_status='{status}')"
+ return f"Audio file for uploaded speaker '{request.voice}' not found on disk"
else:
# need ref_audio for built-in speaker
if request.ref_audio is None:
@@ -1078,13 +855,6 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str
fmt_err = self._validate_ref_audio_format(request.ref_audio)
if fmt_err:
return fmt_err
- if not getattr(request, "x_vector_only_mode", False) and (
- not request.ref_text or not request.ref_text.strip()
- ):
- return (
- "Base task requires non-empty 'ref_text' (transcript of "
- "the reference audio) unless 'x_vector_only_mode' is enabled"
- )
# Validate cross-parameter dependencies
if task_type != "Base":
@@ -1111,32 +881,10 @@ def _validate_qwen_tts_request(self, request: OpenAICreateSpeechRequest) -> str
return None
def _validate_fish_tts_request(self, request: OpenAICreateSpeechRequest) -> str | None:
- """Validate Fish Speech request parameters. Returns error message or None.
-
- Side effect: if request.voice references an uploaded speaker, resolves
- it to request.ref_audio and request.ref_text for voice cloning.
- """
+ """Validate Fish Speech request parameters. Returns error message or None."""
if not request.input or not request.input.strip():
return "Input text cannot be empty"
- # Support uploaded voices: auto-resolve voice → ref_audio + ref_text.
- if request.voice is not None and request.ref_audio is None:
- voice_lower = request.voice.lower()
- if voice_lower in self.uploaded_speakers:
- speaker_info = self.uploaded_speakers[voice_lower]
- file_path = Path(speaker_info["file_path"])
- if not file_path.exists():
- return f"Audio file for uploaded voice '{request.voice}' not found on disk"
- audio_data_url = self._get_uploaded_audio_data(voice_lower)
- if audio_data_url is None:
- return f"Could not load audio for uploaded voice '{request.voice}'"
- request.ref_audio = audio_data_url
- # Use ref_text from upload metadata if not provided in request.
- if not request.ref_text or not request.ref_text.strip():
- upload_ref_text = speaker_info.get("ref_text")
- if upload_ref_text and upload_ref_text.strip():
- request.ref_text = upload_ref_text
-
if request.ref_audio is not None:
fmt_err = self._validate_ref_audio_format(request.ref_audio)
if fmt_err:
@@ -1183,15 +931,11 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int
URLs, ``data:`` base64 URIs, and ``file:`` local paths (the latter
gated by ``--allowed-local-media-path``).
"""
- # In diffusion mode, model_config may not be available
- if self._diffusion_mode:
- connector = MediaConnector()
- else:
- model_config = self.model_config
- connector = MediaConnector(
- allowed_local_media_path=model_config.allowed_local_media_path,
- allowed_media_domains=model_config.allowed_media_domains,
- )
+ model_config = self.model_config
+ connector = MediaConnector(
+ allowed_local_media_path=model_config.allowed_local_media_path,
+ allowed_media_domains=model_config.allowed_media_domains,
+ )
wav_np, sr = await connector.fetch_audio_async(ref_audio_str)
wav_np = np.asarray(wav_np, dtype=np.float32)
if wav_np.ndim > 1:
@@ -1210,13 +954,7 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int
)
return wav_np.tolist(), sr
- async def _generate_audio_chunks(
- self,
- generator,
- request_id: str,
- response_format: str = "pcm",
- raw_request: Request | None = None,
- ):
+ async def _generate_audio_chunks(self, generator, request_id: str, response_format: str = "pcm"):
"""Generate audio chunks for streaming response.
Handles two audio output modes from the engine:
@@ -1291,19 +1029,6 @@ async def _generate_audio_chunks(
except asyncio.CancelledError:
logger.info("Streaming request %s cancelled by client", request_id)
raise
- except EngineDeadError as e:
- logger.error(
- "EngineDeadError during streaming speech for %s: %s",
- request_id,
- e,
- )
- # Actively signal shutdown rather than relying on the watchdog.
- if raw_request is not None:
- terminate_if_errored(
- server=raw_request.app.state.server,
- engine=self.engine_client,
- )
- raise
except Exception as e:
logger.exception("Streaming speech generation failed for %s: %s", request_id, e)
raise
@@ -1316,20 +1041,9 @@ def _extract_audio_output(res) -> tuple[dict | None, str | None]:
streaming needs per-chunk delta slicing; non-streaming needs full concatenation.
"""
mm = getattr(res, "multimodal_output", None)
- ro = None
if not mm:
ro = getattr(res, "request_output", None)
mm = getattr(ro, "multimodal_output", None) if ro else None
- if not mm:
- if ro is None:
- ro = getattr(res, "request_output", None)
- outputs = getattr(ro, "outputs", None) if ro else None
- if outputs:
- for completion_output in outputs:
- completion_mm = getattr(completion_output, "multimodal_output", None)
- if completion_mm:
- mm = completion_mm
- break
if not mm:
return None, None
key = "audio" if "audio" in mm else ("model_outputs" if "model_outputs" in mm else None)
@@ -1341,18 +1055,6 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any
Processes each parameter if present, skips if not.
Values are wrapped in lists as required by the model.
"""
- if self._tts_model_type == "voxcpm":
- params: dict[str, Any] = {
- "text": [request.input],
- "cfg_value": [2.0],
- "inference_timesteps": [10],
- "min_len": [2],
- "max_new_tokens": [request.max_new_tokens or 4096],
- }
- if request.ref_text is not None:
- params["ref_text"] = [request.ref_text]
- return params
-
params: dict[str, Any] = {}
# Text content (always required)
@@ -1377,32 +1079,20 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any
# Uploaded voices use task_type="Base" (CustomVoice requires built-in spk_id).
# If ref_text was provided at upload time, use in-context cloning; otherwise x_vector only.
if request.voice.lower() in self.uploaded_speakers and request.ref_audio is None:
+ audio_data = self._get_uploaded_audio_data(request.voice)
+ if not audio_data:
+ raise ValueError(f"Audio file for uploaded voice '{request.voice}' is missing or corrupted")
speaker_info = self.uploaded_speakers[request.voice.lower()]
-
- # Check if this voice was uploaded with a pre-computed embedding.
- # Populate request.speaker_embedding so the existing code path
- # (below) handles voice_clone_prompt and x_vector_only_mode.
- embedding = self._get_uploaded_speaker_embedding(request.voice)
- if embedding is not None:
- request.speaker_embedding = embedding
- params["task_type"] = ["Base"]
- logger.info("Auto-set speaker_embedding for uploaded voice: %s", request.voice)
+ stored_ref_text = speaker_info.get("ref_text")
+ params["ref_audio"] = [audio_data]
+ params["task_type"] = ["Base"]
+ params["voice_created_at"] = [speaker_info.get("created_at", 0)]
+ if stored_ref_text:
+ params["ref_text"] = [stored_ref_text]
+ params["x_vector_only_mode"] = [False]
else:
- audio_data = self._get_uploaded_audio_data(request.voice)
- if not audio_data:
- raise ValueError(f"Audio file for uploaded voice '{request.voice}' is missing or corrupted")
- stored_ref_text = speaker_info.get("ref_text")
- params["ref_audio"] = [audio_data]
- params["task_type"] = ["Base"]
- params["voice_created_at"] = [speaker_info.get("created_at", 0)]
- if stored_ref_text:
- params["ref_text"] = [stored_ref_text]
- params["x_vector_only_mode"] = [False]
- else:
- params["x_vector_only_mode"] = [True]
- logger.info(
- "Auto-set ref_audio for uploaded voice: %s (icl=%s)", request.voice, bool(stored_ref_text)
- )
+ params["x_vector_only_mode"] = [True]
+ logger.info("Auto-set ref_audio for uploaded voice: %s (icl=%s)", request.voice, bool(stored_ref_text))
elif params["task_type"][0] == "CustomVoice":
params["speaker"] = ["Vivian"] # Default for CustomVoice
@@ -1448,7 +1138,7 @@ def _build_tts_params(self, request: OpenAICreateSpeechRequest) -> dict[str, Any
# ---- Voxtral TTS helpers ----
- def _build_voxtral_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
+ async def _build_voxtral_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
"""Build Voxtral TTS engine prompt from shared TTS parameters."""
from mistral_common.protocol.speech.request import SpeechRequest
@@ -1523,22 +1213,20 @@ def _build_fish_speech_prompt(
wav_samples, sr = ref_audio_data
normalized_text, normalized_ref_text = normalize_fish_voice_clone_texts(request.input, request.ref_text)
ph_len = self._estimate_fish_prompt_len(normalized_text, normalized_ref_text, ref_audio_data)
-
- # Structured clone: scalars (not list-wrapped) because model-side
- # preprocess() consumes per-request fields directly.
- additional_information: dict[str, Any] = {
+ with tempfile.NamedTemporaryFile(prefix="fish_ref_", suffix=".npy", delete=False) as f:
+ np.save(f, np.asarray(wav_samples, dtype=np.float32))
+ ref_audio_path = f.name
+
+ # Structured clone metadata is consumed directly by
+ # FishSpeechSlowARForConditionalGeneration.preprocess(), so keep these
+ # values as scalars instead of the list-wrapped prompt-dict convention.
+ additional_information = {
"text": normalized_text,
"ref_text": normalized_ref_text,
- "ref_audio_wav": torch.from_numpy(np.asarray(wav_samples, dtype=np.float32)),
+ "ref_audio_path": ref_audio_path,
"ref_audio_sr": int(sr),
"fish_structured_voice_clone": True,
}
- # Pass voice identity for model-side DAC code caching.
- if request.voice is not None:
- voice_lower = request.voice.lower()
- if voice_lower in self.uploaded_speakers:
- additional_information["voice_name"] = voice_lower
- additional_information["voice_created_at"] = self.uploaded_speakers[voice_lower].get("created_at", 0)
if request.max_new_tokens is not None:
additional_information["max_new_tokens"] = request.max_new_tokens
return {
@@ -1573,69 +1261,15 @@ async def _build_cosyvoice3_prompt(
},
}
- # ---- Ming-flash-omni standalone-talker (TTS) helpers ----
-
- def _build_ming_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
- # request.instructions accepts two forms:
- # 1. Plain text: mapped to the caption's 风格 (style) field
- # 2. JSON object: parsed and splatted into the caption. Unlocks
- # Unknown keys are dropped by `ming_create_instruction`.
- caption_fields: dict[str, Any] = {}
- if request.instructions:
- stripped = request.instructions.strip()
- if stripped.startswith("{"):
- try:
- parsed = json.loads(stripped)
- except json.JSONDecodeError:
- parsed = None
- if isinstance(parsed, dict):
- caption_fields.update(parsed)
- else:
- caption_fields["风格"] = request.instructions
- else:
- caption_fields["风格"] = request.instructions
-
- has_spk_emb = request.speaker_embedding is not None
-
- # TTS path applies ming task type `instruct`.
- # voice_name enables talker-side voice preset resolution (e.g. "DB30").
- additional_information: dict[str, Any] = {
- "ming_task": "instruct",
- "prompt": MING_DEFAULT_PROMPT,
- "text": request.input,
- "instruction": ming_create_instruction(caption_fields),
- "voice_name": request.voice or None,
- "use_zero_spk_emb": not has_spk_emb,
- "max_decode_steps": request.max_new_tokens or _TTS_MAX_NEW_TOKENS_MAX,
- "cfg": 2.0,
- "sigma": 0.25,
- "temperature": 0.0,
- }
- if has_spk_emb:
- # Passed as plain float list
- additional_information["spk_emb"] = list(request.speaker_embedding)
- return {
- "prompt_token_ids": [0],
- "additional_information": additional_information,
- }
-
# ---- Common speech generation helpers ----
async def _prepare_speech_generation(
self,
request: OpenAICreateSpeechRequest,
- request_id: str | None = None,
) -> tuple[str, Any, dict[str, Any]]:
if self.engine_client.errored:
raise self.engine_client.dead_error
- # If this is a streaming request, we need to coerce
- # cumulative outputs to delta outputs; this ensures
- # we don't emit redundant MM data & drain after emitting.
- # list() makes a copy to avoid mutating the params.
- sampling_params_list = list(self.engine_client.default_sampling_params_list)
- sampling_params_list = coerce_param_message_types(sampling_params_list, request.stream)
-
if self._is_fish_speech:
validation_error = self._validate_fish_tts_request(request)
if validation_error:
@@ -1644,53 +1278,22 @@ async def _prepare_speech_generation(
if request.ref_audio is not None:
wav_list, sr = await self._resolve_ref_audio(request.ref_audio)
ref_audio_data = (wav_list, sr)
- prompt = await self._build_fish_speech_prompt_async(request, ref_audio_data=ref_audio_data)
+ prompt = self._build_fish_speech_prompt(request, ref_audio_data=ref_audio_data)
tts_params = {}
elif self._tts_model_type == "omnivoice":
- if not request.input or not request.input.strip():
- raise ValueError("Input text cannot be empty")
- tts_params = {}
- prompt: dict[str, Any] = {"input": request.input}
- # Resolve ref_audio: explicit request param or uploaded voice
- ref_src = request.ref_audio
- if not ref_src and request.voice:
- vl = request.voice.lower()
- if vl in self.uploaded_speakers:
- sp = self.uploaded_speakers[vl]
- if sp.get("embedding_source") == "audio":
- ref_src = self._get_uploaded_audio_data(request.voice)
- if not ref_src:
- raise ValueError(f"Audio for voice '{request.voice}' missing")
- prompt["ref_text"] = sp.get("ref_text")
- if ref_src:
- fmt_err = self._validate_ref_audio_format(ref_src)
- if fmt_err:
- raise ValueError(fmt_err)
- wav, sr = await self._resolve_ref_audio(ref_src)
- prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr)
- if request.ref_text:
- prompt["ref_text"] = request.ref_text
- if request.language:
- prompt["lang"] = request.language
- if request.instructions:
- prompt["instruct"] = request.instructions
- elif self._tts_model_type == "voxcpm2":
- prompt = await self._build_voxcpm2_prompt(request)
tts_params = {}
+ prompt = request.input # Diffusion engine takes raw text
elif self._is_tts:
validation_error = self._validate_tts_request(request)
if validation_error:
raise ValueError(validation_error)
if self._tts_model_type == "voxtral_tts":
- prompt = await self._build_voxtral_prompt_async(request)
+ prompt = await self._build_voxtral_prompt(request)
tts_params = {}
elif self._tts_model_type == "cosyvoice3":
prompt = await self._build_cosyvoice3_prompt(request)
tts_params = {}
- elif self._tts_model_type == "ming_flash_omni_tts":
- prompt = self._build_ming_prompt(request)
- tts_params = {}
else:
tts_params = self._build_tts_params(request)
# Resolve ref_audio (explicit or auto-set for uploaded voices)
@@ -1703,43 +1306,19 @@ async def _prepare_speech_generation(
wav_list, sr = await self._resolve_ref_audio(ref_audio_source)
tts_params["ref_audio"] = [[wav_list, sr]]
- ph_len = await self._estimate_prompt_len_async(tts_params)
+ ph_len = self._estimate_prompt_len(tts_params)
prompt = {"prompt_token_ids": [1] * ph_len, "additional_information": tts_params}
else:
- # Qwen omni models (Qwen3-Omni, Qwen2.5-Omni) use a "talker"
- # stage whose preprocess requires chat-templated tokens. The
- # async-chunk orchestrator prewarms the talker via
- # compute_talker_prompt_ids_length(), which scans for Qwen
- # chat-template markers (im_start_token_id 151644). A raw-text
- # prompt produces a 1-token placeholder that crashes the talker's
- # prefill/decode handoff. Reject early with an actionable message.
- stage_names = {
- getattr(getattr(s, "engine_args", None), "model_stage", None) for s in self.engine_client.stage_configs
- }
- if "talker" in stage_names:
- raise ValueError(
- "The /v1/audio/speech endpoint is only supported for "
- "dedicated TTS models (e.g., Qwen3-TTS, Voxtral, Fish "
- "Speech, CosyVoice3, OmniVoice, VoxCPM2). For omni "
- "models like Qwen3-Omni, use /v1/chat/completions with "
- '\'"modalities": ["audio"]\' instead.'
- )
tts_params = {}
prompt = {"prompt": request.input}
- request_id = request_id or f"speech-{random_uuid()}"
+ request_id = f"speech-{random_uuid()}"
if self._is_fish_speech:
model_type = "fish_speech"
elif self._tts_model_type == "voxtral_tts":
model_type = "voxtral_tts"
elif self._tts_model_type == "cosyvoice3":
model_type = "cosyvoice3"
- elif self._tts_model_type == "voxcpm":
- model_type = "voxcpm"
- elif self._tts_model_type == "voxcpm2":
- model_type = "voxcpm2"
- elif self._tts_model_type == "ming_flash_omni_tts":
- model_type = "ming_flash_omni_tts"
elif self._is_tts:
model_type = tts_params.get("task_type", ["unknown"])[0]
else:
@@ -1751,43 +1330,7 @@ async def _prepare_speech_generation(
model_type,
)
- # CosyVoice3: set dynamic min/max tokens based on text length.
- # The official model requires min_token_text_ratio to prevent early
- # EOS and max_token_text_ratio to cap generation length.
- if self._tts_model_type == "cosyvoice3" and sampling_params_list:
- import copy
-
- sampling_params_list = copy.deepcopy(sampling_params_list)
- text_len = len(request.input) # rough char-level estimate
- # Use the model's configured ratios (defaults: min=2, max=20)
- hf_cfg = self.model_config.hf_config
- min_ratio = getattr(hf_cfg, "min_token_text_ratio", 2)
- max_ratio = getattr(hf_cfg, "max_token_text_ratio", 20)
- min_tokens = max(1, int(text_len * min_ratio))
- max_tokens = min(2048, int(text_len * max_ratio))
- sampling_params_list[0].min_tokens = min_tokens
- sampling_params_list[0].max_tokens = max_tokens
- logger.info(
- "CosyVoice3 dynamic tokens: text_len=%d, min_tokens=%d, max_tokens=%d",
- text_len,
- min_tokens,
- max_tokens,
- )
-
- # Apply model-specific extra parameters
- if request.extra_params is not None and sampling_params_list:
- if not isinstance(request.extra_params, dict):
- raise HTTPException(
- status_code=HTTPStatus.BAD_REQUEST.value,
- detail="extra_params must be a JSON object/dict.",
- )
- import copy
-
- sampling_params_list = copy.deepcopy(sampling_params_list)
- if sampling_params_list[0].extra_args is None:
- sampling_params_list[0].extra_args = {}
- sampling_params_list[0].extra_args.update(request.extra_params)
- logger.info("Applied extra_params: %s", request.extra_params)
+ sampling_params_list = self.engine_client.default_sampling_params_list
# Fish defaults come from stage_configs YAML. Only override when the caller
# explicitly requests a different generation length.
@@ -1805,15 +1348,6 @@ async def _prepare_speech_generation(
)
return request_id, generator, tts_params
- async def _generate_pcm_chunks(self, generator, request_id: str):
- """Yield raw PCM byte chunks from the engine generator.
-
- Delegates to ``_generate_audio_chunks`` with ``response_format="pcm"``.
- Used by the WebSocket streaming handler and ``_iter_pcm_audio_bytes``.
- """
- async for chunk in self._generate_audio_chunks(generator, request_id, response_format="pcm"):
- yield chunk
-
async def _iter_pcm_audio_bytes(self, request: OpenAICreateSpeechRequest):
"""Yield raw PCM bytes for a speech request as soon as chunks are decoded."""
request_id, generator, _ = await self._prepare_speech_generation(request)
@@ -1824,9 +1358,8 @@ async def _generate_audio_bytes(
self,
request: OpenAICreateSpeechRequest,
base64_encode: bool = False,
- request_id: str | None = None,
) -> tuple[bytes | str, str]:
- request_id, generator, _ = await self._prepare_speech_generation(request, request_id=request_id)
+ request_id, generator, _ = await self._prepare_speech_generation(request)
final_output: OmniRequestOutput | None = None
async for res in generator:
@@ -1884,26 +1417,13 @@ async def _create_diffusion_speech(
from vllm_omni.outputs import OmniRequestOutput
try:
- if not request.input or not request.input.strip():
- raise ValueError("Input text cannot be empty")
-
request_id = f"speech-{random_uuid()}"
- prompt: dict[str, Any] = {"input": request.input}
- if request.ref_audio:
- wav, sr = await self._resolve_ref_audio(request.ref_audio)
- prompt["ref_audio"] = (np.asarray(wav, dtype=np.float32), sr)
- if request.ref_text:
- prompt["ref_text"] = request.ref_text
- if request.language:
- prompt["lang"] = request.language
- if request.instructions:
- prompt["instruct"] = request.instructions
+ prompt = request.input
logger.info(
- "Diffusion TTS speech request %s: text=%r, voice_clone=%s",
+ "Diffusion TTS speech request %s: text=%r",
request_id,
- request.input[:50] + "..." if len(request.input) > 50 else request.input,
- "ref_audio" in prompt,
+ prompt[:50] + "..." if len(prompt) > 50 else prompt,
)
generator = self._diffusion_engine.generate(
@@ -1950,8 +1470,6 @@ async def _create_diffusion_speech(
except asyncio.CancelledError:
return self._diffusion_error_response("Client disconnected")
- except (EngineGenerateError, EngineDeadError):
- raise # Propagate to the global Omni exception handler
except ValueError as e:
return self._diffusion_error_response(str(e))
except Exception as e:
@@ -1997,12 +1515,6 @@ async def create_speech(
logger.error("Error with model %s", error_check_ret)
return error_check_ret
- request_id = f"speech-{random_uuid()}"
- if raw_request:
- raw_request.state.request_metadata = RequestResponseMetadata(
- request_id=request_id,
- )
-
try:
if request.stream:
# Determine response format and media type for streaming
@@ -2023,24 +1535,17 @@ async def create_speech(
)
media_type = "audio/wav" if response_format == "wav" else "audio/pcm"
- _, generator, _ = await self._prepare_speech_generation(request, request_id=request_id)
+ request_id, generator, _ = await self._prepare_speech_generation(request)
return StreamingResponse(
- self._generate_audio_chunks(
- generator,
- request_id,
- response_format,
- raw_request=raw_request,
- ),
+ self._generate_audio_chunks(generator, request_id, response_format),
media_type=media_type,
)
- audio_bytes, media_type = await self._generate_audio_bytes(request, request_id=request_id)
+ audio_bytes, media_type = await self._generate_audio_bytes(request)
return Response(content=audio_bytes, media_type=media_type)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
- except (EngineGenerateError, EngineDeadError):
- raise # Propagate to the global Omni exception handler
except ValueError as e:
return self.create_error_response(e)
except Exception as e:
diff --git a/vllm_omni/entrypoints/openai/serving_video.py b/vllm_omni/entrypoints/openai/serving_video.py
index a4be330eb47..2987c81fba7 100644
--- a/vllm_omni/entrypoints/openai/serving_video.py
+++ b/vllm_omni/entrypoints/openai/serving_video.py
@@ -3,7 +3,6 @@
from __future__ import annotations
-import copy
import time
from dataclasses import dataclass
from http import HTTPStatus
@@ -34,18 +33,6 @@ class ReferenceImage:
data: Image.Image
-@dataclass
-class VideoGenerationArtifacts:
- """Normalized outputs and profiler metadata extracted from one request."""
-
- videos: list[Any]
- audios: list[Any | None]
- audio_sample_rate: int
- output_fps: int
- stage_durations: dict[str, float]
- peak_memory_mb: float
-
-
class OmniOpenAIServingVideo:
"""OpenAI-style video generation handler for omni diffusion models."""
@@ -90,13 +77,17 @@ async def _run_and_extract(
reference_id: str,
*,
reference_image: ReferenceImage | None = None,
- ) -> VideoGenerationArtifacts:
- """Run the generation pipeline and extract video/audio/profiler outputs."""
+ ) -> tuple[list[Any], list[Any | None], int, int]:
+ """Run the generation pipeline and extract video/audio outputs.
+
+ Returns:
+ Tuple of (videos, audios, audio_sample_rate, output_fps).
+ """
prompt: OmniTextPrompt = OmniTextPrompt(prompt=request.prompt)
if request.negative_prompt is not None:
prompt["negative_prompt"] = request.negative_prompt
- gen_params = self._resolve_default_sampling_params()
+ gen_params = OmniDiffusionSamplingParams()
input_image = None if reference_image is None else reference_image.data
vp = request.resolve_video_params()
@@ -114,27 +105,18 @@ async def _run_and_extract(
if vp.fps is not None:
gen_params.fps = vp.fps
gen_params.frame_rate = float(vp.fps)
- provided_fields = request.model_fields_set
- if "enable_frame_interpolation" in provided_fields:
- gen_params.enable_frame_interpolation = request.enable_frame_interpolation
- if "frame_interpolation_exp" in provided_fields:
- gen_params.frame_interpolation_exp = request.frame_interpolation_exp
- if "frame_interpolation_scale" in provided_fields:
- gen_params.frame_interpolation_scale = request.frame_interpolation_scale
- if "frame_interpolation_model_path" in provided_fields:
- gen_params.frame_interpolation_model_path = request.frame_interpolation_model_path
-
- if "num_inference_steps" in provided_fields and request.num_inference_steps is not None:
+
+ if request.num_inference_steps is not None:
gen_params.num_inference_steps = request.num_inference_steps
- if "guidance_scale" in provided_fields and request.guidance_scale is not None:
+ if request.guidance_scale is not None:
gen_params.guidance_scale = request.guidance_scale
- if "guidance_scale_2" in provided_fields and request.guidance_scale_2 is not None:
+ if request.guidance_scale_2 is not None:
gen_params.guidance_scale_2 = request.guidance_scale_2
- if "true_cfg_scale" in provided_fields and request.true_cfg_scale is not None:
+ if request.true_cfg_scale is not None:
gen_params.true_cfg_scale = request.true_cfg_scale
- if "seed" in provided_fields and request.seed is not None:
+ if request.seed is not None:
gen_params.seed = request.seed
- if "boundary_ratio" in provided_fields and request.boundary_ratio is not None:
+ if request.boundary_ratio is not None:
gen_params.boundary_ratio = request.boundary_ratio
logger.info(
@@ -142,7 +124,7 @@ async def _run_and_extract(
request.boundary_ratio,
gen_params.boundary_ratio,
)
- if "flow_shift" in provided_fields and request.flow_shift is not None:
+ if request.flow_shift is not None:
gen_params.extra_args["flow_shift"] = request.flow_shift
# Apply model-specific extra parameters
@@ -170,15 +152,8 @@ async def _run_and_extract(
videos = self._extract_video_outputs(result)
audios = self._extract_audio_outputs(result, expected_count=len(videos))
audio_sample_rate = self._resolve_audio_sample_rate(result)
- output_fps = (vp.fps or self._resolve_fps(result) or 24) * self._resolve_video_fps_multiplier(result)
- return VideoGenerationArtifacts(
- videos=videos,
- audios=audios,
- audio_sample_rate=audio_sample_rate,
- output_fps=output_fps,
- stage_durations=self._extract_stage_durations(result),
- peak_memory_mb=self._extract_peak_memory_mb(result),
- )
+ output_fps = vp.fps or 24
+ return videos, audios, audio_sample_rate, output_fps
async def generate_videos(
self,
@@ -187,38 +162,28 @@ async def generate_videos(
*,
reference_image: ReferenceImage | None = None,
) -> VideoGenerationResponse:
- artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image)
-
- video_codec_options = {"preset": "ultrafast", "threads": "0"}
- if request.extra_params is not None and isinstance(request.extra_params, dict):
- if "video_codec_options" in request.extra_params:
- video_codec_options = request.extra_params["video_codec_options"]
-
+ videos, audios, audio_sample_rate, output_fps = await self._run_and_extract(
+ request, reference_id, reference_image=reference_image
+ )
_t_encode_start = time.perf_counter()
video_data = [
VideoData(
b64_json=(
- encode_video_base64(video, fps=artifacts.output_fps, video_codec_options=video_codec_options)
- if artifacts.audios[idx] is None
+ encode_video_base64(video, fps=output_fps)
+ if audios[idx] is None
else encode_video_base64(
video,
- fps=artifacts.output_fps,
- audio=artifacts.audios[idx],
- audio_sample_rate=artifacts.audio_sample_rate,
- video_codec_options=video_codec_options,
+ fps=output_fps,
+ audio=audios[idx],
+ audio_sample_rate=audio_sample_rate,
)
)
)
- for idx, video in enumerate(artifacts.videos)
+ for idx, video in enumerate(videos)
]
_t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000
logger.info("Video response encoding (MP4+base64): %.2f ms", _t_encode_ms)
- return VideoGenerationResponse(
- created=int(time.time()),
- data=video_data,
- stage_durations=artifacts.stage_durations,
- peak_memory_mb=artifacts.peak_memory_mb,
- )
+ return VideoGenerationResponse(created=int(time.time()), data=video_data)
async def generate_video_bytes(
self,
@@ -226,59 +191,25 @@ async def generate_video_bytes(
reference_id: str,
*,
reference_image: ReferenceImage | None = None,
- ) -> tuple[bytes, dict[str, float], float]:
+ ) -> bytes:
"""Generate a video and return raw MP4 bytes, bypassing base64 encoding."""
- artifacts = await self._run_and_extract(request, reference_id, reference_image=reference_image)
- if len(artifacts.videos) > 1:
+ videos, audios, audio_sample_rate, output_fps = await self._run_and_extract(
+ request, reference_id, reference_image=reference_image
+ )
+ if len(videos) > 1:
logger.warning(
- "Video request %s generated %d outputs; returning only the first.",
- reference_id,
- len(artifacts.videos),
+ "Video request %s generated %d outputs; returning only the first.", reference_id, len(videos)
)
- audio = artifacts.audios[0]
-
- video_codec_options = {"preset": "ultrafast", "threads": "0"}
- if request.extra_params is not None and isinstance(request.extra_params, dict):
- if "video_codec_options" in request.extra_params:
- video_codec_options = request.extra_params["video_codec_options"]
-
+ audio = audios[0]
_t_encode_start = time.perf_counter()
video_bytes = _encode_video_bytes(
- artifacts.videos[0],
- fps=artifacts.output_fps,
- **({"audio": audio, "audio_sample_rate": artifacts.audio_sample_rate} if audio is not None else {}),
- video_codec_options=video_codec_options,
+ videos[0],
+ fps=output_fps,
+ **({"audio": audio, "audio_sample_rate": audio_sample_rate} if audio is not None else {}),
)
_t_encode_ms = (time.perf_counter() - _t_encode_start) * 1000
logger.info("Video response encoding (MP4 bytes): %.2f ms", _t_encode_ms)
- return video_bytes, artifacts.stage_durations, artifacts.peak_memory_mb
-
- @staticmethod
- def _resolve_video_fps_multiplier(result: Any) -> int:
- custom_output = getattr(result, "custom_output", None)
- if isinstance(custom_output, dict):
- multiplier = custom_output.get("video_fps_multiplier")
- if multiplier is not None:
- return int(multiplier)
- request_output = getattr(result, "request_output", None)
- if request_output is not None:
- custom_output = getattr(request_output, "custom_output", None)
- if isinstance(custom_output, dict):
- multiplier = custom_output.get("video_fps_multiplier")
- if multiplier is not None:
- return int(multiplier)
- return 1
-
- def _resolve_default_sampling_params(self) -> OmniDiffusionSamplingParams:
- default_sampling_params_list = getattr(self._engine_client, "default_sampling_params_list", None)
- if default_sampling_params_list:
- for params in default_sampling_params_list:
- if isinstance(params, OmniDiffusionSamplingParams):
- # Requests mutate sampling params in-place, including
- # nested dict fields like extra_args. Deep-copy the stage
- # defaults so one request cannot leak state into another.
- return copy.deepcopy(params)
- return OmniDiffusionSamplingParams()
+ return video_bytes
@staticmethod
def _apply_lora(lora_body: Any, gen_params: OmniDiffusionSamplingParams) -> None:
@@ -434,46 +365,6 @@ def _resolve_audio_sample_rate(self, result: Any) -> int:
return 24000
- @staticmethod
- def _resolve_fps(result: Any) -> int | None:
- """Extract fps from multimodal_output if the model reported it."""
- multimodal_output = getattr(result, "multimodal_output", None)
- if isinstance(multimodal_output, dict):
- fps = multimodal_output.get("fps")
- if fps is not None:
- try:
- fps_val = fps.item() if hasattr(fps, "item") else int(fps)
- if fps_val > 0:
- return fps_val
- except (TypeError, ValueError):
- pass
-
- request_output = getattr(result, "request_output", None)
- if isinstance(request_output, dict):
- mm = request_output.get("multimodal_output") or {}
- if isinstance(mm, dict):
- fps = mm.get("fps")
- if fps is not None:
- try:
- fps_val = fps.item() if hasattr(fps, "item") else int(fps)
- if fps_val > 0:
- return fps_val
- except (TypeError, ValueError):
- pass
- elif hasattr(request_output, "multimodal_output"):
- mm = getattr(request_output, "multimodal_output", None)
- if isinstance(mm, dict):
- fps = mm.get("fps")
- if fps is not None:
- try:
- fps_val = fps.item() if hasattr(fps, "item") else int(fps)
- if fps_val > 0:
- return fps_val
- except (TypeError, ValueError):
- pass
-
- return None
-
@classmethod
def _extract_audio_sample_rate_from_result(cls, result: Any) -> int | None:
multimodal_output = getattr(result, "multimodal_output", None)
@@ -552,16 +443,3 @@ def _coerce_audio_sample_rate(value: Any) -> int | None:
return None
return sample_rate if sample_rate > 0 else None
-
- @staticmethod
- def _extract_stage_durations(result: Any) -> dict[str, float]:
- stage_durations = getattr(result, "stage_durations", None)
- return stage_durations if isinstance(stage_durations, dict) else {}
-
- @staticmethod
- def _extract_peak_memory_mb(result: Any) -> float:
- peak_memory_mb = getattr(result, "peak_memory_mb", 0.0)
- try:
- return float(peak_memory_mb or 0.0)
- except (TypeError, ValueError):
- return 0.0
diff --git a/vllm_omni/entrypoints/openai/serving_video_stream.py b/vllm_omni/entrypoints/openai/serving_video_stream.py
deleted file mode 100644
index a76b241c55b..00000000000
--- a/vllm_omni/entrypoints/openai/serving_video_stream.py
+++ /dev/null
@@ -1,971 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""WebSocket handler for streaming video input understanding.
-
-Accepts video frames incrementally via WebSocket, buffers them, and
-generates text + optional audio responses using the existing Qwen3-Omni
-multi-stage pipeline (thinker -> talker -> code2wav).
-
-Protocol:
- Client -> Server:
- {"type": "session.config", ...} # Session config (sent once)
- {"type": "video.frame", "data": "..."} # base64 JPEG/PNG frame
- {"type": "audio.chunk", "data": "..."} # base64 PCM16 16kHz mono
- {"type": "video.query", "text": "..."} # Submit query about buffered frames
- {"type": "video.done"} # End of session
-
- Server -> Client:
- {"type": "response.start"}
- {"type": "response.text.delta", "delta": "..."}
- {"type": "response.text.done", "text": "..."}
- {"type": "response.audio.delta", "data": "...", "format": "wav"}
- {"type": "response.audio.done"}
- {"type": "session.done"}
- {"type": "error", "message": "..."}
-"""
-
-import asyncio
-import base64
-import hashlib
-import io
-import json
-import time as _time
-import uuid
-import wave
-from typing import Any
-
-import torch
-from fastapi import WebSocket, WebSocketDisconnect
-from PIL import Image
-from pydantic import BaseModel, Field, ValidationError
-from vllm.logger import init_logger
-
-from vllm_omni.entrypoints.openai import video_stream_envs
-from vllm_omni.entrypoints.openai.video_frame_filter import FrameSimilarityFilter
-from vllm_omni.entrypoints.openai.video_stream_context import (
- text_only_message,
-)
-from vllm_omni.outputs import OmniRequestOutput
-
-logger = init_logger(__name__)
-
-_DEFAULT_IDLE_TIMEOUT = 60.0
-_DEFAULT_CONFIG_TIMEOUT = 10.0
-_MAX_FRAME_SIZE = 10 * 1024 * 1024 # 10MB per frame
-_MAX_BUFFER_FRAMES = 64
-_MAX_AUDIO_BUFFER_BYTES = 4 * 1024 * 1024
-_MAX_MSG_QUEUE = 200
-_CODEC_FRAME_SAMPLES = 1920 # CausalConv leading-edge artifact length
-_BAD_FRAME = object()
-
-
-def _decode_frame_bytes(raw_bytes: bytes) -> Any:
- return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
-
-
-class StreamingVideoSessionConfig(BaseModel):
- """Configuration sent as the first WebSocket message."""
-
- model: str | None = None
- modalities: list[str] = Field(
- default_factory=lambda: ["text", "audio"],
- description="Output modalities: 'text', 'audio', or both.",
- )
- num_frames: int = Field(
- default=4,
- ge=1,
- le=128,
- description="Max frames to sample from buffer for the model.",
- )
- max_frames: int = Field(
- default=50,
- ge=1,
- le=256,
- description="Max frames to keep in the buffer.",
- )
- system_prompt: str | None = Field(
- default=None,
- description="Custom system prompt.",
- )
- use_audio_in_video: bool = Field(
- default=True,
- description="Interleave audio chunks with video frames when audio input is present.",
- )
- sampling_params_list: list[dict[str, Any]] | None = Field(
- default=None,
- description="Per-stage sampling params [thinker, talker, code2wav].",
- )
- enable_frame_filter: bool = Field(
- default=True,
- description="EVS pixel-similarity pre-filter to drop near-duplicate frames.",
- )
- frame_filter_threshold: float = Field(
- default=0.95,
- ge=0.0,
- le=1.0,
- description="EVS similarity threshold (higher = keep more frames).",
- )
-
-
-class OmniStreamingVideoHandler:
- """Handles WebSocket sessions for streaming video input.
-
- Supports:
- - Concurrent frame reception during query processing (reader/processor split)
- - PCM audio input (``audio.chunk``)
- - Async-chunk incremental audio output via ``engine_client.generate()``
- - Multi-turn conversation history
- - Soft interrupt (new query cancels current generation)
- """
-
- def __init__(
- self,
- chat_service: Any,
- idle_timeout: float = _DEFAULT_IDLE_TIMEOUT,
- config_timeout: float = _DEFAULT_CONFIG_TIMEOUT,
- engine_client: Any | None = None,
- ) -> None:
- self._chat_service = chat_service
- self._idle_timeout = idle_timeout
- self._config_timeout = config_timeout
- self._engine_client = engine_client
-
- async def handle_session(self, websocket: WebSocket) -> None:
- """Main session loop for a single WebSocket connection."""
- await websocket.accept()
-
- try:
- config = await self._receive_config(websocket)
- if config is None:
- return
-
- frame_buffer: list[str] = [] # base64-encoded JPEG frames
- # Per-frame PIL cache + uuid for mm_hash reuse. Aligned with frame_buffer by index.
- frame_pil_cache: dict[str, tuple[Any, str] | object] = {} # b64 -> (PIL.Image, uuid) or _BAD_FRAME
- frame_filter = (
- FrameSimilarityFilter(threshold=config.frame_filter_threshold) if config.enable_frame_filter else None
- )
- audio_buffer = bytearray() # raw PCM16 16kHz mono
- message_history: list[dict[str, Any]] = []
- active_request_id: str | None = None
- prev_request_id: str | None = None # abort target iff prev was interrupted
- prev_was_interrupted: bool = False
- interrupt_event = asyncio.Event()
- prewarm_tasks: set[asyncio.Task[Any]] = set()
- query_task: asyncio.Task[Any] | None = None
-
- msg_queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue(maxsize=_MAX_MSG_QUEUE)
-
- async def _reader() -> None:
- """Receive WebSocket messages and enqueue them."""
- try:
- while True:
- try:
- raw = await asyncio.wait_for(
- websocket.receive_text(),
- timeout=self._idle_timeout,
- )
- except asyncio.TimeoutError:
- await self._send_error(websocket, "Idle timeout")
- await msg_queue.put(None)
- return
-
- try:
- msg = json.loads(raw)
- except json.JSONDecodeError:
- await self._send_error(websocket, "Invalid JSON")
- continue
-
- if not isinstance(msg, dict):
- await self._send_error(websocket, "Messages must be JSON objects")
- continue
-
- msg_type = str(msg.get("type", ""))
- if msg_type.startswith("_internal."):
- await self._send_error(websocket, f"Unknown type: {msg_type}")
- continue
-
- await msg_queue.put(msg)
- if msg.get("type") == "video.done":
- return
- except WebSocketDisconnect:
- await msg_queue.put(None)
- except Exception:
- await msg_queue.put(None)
- raise
-
- async def _cancel_active_query(*, abort_now: bool = False) -> None:
- """Signal soft interrupt for the active query."""
- nonlocal active_request_id, prev_was_interrupted, query_task
- if active_request_id is not None:
- interrupt_event.set()
- prev_was_interrupted = True
- logger.info("Interrupt signaled for %s", active_request_id)
- if abort_now and self._engine_client:
- try:
- await self._engine_client.abort(active_request_id)
- except Exception:
- logger.debug("Abort failed for %s", active_request_id, exc_info=True)
- if query_task is not None and not query_task.done():
- query_task.cancel()
- await asyncio.gather(query_task, return_exceptions=True)
- query_task = None
-
- async def _processor() -> None:
- """Process enqueued messages."""
- nonlocal active_request_id, prev_request_id, prev_was_interrupted, query_task
-
- while True:
- msg = await msg_queue.get()
- if msg is None:
- await _cancel_active_query(abort_now=True)
- return
-
- msg_type = msg.get("type")
-
- if msg_type == "_internal.frame_decode_failed":
- frame_data = msg.get("b64", "")
- removed = frame_data in frame_buffer
- if removed:
- frame_buffer[:] = [f for f in frame_buffer if f != frame_data]
- if frame_pil_cache.get(frame_data) is _BAD_FRAME:
- frame_pil_cache.pop(frame_data, None)
- if removed:
- await self._send_error(websocket, "Frame decode failed")
-
- elif msg_type == "video.frame":
- frame_data = msg.get("data", "")
- if not frame_data:
- continue
- if len(frame_data) > _MAX_FRAME_SIZE:
- await self._send_error(websocket, "Frame too large")
- continue
- try:
- raw_bytes = base64.b64decode(frame_data, validate=True)
- except Exception:
- await self._send_error(websocket, "Invalid image data")
- continue
- if frame_filter is not None:
- try:
- if not frame_filter.should_retain(raw_bytes):
- continue
- except Exception:
- await self._send_error(websocket, "Invalid image data")
- continue
- max_buf = config.max_frames
- if len(frame_buffer) >= max_buf:
- dropped = frame_buffer.pop(0)
- frame_pil_cache.pop(dropped, None)
- frame_buffer.append(frame_data)
- # Prewarm: decode PIL off the event loop so query-time chat_template
- # can skip base64+Image.open. uuid=md5 lets mm_cache dedupe identical frames.
- if frame_data not in frame_pil_cache:
- mm_uuid = hashlib.md5(raw_bytes, usedforsecurity=False).hexdigest()
-
- async def _prewarm(b64: str, b: bytes, u: str) -> None:
- try:
- pil = await asyncio.to_thread(_decode_frame_bytes, b)
- frame_pil_cache[b64] = (pil, u)
- except Exception:
- frame_pil_cache[b64] = _BAD_FRAME
- logger.warning("prewarm decode failed for frame (len=%d)", len(b))
- try:
- msg_queue.put_nowait({"type": "_internal.frame_decode_failed", "b64": b64})
- except asyncio.QueueFull:
- logger.warning(
- "frame decode failure event dropped because message queue is full"
- )
-
- task = asyncio.create_task(_prewarm(frame_data, raw_bytes, mm_uuid))
- prewarm_tasks.add(task)
- task.add_done_callback(prewarm_tasks.discard)
-
- elif msg_type == "audio.chunk":
- data_b64 = msg.get("data", "")
- try:
- pcm_bytes = base64.b64decode(data_b64)
- except Exception:
- continue
- if len(audio_buffer) + len(pcm_bytes) > _MAX_AUDIO_BUFFER_BYTES:
- await self._send_error(websocket, "Audio buffer overflow")
- audio_buffer.clear()
- continue
- audio_buffer.extend(pcm_bytes)
-
- elif msg_type == "video.query":
- await _cancel_active_query()
-
- query_text = msg.get("text", "")
- audio_data_b64 = msg.get("audio_data")
- if audio_data_b64:
- try:
- decoded = base64.b64decode(audio_data_b64)
- if len(audio_buffer) + len(decoded) <= _MAX_AUDIO_BUFFER_BYTES:
- audio_buffer.extend(decoded)
- else:
- await self._send_error(websocket, "Audio buffer overflow")
- audio_buffer.clear()
- except Exception:
- pass
-
- if not frame_buffer:
- await self._send_error(websocket, "No frames buffered")
- continue
-
- # Abort only if the previous turn was interrupted mid-flight.
- # A naturally-finished request is already released by the scheduler;
- # aborting it again can race with stage-1/2 tear-down and has been
- # observed to crash flash_attn with a mixed prefill+decode batch
- # (scheduler_metadata shape mismatch) under longer sessions.
- if prev_was_interrupted and prev_request_id and self._engine_client:
- try:
- await self._engine_client.abort(prev_request_id)
- except Exception:
- pass
- await asyncio.sleep(0.1)
- prev_was_interrupted = False
-
- request_id = f"video-{uuid.uuid4().hex[:12]}"
- active_request_id = request_id
- interrupt_event.clear()
- query_frames = list(frame_buffer)
- query_audio_buffer = bytearray(audio_buffer)
- audio_buffer.clear()
- query_prewarmed_frames = dict(frame_pil_cache)
-
- async def _run_query() -> None:
- nonlocal active_request_id, prev_request_id
- try:
- await self._process_query(
- websocket,
- config,
- query_frames,
- query_audio_buffer,
- message_history,
- query_text,
- request_id,
- interrupt_event,
- query_prewarmed_frames,
- )
- finally:
- if active_request_id == request_id:
- prev_request_id = request_id
- active_request_id = None
-
- query_task = asyncio.create_task(_run_query())
-
- elif msg_type == "video.done":
- if query_task is not None and not query_task.done():
- await asyncio.gather(query_task, return_exceptions=True)
- query_task = None
- await websocket.send_json({"type": "session.done"})
- return
-
- elif msg_type == "ping":
- try:
- await websocket.send_json({"type": "pong"})
- except Exception:
- pass
-
- else:
- await self._send_error(websocket, f"Unknown type: {msg_type}")
-
- reader_task = asyncio.create_task(_reader())
- try:
- await _processor()
- finally:
- reader_task.cancel()
- try:
- await reader_task
- except (asyncio.CancelledError, Exception):
- pass
- for t in list(prewarm_tasks):
- t.cancel()
- if prewarm_tasks:
- await asyncio.gather(*prewarm_tasks, return_exceptions=True)
- if query_task is not None and not query_task.done():
- await _cancel_active_query(abort_now=True)
-
- except WebSocketDisconnect:
- logger.info("Streaming video: client disconnected")
- except Exception as e:
- logger.exception("Streaming video session error: %s", e)
- try:
- await self._send_error(websocket, f"Internal error: {e}")
- except Exception:
- pass
-
- async def _receive_config(self, websocket: WebSocket) -> StreamingVideoSessionConfig | None:
- """Wait for and validate the session.config message."""
- try:
- raw = await asyncio.wait_for(
- websocket.receive_text(),
- timeout=self._config_timeout,
- )
- except asyncio.TimeoutError:
- await self._send_error(websocket, "Timeout waiting for session.config")
- return None
-
- try:
- msg = json.loads(raw)
- except json.JSONDecodeError:
- await self._send_error(websocket, "Invalid JSON in session.config")
- return None
-
- if not isinstance(msg, dict) or msg.get("type") != "session.config":
- await self._send_error(
- websocket,
- f"Expected session.config, got: {msg.get('type') if isinstance(msg, dict) else type(msg).__name__}",
- )
- return None
-
- config_data = {k: v for k, v in msg.items() if k != "type"}
- alias_map = {
- "num_sample_frames": "num_frames",
- "evs_enabled": "enable_frame_filter",
- "evs_threshold": "frame_filter_threshold",
- }
- for old_key, new_key in alias_map.items():
- if old_key in config_data and new_key not in config_data:
- config_data[new_key] = config_data[old_key]
-
- try:
- config = StreamingVideoSessionConfig(**config_data)
- except ValidationError as e:
- await self._send_error(websocket, f"Invalid session config: {e}")
- return None
-
- return config
-
- async def _process_query(
- self,
- websocket: WebSocket,
- config: StreamingVideoSessionConfig,
- frame_buffer: list[str],
- audio_buffer: bytearray,
- message_history: list[dict[str, Any]],
- query_text: str,
- request_id: str,
- interrupt_event: asyncio.Event,
- prewarmed_frames: dict[str, tuple[Any, str]],
- ) -> None:
- """Build prompt, run inference, stream text + audio response."""
-
- if self._engine_client is None:
- await self._send_error(websocket, "Streaming video requires an engine client")
- return
-
- await self._process_query_engine(
- websocket,
- config,
- frame_buffer,
- audio_buffer,
- message_history,
- query_text,
- request_id,
- interrupt_event,
- prewarmed_frames,
- )
-
- # ------------------------------------------------------------------
- # Engine-client path (async_chunk audio streaming)
- # ------------------------------------------------------------------
-
- async def _process_query_engine(
- self,
- websocket: WebSocket,
- config: StreamingVideoSessionConfig,
- frame_buffer: list[str],
- audio_buffer: bytearray,
- message_history: list[dict[str, Any]],
- query_text: str,
- request_id: str,
- interrupt_event: asyncio.Event,
- prewarmed_frames: dict[str, tuple[Any, str]],
- ) -> None:
- """Direct engine_client.generate() path for async_chunk audio."""
- from vllm.entrypoints.openai.chat_completion.protocol import (
- ChatCompletionRequest,
- )
-
- messages, user_message = self._build_messages(
- config,
- frame_buffer,
- audio_buffer,
- message_history,
- query_text,
- prewarmed_frames,
- )
-
- request_kwargs: dict[str, Any] = {
- "model": config.model or "default",
- "messages": messages,
- "stream": True,
- "modalities": config.modalities,
- "add_generation_prompt": True,
- "continue_final_message": False,
- "add_special_tokens": False,
- }
- if config.use_audio_in_video and len(audio_buffer) > 0:
- request_kwargs["mm_processor_kwargs"] = {
- "use_audio_in_video": True,
- }
- if config.sampling_params_list:
- request_kwargs["sampling_params_list"] = config.sampling_params_list
-
- try:
- chat_request = ChatCompletionRequest(**request_kwargs)
- except Exception as e:
- await self._send_error(websocket, f"Failed to build request: {e}")
- return
-
- try:
- engine_prompt = await self._preprocess_to_engine_prompt(chat_request)
- except Exception as e:
- await self._send_error(websocket, f"Preprocess failed: {e}")
- return
-
- await websocket.send_json({"type": "response.start"})
- text_parts: list[str] = []
- text_done_sent = False
- audio_chunk_count = 0
- # Number of per-step tensors in OmniRequestOutput.audio_data already
- # drained. Used by the fast path to skip already-emitted history.
- audio_chunks_drained = 0
- previous_text = ""
- interrupted = False
- t_start = _time.monotonic()
- t_first_text = None
- t_first_audio = None
-
- # Wire-level async-chunk switch. "off" means
- # buffer all deltas server-side and flush once at the end; the engine
- # pipeline still overlaps internally.
- async_chunk_mode = video_stream_envs.VLLM_VIDEO_ASYNC_CHUNK
- streaming = async_chunk_mode == "on"
- audio_tail_tensors: list[Any] = []
-
- try:
- result_gen = self._engine_client.generate(
- prompt=engine_prompt,
- request_id=request_id,
- output_modalities=config.modalities,
- )
-
- async for output in result_gen:
- # Soft interrupt: drain without sending
- if interrupt_event.is_set():
- if not interrupted:
- logger.info("Generation interrupted — draining")
- interrupted = True
- continue
-
- if not isinstance(output, OmniRequestOutput):
- continue
-
- out_type = getattr(output, "final_output_type", "text")
-
- if out_type == "audio":
- if streaming and not text_done_sent:
- full_text = "".join(text_parts)
- await websocket.send_json({"type": "response.text.done", "text": full_text})
- text_done_sent = True
-
- if t_first_audio is None:
- t_first_audio = _time.monotonic()
- audio_chunk_count += 1
- if streaming:
- b64, audio_chunks_drained = self._extract_audio_delta_b64(
- output,
- audio_chunks_drained,
- )
- if b64:
- await websocket.send_json(
- {
- "type": "response.audio.delta",
- "data": b64,
- "format": "wav",
- }
- )
- else:
- audio_data = self._get_audio_data(output)
- if audio_data is not None:
- if isinstance(audio_data, list):
- audio_tail_tensors = list(audio_data)
- else:
- audio_tail_tensors = [audio_data]
- else:
- delta_text, previous_text = self._extract_text_delta(
- output,
- previous_text,
- )
- if delta_text:
- if t_first_text is None:
- t_first_text = _time.monotonic()
- text_parts.append(delta_text)
- if streaming:
- await websocket.send_json({"type": "response.text.delta", "delta": delta_text})
-
- if not text_done_sent:
- full_text = "".join(text_parts)
- await websocket.send_json({"type": "response.text.done", "text": full_text})
- text_done_sent = True
-
- if not streaming and audio_tail_tensors:
- try:
- coalesced = (
- audio_tail_tensors[0] if len(audio_tail_tensors) == 1 else torch.cat(audio_tail_tensors, dim=-1)
- )
- tail_np = self._tensor_to_1d_np(coalesced)
- b64, _ = self._encode_tail(
- tail_np,
- 0,
- new_drained=len(audio_tail_tensors),
- is_first=True,
- )
- if b64:
- await websocket.send_json(
- {
- "type": "response.audio.delta",
- "data": b64,
- "format": "wav",
- }
- )
- except Exception:
- logger.exception("Failed to coalesce off-path audio")
-
- if audio_chunk_count > 0:
- await websocket.send_json({"type": "response.audio.done"})
-
- response_text = "".join(text_parts)
- message_history.append(user_message)
- message_history.append({"role": "assistant", "content": response_text})
-
- t_end = _time.monotonic()
- logger.info(
- "[TIMING] mode=%s total=%.2fs first_text=%.2fs first_audio=%.2fs audio_chunks=%d",
- async_chunk_mode,
- t_end - t_start,
- (t_first_text - t_start) if t_first_text else -1,
- (t_first_audio - t_start) if t_first_audio else -1,
- audio_chunk_count,
- )
-
- except Exception:
- logger.exception("Engine query failed")
- await self._send_error(websocket, "Query processing failed")
-
- if not text_done_sent:
- full_text = "".join(text_parts)
- await websocket.send_json({"type": "response.text.done", "text": full_text})
-
- # ------------------------------------------------------------------
- # Message building
- # ------------------------------------------------------------------
-
- def _build_messages(
- self,
- config: StreamingVideoSessionConfig,
- frame_buffer: list[str],
- audio_buffer: bytearray,
- message_history: list[dict[str, Any]],
- query_text: str,
- prewarmed_frames: dict[str, tuple[Any, str]],
- ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
- """Build OpenAI-style messages list and the current user message.
-
- Returns (messages, user_message).
- """
- # Stride sampling (index 0 anchor, last slot = newest). Covers full buffer + stable mm_hash.
- n_buf = len(frame_buffer)
- if n_buf <= config.num_frames:
- frames = list(frame_buffer)
- else:
- stride = max(1, n_buf // config.num_frames)
- idx = [i * stride for i in range(config.num_frames - 1)] + [n_buf - 1]
- frames = [frame_buffer[i] for i in idx]
-
- # Prefer prewarmed PIL + uuid so mm_cache can dedupe by hash.
- prewarmed = prewarmed_frames or {}
- user_content: list[dict] = []
- for frame_b64 in frames:
- cached = prewarmed.get(frame_b64)
- if cached is _BAD_FRAME:
- continue
- if cached is not None:
- pil, pil_uuid = cached
- user_content.append(
- {
- "type": "image_pil",
- "image_pil": pil,
- "uuid": pil_uuid,
- }
- )
- else:
- user_content.append(
- {
- "type": "image_url",
- "image_url": {"url": f"data:image/jpeg;base64,{frame_b64}"},
- }
- )
-
- if len(audio_buffer) > 0:
- wav_b64 = self._pcm_to_wav_b64(bytes(audio_buffer))
- user_content.append(
- {
- "type": "input_audio",
- "input_audio": {
- "data": wav_b64,
- "format": "wav",
- },
- }
- )
-
- if query_text:
- user_content.append({"type": "text", "text": query_text})
-
- user_message: dict[str, Any] = {"role": "user", "content": user_content}
-
- messages: list[dict[str, Any]] = []
- if config.system_prompt:
- messages.append({"role": "system", "content": config.system_prompt})
-
- # Add text-only history (strip images/audio from old turns).
- # Keep only the last turn (2 messages) to keep prompt short
- # enough for single-step mm_encoder scheduling. When prompt
- # exceeds ~50 tokens, the V1 scheduler splits mm_encoder and
- # text prefill, causing incomplete thinker embeddings and
- # garbled audio.
- recent_history = message_history[-2:] if len(message_history) > 2 else message_history
- for hist_msg in recent_history:
- messages.append(self._text_only_message(hist_msg))
-
- messages.append(user_message)
-
- return messages, user_message
-
- # ------------------------------------------------------------------
- # Audio helpers
- # ------------------------------------------------------------------
-
- @staticmethod
- def _pcm_to_wav_b64(pcm_data: bytes, sample_rate: int = 16000) -> str:
- """Wrap raw PCM16 mono in a WAV container and return base64."""
- buf = io.BytesIO()
- with wave.open(buf, "wb") as wf:
- wf.setnchannels(1)
- wf.setsampwidth(2)
- wf.setframerate(sample_rate)
- wf.writeframes(pcm_data)
- return base64.b64encode(buf.getvalue()).decode()
-
- @classmethod
- def _extract_audio_delta_b64(
- cls,
- result: OmniRequestOutput,
- chunks_drained: int,
- ) -> tuple[str | None, int]:
- """Return (base64 WAV of new samples, updated chunks_drained).
-
- `chunks_drained` is the number of per-step tensors in
- ``audio_data`` that have already been emitted. Each engine step appends
- one tensor, so new samples are ``audio_data[chunks_drained:]`` — no
- matter how many steps accumulated between reads (handles backpressure
- cleanly, unlike a simple ``audio_data[-1]``).
-
- Two paths, selected at runtime by ``VLLM_VIDEO_AUDIO_DELTA_MODE``:
- * fast — only D2H the new tail. Per-call cost ∝ new chunks.
- * slow — full cat + D2H each call. Per-call cost ∝ total history.
- Retained for A/B; remove once downstream callers confirm.
- """
- audio_data = cls._get_audio_data(result)
- if audio_data is None:
- return None, chunks_drained
-
- if video_stream_envs.VLLM_VIDEO_AUDIO_DELTA_MODE == "slow":
- return cls._delta_slow(audio_data, chunks_drained)
- return cls._delta_fast(audio_data, chunks_drained)
-
- @staticmethod
- def _get_audio_data(result: OmniRequestOutput):
- """Navigate OmniRequestOutput → multimodal_output['audio']. None on miss."""
- request_output = getattr(result, "request_output", None)
- if request_output is None:
- return None
- outputs = getattr(request_output, "outputs", None)
- if not isinstance(outputs, list) or not outputs:
- return None
- mm_output = getattr(outputs[0], "multimodal_output", None)
- if not isinstance(mm_output, dict):
- return None
- return mm_output.get("audio")
-
- @classmethod
- def _delta_fast(
- cls,
- audio_data,
- chunks_drained: int,
- ) -> tuple[str | None, int]:
- """Emit only tensors appended since the last call."""
- # Single tensor: output_processor hands us one tensor before it becomes a
- # list (see output_processor.py:89). Treat it as chunk #0.
- if not isinstance(audio_data, list):
- if chunks_drained >= 1:
- return None, chunks_drained
- tail_np = cls._tensor_to_1d_np(audio_data)
- return cls._encode_tail(tail_np, chunks_drained, new_drained=1, is_first=True)
-
- n = len(audio_data)
- if n <= chunks_drained:
- return None, chunks_drained
-
- new_chunks = audio_data[chunks_drained:]
- tail = new_chunks[0] if len(new_chunks) == 1 else torch.cat(new_chunks, dim=-1)
- tail_np = cls._tensor_to_1d_np(tail)
- return cls._encode_tail(tail_np, chunks_drained, new_drained=n, is_first=(chunks_drained == 0))
-
- @classmethod
- def _delta_slow(
- cls,
- audio_data,
- chunks_drained: int,
- ) -> tuple[str | None, int]:
- """Pre-fix behaviour: concat everything each call and slice on CPU."""
- if isinstance(audio_data, list):
- if not audio_data:
- return None, chunks_drained
- audio_tensor = torch.cat(audio_data, dim=-1)
- new_drained = len(audio_data)
- else:
- audio_tensor = audio_data
- new_drained = 1
-
- full_np = cls._tensor_to_1d_np(audio_tensor)
- if full_np is None:
- return None, chunks_drained
- # chunks_drained doesn't map directly to sample offset without tracking
- # per-chunk lengths, so we re-derive: replay the tail that corresponds
- # to chunks appended since last call by slicing off the part produced
- # by the already-drained prefix. For slow path this is intentionally
- # wasteful — the point is to reproduce the pre-fix hot loop.
- if chunks_drained == 0:
- tail_np = full_np
- else:
- # Recover prefix length by re-concatenating the already-drained
- # prefix tensors (cost intentionally identical to the baseline
- # implementation this was lifted from).
- if isinstance(audio_data, list) and chunks_drained < len(audio_data):
- prefix_len = sum(int(t.shape[-1]) for t in audio_data[:chunks_drained])
- tail_np = full_np[prefix_len:]
- else:
- tail_np = full_np[0:0]
- return cls._encode_tail(tail_np, chunks_drained, new_drained=new_drained, is_first=(chunks_drained == 0))
-
- @classmethod
- def _encode_tail(
- cls,
- tail_np,
- old_drained: int,
- *,
- new_drained: int,
- is_first: bool,
- ) -> tuple[str | None, int]:
- """Strip the CausalConv leading artifact on first emit, then b64-encode."""
- if tail_np is None or len(tail_np) == 0:
- return None, new_drained
- if is_first and len(tail_np) > _CODEC_FRAME_SAMPLES * 2:
- tail_np = tail_np[_CODEC_FRAME_SAMPLES:]
- if len(tail_np) == 0:
- return None, new_drained
- try:
- return cls._encode_audio_wav_b64(tail_np), new_drained
- except Exception:
- logger.exception("Failed to encode audio delta WAV")
- return None, old_drained
-
- @staticmethod
- def _tensor_to_1d_np(t):
- """Tensor → flat float32 numpy on CPU. None on failure."""
- if t is None or not hasattr(t, "float"):
- return None
- arr = t.float().detach().cpu().numpy()
- if arr.ndim > 1:
- arr = arr.flatten()
- return arr
-
- @staticmethod
- def _encode_audio_wav_b64(audio_np) -> str:
- """Encode numpy float32 audio to base64 WAV (24kHz)."""
- from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
- from vllm_omni.entrypoints.openai.protocol.audio import CreateAudio
-
- audio_obj = CreateAudio(
- audio_tensor=audio_np,
- sample_rate=24000,
- response_format="wav",
- speed=1.0,
- stream_format="audio",
- base64_encode=True,
- )
- mixin = AudioMixin()
- resp = mixin.create_audio(audio_obj)
- return resp.audio_data
-
- @staticmethod
- def _extract_text_delta(
- result: OmniRequestOutput,
- previous_text: str,
- ) -> tuple[str, str]:
- """Extract incremental text delta from OmniRequestOutput."""
- if result.final_output_type != "text":
- return "", previous_text
-
- request_output = getattr(result, "request_output", None)
- if request_output is None:
- return "", previous_text
-
- outputs = getattr(request_output, "outputs", None)
- if not isinstance(outputs, list) or not outputs:
- return "", previous_text
-
- text = getattr(outputs[0], "text", None)
- if not isinstance(text, str) or not text:
- return "", previous_text
-
- if text.startswith(previous_text):
- return text[len(previous_text) :], text
- return text, text
-
- # ------------------------------------------------------------------
- # Preprocessing
- # ------------------------------------------------------------------
-
- async def _preprocess_to_engine_prompt(self, request) -> Any:
- """Use the chat handler's preprocessing to build an engine prompt."""
- handler = self._chat_service
- renderer = handler.renderer
-
- _conversation, engine_prompts = await handler._preprocess_chat(
- request,
- request.messages,
- default_template=getattr(request, "chat_template", None) or handler.chat_template,
- default_template_content_format=handler.chat_template_content_format,
- renderer=renderer,
- add_generation_prompt=request.add_generation_prompt,
- continue_final_message=request.continue_final_message,
- add_special_tokens=request.add_special_tokens,
- )
- return engine_prompts[0]
-
- # ------------------------------------------------------------------
- # Utilities
- # ------------------------------------------------------------------
-
- _text_only_message = staticmethod(text_only_message)
-
- async def _send_error(self, websocket: WebSocket, message: str) -> None:
- """Send an error message to the client."""
- try:
- await websocket.send_json({"type": "error", "message": message})
- except Exception:
- pass
diff --git a/vllm_omni/entrypoints/openai/utils.py b/vllm_omni/entrypoints/openai/utils.py
index f411526fdb2..84b28ef5b19 100644
--- a/vllm_omni/entrypoints/openai/utils.py
+++ b/vllm_omni/entrypoints/openai/utils.py
@@ -53,33 +53,3 @@ def parse_lora_request(lora_body: Any) -> tuple[LoRARequest | None, float | None
scale = float(lora_scale) if lora_scale is not None else None
return LoRARequest(str(lora_name), int(lora_int_id), str(lora_path)), scale
-
-
-def get_supported_speakers_from_hf_config(hf_config: Any) -> set[str]:
- """Extract supported speaker names from a model hf_config."""
- config = (
- hf_config.get("talker_config") if isinstance(hf_config, dict) else getattr(hf_config, "talker_config", None)
- )
- if config is None:
- return set()
-
- for spk_attr in ("speaker_id", "spk_id"):
- speakers_dict = config.get(spk_attr) if isinstance(config, dict) else getattr(config, spk_attr, None)
- if speakers_dict and isinstance(speakers_dict, dict):
- return {speaker.lower() for speaker in speakers_dict}
- return set()
-
-
-def validate_requested_speaker(speaker: str | None, supported_speakers: set[str]) -> str | None:
- """Normalize and validate an optional speaker value.
-
- Returns the normalized speaker string when provided, otherwise ``None``.
- Raises ``ValueError`` when the speaker is not in the supported list.
- """
- if not isinstance(speaker, str) or not speaker.strip():
- return None
-
- normalized = speaker.lower().strip()
- if supported_speakers and normalized not in supported_speakers:
- raise ValueError(f"Invalid speaker '{speaker}'. Supported: {', '.join(sorted(supported_speakers))}")
- return normalized
diff --git a/vllm_omni/entrypoints/openai/video_api_utils.py b/vllm_omni/entrypoints/openai/video_api_utils.py
index 3fb991225c0..2ed1fd3de6d 100644
--- a/vllm_omni/entrypoints/openai/video_api_utils.py
+++ b/vllm_omni/entrypoints/openai/video_api_utils.py
@@ -8,6 +8,8 @@
import base64
import binascii
+import os
+import tempfile
from io import BytesIO
from typing import Any
@@ -158,7 +160,7 @@ def _normalize_frames(frames: list[Any]) -> list[np.ndarray]:
def _coerce_video_to_frames(video: Any) -> list[np.ndarray]:
- """Convert a video payload into a list of normalized float32 frames."""
+ """Convert a video payload into a list of frames for export_to_video."""
if isinstance(video, torch.Tensor):
video_array = _normalize_video_tensor(video)
return list(video_array)
@@ -184,72 +186,84 @@ def _coerce_video_to_frames(video: Any) -> list[np.ndarray]:
raise ValueError(f"Unsupported video payload type: {type(video)}")
-def _coerce_audio_to_numpy(audio: Any) -> np.ndarray:
- """Convert an audio payload into a float32 numpy array for muxing."""
+def _coerce_audio_to_waveform(audio: Any) -> torch.Tensor:
+ """Convert an audio payload into a 2-channel CPU float tensor for LTX2 export."""
if isinstance(audio, torch.Tensor):
- arr = audio.detach().cpu().float().numpy()
+ waveform = audio.detach().cpu()
elif isinstance(audio, np.ndarray):
- arr = audio
+ waveform = torch.from_numpy(audio)
elif isinstance(audio, list):
- arr = np.array(audio)
+ waveform = torch.tensor(audio)
else:
raise ValueError(f"Unsupported audio payload type: {type(audio)}")
- arr = np.squeeze(arr)
- if arr.ndim == 0:
+ waveform = waveform.squeeze()
+
+ if waveform.ndim == 0:
raise ValueError("Audio payload must contain at least one sample.")
- return arr.astype(np.float32)
+ if waveform.ndim == 1:
+ waveform = waveform.unsqueeze(0)
+ elif waveform.ndim == 2:
+ if waveform.shape[0] in (1, 2):
+ pass
+ elif waveform.shape[1] in (1, 2):
+ waveform = waveform.transpose(0, 1)
+ else:
+ raise ValueError(f"Unsupported audio payload shape: {tuple(waveform.shape)}")
+ else:
+ raise ValueError(f"Unsupported audio payload rank: {waveform.ndim}")
+
+ if waveform.shape[0] == 1:
+ waveform = waveform.repeat(2, 1)
+ elif waveform.shape[0] != 2:
+ raise ValueError(f"Expected mono or stereo audio, got shape {tuple(waveform.shape)}")
+
+ return waveform.float().contiguous()
-def _encode_video_bytes(
- video: Any,
- fps: int,
- audio: Any | None = None,
- audio_sample_rate: int | None = None,
- video_codec_options: dict[str, str] | None = None,
-) -> bytes:
+def _encode_video_bytes(video: Any, fps: int, audio: Any | None = None, audio_sample_rate: int | None = None) -> bytes:
"""Encode a video payload into MP4 bytes, optionally muxing audio."""
- from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
+ try:
+ from diffusers.utils import export_to_video
+ except ImportError as exc: # pragma: no cover - optional dependency
+ raise ImportError("diffusers is required for export_to_video.") from exc
frames = _coerce_video_to_frames(video)
if not frames:
raise ValueError("No frames found to encode.")
- frames_np = np.stack(frames, axis=0)
- if frames_np.ndim == 4 and frames_np.shape[-1] == 4:
- frames_np = frames_np[..., :3]
+ tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+ tmp_file.close()
+ try:
+ if audio is not None:
+ from diffusers.pipelines.ltx2.export_utils import encode_video as encode_ltx2_video
+
+ frames_np = np.stack(frames, axis=0)
+ if frames_np.ndim == 4 and frames_np.shape[-1] == 4:
+ frames_np = frames_np[..., :3]
+ frames_np = np.clip(frames_np, 0.0, 1.0)
+ frames_u8 = (frames_np * 255).round().clip(0, 255).astype("uint8")
+ video_tensor = torch.from_numpy(frames_u8)
+ encode_ltx2_video(
+ video_tensor,
+ fps=fps,
+ audio=_coerce_audio_to_waveform(audio),
+ audio_sample_rate=audio_sample_rate,
+ output_path=tmp_file.name,
+ )
+ else:
+ export_to_video(frames, tmp_file.name, fps=fps)
+ with open(tmp_file.name, "rb") as f:
+ return f.read()
+ finally:
+ try:
+ os.remove(tmp_file.name)
+ except OSError:
+ pass
- if frames_np.dtype == np.uint8:
- frames_u8 = frames_np
- else:
- frames_np = np.clip(frames_np, 0.0, 1.0)
- frames_np *= 255.0
- frames_u8 = np.round(frames_np).astype(np.uint8)
-
- # Ensure contiguous memory layout for faster PyAV muxing
- frames_u8 = np.ascontiguousarray(frames_u8)
-
- audio_np = _coerce_audio_to_numpy(audio) if audio is not None else None
-
- return mux_video_audio_bytes(
- frames_u8,
- audio_np,
- fps=float(fps),
- audio_sample_rate=audio_sample_rate or 24000,
- video_codec_options=video_codec_options,
- )
-
-
-def encode_video_base64(
- video: Any,
- fps: int,
- audio: Any | None = None,
- audio_sample_rate: int | None = None,
- video_codec_options: dict[str, str] | None = None,
-) -> str:
+
+def encode_video_base64(video: Any, fps: int, audio: Any | None = None, audio_sample_rate: int | None = None) -> str:
"""Encode a video (frames/array/tensor) to base64 MP4."""
- video_bytes = _encode_video_bytes(
- video, fps=fps, audio=audio, audio_sample_rate=audio_sample_rate, video_codec_options=video_codec_options
- )
+ video_bytes = _encode_video_bytes(video, fps=fps, audio=audio, audio_sample_rate=audio_sample_rate)
return base64.b64encode(video_bytes).decode("utf-8")
diff --git a/vllm_omni/entrypoints/openai/video_frame_filter.py b/vllm_omni/entrypoints/openai/video_frame_filter.py
deleted file mode 100644
index 4ea427835aa..00000000000
--- a/vllm_omni/entrypoints/openai/video_frame_filter.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""EVS (Efficient Video Sampling) frame pre-filter for streaming video input.
-
-Lightweight pixel-level similarity filter that runs before frames reach the
-vision encoder. For static or slow-moving scenes this can reduce the number
-of frames by 2-5x, proportionally cutting encoder compute and KV-cache usage.
-
-Usage:
- filter = FrameSimilarityFilter(threshold=0.95)
- for jpeg_bytes in incoming_frames:
- if filter.should_retain(jpeg_bytes):
- buffer.append(jpeg_bytes)
-"""
-
-from __future__ import annotations
-
-import io
-
-import numpy as np
-from PIL import Image
-
-_DEFAULT_THRESHOLD = 0.95
-_DEFAULT_THUMBNAIL_SIZE = 64
-
-
-class FrameSimilarityFilter:
- """Drop near-duplicate frames based on pixel-level similarity.
-
- Each incoming JPEG frame is down-scaled to a small thumbnail and compared
- against the last *retained* frame. If the normalised similarity exceeds
- ``threshold`` the frame is considered redundant and dropped.
-
- Args:
- threshold: Similarity score in [0, 1] above which a frame is dropped.
- Higher values keep more frames (less aggressive filtering).
- Default 0.95 works well for typical webcam / surveillance feeds.
- thumbnail_size: Edge length of the square thumbnail used for
- comparison. Larger values are more accurate but slower.
- """
-
- def __init__(
- self,
- threshold: float = _DEFAULT_THRESHOLD,
- thumbnail_size: int = _DEFAULT_THUMBNAIL_SIZE,
- ) -> None:
- if not 0.0 <= threshold <= 1.0:
- raise ValueError(f"threshold must be in [0, 1], got {threshold}")
- if thumbnail_size < 1:
- raise ValueError(f"thumbnail_size must be >= 1, got {thumbnail_size}")
-
- self._threshold = threshold
- self._thumbnail_size = thumbnail_size
- self._last_retained: np.ndarray | None = None
- self._retained_count = 0
- self._dropped_count = 0
-
- # ------------------------------------------------------------------
- # Public API
- # ------------------------------------------------------------------
-
- def should_retain(self, frame_jpeg: bytes) -> bool:
- """Return ``True`` if *frame_jpeg* is sufficiently different from the
- last retained frame and should be kept in the buffer."""
- current = self._decode_and_resize(frame_jpeg)
-
- if self._last_retained is None:
- self._last_retained = current
- self._retained_count += 1
- return True
-
- similarity = self._compute_similarity(self._last_retained, current)
- if similarity >= self._threshold:
- self._dropped_count += 1
- return False
-
- self._last_retained = current
- self._retained_count += 1
- return True
-
- def reset(self) -> None:
- """Clear internal state so the next frame is always retained."""
- self._last_retained = None
- self._retained_count = 0
- self._dropped_count = 0
-
- @property
- def stats(self) -> dict[str, int | float]:
- """Return filtering statistics."""
- total = self._retained_count + self._dropped_count
- return {
- "retained_count": self._retained_count,
- "dropped_count": self._dropped_count,
- "total_count": total,
- "drop_rate": self._dropped_count / total if total > 0 else 0.0,
- }
-
- # ------------------------------------------------------------------
- # Internals
- # ------------------------------------------------------------------
-
- @staticmethod
- def _compute_similarity(a: np.ndarray, b: np.ndarray) -> float:
- """Normalised pixel similarity in [0, 1]. 1 = identical."""
- mse = float(np.mean((a.astype(np.float32) - b.astype(np.float32)) ** 2))
- return 1.0 - mse / (255.0 * 255.0)
-
- def _decode_and_resize(self, jpeg_bytes: bytes) -> np.ndarray:
- """Decode JPEG and resize to a small square thumbnail for fast
- comparison."""
- img = Image.open(io.BytesIO(jpeg_bytes))
- img = img.resize(
- (self._thumbnail_size, self._thumbnail_size),
- Image.Resampling.BILINEAR,
- ).convert("RGB")
- return np.asarray(img, dtype=np.uint8)
diff --git a/vllm_omni/entrypoints/openai/video_stream_context.py b/vllm_omni/entrypoints/openai/video_stream_context.py
deleted file mode 100644
index e72dfcd41d0..00000000000
--- a/vllm_omni/entrypoints/openai/video_stream_context.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Small helpers shared by streaming video handlers."""
-
-from __future__ import annotations
-
-from typing import Any
-
-
-def text_only_message(message: dict[str, Any]) -> dict[str, Any]:
- """Return a history message with multimodal content stripped out."""
- role = message.get("role", "user")
- content = message.get("content", "")
-
- if isinstance(content, str):
- return {"role": role, "content": content}
-
- if not isinstance(content, list):
- return {"role": role, "content": ""}
-
- text_parts: list[str] = []
- for part in content:
- if isinstance(part, dict) and part.get("type") == "text":
- text = part.get("text")
- if isinstance(text, str):
- text_parts.append(text)
-
- return {"role": role, "content": "".join(text_parts)}
diff --git a/vllm_omni/entrypoints/openai/video_stream_envs.py b/vllm_omni/entrypoints/openai/video_stream_envs.py
deleted file mode 100644
index f8571e7f94b..00000000000
--- a/vllm_omni/entrypoints/openai/video_stream_envs.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Environment variables for the streaming video OpenAI entrypoint."""
-
-import logging
-import os
-from collections.abc import Callable
-from typing import TYPE_CHECKING, Literal
-
-if TYPE_CHECKING:
- VLLM_VIDEO_AUDIO_DELTA_MODE: Literal["fast", "slow"] = "fast"
- VLLM_VIDEO_ASYNC_CHUNK: Literal["on", "off"] = "on"
-
-logger = logging.getLogger(__name__)
-_warned_invalid_envs: set[tuple[str, str]] = set()
-_VIDEO_AUDIO_DELTA_MODE = "VLLM_VIDEO_AUDIO_DELTA_MODE"
-_VIDEO_ASYNC_CHUNK = "VLLM_VIDEO_ASYNC_CHUNK"
-
-
-def _choice_env(
- name: str,
- default: str,
- allowed: tuple[str, ...],
-) -> str:
- value = os.getenv(name, default).strip().lower()
- if value in allowed:
- return value
- warning_key = (name, value)
- if warning_key not in _warned_invalid_envs:
- logger.warning("%s=%s not recognized; falling back to %r", name, value, default)
- _warned_invalid_envs.add(warning_key)
- return default
-
-
-environment_variables: dict[str, Callable[[], str]] = {
- _VIDEO_AUDIO_DELTA_MODE: lambda: _choice_env(
- _VIDEO_AUDIO_DELTA_MODE,
- "fast",
- ("fast", "slow"),
- ),
- _VIDEO_ASYNC_CHUNK: lambda: _choice_env(
- _VIDEO_ASYNC_CHUNK,
- "on",
- ("on", "off"),
- ),
-}
-
-
-def __getattr__(name: str) -> str:
- if name in environment_variables:
- return environment_variables[name]()
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
-
-
-def __dir__() -> list[str]:
- return list(environment_variables.keys())
-
-
-def is_set(name: str) -> bool:
- if name in environment_variables:
- return name in os.environ
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/vllm_omni/entrypoints/openai/video_stream_session.py b/vllm_omni/entrypoints/openai/video_stream_session.py
deleted file mode 100644
index ead3b1ec50c..00000000000
--- a/vllm_omni/entrypoints/openai/video_stream_session.py
+++ /dev/null
@@ -1,439 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Video streaming session manager and WebSocket handler.
-
-Provides ``VideoStreamConfig``, ``VideoStreamSession`` (frame/audio buffer
-with EVS filtering), and ``VideoStreamHandler`` (WebSocket session loop).
-Builds standard ``ChatCompletionRequest`` so ``OmniOpenAIServingChat`` is
-reused with zero changes.
-"""
-
-from __future__ import annotations
-
-import asyncio
-import base64
-import json
-from collections import deque
-from typing import Any
-
-from fastapi import WebSocket, WebSocketDisconnect
-from vllm.entrypoints.openai.chat_completion.protocol import (
- ChatCompletionRequest,
-)
-from vllm.entrypoints.openai.engine.protocol import ErrorResponse
-from vllm.logger import init_logger
-
-from vllm_omni.entrypoints.openai.video_frame_filter import FrameSimilarityFilter
-
-logger = init_logger(__name__)
-
-_DEFAULT_IDLE_TIMEOUT = 60.0
-_DEFAULT_CONFIG_TIMEOUT = 10.0
-_MAX_FRAME_BYTES = 10 * 1024 * 1024 # 10 MB per frame
-
-
-class VideoStreamConfig:
- """Per-session configuration sent by the client in ``session.config``."""
-
- __slots__ = (
- "model",
- "modalities",
- "max_frames",
- "num_sample_frames",
- "evs_enabled",
- "evs_threshold",
- )
-
- def __init__(
- self,
- model: str = "",
- modalities: list[str] | None = None,
- max_frames: int = 64,
- num_sample_frames: int = 16,
- evs_enabled: bool = True,
- evs_threshold: float = 0.95,
- ) -> None:
- self.model = model
- self.modalities = modalities if modalities is not None else ["text"]
- self.max_frames = max_frames
- self.num_sample_frames = num_sample_frames
- self.evs_enabled = evs_enabled
- self.evs_threshold = evs_threshold
-
- # JSON doesn't distinguish int/float, so float fields accept both.
- _FIELD_TYPES: dict[str, tuple[type, ...]] = {
- "model": (str,),
- "modalities": (list,),
- "max_frames": (int,),
- "num_sample_frames": (int,),
- "evs_enabled": (bool,),
- "evs_threshold": (int, float),
- }
-
- @classmethod
- def from_dict(cls, data: dict[str, Any]) -> VideoStreamConfig:
- known = set(cls.__slots__)
- filtered = {k: v for k, v in data.items() if k in known}
-
- for field, accepted in cls._FIELD_TYPES.items():
- if field in filtered and not isinstance(filtered[field], accepted):
- names = "/".join(t.__name__ for t in accepted)
- raise TypeError(f"Invalid type for '{field}': expected {names}, got {type(filtered[field]).__name__}")
-
- return cls(**filtered)
-
-
-class VideoStreamSession:
- """Manages frame and audio buffers for a single streaming session.
-
- Frames are optionally filtered through ``FrameSimilarityFilter`` (EVS)
- before being stored in a fixed-size ring buffer. ``build_chat_request``
- merges buffers into a ``ChatCompletionRequest`` with ``image_url`` (+
- ``audio_url``) content blocks.
- """
-
- def __init__(self, config: VideoStreamConfig) -> None:
- self._config = config
- self._frame_filter: FrameSimilarityFilter | None = (
- FrameSimilarityFilter(threshold=config.evs_threshold) if config.evs_enabled else None
- )
- self._frames: deque[bytes] = deque(maxlen=config.max_frames)
- self._audio_chunks: list[bytes] = []
-
- def add_frame(self, jpeg_bytes: bytes) -> bool:
- """Add a JPEG frame after EVS filtering. Returns ``True`` if kept."""
- if len(jpeg_bytes) > _MAX_FRAME_BYTES:
- raise ValueError(f"Frame too large: {len(jpeg_bytes)} bytes (limit {_MAX_FRAME_BYTES})")
-
- if self._frame_filter and not self._frame_filter.should_retain(jpeg_bytes):
- return False
-
- self._frames.append(jpeg_bytes)
- return True
-
- def sample_frames(self) -> list[bytes]:
- """Return up to ``num_sample_frames`` uniformly sampled frames."""
- n = len(self._frames)
- k = min(n, self._config.num_sample_frames)
- if k == 0:
- return []
- if k == 1:
- return [self._frames[-1]]
- if k >= n:
- return list(self._frames)
- indices = [int(i * (n - 1) / (k - 1)) for i in range(k)]
- return [self._frames[i] for i in indices]
-
- @property
- def frame_count(self) -> int:
- return len(self._frames)
-
- def add_audio_chunk(self, pcm_bytes: bytes) -> None:
- """Append raw PCM 16 kHz audio bytes."""
- self._audio_chunks.append(pcm_bytes)
-
- def clear_audio(self) -> None:
- """Clear audio buffer after a query is submitted."""
- self._audio_chunks.clear()
-
- @property
- def has_audio(self) -> bool:
- return bool(self._audio_chunks)
-
- def build_chat_request(self, query_text: str) -> ChatCompletionRequest:
- """Build a ``ChatCompletionRequest`` from the current buffers."""
- sampled = self.sample_frames()
- content_parts: list[dict[str, Any]] = []
-
- for frame in sampled:
- frame_b64 = base64.b64encode(frame).decode()
- content_parts.append(
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{frame_b64}",
- },
- }
- )
-
- if self.has_audio:
- combined_pcm = b"".join(self._audio_chunks)
- audio_b64 = base64.b64encode(combined_pcm).decode()
- content_parts.append(
- {
- "type": "audio_url",
- "audio_url": {
- "url": f"data:audio/L16;rate=16000;base64,{audio_b64}",
- },
- }
- )
-
- content_parts.append({"type": "text", "text": query_text})
-
- messages: list[dict[str, Any]] = [
- {"role": "user", "content": content_parts},
- ]
-
- request_dict: dict[str, Any] = {
- "model": self._config.model,
- "messages": messages,
- "stream": True,
- }
-
- if self.has_audio:
- request_dict["mm_processor_kwargs"] = {
- "use_audio_in_video": True,
- }
-
- return ChatCompletionRequest(**request_dict)
-
- @property
- def evs_stats(self) -> dict[str, int | float] | None:
- return self._frame_filter.stats if self._frame_filter else None
-
-
-class VideoStreamHandler:
- """Drives a ``VideoStreamSession`` over a FastAPI ``WebSocket``.
-
- Instantiate once at server startup and call ``handle_session`` for each
- incoming connection. Uses two concurrent tasks (_reader + _processor)
- so frames keep arriving while a query is being processed.
- """
-
- def __init__(
- self,
- chat_handler: Any,
- idle_timeout: float = _DEFAULT_IDLE_TIMEOUT,
- config_timeout: float = _DEFAULT_CONFIG_TIMEOUT,
- ) -> None:
- self._chat_handler = chat_handler
- self._idle_timeout = idle_timeout
- self._config_timeout = config_timeout
-
- async def handle_session(self, websocket: WebSocket) -> None:
- """Main session loop for one WebSocket connection."""
- await websocket.accept()
-
- try:
- config = await self._receive_config(websocket)
- if config is None:
- return
-
- session = VideoStreamSession(config)
- logger.info(
- "Video stream session started: model=%s modalities=%s max_frames=%d evs=%s",
- config.model,
- config.modalities,
- config.max_frames,
- config.evs_enabled,
- )
-
- msg_queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue()
-
- async def _reader() -> None:
- try:
- while True:
- try:
- raw = await asyncio.wait_for(
- websocket.receive_text(),
- timeout=self._idle_timeout,
- )
- except asyncio.TimeoutError:
- await self._send_error(websocket, "Idle timeout: no message received")
- await msg_queue.put(None)
- return
-
- try:
- msg = json.loads(raw)
- except json.JSONDecodeError:
- await self._send_error(websocket, "Invalid JSON")
- continue
-
- if not isinstance(msg, dict):
- await self._send_error(websocket, "Messages must be JSON objects")
- continue
-
- await msg_queue.put(msg)
-
- if msg.get("type") == "video.done":
- return
- except WebSocketDisconnect:
- await msg_queue.put(None)
- except Exception:
- await msg_queue.put(None)
- raise
-
- async def _processor() -> None:
- while True:
- msg = await msg_queue.get()
- if msg is None:
- return
-
- msg_type = msg.get("type")
-
- if msg_type == "video.frame":
- data_b64 = msg.get("data", "")
- try:
- jpeg_bytes = base64.b64decode(data_b64)
- except Exception:
- await self._send_error(websocket, "Invalid base64 in video.frame")
- continue
- try:
- session.add_frame(jpeg_bytes)
- except ValueError as exc:
- await self._send_error(websocket, str(exc))
- continue
-
- elif msg_type == "audio.chunk":
- data_b64 = msg.get("data", "")
- try:
- pcm_bytes = base64.b64decode(data_b64)
- except Exception:
- await self._send_error(websocket, "Invalid base64 in audio.chunk")
- continue
- session.add_audio_chunk(pcm_bytes)
-
- elif msg_type == "video.query":
- query_text = msg.get("text", "")
- if not query_text:
- await self._send_error(
- websocket,
- "video.query requires a non-empty 'text' field",
- )
- continue
- await self._handle_query(websocket, session, query_text)
-
- elif msg_type == "video.done":
- evs = session.evs_stats
- if evs is not None:
- await websocket.send_json({"type": "response.evs_stats", **evs})
- await websocket.send_json({"type": "session.done"})
- return
-
- else:
- await self._send_error(websocket, f"Unknown message type: {msg_type}")
-
- reader_task = asyncio.create_task(_reader())
- try:
- await _processor()
- finally:
- reader_task.cancel()
- try:
- await reader_task
- except asyncio.CancelledError:
- pass
- except Exception:
- logger.debug(
- "Reader task raised an exception",
- exc_info=True,
- )
-
- except WebSocketDisconnect:
- logger.info("Video stream: client disconnected")
- except Exception:
- logger.exception("Video stream session error")
- try:
- await self._send_error(websocket, "Internal server error")
- except Exception:
- pass
-
- async def _receive_config(self, websocket: WebSocket) -> VideoStreamConfig | None:
- try:
- raw = await asyncio.wait_for(
- websocket.receive_text(),
- timeout=self._config_timeout,
- )
- except asyncio.TimeoutError:
- await self._send_error(websocket, "Timeout waiting for session.config")
- return None
-
- try:
- msg = json.loads(raw)
- except json.JSONDecodeError:
- await self._send_error(websocket, "Invalid JSON in session.config")
- return None
-
- if not isinstance(msg, dict) or msg.get("type") != "session.config":
- await self._send_error(
- websocket,
- f"Expected session.config, got: {msg.get('type') if isinstance(msg, dict) else type(msg).__name__}",
- )
- return None
-
- try:
- return VideoStreamConfig.from_dict(msg)
- except TypeError as exc:
- await self._send_error(websocket, str(exc))
- return None
-
- async def _handle_query(
- self,
- websocket: WebSocket,
- session: VideoStreamSession,
- query_text: str,
- ) -> None:
- """Build a ChatCompletionRequest, run it, and stream the response."""
- if session.frame_count == 0:
- await self._send_error(
- websocket,
- "No frames available. Send video.frame before video.query.",
- )
- return
-
- request = session.build_chat_request(query_text)
- await websocket.send_json({"type": "response.start"})
-
- # After response.start, the protocol contract requires us to always
- # send response.text.done — even on error / exception paths.
- text_parts: list[str] = []
- try:
- # raw_request=None: serving_chat.py guards with `if raw_request:`
- generator = await self._chat_handler.create_chat_completion(request, raw_request=None)
-
- if isinstance(generator, ErrorResponse):
- error_msg = generator.error.message if generator.error else "Unknown error"
- await self._send_error(websocket, error_msg)
- else:
- async for chunk in generator:
- if isinstance(chunk, str):
- for delta in self._parse_sse_deltas(chunk):
- text_parts.append(delta)
- await websocket.send_json(
- {
- "type": "response.text.delta",
- "delta": delta,
- }
- )
-
- except Exception:
- logger.exception("Query failed")
- await self._send_error(websocket, "Query processing failed")
-
- full_text = "".join(text_parts)
- await websocket.send_json({"type": "response.text.done", "text": full_text})
-
- session.clear_audio()
-
- @staticmethod
- def _parse_sse_deltas(chunk: str) -> list[str]:
- """Extract content deltas from SSE-formatted chunk."""
- deltas: list[str] = []
- for line in chunk.split("\n"):
- line = line.strip()
- if not line or not line.startswith("data: ") or line == "data: [DONE]":
- continue
- try:
- data = json.loads(line[6:])
- delta = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
- if delta:
- deltas.append(delta)
- except (json.JSONDecodeError, IndexError, AttributeError):
- pass
- return deltas
-
- @staticmethod
- async def _send_error(websocket: WebSocket, message: str) -> None:
- try:
- await websocket.send_json({"type": "error", "message": message})
- except Exception:
- pass
diff --git a/vllm_omni/entrypoints/pd_utils.py b/vllm_omni/entrypoints/pd_utils.py
index 413d5d6b448..0e3d65f5537 100644
--- a/vllm_omni/entrypoints/pd_utils.py
+++ b/vllm_omni/entrypoints/pd_utils.py
@@ -23,19 +23,9 @@
class PDDisaggregationMixin:
"""Mixin supplying PD disaggregation helpers to OmniBase."""
- def _get_pd_separation_pair(self) -> tuple[int, int] | None:
- """PD prefill/decode indices when ``_init_pd_state`` ran; else ``None``.
-
- Partial test doubles may skip ``OmniBase.__init__``; treat missing state as
- no PD disaggregation instead of raising ``AttributeError``.
- """
- return getattr(self, "_pd_separation_pair", None)
-
def _init_pd_state(self) -> None:
"""Initialise PD disaggregation state."""
- self._pd_separation_pair: tuple[int, int] | None = self.detect_pd_separation_from_stage_configs(
- self.stage_configs
- )
+ self._pd_separation_pair: tuple[int, int] | None = self._detect_pd_separation()
self._pd_connector_info: dict[str, Any] | None = None
self._pd_kv_params_by_req: dict[str, dict[str, Any]] = {}
self._pd_kv_params_lock = threading.Lock()
@@ -50,19 +40,11 @@ def _init_pd_state(self) -> None:
d_id,
)
- @staticmethod
- def detect_pd_separation_from_stage_configs(stage_configs: list[Any]) -> tuple[int, int] | None:
- """Scan stage configs for a prefill/decode pair.
-
- Returns:
- (prefill_idx, decode_idx) if one pair exists, None if not found.
-
- Raises:
- ValueError: if multiple candidate PD pairs are found.
- """
+ def _detect_pd_separation(self) -> tuple[int, int] | None:
+ """Scan stage_list for a prefill/decode pair. Returns (p_id, d_id) or None."""
prefill_by_id: dict[int, int] = {}
decode_indices: list[int] = []
- for i, stage in enumerate(stage_configs):
+ for i, stage in enumerate(self.stage_list):
if getattr(stage, "is_prefill_only", False):
prefill_by_id[i] = i
sid = getattr(stage, "stage_id", i)
@@ -73,7 +55,7 @@ def detect_pd_separation_from_stage_configs(stage_configs: list[Any]) -> tuple[i
pd_pairs: list[tuple[int, int]] = []
for j in decode_indices:
- source_ids = getattr(stage_configs[j], "engine_input_source", [])
+ source_ids = getattr(self.stage_list[j], "engine_input_source", [])
for src in source_ids:
if src in prefill_by_id:
pd_pairs.append((prefill_by_id[src], j))
@@ -125,11 +107,10 @@ def _normalize_kv_transfer_params(self, kv_params: Any) -> dict[str, Any] | None
def _validate_pd_separation_config(self) -> None:
"""Validate PD stage configurations are consistent."""
- pair = self._get_pd_separation_pair()
- assert pair is not None
- p_id, d_id = pair
- p_stage = self.stage_configs[p_id]
- d_stage = self.stage_configs[d_id]
+ assert self._pd_separation_pair is not None
+ p_id, d_id = self._pd_separation_pair
+ p_stage = self.stage_list[p_id]
+ d_stage = self.stage_list[d_id]
def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]:
ea = stage.engine_args
@@ -177,12 +158,11 @@ def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]:
def _get_pd_connector_info(self) -> dict[str, Any] | None:
"""Extract prefill engine KV connector info."""
- pair = self._get_pd_separation_pair()
- if pair is None:
+ if self._pd_separation_pair is None:
return None
- p_id, _ = pair
- p_stage = self.stage_configs[p_id]
+ p_id, _ = self._pd_separation_pair
+ p_stage = self.stage_list[p_id]
ea = p_stage.engine_args
kv_cfg = getattr(ea, "kv_transfer_config", None)
@@ -261,17 +241,18 @@ def _extract_kv_transfer_params(self, engine_outputs: Any) -> dict[str, Any] | N
def _is_pd_routing(self, stage_id: int, next_stage_id: int) -> bool:
"""True when edge stage_id → next_stage_id is the prefill→decode boundary."""
- pair = self._get_pd_separation_pair()
- return pair is not None and pair == (stage_id, next_stage_id)
+ return self._pd_separation_pair is not None and self._pd_separation_pair == (
+ stage_id,
+ next_stage_id,
+ )
def _maybe_expand_sampling_params(self, sampling_params_list: list) -> list:
"""Auto-duplicate thinker SP for decode stage when user provides N-1 params."""
- pair = self._get_pd_separation_pair()
- if pair is None:
+ if self._pd_separation_pair is None:
return sampling_params_list
- if len(sampling_params_list) != len(self.stage_configs) - 1:
+ if len(sampling_params_list) != len(self.stage_list) - 1:
return sampling_params_list
- p_id, d_id = pair
+ p_id, d_id = self._pd_separation_pair
sp_list = list(sampling_params_list)
sp_list.insert(d_id, sp_list[p_id])
return sp_list
diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py
index 7b725f469eb..8674d3c33da 100644
--- a/vllm_omni/entrypoints/stage_utils.py
+++ b/vllm_omni/entrypoints/stage_utils.py
@@ -78,7 +78,7 @@ def _parse_device_list(devices: str | int) -> list[str]:
def _map_device_list(stage_id: int, device_list: list[str], visible_device_list: list[str]) -> list[str]:
- """Map logical stage devices onto the currently available device pool.
+ """Maps logical to physical devices if we have enough visible devices available.
Args:
stage_id: The stage ID currently configuring devices.
@@ -87,42 +87,22 @@ def _map_device_list(stage_id: int, device_list: list[str], visible_device_list:
visible_device_list: List of physical devices available.
"""
num_visible = len(visible_device_list)
+ num_logical = len(device_list)
+ if num_visible < num_logical:
+ raise ValueError(f"Stage {stage_id} requires {num_logical} devices, but only {num_visible} devices are visible")
# Ensure that the logical IDs are actually in range to avoid index errors;
- # if some requested ids exceed the available pool, we will fall back to the
- # subset that can be mapped and leave the final capacity check to the later
- # parallel-config validation path.
+ # If the check above passes and those below fail, the logical devices are wrong,
+ # i.e., not actually 0, 1, ..., n
if not all(device.isdigit() for device in device_list):
raise ValueError("Logical devices must be non-negative integers")
logical_ids = [int(device) for device in device_list]
- mapped_devices = [visible_device_list[idx] for idx in logical_ids if idx < num_visible]
- mapping_pairs = [
- f"{logical_id}->{visible_device_list[logical_id]}" for logical_id in logical_ids if logical_id < num_visible
- ]
- if not mapped_devices:
+ if max(logical_ids) >= num_visible:
raise ValueError(
- f"Stage {stage_id} has logical IDs {device_list}, none of which map to the visible devices "
- f"{visible_device_list}"
+ f"Stage {stage_id} has logical IDs {device_list}, one or more of which exceed the number of visible devices"
)
- if len(mapped_devices) < len(logical_ids):
- logger.warning(
- "Stage %s requested logical devices %s, but only %d device(s) are currently available: %s. "
- "Resolved logical-to-physical mapping: %s. Falling back to mapped subset %s",
- stage_id,
- device_list,
- num_visible,
- visible_device_list,
- ", ".join(mapping_pairs) if mapping_pairs else "(none)",
- mapped_devices,
- )
- else:
- logger.info(
- "Stage %s logical-to-physical device mapping: %s",
- stage_id,
- ", ".join(mapping_pairs),
- )
- return mapped_devices
+ return [visible_device_list[idx] for idx in logical_ids]
def serialize_obj(obj: Any) -> bytes:
diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py
index d728e76417c..e29e9eea1c2 100644
--- a/vllm_omni/entrypoints/utils.py
+++ b/vllm_omni/entrypoints/utils.py
@@ -1,4 +1,3 @@
-import argparse
import os
import types
from collections import Counter
@@ -6,16 +5,12 @@
from pathlib import Path
from typing import Any, get_args, get_origin
-import yaml
from vllm.logger import init_logger
-from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.config import get_config, get_hf_file_to_dict
from vllm.transformers_utils.repo_utils import file_or_path_exists
-from vllm_omni.config.stage_config import StageConfigFactory
from vllm_omni.config.yaml_util import create_config, load_yaml_config, merge_configs
from vllm_omni.entrypoints.stage_utils import _to_dict
-from vllm_omni.inputs.data import OmniSamplingParams
from vllm_omni.platforms import current_omni_platform
# Get the project root directory (2 levels up from this file)
@@ -23,82 +18,11 @@
logger = init_logger(__name__)
-
-def _warn_deprecated_explicit_keys(kwargs: dict[str, Any]) -> None:
- if "cli_explicit_keys" in kwargs:
- import warnings
-
- warnings.warn(
- "cli_explicit_keys= is deprecated and ignored. Remove the kwarg.",
- DeprecationWarning,
- stacklevel=3,
- )
-
-
_DIFFUSERS_CLASS_TO_CONFIG: dict[str, str] = {
"GlmImagePipeline": "glm_image",
}
-def detect_explicit_cli_keys(
- argv: list[str],
- parser: argparse.ArgumentParser | None = None,
-) -> set[str]:
- """Walk ``argv`` and return the set of ``dest`` attribute names the user
- explicitly provided (e.g. ``--max-num-seqs 64`` → ``max_num_seqs``).
-
- Used to distinguish user-typed CLI args from argparse default values so
- deploy YAMLs are not silently overridden by parser defaults. Shared
- across online (``vllm serve``) and offline (scripts, examples, tests,
- CI) entry points — offline callers that parse CLI args via argparse
- should invoke this on ``sys.argv[1:]`` and pass the result through to
- ``AsyncOmni`` / ``Omni`` via the ``_cli_explicit_keys`` kwarg.
-
- When ``parser`` is provided, each token is looked up in the parser's
- action table to find its real ``dest``. This correctly handles flags
- with ``dest=`` overrides, alias flags (e.g. ``--usp`` /
- ``--ulysses-degree`` both mapping to ``ulysses_degree``), and
- ``--disable-foo`` / ``store_false`` patterns that map to a differently
- named dest. Callers with access to an ``argparse.ArgumentParser`` should
- always pass it.
-
- When ``parser`` is ``None``, a name-based heuristic is used as a
- fallback (hyphens → underscores, plus a ``no_`` prefix strip for
- ``argparse.BooleanOptionalAction``). This is correct for simple flags
- but silently misidentifies ``--disable-X``-style flags and explicit
- ``dest=`` overrides, so prefer the parser-aware form.
- """
- if parser is not None:
- dest_map: dict[str, str] = {}
- for action in parser._actions:
- for opt in action.option_strings:
- dest_map[opt] = action.dest
- explicit: set[str] = set()
- for tok in argv:
- if not tok.startswith("--"):
- continue
- flag = tok.split("=", 1)[0]
- dest = dest_map.get(flag)
- if dest is not None:
- explicit.add(dest)
- return explicit
-
- # Fallback: name-based heuristic (legacy path for callers without a parser).
- explicit = set()
- for tok in argv:
- if not tok.startswith("--"):
- continue
- name = tok[2:].split("=", 1)[0]
- if not name:
- continue
- attr = name.replace("-", "_")
- explicit.add(attr)
- # BooleanOptionalAction: --no-foo records as dest `foo`, not `no_foo`.
- if attr.startswith("no_"):
- explicit.add(attr[3:])
- return explicit
-
-
def inject_omni_kv_config(stage: Any, omni_conn_cfg: dict[str, Any], omni_from: str, omni_to: str) -> None:
"""Inject connector configuration into stage engine arguments."""
# Prepare omni_kv_config dict
@@ -258,31 +182,6 @@ def _convert_dataclasses_to_dict(obj: Any) -> Any:
return obj
-def _try_resolve_omni_model_type(model: str) -> str | None:
- """Try to resolve model_type for omni models with empty config.json.
-
- Searches both the legacy ``stage_configs/*.yaml`` directory and the
- migrated ``deploy/*.yaml`` directory for a stem that substring-matches
- the model path (e.g. ``cosyvoice3`` in
- ``FunAudioLLM/Fun-CosyVoice3-0.5B-2512``). The longest match wins so
- ``cosyvoice3`` beats ``cosyvoice`` and ``bagel_single_stage`` beats
- ``bagel``.
- """
- model_lower = model.lower().replace("-", "").replace("_", "")
- best_match: str | None = None
- best_len = 0
- for subdir in ("model_executor/stage_configs", "deploy"):
- config_dir = PROJECT_ROOT / "vllm_omni" / subdir
- if not config_dir.exists():
- continue
- for config_file in sorted(config_dir.glob("*.yaml")):
- candidate = config_file.stem.replace("-", "").replace("_", "")
- if candidate and candidate in model_lower and len(candidate) > best_len:
- best_match = config_file.stem
- best_len = len(candidate)
- return best_match
-
-
def resolve_model_config_path(model: str) -> str:
"""Resolve the stage config file path from the model name.
@@ -321,11 +220,7 @@ def resolve_model_config_path(model: str) -> str:
if config_dict and "model_type" in config_dict:
model_type = config_dict["model_type"]
else:
- # For models with empty config.json (e.g. CosyVoice3),
- # try matching against registered omni stage configs.
- model_type = _try_resolve_omni_model_type(model)
- if model_type is None:
- raise ValueError(f"config.json found but missing 'model_type' for model: {model}")
+ raise ValueError(f"config.json found but missing 'model_type' for model: {model}")
except Exception as e:
raise ValueError(f"Failed to read config.json for model: {model}. Error: {e}") from e
else:
@@ -345,10 +240,6 @@ def resolve_model_config_path(model: str) -> str:
if os.path.exists(complete_config_path):
return str(complete_config_path)
- deploy_config_path = PROJECT_ROOT / "vllm_omni" / "deploy" / model_type_str
- if os.path.exists(deploy_config_path):
- return str(deploy_config_path)
-
stage_config_file = f"vllm_omni/model_executor/stage_configs/{normalized_model_type}.yaml"
stage_config_path = PROJECT_ROOT / stage_config_file
if not os.path.exists(stage_config_path):
@@ -356,76 +247,44 @@ def resolve_model_config_path(model: str) -> str:
return str(stage_config_path)
-def load_stage_configs_from_model(
- model: str,
- base_engine_args: dict | None = None,
- deploy_config_path: str | None = None,
- stage_overrides: dict[str, dict[str, Any]] | None = None,
- **deprecated_kwargs: Any,
-) -> list:
+def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list:
"""Load stage configurations from model's default config file.
- For models registered in the pipeline registry (new path), uses
- ``StageConfigFactory.create_from_model()`` which merges
- PipelineConfig + DeployConfig + CLI overrides.
+ .. deprecated::
+ This is the legacy OmegaConf-based loading path. New code should use
+ ``StageConfigFactory.create_from_model()`` instead. This function will
+ be removed once all callers are migrated (see PR series [2/N]).
- For other models (legacy path), loads stage configs from YAML.
+ Loads stage configurations based on the model type and device type.
+ First tries to load a device-specific YAML file from stage_configs/{device_type}/
+ directory. If not found, falls back to the default config file.
Args:
model: Model name or path (used to determine model_type)
- base_engine_args: Base engine args to merge as CLI overrides.
- deploy_config_path: Optional explicit deploy config path.
- stage_overrides: Per-stage overrides from --stage-overrides.
Returns:
List of stage configuration dictionaries
- """
- _warn_deprecated_explicit_keys(deprecated_kwargs)
+ Raises:
+ FileNotFoundError: If no stage config file exists for the model type
+ """
if base_engine_args is None:
base_engine_args = {}
-
- cli_overrides = _convert_dataclasses_to_dict(dict(base_engine_args))
- if stage_overrides:
- for stage_id_str, overrides in stage_overrides.items():
- for key, val in overrides.items():
- cli_overrides[f"stage_{stage_id_str}_{key}"] = val
-
- stages = StageConfigFactory.create_from_model(
- model,
- cli_overrides=cli_overrides,
- deploy_config_path=deploy_config_path,
- )
- if stages is not None:
- # Convert StageConfig objects to OmegaConf for backward compat
- return [stage.to_omegaconf() for stage in stages]
-
- # Legacy fallback: load from YAML
stage_config_path = resolve_model_config_path(model)
if stage_config_path is None:
return []
- stage_configs = load_stage_configs_from_yaml(
- config_path=stage_config_path,
- base_engine_args=base_engine_args,
- prefer_stage_engine_args=True,
- )
+ stage_configs = load_stage_configs_from_yaml(config_path=stage_config_path, base_engine_args=base_engine_args)
return stage_configs
-def load_stage_configs_from_yaml(
- config_path: str,
- base_engine_args: dict | None = None,
- prefer_stage_engine_args: bool = True,
-) -> list:
- """Load stage configurations from a YAML file (legacy OmegaConf path).
+def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None = None) -> list:
+ """Load stage configurations from a YAML file.
- TODO(@lishunyang12): remove once all models use PipelineConfig + DeployConfig.
+ .. deprecated::
+ Legacy OmegaConf-based loader. Will be removed in PR series [2/N].
Args:
config_path: Path to the YAML configuration file
- base_engine_args: Engine args supplied by the caller.
- prefer_stage_engine_args: When True, YAML stage args override caller
- engine args. When False, caller engine args override YAML defaults.
Returns:
List of stage configuration dictionaries from the file's stage_args
@@ -442,11 +301,7 @@ def load_stage_configs_from_yaml(
base_engine_args_tmp = base_engine_args.copy()
# Update base_engine_args with stage-specific engine_args if they exist
if hasattr(stage_arg, "engine_args") and stage_arg.engine_args is not None:
- if prefer_stage_engine_args:
- merged_engine_args = merge_configs(base_engine_args_tmp, stage_arg.engine_args)
- else:
- merged_engine_args = merge_configs(stage_arg.engine_args, base_engine_args_tmp)
- base_engine_args_tmp = create_config(merged_engine_args)
+ base_engine_args_tmp = create_config(merge_configs(base_engine_args_tmp, stage_arg.engine_args))
stage_type = getattr(stage_arg, "stage_type", "llm")
if hasattr(stage_arg, "runtime") and stage_arg.runtime is not None and stage_type != "diffusion":
base_engine_args_tmp.async_chunk = global_async_chunk
@@ -553,75 +408,22 @@ def load_and_resolve_stage_configs(
stage_configs_path: str | None,
kwargs: dict | None,
default_stage_cfg_factory: Any = None,
- deploy_config_path: str | None = None,
- stage_overrides: dict[str, dict[str, Any]] | None = None,
- **deprecated_kwargs: Any,
) -> tuple[str, list]:
"""Load stage configurations from model or YAML file with fallback to defaults.
Args:
model: Model name or path
- stage_configs_path: Optional path to legacy YAML (stage_args format)
+ stage_configs_path: Optional path to YAML file containing stage configurations
kwargs: Engine arguments to merge with stage configs
default_stage_cfg_factory: Optional callable that takes no args and returns
default stage config list when no configs are found
- deploy_config_path: Optional path to deploy YAML (new format).
- Mutually exclusive with ``stage_configs_path``.
- stage_overrides: Per-stage overrides from ``--stage-overrides`` JSON.
- Keys are stage_id strings, values are dicts of overrides.
Returns:
Tuple of (config_path, stage_configs)
"""
- if stage_configs_path is not None and deploy_config_path is not None:
- raise ValueError(
- "--stage-configs-path and --deploy-config are mutually exclusive: "
- "they use different path resolution rules and loading paths. "
- "Use --deploy-config for new-format YAMLs (preferred); "
- "--stage-configs-path is kept only for the legacy `stage_args` format "
- "and will be removed in a future release."
- )
- if stage_configs_path is not None and deploy_config_path is None:
- if not os.path.exists(stage_configs_path):
- raise FileNotFoundError(
- f"--stage-configs-path {stage_configs_path!r} does not exist. "
- "Legacy `stage_configs/` yamls were replaced by `vllm_omni/deploy/.yaml`; "
- "use --deploy-config. See docs/configuration/stage_configs.md."
- )
- with open(stage_configs_path, encoding="utf-8") as f:
- _peek = yaml.safe_load(f) or {}
- if "stages" in _peek and "stage_args" not in _peek:
- deploy_config_path = stage_configs_path
- stage_configs_path = None
- else:
- logger.warning(
- "--stage-configs-path is deprecated; migrate %r and use --deploy-config.",
- stage_configs_path,
- )
-
- _warn_deprecated_explicit_keys(deprecated_kwargs)
-
- if deploy_config_path is not None:
- config_path = deploy_config_path
- stage_configs = load_stage_configs_from_model(
- model,
- base_engine_args=kwargs,
- deploy_config_path=deploy_config_path,
- stage_overrides=stage_overrides,
- )
- if not stage_configs:
- if default_stage_cfg_factory is not None:
- default_stage_cfg = default_stage_cfg_factory()
- stage_configs = create_config(default_stage_cfg)
- else:
- stage_configs = []
- elif stage_configs_path is None:
+ if stage_configs_path is None:
config_path = resolve_model_config_path(model)
- stage_configs = load_stage_configs_from_model(
- model,
- base_engine_args=kwargs,
- stage_overrides=stage_overrides,
- )
+ stage_configs = load_stage_configs_from_model(model, base_engine_args=kwargs)
if not stage_configs:
if default_stage_cfg_factory is not None:
default_stage_cfg = default_stage_cfg_factory()
@@ -852,40 +654,3 @@ def detect_pid_host() -> bool:
return True
return has_pid_host()
-
-
-### Helpers for handling delta messages
-def coerce_param_message_types(params: list[OmniSamplingParams], is_streaming: bool):
- """Iterate over the sampling params and convert to the message types
- to DELTA messages, if streaming is enabled, or FINAL_ONLY if
- it's disabled, while respecting `.skip_clone` on the params.
-
- This is needed to avoid emitting redundant multimodal data.
- """
- # Coerce vLLM's default output kinds as needed to handle streaming
- # (i.e., DELTA output kind). Note that this is only applied to non
- # Diffusion sampling params.
- #
- # NOTE: Hidden states will still be passed between stages.
- for idx, sp in enumerate(params):
- # For OmniDiffusionParams don't set output kind
- if isinstance(sp, SamplingParams):
- params[idx] = maybe_coerce_to_message_type(sp, is_streaming)
- return params
-
-
-def maybe_coerce_to_message_type(params: SamplingParams, is_streaming: bool):
- """If this is a CUMULATIVE message, coerce it to DELTA if streaming, otherwise FINAL_ONLY."""
- target_type = RequestOutputKind.DELTA if is_streaming else RequestOutputKind.FINAL_ONLY
- if params.output_kind == target_type:
- return params
- elif is_streaming and params.output_kind == RequestOutputKind.FINAL_ONLY:
- logger.warning("Request appears to be streaming, but got request type final only!")
- elif not is_streaming and params.output_kind == RequestOutputKind.DELTA:
- logger.warning("Request appears to not be streaming, but got request type delta!")
-
- if not params.skip_clone:
- params = params.clone()
- params.skip_clone = True
- params.output_kind = target_type
- return params
diff --git a/vllm_omni/inputs/data.py b/vllm_omni/inputs/data.py
index e4c33a58c20..7824e7092dc 100644
--- a/vllm_omni/inputs/data.py
+++ b/vllm_omni/inputs/data.py
@@ -227,10 +227,6 @@ class OmniDiffusionSamplingParams:
frame_rate: float | None = None # Floating-point rate used by the diffusion model when it differs from `fps`.
height_not_provided: bool = False
width_not_provided: bool = False
- enable_frame_interpolation: bool = False
- frame_interpolation_exp: int = 1
- frame_interpolation_scale: float = 1.0
- frame_interpolation_model_path: str | None = None
# Timesteps
timesteps: torch.Tensor | None = None
@@ -245,7 +241,6 @@ class OmniDiffusionSamplingParams:
guidance_scale_provided: bool = False
guidance_scale_2: float | None = None
guidance_rescale: float = 0.0
- strength: float | None = None # I2I: Z-Image specific now, uses to control denoising start timestep
decode_timestep: float | list[float] | None = None
decode_noise_scale: float | list[float] | None = None
eta: float = 0.0
@@ -267,10 +262,6 @@ class OmniDiffusionSamplingParams:
cfg_text_kv_metadata: dict[str, Any] | None = None
cfg_img_kv_metadata: dict[str, Any] | None = None
cfg_kv_request_ids: dict[str, str] | None = None
- cfg_active_branch: str | None = None
- cfg_branch_roles: list[str] | None = None
- cfg_branch_past_key_values: dict[str, Any] | None = None
- cfg_branch_kv_metadata: dict[str, dict[str, Any]] | None = None
# Component modules
modules: dict[str, Any] = field(default_factory=dict)
diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py
index cca6ce56870..c6dffd05426 100644
--- a/vllm_omni/inputs/preprocess.py
+++ b/vllm_omni/inputs/preprocess.py
@@ -29,8 +29,6 @@ def _process_text(
self,
parsed_content: OmniTextPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
- *,
- mm_uuids: Any | None = None,
) -> OmniTokenInputs | MultiModalInput:
"""Process text prompts with support for mm_processor_kwargs.
@@ -40,10 +38,6 @@ def _process_text(
"""
prompt_text = parsed_content["prompt"]
mm_processor_kwargs = parsed_content.get("mm_processor_kwargs") or {}
- # When the deprecated raw-prompt path is used, process_inputs does
- # not pass mm_uuids to preprocess(). Fall back to reading it from
- # the prompt dict so the Renderer's _validate_mm_uuids can see it.
- effective_mm_uuids = mm_uuids or parsed_content.get("multi_modal_uuids")
inputs: OmniTokenInputs | MultiModalInput
if multi_modal_data := parsed_content.get("multi_modal_data"):
@@ -52,7 +46,6 @@ def _process_text(
multi_modal_data,
mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
- mm_uuids=effective_mm_uuids,
)
prompt_embeds = parsed_content.get("prompt_embeds")
if prompt_embeds is not None:
@@ -66,7 +59,6 @@ def _process_text(
{},
mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
- mm_uuids=effective_mm_uuids,
)
else:
prompt_token_ids = self._tokenize_prompt(
@@ -150,8 +142,6 @@ def _prompt_to_llm_inputs(
self,
prompt: SingletonDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
- *,
- mm_uuids: Any | None = None,
) -> SingletonInput:
"""
Extract the singleton inputs from a prompt.
@@ -176,7 +166,6 @@ def _prompt_to_llm_inputs(
return self._process_text(
prompt, # type: ignore[arg-type]
tokenization_kwargs=tokenization_kwargs,
- mm_uuids=mm_uuids,
)
assert_never(prompt) # type: ignore[arg-type]
diff --git a/vllm_omni/metrics/stats.py b/vllm_omni/metrics/stats.py
index 4245deb5453..46fec6fcc1f 100644
--- a/vllm_omni/metrics/stats.py
+++ b/vllm_omni/metrics/stats.py
@@ -41,7 +41,6 @@ class StageRequestStats:
postprocess_time_ms: float = 0.0
diffusion_metrics: dict[str, int] = None
audio_generated_frames: int = 0
- pipeline_timings: dict[str, float] | None = None
@property
def rx_mbps(self) -> float:
@@ -102,7 +101,6 @@ def e2e_tpt(self) -> float:
"rx_decode_time_ms",
"rx_in_flight_time_ms",
"final_output_type",
- "pipeline_timings",
}
TRANSFER_EXCLUDE = {"from_stage", "to_stage", "request_id", "used_shm"}
E2E_EXCLUDE = {"request_id"}
@@ -489,15 +487,6 @@ def build_and_log_summary(self) -> dict[str, Any]:
"e2e_avg_time_per_request_ms": float(e2e_avg_req),
"e2e_avg_tokens_per_s": float(e2e_avg_tok),
}
- # Add average input preprocess time across all requests
- preprocess_times = []
- for _rid, evts in self.stage_events.items():
- for evt in evts:
- if evt.pipeline_timings and "preprocess_ms" in evt.pipeline_timings:
- preprocess_times.append(evt.pipeline_timings["preprocess_ms"])
- break # only once per request
- if preprocess_times:
- overall_summary["input_preprocess_time_ms"] = sum(preprocess_times) / len(preprocess_times)
# Add stage_wall_time_ms as separate fields for each stage
for idx, wall_time in enumerate(stage_wall_time_ms):
overall_summary[f"e2e_stage_{idx}_wall_time_ms"] = wall_time
@@ -545,41 +534,11 @@ def build_and_log_summary(self) -> dict[str, Any]:
),
)
- # === [OmniTiming] concise per-request summary ===
+ # === Stage table (columns = stage_id) ===
stage_evts = sorted(
self.stage_events.get(rid, []),
key=lambda e: e.stage_id if e.stage_id is not None else -1,
)
- pt = {}
- if stage_evts:
- pt = stage_evts[-1].pipeline_timings or {}
- if pt or e2e_evt:
- parts = [f"req={rid}"]
- if e2e_evt:
- parts.append(f"total={e2e_evt.e2e_total_ms / 1000.0:.2f}s")
- if "preprocess_ms" in pt:
- parts.append(f"preprocess={pt['preprocess_ms'] / 1000.0:.2f}s")
- if e2e_evt:
- engine_ms = e2e_evt.e2e_total_ms - pt.get("preprocess_ms", 0.0)
- parts.append(f"engine={engine_ms / 1000.0:.2f}s")
- stage_parts = []
- for evt in stage_evts:
- sid = evt.stage_id if evt.stage_id is not None else "?"
- t = evt.stage_gen_time_ms / 1000.0
- stage_parts.append(f"{sid}:{t:.2f}s")
- if stage_parts:
- parts.append(f"stages=[{','.join(stage_parts)}]")
- transfer_parts = []
- for te in self.transfer_events.values():
- if te.request_id == rid:
- transfer_parts.append(f"{te.from_stage}->{te.to_stage}={te.tx_time_ms:.2f}ms")
- if transfer_parts:
- parts.append(f"transfers=[{','.join(transfer_parts)}]")
- if "ar2diffusion_ms" in pt:
- parts.append(f"ar2diffusion={pt['ar2diffusion_ms']:.2f}ms")
- logger.info("[OmniTiming] %s", " ".join(parts))
-
- # === Stage table (columns = stage_id) ===
# if any stage has diffusion_metrics, remove postprocess_time_ms field
# because it is already included in diffusion_metrics
local_exclude = STAGE_EXCLUDE.copy()
diff --git a/vllm_omni/model_executor/models/bagel/bagel.py b/vllm_omni/model_executor/models/bagel/bagel.py
index 3714836f254..934f434e64e 100644
--- a/vllm_omni/model_executor/models/bagel/bagel.py
+++ b/vllm_omni/model_executor/models/bagel/bagel.py
@@ -1,3 +1,4 @@
+from collections import deque
from collections.abc import Iterable, Mapping, Sequence
from math import isqrt
from typing import Any
@@ -7,7 +8,7 @@
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs import ModalityData, MultiModalDataDict
+from vllm.inputs import MultiModalDataDict
from vllm.model_executor.layers.layernorm import RMSNorm as VllmRMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
@@ -26,6 +27,7 @@
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
+ ModalityData,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
@@ -203,13 +205,6 @@ def _get_subparsers(self):
class OmniBagelMultiModalProcessor(BaseMultiModalProcessor[OmniBagelProcessingInfo]):
IMG2IMG_PLACEHOLDER = "<|fim_middle|>"
- @staticmethod
- def _mm_kwargs_for_bagel_img2img_hf(mm_kwargs: Mapping[str, object]) -> dict[str, object]:
- # OpenAI / GLM-style serving may pass target_h/target_w for output grid sizing.
- # BagelProcessor does not accept these in img2img mode; strip here so callers
- # (e.g. serving_chat) can stay model-agnostic.
- return {k: v for k, v in mm_kwargs.items() if k not in ("target_h", "target_w")}
-
def _cached_apply_hf_processor(self, inputs, timing_ctx):
# img2img: prompt text must be modified based on mm data presence,
# so text and mm data cannot be tokenized separately — bypass cache.
@@ -255,7 +250,7 @@ def _call_hf_processor(
if "images" in img2img_data:
del img2img_data["images"]
img2img_data["images"] = img2img_data.pop("pixel_values_img2img")
- kwargs_img2img = self._mm_kwargs_for_bagel_img2img_hf(mm_kwargs)
+ kwargs_img2img = dict(mm_kwargs)
kwargs_img2img["is_img2img"] = True
out_img2img = super()._call_hf_processor(prompt, img2img_data, kwargs_img2img, tok_kwargs)
if "pixel_values" in out_img2img:
@@ -269,7 +264,7 @@ def _call_hf_processor(
elif has_img2img:
mm_data = dict(mm_data)
mm_data["images"] = mm_data.pop("pixel_values_img2img")
- mm_kwargs = self._mm_kwargs_for_bagel_img2img_hf(mm_kwargs)
+ mm_kwargs = dict(mm_kwargs)
mm_kwargs["is_img2img"] = True
outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
if "pixel_values" in outputs:
@@ -413,22 +408,6 @@ class OmniBagelForConditionalGeneration(BagelForConditionalGeneration):
the DiT's denoising loop.
"""
- # LoRA packed→sublayer mapping for both standard Qwen2 projections
- # and the MoE generation-mode projections added by _install_mot_modules().
- packed_modules_mapping = {
- "qkv_proj": ["q_proj", "k_proj", "v_proj"],
- "gate_up_proj": ["gate_proj", "up_proj"],
- "qkv_proj_moe_gen": [
- "q_proj_moe_gen",
- "k_proj_moe_gen",
- "v_proj_moe_gen",
- ],
- "mlp_moe_gen.gate_up_proj": [
- "mlp_moe_gen.gate_proj",
- "mlp_moe_gen.up_proj",
- ],
- }
-
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config
@@ -448,7 +427,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._pending_img2img_info: list[tuple[int, int, int, int]] = []
self._ropes_pending: list[dict[str, Any]] = []
self._ropes_metadata: dict[str, dict[str, Any]] = {}
- self._last_img2img_info: tuple[int, int, int, int] | None = None
+ self._cfg_companion_queue: deque[tuple[tuple[int, int, int, int], int]] = deque()
+
+ # Per-request position offset for decode after img2img prefill.
+ # Prefill rewrites positions (VAE→0, ViT→1, text→2..N) but the model
+ # runner assigns decode positions starting from prefill_len, not N+1.
+ # offset = rope - prefill_len (a negative number).
+ self._pending_decode_offsets: list[int] = []
+ self._decode_position_offsets: dict[str, int] = {}
from transformers import AutoTokenizer
@@ -460,6 +446,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._start_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_start|>"))
self._end_of_image_id = int(_tok.convert_tokens_to_ids("<|vision_end|>"))
self._img2img_token_id = int(_tok.convert_tokens_to_ids("<|fim_middle|>"))
+
self._vae_token_mask: torch.Tensor | None = None
self.device = get_local_device()
self._install_mot_modules(config)
@@ -538,7 +525,9 @@ def _clear_warmup_state(self):
self._ropes_pending.clear()
self._ropes_metadata.clear()
self._pending_img2img_info.clear()
- self._last_img2img_info = None
+ self._cfg_companion_queue.clear()
+ self._pending_decode_offsets.clear()
+ self._decode_position_offsets.clear()
self._vae_token_mask = None
def get_kv_transfer_metadata(
@@ -550,10 +539,12 @@ def get_kv_transfer_metadata(
meta = self._ropes_metadata.pop(req_id, None)
if meta is None:
return None
- if num_computed_tokens is not None and "image_shape" in meta:
- prefill_rope = meta["ropes"][0] if meta.get("ropes") else 0
- if num_computed_tokens > prefill_rope:
- meta["ropes"] = [num_computed_tokens]
+ # In think-mode img2img the prefill rope doesn't account for decoded
+ # thinking tokens; correct it to num_computed_tokens + offset.
+ # Skip correction when num_computed_tokens is unavailable (None).
+ offset = self._decode_position_offsets.pop(req_id, 0)
+ if offset != 0 and "ropes" in meta and num_computed_tokens is not None:
+ meta["ropes"] = [num_computed_tokens + offset]
return meta
def prepare_runner_inputs(
@@ -566,32 +557,48 @@ def prepare_runner_inputs(
num_scheduled_tokens: list[int],
input_ids_buffer: torch.Tensor | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
- """Restore input_ids so _adjust_positions_for_img2img can locate
- the <|fim_middle|> placeholder for thinking-mode pre_text_len
- detection."""
+ """Model-runner hook: adjust inputs before ``forward()``.
+
+ Returns ``(input_ids, positions)`` — possibly modified.
+
+ Two adjustments for BAGEL img2img:
+
+ 1. **Restore input_ids** when ``inputs_embeds`` is present so that
+ ``_adjust_positions_for_img2img`` can locate the
+ ``<|fim_middle|>`` placeholder.
+ 2. **Decode position offset**: prefill rewrites positions to a
+ compact scheme (rope ≪ prefill_len). The runner assigns decode
+ positions from ``num_computed_tokens``, which is far too large;
+ apply the stored per-request offset.
+ """
if inputs_embeds is not None and input_ids is None and input_ids_buffer is not None:
input_ids = input_ids_buffer
+
+ if self._decode_position_offsets and positions is not None:
+ token_start = 0
+ for i, rid in enumerate(req_ids):
+ sched = num_scheduled_tokens[i]
+ offset = self._decode_position_offsets.get(rid, 0)
+ if offset != 0 and num_computed_tokens[i] > 0:
+ positions[token_start : token_start + sched] += offset
+ token_start += sched
+
return input_ids, positions
def flush_pending_metadata(self, req_ids: list[str]) -> None:
- """Map pending metadata (batch order) to req_ids after forward().
-
- Guard: if a request already has metadata with ``image_shape``
- (written during img2img prefill), don't overwrite it with
- decode-step metadata that lacks ``image_shape``.
- """
+ """Map pending metadata (batch order) to req_ids after forward()."""
pending = self._ropes_pending
self._ropes_pending = []
for i, meta in enumerate(pending):
if i < len(req_ids):
- rid = req_ids[i]
- existing = self._ropes_metadata.get(rid)
- if existing and "image_shape" in existing and "image_shape" not in meta:
- continue
- ropes = meta.get("ropes")
- if ropes:
- meta["ropes"] = [int(r.item()) if isinstance(r, torch.Tensor) else r for r in ropes]
- self._ropes_metadata[rid] = meta
+ if req_ids[i] not in self._ropes_metadata:
+ self._ropes_metadata[req_ids[i]] = meta
+
+ pending_offsets = self._pending_decode_offsets
+ self._pending_decode_offsets = []
+ for i, offset in enumerate(pending_offsets):
+ if i < len(req_ids) and offset != 0:
+ self._decode_position_offsets[req_ids[i]] = offset
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
@@ -705,7 +712,16 @@ def _process_img2img_input(self, multimodal_input):
num_vit = vit_emb.shape[0] + 2
info = (num_vae, num_vit, int(H), int(W))
self._pending_img2img_info.append(info)
- self._last_img2img_info = info
+ # Only the gen (main) request should add a companion queue entry.
+ # Companion requests (cfg_text, cfg_img) also call this method with
+ # the same image, so guard by checking whether this exact info
+ # tuple is already enqueued. For batched img2img with multiple
+ # concurrent gen requests this correctly adds one entry per unique
+ # image; images with identical (num_vae, num_vit, H, W) that arrive
+ # in the same batch are indistinguishable here and will share one
+ # entry, but that is an uncommon edge case.
+ if not any(entry[0] == info for entry in self._cfg_companion_queue):
+ self._cfg_companion_queue.append((info, 2)) # cfg_text + cfg_img
return tuple(results)
@@ -724,18 +740,31 @@ def forward(
positions = self._adjust_positions_for_img2img(positions, input_ids)
use_mot = True
- elif self._last_img2img_info is not None:
- info = self._last_img2img_info
- num_vae, num_vit, _, _ = info
- num_img2img = num_vae + 1 + num_vit
-
- if seq_len >= num_img2img:
- self._pending_img2img_info = [info]
- positions = self._adjust_positions_for_img2img(positions, input_ids)
- use_mot = True
+ elif self._cfg_companion_queue:
+ # Guard: if this looks like a pure decode step (small token count,
+ # no multimodal embeddings), the queue has stale entries from a
+ # previous prefill cycle — clear them instead of consuming.
+ if inputs_embeds is None and seq_len <= 2:
+ self._cfg_companion_queue.clear()
else:
- rope = positions[seq_len - 1] + 1
- self._ropes_pending.append({"ropes": [rope]})
+ cached, remaining = self._cfg_companion_queue[0]
+ remaining -= 1
+ num_vae, num_vit, img_H, img_W = cached
+ num_img2img = num_vae + 1 + num_vit # +1 separator
+ seq_len = inputs_embeds.shape[0] if inputs_embeds is not None else positions.shape[0]
+
+ if inputs_embeds is not None and seq_len >= num_img2img:
+ self._pending_img2img_info = [cached]
+ positions = self._adjust_positions_for_img2img(positions, input_ids)
+ use_mot = True
+ else:
+ rope = int(positions[seq_len - 1].item()) + 1
+ self._ropes_pending.append({"ropes": [rope]})
+
+ if remaining == 0:
+ self._cfg_companion_queue.popleft()
+ else:
+ self._cfg_companion_queue[0] = (cached, remaining)
if use_mot:
return self._mot_forward(input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs)
@@ -746,18 +775,27 @@ def _adjust_positions_for_img2img(
positions: torch.Tensor,
input_ids: torch.Tensor | None = None,
) -> torch.Tensor:
- """Rewrite position IDs for img2img.
+ """Rewrite position IDs to match the original BAGEL position scheme:
+
+ If there are ``pre_text_len`` text tokens before the img2img block::
+
+ pre_text → 0, 1, ..., M-1
+ VAE → M (all share)
+ separator→ M
+ ViT → M+1 (all share)
+ post_text→ M+2, M+3, ...
- Supports an optional ``pre_text_len`` prefix (thinking-mode) detected
- via the ``<|fim_middle|>`` token in *input_ids*:
+ When no text precedes the img2img block (M=0), this reduces to the
+ simpler scheme: VAE→0, ViT→1, text→2, 3, ...
- pre_text -> 0 .. M-1
- VAE -> M (all share)
- separator-> M
- ViT -> M+1 (all share)
- post_text-> M+2, M+3, ...
+ Also computes ``self._vae_token_mask`` (bool tensor, True for actual
+ VAE latent patches that should use gen-mode weights) and pushes
+ per-request ropes + image_shape to the FIFO consumed by
+ ``get_kv_transfer_metadata``.
- When M=0 (standard img2img) this reduces to VAE->0, ViT->1, text->2..
+ For img2img requests, also stores a decode position offset so that
+ subsequent autoregressive decode steps use positions that continue
+ from the rewritten scheme rather than from the original prefill length.
"""
info_list = self._pending_img2img_info
self._pending_img2img_info = []
@@ -783,64 +821,70 @@ def _adjust_positions_for_img2img(
req_len = end - start
if img2img_idx < len(info_list):
- cur_info = info_list[img2img_idx]
- elif self._last_img2img_info is not None:
- cur_info = self._last_img2img_info
- else:
- cur_info = None
-
- if cur_info is not None:
- num_vae, num_vit, img_H, img_W = cur_info
+ num_vae, num_vit, img_H, img_W = info_list[img2img_idx]
num_img2img = num_vae + 1 + num_vit # +1 separator
if req_len >= num_img2img:
+ # Detect offset of img2img tokens within this request
+ # by searching for the img2img placeholder token ID.
pre_text_len = 0
if input_ids is not None:
- req_ids_slice = input_ids[start:end]
- indices = (req_ids_slice == self._img2img_token_id).nonzero(as_tuple=True)[0]
+ req_ids = input_ids[start:end]
+ mask = req_ids == self._img2img_token_id
+ indices = mask.nonzero(as_tuple=True)[0]
if indices.numel() > 0:
pre_text_len = int(indices[0].item())
- M = pre_text_len
- img_start = start + M
+ img_start = start + pre_text_len
post_text_start = img_start + num_img2img
+ # pre_text_pos: position base for image tokens
+ pre_text_pos = pre_text_len
- if M > 0:
+ # Pre-image text: sequential positions 0..pre_text_pos-1
+ if pre_text_len > 0:
new_positions[start:img_start] = torch.arange(
- 0, M, device=positions.device, dtype=positions.dtype
+ 0, pre_text_pos, device=positions.device, dtype=positions.dtype
)
- new_positions[img_start : img_start + num_vae] = M
- new_positions[img_start + num_vae] = M # separator
+ # VAE tokens: all share position pre_text_pos
+ new_positions[img_start : img_start + num_vae] = pre_text_pos
+ # Separator: position pre_text_pos
+ new_positions[img_start + num_vae] = pre_text_pos
+ # ViT tokens: all share position pre_text_pos+1
vit_start = img_start + num_vae + 1
- new_positions[vit_start : vit_start + num_vit] = M + 1
+ new_positions[vit_start : vit_start + num_vit] = pre_text_pos + 1
+ # Post-image text: sequential positions pre_text_pos+2, pre_text_pos+3, ...
num_post_text = end - post_text_start
if num_post_text > 0:
new_positions[post_text_start:end] = torch.arange(
- M + 2,
- M + 2 + num_post_text,
+ pre_text_pos + 2,
+ pre_text_pos + 2 + num_post_text,
device=positions.device,
dtype=positions.dtype,
)
- vae_patches_start = img_start + 1
- vae_patches_end = img_start + num_vae - 1
+ # VAE gen-mode mask: only actual VAE latent patches (not markers)
+ vae_patches_start = img_start + 1 # skip start_marker
+ vae_patches_end = img_start + num_vae - 1 # before end_marker
if vae_patches_end > vae_patches_start:
vae_mask[vae_patches_start:vae_patches_end] = True
- rope = M + 2 + num_post_text
+ rope = pre_text_pos + 2 + num_post_text
self._ropes_pending.append(
{
"ropes": [rope],
"image_shape": [img_H, img_W],
}
)
+ decode_offset = rope - req_len
+ self._pending_decode_offsets.append(decode_offset)
img2img_idx += 1
continue
rope = int(new_positions[end - 1].item()) + 1
self._ropes_pending.append({"ropes": [rope]})
+ self._pending_decode_offsets.append(0)
self._vae_token_mask = vae_mask if vae_mask.any() else None
return new_positions
diff --git a/vllm_omni/model_executor/models/bagel/pipeline.py b/vllm_omni/model_executor/models/bagel/pipeline.py
deleted file mode 100644
index c68a531c294..00000000000
--- a/vllm_omni/model_executor/models/bagel/pipeline.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""BAGEL-7B-MoT pipeline topologies (frozen).
-
-Two-stage (default):
- Stage 0: Thinker — multimodal understanding + text generation (AR)
- Stage 1: DiT — diffusion image generation
-
-Two-stage think:
- Same as two-stage but the Thinker decodes ... tokens before
- KV transfer. Uses expand_cfg_prompts_think (companion max_tokens=1) and
- omits kv_transfer_criteria so transfer happens after EOS, not after prefill.
-
-Single-stage:
- Stage 0: DiT — self-contained diffusion stage that handles all modalities
- (text2img, img2img, img2text, text2text, think) internally via its
- own LLM, ViT, VAE, and tokenizer.
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-_PROC = "vllm_omni.model_executor.stage_input_processors.bagel"
-
-BAGEL_PIPELINE = PipelineConfig(
- model_type="bagel",
- model_arch="OmniBagelForConditionalGeneration",
- hf_architectures=("BagelForConditionalGeneration",),
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="thinker",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- final_output=True,
- final_output_type="text",
- owns_tokenizer=True,
- requires_multimodal_data=True,
- model_arch="OmniBagelForConditionalGeneration",
- engine_output_type="text",
- prompt_expand_func=f"{_PROC}.expand_cfg_prompts",
- omni_kv_config={
- "need_send_cache": True,
- "kv_transfer_criteria": {"type": "prefill_finished"},
- },
- sampling_constraints={"detokenize": True},
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="dit",
- execution_type=StageExecutionType.DIFFUSION,
- input_sources=(0,),
- final_output=True,
- final_output_type="image",
- cfg_kv_collect_func=f"{_PROC}.collect_cfg_kv_caches",
- omni_kv_config={"need_recv_cache": True},
- ),
- ),
-)
-
-BAGEL_THINK_PIPELINE = PipelineConfig(
- model_type="bagel_think",
- model_arch="OmniBagelForConditionalGeneration",
- hf_architectures=(),
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="thinker",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- final_output=True,
- final_output_type="text",
- owns_tokenizer=True,
- requires_multimodal_data=True,
- model_arch="OmniBagelForConditionalGeneration",
- engine_output_type="text",
- prompt_expand_func=f"{_PROC}.expand_cfg_prompts_think",
- omni_kv_config={"need_send_cache": True},
- sampling_constraints={"detokenize": True},
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="dit",
- execution_type=StageExecutionType.DIFFUSION,
- input_sources=(0,),
- final_output=True,
- final_output_type="image",
- cfg_kv_collect_func=f"{_PROC}.collect_cfg_kv_caches",
- omni_kv_config={"need_recv_cache": True},
- ),
- ),
-)
-
-BAGEL_SINGLE_STAGE_PIPELINE = PipelineConfig(
- model_type="bagel_single_stage",
- model_arch="BagelForConditionalGeneration",
- hf_architectures=(),
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="dit",
- execution_type=StageExecutionType.DIFFUSION,
- input_sources=(),
- final_output=True,
- final_output_type="image",
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/common/__init__.py b/vllm_omni/model_executor/models/common/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/vllm_omni/model_executor/models/common/qwen3_code_predictor.py b/vllm_omni/model_executor/models/common/qwen3_code_predictor.py
deleted file mode 100644
index ecfeec38174..00000000000
--- a/vllm_omni/model_executor/models/common/qwen3_code_predictor.py
+++ /dev/null
@@ -1,753 +0,0 @@
-"""Qwen3 Code Predictor -- optimized re-prefill, no KV cache.
-
-Shared by Qwen3-Omni and Qwen3-TTS talker models.
-
-* SDPA attention (F.scaled_dot_product_attention) with native GQA support
-* HF-compatible numerics (float32 RMSNorm, float32 RoPE, separate linear layers)
-* Per-call embedding buffer to avoid cross-request aliasing
-* Pre-allocated position_ids (read-only, safe to persist)
-* torch.compile (epilogue_fusion=False) on inner transformer by default
-* Optional manual CUDA graph capture per batch-size bucket
-* Inline sampling (top-k + top-p) -- no custom op overhead
-"""
-
-from __future__ import annotations
-
-import dataclasses
-from collections.abc import Iterable
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
-from vllm_omni.platforms import current_omni_platform
-
-logger = init_logger(__name__)
-
-
-# ===================================================================
-# HF-numerics-compatible layers for code predictor
-# ===================================================================
-#
-# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32,
-# rotate_half RoPE) to produce outputs numerically identical to the
-# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel,
-# get_rope) introduce small precision differences that compound across
-# the autoregressive steps of the code predictor, causing severe
-# audio quality degradation.
-#
-# See: https://github.com/vllm-project/vllm-omni/issues/2274
-
-
-class _RMSNorm(nn.Module):
- """RMSNorm matching HuggingFace's implementation exactly.
-
- Computes variance in float32 to avoid bfloat16 precision loss.
- """
-
- def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-def _rotate_half(x: torch.Tensor) -> torch.Tensor:
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-class _RotaryEmbedding(nn.Module):
- """RoPE matching HuggingFace's implementation exactly.
-
- Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False).
- """
-
- def __init__(self, config) -> None:
- super().__init__()
- head_dim = getattr(
- config,
- "head_dim",
- config.hidden_size // config.num_attention_heads,
- )
- rope_theta = getattr(config, "rope_theta", 10000.0)
- inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
- # position_ids: [batch, seq_len]
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
-
- # Force float32 (matching HF)
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
-# ===================================================================
-# Attention
-# ===================================================================
-
-
-class CodePredictorAttention(nn.Module):
- """Multi-head self-attention for code predictor.
-
- Uses ``F.scaled_dot_product_attention`` with HF-compatible RoPE and RMSNorm.
- No KV cache -- the code predictor always re-prefills the full (short)
- sequence each AR step.
-
- Input : [B, seq_len, hidden_size]
- Output: [B, seq_len, hidden_size]
- """
-
- def __init__(self, config, *, prefix: str = "") -> None:
- super().__init__()
- self.num_heads = config.num_attention_heads
- self.num_kv_heads = config.num_key_value_heads
- assert self.num_heads % self.num_kv_heads == 0
- self.is_gqa = self.num_kv_heads != self.num_heads
- self.num_queries_per_kv = self.num_heads // self.num_kv_heads
- self.head_dim = getattr(
- config,
- "head_dim",
- config.hidden_size // config.num_attention_heads,
- )
- self.hidden_size = config.hidden_size
- self.scaling = self.head_dim**-0.5
- self.max_seq = int(config.num_code_groups) + 1
-
- # Separate q/k/v projections matching HF (no fused packing)
- bias = getattr(config, "attention_bias", False)
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias)
- self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=bias)
- self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=bias)
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
- self.q_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps)
- self.k_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
- if current_omni_platform.is_npu():
- if self.max_seq > 2048:
- raise ValueError(
- "Qwen3-TTS code predictor NPU fusion attention uses a fixed 2048x2048 "
- f"causal mask, but max_seq={self.max_seq} exceeds the mask size."
- )
- # Ascend SDPA is_causal migration example uses a fixed 2048x2048
- # compressed causal mask with sparse_mode=2.
- fusion_mask = torch.triu(
- torch.ones(2048, 2048, dtype=torch.bool),
- diagonal=1,
- )
- self.register_buffer("_fusion_causal_mask", fusion_mask, persistent=False)
-
- def _forward_npu_attention(
- self,
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
- bsz: int,
- seq_len: int,
- ) -> torch.Tensor:
- import torch_npu
-
- q_f, k_f, v_f = q, k, v
- if self.is_gqa:
- k_f = (
- k[:, :, None, :, :]
- .expand(bsz, self.num_kv_heads, self.num_queries_per_kv, seq_len, self.head_dim)
- .reshape(bsz, self.num_heads, seq_len, self.head_dim)
- )
- v_f = (
- v[:, :, None, :, :]
- .expand(bsz, self.num_kv_heads, self.num_queries_per_kv, seq_len, self.head_dim)
- .reshape(bsz, self.num_heads, seq_len, self.head_dim)
- )
-
- mask = self._fusion_causal_mask
- mask = mask.contiguous()
- q_f = q_f.contiguous()
- k_f = k_f.contiguous()
- v_f = v_f.contiguous()
- return torch_npu.npu_fusion_attention(
- q_f,
- k_f,
- v_f,
- self.num_heads,
- "BNSD",
- pse=None,
- padding_mask=None,
- atten_mask=mask,
- scale=float(self.scaling),
- keep_prob=1.0,
- # Keep torch_npu's API spelling.
- pre_tockens=2147483647,
- next_tockens=2147483647,
- inner_precise=0,
- prefix=None,
- actual_seq_qlen=None,
- actual_seq_kvlen=None,
- # Ascend SDPA is_causal migration example uses sparse_mode=2.
- sparse_mode=2,
- gen_mask_parallel=True,
- # Keep sync=True for the NPU fused attention path.
- sync=True,
- )[0]
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- ) -> torch.Tensor:
- bsz, seq_len, _ = hidden_states.shape
- hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim)
- hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim)
-
- q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2)
- k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2)
- v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2)
-
- cos, sin = position_embeddings
- # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads
- cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim]
- sin = sin.unsqueeze(1)
- q = (q * cos) + (_rotate_half(q) * sin)
- k = (k * cos) + (_rotate_half(k) * sin)
-
- if not current_omni_platform.is_npu():
- attn_out = F.scaled_dot_product_attention(
- q,
- k,
- v,
- scale=self.scaling,
- is_causal=True,
- enable_gqa=self.is_gqa,
- )
- else:
- attn_out = self._forward_npu_attention(q, k, v, bsz, seq_len)
-
- attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1)
- return self.o_proj(attn_out)
-
-
-# ===================================================================
-# MLP
-# ===================================================================
-
-
-class CodePredictorMLP(nn.Module):
- """SiLU-gated MLP for code predictor, matching HF's implementation."""
-
- def __init__(self, config, *, prefix: str = "") -> None:
- super().__init__()
- self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
- self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
-
-
-# ===================================================================
-# Decoder Layer
-# ===================================================================
-
-
-class CodePredictorDecoderLayer(nn.Module):
- """Transformer decoder layer (SDPA, no KV cache)."""
-
- def __init__(self, config, *, prefix: str = "") -> None:
- super().__init__()
- self.self_attn = CodePredictorAttention(config, prefix=f"{prefix}.self_attn")
- self.mlp = CodePredictorMLP(config, prefix=f"{prefix}.mlp")
- self.input_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- ) -> torch.Tensor:
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- hidden_states = self.self_attn(hidden_states, position_embeddings)
- hidden_states = residual + hidden_states
-
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = residual + hidden_states
- return hidden_states
-
-
-# ===================================================================
-# Base Transformer Model (re-prefill, no KV cache)
-# ===================================================================
-
-
-class CodePredictorBaseModel(nn.Module):
- """Inner transformer for code predictor.
-
- Signature: ``forward(inputs_embeds, position_ids) -> hidden_states``
- """
-
- def __init__(
- self,
- config,
- *,
- embedding_dim: int | None = None,
- use_parallel_embedding: bool = False,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.config = config
-
- emb_dim = int(embedding_dim) if embedding_dim is not None else int(config.hidden_size)
- if use_parallel_embedding:
- self.codec_embedding = nn.ModuleList(
- [VocabParallelEmbedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)]
- )
- else:
- self.codec_embedding = nn.ModuleList(
- [nn.Embedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)]
- )
-
- self.layers = nn.ModuleList(
- [
- CodePredictorDecoderLayer(config, prefix=f"{prefix}.layers.{idx}")
- for idx in range(config.num_hidden_layers)
- ]
- )
- self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.rotary_emb = _RotaryEmbedding(config)
-
- def get_input_embeddings(self) -> nn.ModuleList:
- return self.codec_embedding
-
- def forward(
- self,
- inputs_embeds: torch.Tensor,
- position_ids: torch.Tensor,
- ) -> torch.Tensor:
- hidden_states = inputs_embeds
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
- for layer in self.layers:
- hidden_states = layer(hidden_states, position_embeddings)
- hidden_states = self.norm(hidden_states)
- return hidden_states
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- loaded_params: set[str] = set()
- for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name:
- continue
- param = params_dict.get(name)
- if param is None:
- continue
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
- return loaded_params
-
-
-# ===================================================================
-# Wrapper Configuration
-# ===================================================================
-
-
-@dataclasses.dataclass
-class CodePredictorWrapperConfig:
- """Controls behavioral differences between model-specific code predictors."""
-
- use_cuda_graphs: bool = False
- use_parallel_embedding: bool = False
- use_projection: bool = False
- return_proj_buf: bool = False
- sampling_mode: str = "stored"
-
-
-# ===================================================================
-# Code Predictor Wrapper (optimized re-prefill, persistent buffers)
-# ===================================================================
-
-
-class CodePredictorWrapper(nn.Module):
- """Optimized code predictor -- re-prefill approach, no KV cache.
-
- Each AR step forwards the full growing sequence (len 2 -> num_code_groups+1)
- through the transformer. The extra O(T^2) FLOPs are negligible for
- short sequences, and this avoids all KV-cache management overhead.
-
- Optimizations:
- 1. Per-call embedding buffer -- avoids cross-request aliasing.
- 2. Pre-allocated position_ids -- no torch.arange per step.
- 3. Cached module references -- bypass ModuleList indexing.
- 4. torch.compile on inner transformer.
- 5. Inline sampling (top-k + top-p) -- no custom op overhead.
- 6. Optional manual CUDA graph capture per batch-size bucket.
- """
-
- def __init__(
- self,
- *,
- vllm_config: VllmConfig,
- cp_config,
- wrapper_config: CodePredictorWrapperConfig,
- talker_hidden_size: int | None = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self._vllm_config = vllm_config
- self.config = cp_config
- self._wrapper_config = wrapper_config
- self.prefix = prefix
-
- self._num_groups = int(cp_config.num_code_groups)
- self._cp_hidden = int(cp_config.hidden_size)
-
- # For Omni backward compat (accessed by the talker)
- self.num_code_groups = self._num_groups
-
- # Determine embedding dimension
- _talker_hidden = int(talker_hidden_size) if talker_hidden_size is not None else self._cp_hidden
-
- self.model = CodePredictorBaseModel(
- cp_config,
- embedding_dim=_talker_hidden,
- use_parallel_embedding=wrapper_config.use_parallel_embedding,
- prefix=f"{prefix}.model" if prefix else "model",
- )
-
- self.lm_head = nn.ModuleList(
- [nn.Linear(cp_config.hidden_size, cp_config.vocab_size, bias=False) for _ in range(self._num_groups - 1)]
- )
-
- # Projection: Identity when hidden sizes match or not needed
- if wrapper_config.use_projection and _talker_hidden != self._cp_hidden:
- self.small_to_mtp_projection = nn.Linear(_talker_hidden, self._cp_hidden, bias=True)
- else:
- self.small_to_mtp_projection = nn.Identity()
-
- # Sampling defaults for "stored" mode
- self._top_k: int = 50
- self._top_p: float = 0.8
-
- # Lazily initialised state
- self._proj_buf: torch.Tensor | None = None
- self._model_dtype: torch.dtype | None = None
- self._compiled_model_fwd = None
- self._bucket_sizes: list[int] = []
- self._bucket_pos_ids: dict[int, torch.Tensor] = {}
- self._lm_heads_list: list[nn.Module] | None = None
- self._codec_embeds_list: list[nn.Module] | None = None
- self._device_graphs: dict[int, tuple] = {} # (graph, static_output) per bucket
-
- def get_input_embeddings(self) -> nn.ModuleList:
- return self.model.get_input_embeddings()
-
- def set_sampling_params(self, top_k: int = 50, top_p: float = 0.8) -> None:
- """Configure sampling parameters to maintain consistency with previous implementation."""
- self._top_k = top_k
- self._top_p = top_p
- logger.debug("Sampling parameters updated: top_k=%d, top_p=%.2f", top_k, top_p)
-
- # ------------------------------------------------------------------
- # Lazy-init helpers
- # ------------------------------------------------------------------
-
- def _ensure_buffers(self, device: torch.device, dtype: torch.dtype, bsz: int) -> None:
- """Ensure the projection buffer can hold at least *bsz* rows."""
- max_seq = self._num_groups + 1
- if (
- self._proj_buf is not None
- and self._proj_buf.device == device
- and self._proj_buf.dtype == dtype
- and self._proj_buf.shape[0] >= bsz
- ):
- return
- self._proj_buf = torch.zeros(bsz, max_seq, self._cp_hidden, dtype=dtype, device=device)
-
- def _setup_compile(self) -> None:
- """Lazily set up torch.compile with optional device graph capture."""
- if self._compiled_model_fwd is not None:
- return
-
- # Cache model parameter dtype so forward() doesn't need to query it
- # on every call. Also ensures warmup buffers match model precision
- # even when upstream modules produce a different dtype (#2385).
- self._model_dtype = next(self.model.parameters()).dtype
- self._lm_heads_list = list(self.lm_head)
- self._codec_embeds_list = list(self.model.codec_embedding)
-
- if not current_omni_platform.supports_torch_inductor():
- # NPU or other platforms without Inductor support
- self._compiled_model_fwd = self.model.forward
-
- if current_omni_platform.is_npu() and self._wrapper_config.use_cuda_graphs:
- # For NPU, use eager + NPU graphs (no torch.compile)
- self._warmup_buckets()
- self._capture_npu_graphs()
- logger.info("code_predictor: eager mode + NPU graphs")
- else:
- logger.warning_once("code_predictor: torch.compile disabled")
- return
-
- # torch.compile fuses RMSNorm/RoPE in ways that lose float32
- # precision, compounding across AR steps. Use epilogue_fusion=False
- # to disable the problematic fusions while still getting kernel
- # fusion benefits for the linear layers and SDPA.
- self._compiled_model_fwd = torch.compile(
- self.model.forward,
- dynamic=False,
- options={"epilogue_fusion": False},
- )
- self._warmup_buckets()
-
- if self._wrapper_config.use_cuda_graphs:
- self._capture_cuda_graphs()
- logger.info("code_predictor: torch.compile (no epilogue fusion) + CUDA graphs")
- else:
- logger.info("code_predictor: torch.compile (dynamic=False, no epilogue fusion)")
-
- def _padded_bsz(self, bsz: int) -> int:
- """Round batch size up to nearest power-of-2 bucket."""
- for bucket in self._bucket_sizes:
- if bsz <= bucket:
- return bucket
- return bsz
-
- def _warmup_buckets(self) -> None:
- """Warmup power-of-2 batch-size buckets to front-load Inductor compilation."""
- max_bsz = self._vllm_config.scheduler_config.max_num_seqs
- bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz]
- if max_bsz not in bucket_sizes:
- bucket_sizes.append(max_bsz)
- self._bucket_sizes = sorted(bucket_sizes)
-
- max_seq = self._num_groups + 1
- device = next(self.model.parameters()).device
-
- # Ensure proj_buf matches model parameter dtype to avoid dtype
- # mismatch during warmup compilation (see #2385).
- self._ensure_buffers(device, self._model_dtype, max(self._bucket_sizes))
- proj_buf = self._proj_buf
-
- for bsz in self._bucket_sizes:
- pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1).contiguous()
- self._bucket_pos_ids[bsz] = pos_ids
- for _ in range(3):
- self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids)
- logger.info("code_predictor: warmup done for buckets %s", self._bucket_sizes)
-
- def _capture_cuda_graphs(self) -> None:
- """Capture a CUDA graph per bucket using vLLM's global graph pool."""
- from vllm.platforms import current_platform
-
- pool = current_platform.get_global_graph_pool()
- max_seq = self._num_groups + 1
- proj_buf = self._proj_buf
-
- for bsz in self._bucket_sizes:
- static_input = proj_buf[:bsz, :max_seq, :]
- pos_ids = self._bucket_pos_ids[bsz]
-
- g = torch.cuda.CUDAGraph()
- with torch.cuda.graph(g, pool=pool):
- static_output = self._compiled_model_fwd(static_input, pos_ids)
-
- self._device_graphs[bsz] = (g, static_output)
-
- logger.info("code_predictor: captured CUDA graphs for buckets %s", self._bucket_sizes)
-
- def _capture_npu_graphs(self) -> None:
- """Capture an NPU graph per bucket using torch_npu's NPUGraph."""
- max_seq = self._num_groups + 1
- proj_buf = self._proj_buf
- pool = torch.npu.graph_pool_handle()
-
- for bsz in self._bucket_sizes:
- static_input = proj_buf[:bsz, :max_seq, :]
- pos_ids = self._bucket_pos_ids[bsz]
-
- g = torch.npu.NPUGraph()
- with torch.npu.graph(g, pool=pool):
- static_output = self._compiled_model_fwd(static_input, pos_ids)
-
- self._device_graphs[bsz] = (g, static_output)
-
- logger.info("code_predictor: captured NPU graphs for buckets %s", self._bucket_sizes)
-
- # ------------------------------------------------------------------
- # Forward -- re-prefill + inline sampling
- # ------------------------------------------------------------------
-
- @torch.inference_mode()
- def forward(
- self,
- layer0_code: torch.Tensor,
- layer0_embed: torch.Tensor,
- last_talker_hidden: torch.Tensor,
- do_sample: bool = True,
- temperature: float = 0.9,
- top_k: int = 50,
- top_p: float = 1.0,
- ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- """Predict residual codebooks 1..G-1 autoregressively via re-prefill."""
- bsz = int(layer0_code.shape[0])
- num_groups = self._num_groups
- device = layer0_code.device
-
- # _setup_compile caches _model_dtype on first call; use it for buffers
- # so they always match model weight precision (#2385).
- self._setup_compile()
- dtype = self._model_dtype
-
- padded_bsz = self._padded_bsz(bsz)
- self._ensure_buffers(device, dtype, padded_bsz)
-
- proj_buf = self._proj_buf
- max_seq = num_groups + 1
- projection = self.small_to_mtp_projection
- model_fwd = self._compiled_model_fwd
- lm_heads = self._lm_heads_list
- codec_embeds = self._codec_embeds_list
-
- # Zero the padded region of the buffer
- proj_buf[:padded_bsz].zero_()
-
- # Fill buffer positions 0 (talker hidden) & 1 (layer0 embed)
- proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1)
- proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1).to(dtype)).reshape(bsz, -1)
-
- # Get pre-computed pos_ids for this bucket
- full_pos_ids = self._bucket_pos_ids.get(padded_bsz)
- if full_pos_ids is None:
- full_pos_ids = (
- torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1).contiguous()
- )
-
- # Use captured device graph if available, otherwise call compiled fn.
- device_graph_entry = self._device_graphs.get(padded_bsz)
-
- # Prepare sampling parameters
- stored_mode = self._wrapper_config.sampling_mode == "stored"
- if stored_mode:
- s_top_k = self._top_k
- s_top_p = self._top_p
- else:
- use_sampling = do_sample and temperature > 0
- inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0
- if use_sampling and top_p != 1.0:
- raise NotImplementedError(
- "top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0."
- )
-
- # Output codes -- shape depends on return mode
- if self._wrapper_config.return_proj_buf:
- all_codes = torch.empty(bsz, num_groups, 1, dtype=torch.int64, device=device)
- all_codes[:, 0] = layer0_code.reshape(bsz, -1)[:, :1]
- else:
- all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device)
- all_codes[:, 0] = layer0_code.reshape(bsz)
-
- # Autoregressive loop: predict layers 1..G-1
- for step in range(1, num_groups):
- # Run transformer (device graph replay or compiled forward)
- if device_graph_entry is not None:
- device_graph_entry[0].replay()
- hidden_out = device_graph_entry[1]
- else:
- hidden_out = model_fwd(proj_buf[:padded_bsz, :max_seq, :], full_pos_ids)
-
- logits = lm_heads[step - 1](hidden_out[:bsz, step, :])
-
- # Sample next code
- if stored_mode:
- # "stored" mode: top-k -> top-p -> softmax -> multinomial
- if s_top_k > 0:
- topk_vals, _ = logits.topk(s_top_k, dim=-1)
- logits = logits.masked_fill(logits < topk_vals[:, -1:], float("-inf"))
- if s_top_p < 1.0:
- sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
- sorted_probs = F.softmax(sorted_logits, dim=-1)
- cumulative_probs = sorted_probs.cumsum(dim=-1)
- remove_mask = (cumulative_probs - sorted_probs) >= s_top_p
- sorted_logits[remove_mask] = float("-inf")
- logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
- probs = F.softmax(logits, dim=-1)
- code = torch.multinomial(probs, num_samples=1)
- else:
- # "per_call" mode: temperature-scaled + top-k
- if use_sampling:
- scaled = logits * inv_temperature
- if top_k > 0:
- topk_vals, _ = scaled.topk(top_k, dim=-1)
- scaled = scaled.masked_fill(scaled < topk_vals[:, -1:], float("-inf"))
- probs = F.softmax(scaled, dim=-1)
- code = torch.multinomial(probs, num_samples=1)
- else:
- code = logits.argmax(dim=-1, keepdim=True)
-
- # Store code
- if self._wrapper_config.return_proj_buf:
- all_codes[:, step] = code
- else:
- all_codes[:, step] = code.reshape(bsz)
-
- # Embed predicted code -> project -> next buffer position
- if step < num_groups - 1 or self._wrapper_config.return_proj_buf:
- new_embed = codec_embeds[step - 1](code)
- proj_buf[:bsz, step + 1, :] = projection(new_embed.reshape(bsz, 1, -1)).reshape(bsz, -1)
-
- if self._wrapper_config.return_proj_buf:
- return all_codes, proj_buf[:bsz].clone()
- return all_codes
-
- # ------------------------------------------------------------------
- # Weight loading
- # ------------------------------------------------------------------
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- """Load weights directly (no fused projection remapping needed)."""
- loaded: set[str] = set()
- model_weights: list[tuple[str, torch.Tensor]] = []
- other_weights: list[tuple[str, torch.Tensor]] = []
-
- for name, w in weights:
- if "rotary_emb.inv_freq" in name:
- continue
- if name.startswith("model."):
- model_weights.append((name[len("model.") :], w))
- else:
- other_weights.append((name, w))
-
- loaded_model = self.model.load_weights(model_weights)
- loaded |= {f"model.{n}" for n in loaded_model}
-
- params = dict(self.named_parameters(remove_duplicate=False))
- for name, w in other_weights:
- param = params.get(name)
- if param is None:
- continue
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, w)
- loaded.add(name)
-
- return loaded
diff --git a/vllm_omni/model_executor/models/cosyvoice3/code2wav_core/cfm.py b/vllm_omni/model_executor/models/cosyvoice3/code2wav_core/cfm.py
index 36ff0d45659..7281cd81f97 100644
--- a/vllm_omni/model_executor/models/cosyvoice3/code2wav_core/cfm.py
+++ b/vllm_omni/model_executor/models/cosyvoice3/code2wav_core/cfm.py
@@ -174,7 +174,7 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
@torch.inference_mode()
- def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming: bool = False):
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
@@ -277,9 +277,7 @@ def inference(
prompt_feat,
prompt_feat_len,
embedding,
- streaming: bool = True,
- finalize: bool = False,
- n_timesteps: int = 10,
+ finalize,
):
assert token.shape[0] == 1
# xvec projection
@@ -316,8 +314,7 @@ def inference(
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
- n_timesteps=max(1, int(n_timesteps)),
- streaming=streaming,
+ n_timesteps=10,
)
feat = feat[:, :, mel_len1:]
diff --git a/vllm_omni/model_executor/models/cosyvoice3/config.py b/vllm_omni/model_executor/models/cosyvoice3/config.py
index 518fe76b78a..0c9a2899797 100644
--- a/vllm_omni/model_executor/models/cosyvoice3/config.py
+++ b/vllm_omni/model_executor/models/cosyvoice3/config.py
@@ -7,10 +7,6 @@ class CosyVoice3Config(PretrainedConfig):
model_type = "cosyvoice3"
def __init__(self, **kwargs):
- # Set primary speech EOS so vLLM stops generation at the right token.
- # The official CosyVoice3 treats ALL tokens >= speech_token_size
- # (6561-6760) as stop signals; see stop_token_ids in the YAML configs.
- kwargs.setdefault("eos_token_id", 6562)
super().__init__(**kwargs)
self.sample_rate = 24000
self.llm_input_size = 896
diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py
index 2fba8fb8af1..bc04aae33c9 100644
--- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py
+++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3.py
@@ -2,16 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Iterable, Mapping, Sequence
-from dataclasses import replace
from functools import partial
-from threading import Lock
+import numpy as np
import torch
import torch.nn as nn
from transformers.feature_extraction_utils import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
-from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.inputs import MultiModalDataDict
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsMultiModal
@@ -28,9 +26,6 @@
PromptUpdate,
)
from vllm.sequence import IntermediateTensors
-from vllm.v1.outputs import SamplerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config
from vllm_omni.model_executor.models.cosyvoice3.utils import (
@@ -272,8 +267,6 @@ class CosyVoice3Model(
supports_multimodal_raw_input_only = True
supports_multimodal = True
requires_raw_input_tokens = True
- prefer_model_sampler = True
- _sampling_eps = 1e-5
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -312,8 +305,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.code2wav = CosyVoice3Code2Wav(self.config)
self.model = self.code2wav.flow_model
self.hift = self.code2wav.hift
- # Keep additional information synchronized for async_chunk updates.
- self.enable_update_additional_information = True
# Expose streaming parameters
self.token_overlap_len = self.code2wav.token_overlap_len
@@ -322,9 +313,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.mel_cache_len = self.code2wav.mel_cache_len
self.source_cache_len = self.code2wav.source_cache_len
self.speech_window = self.code2wav.speech_window
- self._stream_audio_cache_by_req: dict[str, torch.Tensor] = {}
- self._stream_audio_cache_lock = Lock()
- self._stream_vocoder_cache_by_req: dict[str, dict[str, torch.Tensor]] = {}
else:
raise ValueError(f"Model stage not supported {self.model_stage}")
@@ -343,277 +331,19 @@ def _create_llm_vllm_config(self, parent_config: VllmConfig) -> VllmConfig:
# Use parent's cache config - critical for PagedAttention to work correctly
return parent_config.with_hf_config(qwen_hf_config, architectures=["Qwen2Model"])
- @staticmethod
- def _as_tensor(value: object) -> torch.Tensor | None:
- """Extract tensor payload from runtime info fields."""
- if isinstance(value, list):
- if not value:
- return None
- value = value[0]
- if isinstance(value, torch.Tensor):
- return value
- return None
-
- @staticmethod
- def _as_str(value: object) -> str | None:
- """Extract string payload from runtime info fields."""
- if isinstance(value, list):
- if not value:
- return None
- value = value[0]
- if value is None:
- return None
- return str(value)
-
- @staticmethod
- def _as_bool(value: object) -> bool:
- """Extract boolean payload from runtime info fields."""
- if isinstance(value, list):
- if not value:
- return False
- value = value[0]
- if isinstance(value, torch.Tensor):
- if value.numel() == 0:
- return False
- return bool(value.reshape(-1)[0].item())
- if value is None:
- return False
- return bool(value)
-
- @staticmethod
- def _cross_fade_audio(audio: torch.Tensor, prev_tail: torch.Tensor) -> torch.Tensor:
- """Blend previous chunk tail into current chunk head using a Hamming window.
-
- This mirrors upstream CosyVoice's `fade_in_out(...)` semantics:
- update the current head in-place using a 2*overlap window, then
- concatenate the unchanged remainder.
- """
- if audio.numel() == 0 or prev_tail.numel() == 0:
- return audio
- overlap = min(int(audio.numel()), int(prev_tail.numel()))
- if overlap <= 0:
- return audio
- window = torch.hamming_window(2 * overlap, periodic=False, dtype=audio.dtype, device=audio.device)
- fade_in = window[:overlap]
- fade_out = window[overlap:]
- blended = audio[:overlap] * fade_in + prev_tail[-overlap:].to(device=audio.device, dtype=audio.dtype) * fade_out
- if overlap == int(audio.numel()):
- return blended
- return torch.cat([blended, audio[overlap:]], dim=0)
-
- def _stitch_stream_audio(self, req_id: str | None, audio: torch.Tensor, stream_finished: bool) -> torch.Tensor:
- """Pass-through stitching for async_chunk.
-
- Chunk overlap is already removed in mel domain via token_offset_tokens.
- Applying an additional waveform-domain fade/cache step introduces either
- duplicated overlap (if no tail trim) or duration shrink (if tail trim).
- """
- if req_id is not None and stream_finished and hasattr(self, "_stream_audio_cache_by_req"):
- with self._stream_audio_cache_lock:
- self._stream_audio_cache_by_req.pop(req_id, None)
- if hasattr(self, "_stream_vocoder_cache_by_req"):
- self._stream_vocoder_cache_by_req.pop(req_id, None)
- return audio
-
- @staticmethod
- def _split_request_ids(ids: torch.Tensor, seq_token_counts: list[int] | None = None) -> list[torch.Tensor]:
- """Split concatenated input_ids into per-request segments."""
- if seq_token_counts is not None:
- boundaries = [0]
- for count in seq_token_counts:
- boundaries.append(boundaries[-1] + int(count))
- total = ids.numel()
- return [ids[boundaries[i] : min(boundaries[i + 1], total)] for i in range(len(seq_token_counts))]
-
- if is_forward_context_available():
- slices = get_forward_context().ubatch_slices
- if slices is not None and len(slices) > 1 and not any(hasattr(s, "token_slice") for s in slices):
- boundaries = [0]
- for s in slices:
- boundaries.append(boundaries[-1] + int(s))
- return [ids[boundaries[i] : boundaries[i + 1]] for i in range(len(boundaries) - 1)]
-
- return [ids]
-
- def _sanitize_codec_tokens(self, req_ids: torch.Tensor) -> torch.Tensor:
- """Filter non-code tokens before feeding flow token embedding."""
- vocab_size = int(self.code2wav.input_embedding.num_embeddings)
- valid_mask = (req_ids >= 0) & (req_ids < vocab_size)
- return req_ids[valid_mask]
-
- @staticmethod
- def _req_scalar(param: torch.Tensor | None, req_idx: int, default: float | int) -> float | int:
- if param is None or param.numel() == 0:
- return default
- index = min(req_idx, int(param.numel()) - 1)
- value = param.reshape(-1)[index].item()
- if isinstance(default, int):
- return int(value)
- return float(value)
-
- @staticmethod
- def _multinomial_sample(probs: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
- return torch.multinomial(probs, 1, replacement=True, generator=generator).reshape(())
-
- @classmethod
- def _nucleus_sample_one(
- cls,
- weighted_scores: torch.Tensor,
- *,
- top_p: float,
- top_k: int,
- generator: torch.Generator | None,
- ) -> int:
- probs = weighted_scores.softmax(dim=0)
- sorted_prob, sorted_idx = probs.sort(descending=True, stable=True)
- kept_probs: list[torch.Tensor] = []
- kept_indices: list[torch.Tensor] = []
- cum_prob = 0.0
- max_keep = len(sorted_idx) if top_k <= 0 else min(int(top_k), len(sorted_idx))
- for i in range(len(sorted_idx)):
- if cum_prob < top_p and len(kept_probs) < max_keep:
- cum_prob += float(sorted_prob[i].item())
- kept_probs.append(sorted_prob[i])
- kept_indices.append(sorted_idx[i])
- else:
- break
-
- if not kept_probs:
- return int(sorted_idx[0].item())
-
- sample_probs = torch.stack(kept_probs)
- sample_idx = cls._multinomial_sample(sample_probs, generator=generator)
- return int(torch.stack(kept_indices)[int(sample_idx.item())].item())
-
- @classmethod
- def _ras_sample_one(
- cls,
- weighted_scores: torch.Tensor,
- decoded_tokens: Sequence[int],
- *,
- top_p: float,
- top_k: int,
- win_size: int,
- tau_r: float,
- generator: torch.Generator | None,
- ) -> int:
- top_id = cls._nucleus_sample_one(
- weighted_scores,
- top_p=top_p,
- top_k=top_k,
- generator=generator,
- )
- if win_size > 0 and decoded_tokens:
- recent = torch.as_tensor(
- list(decoded_tokens[-win_size:]),
- device=weighted_scores.device,
- dtype=torch.long,
- )
- rep_num = int((recent == top_id).sum().item())
- if rep_num >= win_size * tau_r:
- weighted_scores = weighted_scores.clone()
- weighted_scores[top_id] = float("-inf")
- fallback_probs = weighted_scores.softmax(dim=0)
- top_id = int(cls._multinomial_sample(fallback_probs, generator=generator).item())
- return top_id
-
- def _cosyvoice3_ras_enabled(self, sampling_metadata: SamplingMetadata) -> bool:
- if self.model_stage != "cosyvoice3_talker":
- return False
- if sampling_metadata.max_num_logprobs is not None:
- return False
- if sampling_metadata.temperature is None:
- return False
- if bool(sampling_metadata.bad_words_token_ids):
- return False
- if torch.any(sampling_metadata.frequency_penalties != 0):
- return False
- if torch.any(sampling_metadata.presence_penalties != 0):
- return False
- return True
-
- def sample(
- self,
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> SamplerOutput | None:
- if logits is None or logits.numel() == 0:
- return None
- if self.model_stage != "cosyvoice3_talker":
- return None
-
- sampler = getattr(self, "_talker_sampler", None)
- if sampler is None:
- sampler = Sampler()
- self._talker_sampler = sampler
-
- if not self._cosyvoice3_ras_enabled(sampling_metadata):
- return sampler(logits=logits, sampling_metadata=sampling_metadata)
-
- logits = logits.to(torch.float32)
- sampling_for_processors = replace(sampling_metadata, no_penalties=True)
- logits = sampler.apply_logits_processors(logits, sampling_for_processors, predict_bonus_token=False)
-
- sampling_cfg = dict(self.config.llm.get("sampling", {}))
- default_top_p = float(sampling_cfg.get("top_p", 0.8))
- default_top_k = int(sampling_cfg.get("top_k", 25))
- win_size = int(sampling_cfg.get("win_size", 10))
- tau_r = float(sampling_cfg.get("tau_r", 0.1))
-
- sampled_ids: list[int] = []
- for req_idx in range(int(logits.shape[0])):
- row_logits = logits[req_idx]
-
- temperature = float(self._req_scalar(sampling_metadata.temperature, req_idx, 1.0))
- if temperature < self._sampling_eps:
- sampled_ids.append(int(torch.argmax(row_logits).item()))
- continue
-
- top_p = float(self._req_scalar(sampling_metadata.top_p, req_idx, default_top_p))
- top_k = int(self._req_scalar(sampling_metadata.top_k, req_idx, default_top_k))
- generator = sampling_metadata.generators.get(req_idx)
- weighted_scores = torch.log_softmax(row_logits / max(temperature, self._sampling_eps), dim=0)
- decoded_tokens = (
- sampling_metadata.output_token_ids[req_idx] if req_idx < len(sampling_metadata.output_token_ids) else []
- )
- sampled_ids.append(
- self._ras_sample_one(
- weighted_scores,
- decoded_tokens,
- top_p=top_p,
- top_k=top_k,
- win_size=win_size,
- tau_r=tau_r,
- generator=generator,
- )
- )
-
- sampled = torch.tensor(sampled_ids, device=logits.device, dtype=torch.int32)
- return SamplerOutput(sampled_token_ids=sampled.unsqueeze(-1), logprobs_tensors=None)
-
def compute_logits(self, hidden_states: torch.Tensor | OmniOutput) -> torch.Tensor | None:
if isinstance(hidden_states, OmniOutput):
hidden_states = hidden_states.text_hidden_states
if self.model_stage == "cosyvoice3_talker":
logits = self.model.llm_decoder(hidden_states)
- # The decoder outputs speech_token_size + 200 logits. The official
- # CosyVoice3 treats ALL tokens >= speech_token_size (the last 200)
- # as stop signals. Merge their probabilities into a single EOS
- # token (6562) via logsumexp so that vLLM's stop_token_ids=[6562]
- # fires with the correct aggregate stop probability.
- speech_token_size = self.config.llm["speech_token_size"]
- eos_idx = self.config.llm["eos_token_id"]
- stop_logits = logits[..., speech_token_size:] # last 200
- merged_stop = torch.logsumexp(stop_logits, dim=-1, keepdim=True)
- logits[..., speech_token_size:] = float("-inf") # mask all
- logits[..., eos_idx] = merged_stop.squeeze(-1) # restore merged
- # Pad to full vocab_size for vLLM token handling.
vocab_size = self.config.vocab_size
pad_size = vocab_size - logits.size(-1)
- if pad_size > 0:
- pad_shape = logits.shape[:-1] + (pad_size,)
- pad = logits.new_full(pad_shape, float("-inf"))
- logits = torch.cat([logits, pad], dim=-1)
+ pad_shape = logits.shape[:-1] + (pad_size,)
+ pad = logits.new_full(pad_shape, float("-inf"))
+ eos_token_val = logits[..., self.config.llm["eos_token_id"]].clone()
+ logits[..., -200:] = float("-inf")
+ logits[..., self.config.llm["eos_token_id"]] = eos_token_val
+ logits = torch.cat([logits, pad], dim=-1)
return logits
else:
raise RuntimeError(f"compute_logits is only valid for {self.model_stage}.")
@@ -650,7 +380,6 @@ def embed_input_ids(
hidden = int(self.config.hidden_size)
return torch.zeros(
(input_ids.shape[0], hidden),
- device=input_ids.device,
)
else:
raise RuntimeError(f"embed_input_ids is not valid for {self.model_stage}.")
@@ -683,116 +412,28 @@ def forward(
return OmniOutput(text_hidden_states=hidden_states, multimodal_outputs=multimodal_outputs)
elif self.model_stage == "cosyvoice3_code2wav":
- runtime_info = kwargs.get("model_intermediate_buffer")
- if runtime_info is None:
- runtime_info = kwargs.get("runtime_additional_information", [])
- if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
- logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
-
- seq_token_counts = kwargs.get("seq_token_counts")
- flat_ids = input_ids.reshape(-1).to(dtype=torch.long)
- request_ids_list = self._split_request_ids(flat_ids, seq_token_counts)
-
- num_reqs = max(1, len(request_ids_list))
- sample_rate = torch.tensor(int(self.config.sample_rate), dtype=torch.int32)
- empty_audio = torch.zeros((0,), dtype=torch.float32, device=input_ids.device)
- audios: list[torch.Tensor] = [empty_audio] * num_reqs
- srs: list[torch.Tensor] = [sample_rate] * num_reqs
- if not isinstance(runtime_info, list):
- runtime_info = []
-
- for idx, req_ids in enumerate(request_ids_list):
- info = runtime_info[idx] if idx < len(runtime_info) and isinstance(runtime_info[idx], dict) else {}
- req_id = self._as_str(info.get("req_id")) if info else None
- stream_finished = self._as_bool(info.get("stream_finished")) if info else False
- speech_token = self._as_tensor(info.get("speech_token")) if info else None
- speech_feat = self._as_tensor(info.get("speech_feat")) if info else None
- embedding = self._as_tensor(info.get("embedding")) if info else None
- if speech_token is None or speech_feat is None or embedding is None:
- if stream_finished and req_id is not None and hasattr(self, "_stream_vocoder_cache_by_req"):
- with self._stream_audio_cache_lock:
- self._stream_vocoder_cache_by_req.pop(req_id, None)
- audios[idx] = self._stitch_stream_audio(req_id, empty_audio, stream_finished)
- if (
- req_ids.numel() > 0
- and info
- and ("token_offset" in info or "left_context_size" in info or "generated_len" in info)
- ):
- info_keys = ",".join(sorted(info.keys())) if info else ""
- logger.warning_once(
- "CosyVoice3 code2wav missing prompt conditioning for non-empty codec tokens: "
- "raw_len=%d info_keys=%s",
- int(req_ids.numel()),
- info_keys,
- )
- continue
-
- token = self._sanitize_codec_tokens(req_ids)
- if token.numel() == 0:
- audios[idx] = self._stitch_stream_audio(req_id, empty_audio, stream_finished)
- if req_ids.numel() > 0:
- logger.warning_once(
- "CosyVoice3 code2wav received no valid codec tokens after filtering: "
- "raw_len=%d raw_range=[%d,%d] vocab_size=%d",
- req_ids.numel(),
- int(req_ids.min().item()),
- int(req_ids.max().item()),
- int(self.code2wav.input_embedding.num_embeddings),
- )
- continue
-
- # `generated_len` is injected for many models by the generic
- # runner, so only explicit chunk-routing fields should switch
- # code2wav into the streaming path.
- uses_streaming_decode = bool(info) and (
- "stream_finished" in info or "token_offset" in info or "left_context_size" in info
- )
- if uses_streaming_decode:
- token_offset = 0
- try:
- if info and "token_offset" in info:
- token_offset = max(0, int(info.get("token_offset", 0)))
- elif info:
- token_offset = max(0, int(info.get("left_context_size", 0)))
- except (TypeError, ValueError):
- token_offset = 0
-
- cache_state = None
- if req_id is not None and hasattr(self, "_stream_vocoder_cache_by_req"):
- with self._stream_audio_cache_lock:
- cache_state = self._stream_vocoder_cache_by_req.get(req_id)
-
- tts_speech, new_cache_state = self.code2wav.forward_streaming(
- token=token.unsqueeze(0),
- prompt_token=speech_token[:1],
- prompt_feat=speech_feat[:1],
- embedding=embedding[:1],
- cache_state=cache_state,
- n_timesteps=10,
- token_offset_tokens=token_offset,
- finalize=stream_finished,
- )
-
- if req_id is not None and hasattr(self, "_stream_vocoder_cache_by_req"):
- with self._stream_audio_cache_lock:
- if new_cache_state is None or stream_finished:
- self._stream_vocoder_cache_by_req.pop(req_id, None)
- else:
- self._stream_vocoder_cache_by_req[req_id] = new_cache_state
- else:
- tts_speech = self.code2wav.forward(
- token=token.unsqueeze(0),
- prompt_token=speech_token[:1],
- prompt_feat=speech_feat[:1],
- embedding=embedding[:1],
- n_timesteps=10,
- )
-
- audio = tts_speech.reshape(-1).to(dtype=torch.float32)
-
- audios[idx] = self._stitch_stream_audio(req_id, audio, stream_finished)
-
- return OmniOutput(text_hidden_states=None, multimodal_outputs={"audio": audios, "sr": srs})
+ runtime_info = kwargs.get("runtime_additional_information", [])
+ if not runtime_info:
+ length = 30 * 24000
+ audio = np.zeros((length,))
+ return OmniOutput(text_hidden_states=None, multimodal_outputs={"audio": audio})
+
+ # Remove the last eos token and add batch dimension
+ token = input_ids[..., :-1].unsqueeze(0)
+
+ # Generate audio using code2wav
+ tts_speech = self.code2wav(
+ token=token,
+ prompt_token=runtime_info[0]["speech_token"][:1],
+ prompt_feat=runtime_info[0]["speech_feat"][:1],
+ embedding=runtime_info[0]["embedding"][:1],
+ n_timesteps=10,
+ )
+
+ return OmniOutput(
+ text_hidden_states=None,
+ multimodal_outputs={"audio": tts_speech, "sr": 22050},
+ )
else:
raise ValueError(f"Unsupported model_stage: {self.model_stage}")
diff --git a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py
index 3ad23cdb108..f5e0d04a8ae 100644
--- a/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py
+++ b/vllm_omni/model_executor/models/cosyvoice3/cosyvoice3_code2wav.py
@@ -11,12 +11,11 @@
from __future__ import annotations
-from contextlib import nullcontext
-
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig
+from torch.nn import functional as F
from vllm.logger import init_logger
from vllm_omni.diffusion.models.cosyvoice3_audio.cosyvoice3_dit import DiT
@@ -30,6 +29,7 @@
)
from vllm_omni.model_executor.models.cosyvoice3.code2wav_core.layers import PreLookaheadLayer
from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config
+from vllm_omni.model_executor.models.cosyvoice3.utils import make_pad_mask
logger = init_logger(__name__)
@@ -151,160 +151,82 @@ def spk_embed_affine_layer(self) -> nn.Linear:
return self.flow_model.spk_embed_affine_layer
@torch.inference_mode()
- def _forward_mel(
+ def forward(
self,
token: torch.Tensor,
prompt_token: torch.Tensor,
prompt_feat: torch.Tensor,
embedding: torch.Tensor,
n_timesteps: int = 10,
- token_offset_tokens: int = 0,
- streaming: bool = True,
- finalize: bool = False,
- ) -> torch.Tensor:
- """Generate mel features via the upstream flow-model inference path."""
- flow_weight = next(self.flow_model.parameters())
- device = flow_weight.device
- dtype = flow_weight.dtype
-
- token = token.to(device=device, dtype=torch.int32)
- prompt_token = prompt_token.to(device=device, dtype=torch.int32)
- prompt_feat = prompt_feat.to(device=device, dtype=dtype)
- embedding = embedding.to(device=device, dtype=dtype)
- token_len = torch.tensor([token.shape[1]], device=device, dtype=torch.int32)
- prompt_token_len = torch.tensor([prompt_token.shape[1]], device=device, dtype=torch.int32)
- prompt_feat_len = torch.tensor([prompt_feat.shape[1]], device=device, dtype=torch.int32)
-
- with nullcontext():
- feat, _ = self.flow_model.inference(
- token=token,
- token_len=token_len,
- prompt_token=prompt_token,
- prompt_token_len=prompt_token_len,
- prompt_feat=prompt_feat,
- prompt_feat_len=prompt_feat_len,
- embedding=embedding,
- streaming=streaming,
- finalize=finalize,
- n_timesteps=n_timesteps,
- )
-
- trim_mel = max(0, int(token_offset_tokens)) * int(self.token_mel_ratio)
- if trim_mel > 0:
- feat = feat[:, :, trim_mel:]
-
- return feat
-
- @staticmethod
- def _fade_speech(
- speech: torch.Tensor,
- prev_speech: torch.Tensor,
) -> torch.Tensor:
- """Blend previous speech tail into current speech head."""
- if speech.numel() == 0 or prev_speech.numel() == 0:
- return speech
- overlap = min(int(speech.shape[-1]), int(prev_speech.shape[-1]))
- if overlap <= 0:
- return speech
- window = torch.hamming_window(2 * overlap, periodic=False, dtype=speech.dtype, device=speech.device)
- fade_in = window[:overlap].view(1, -1)
- fade_out = window[overlap:].view(1, -1)
- blended_head = (
- speech[:, :overlap] * fade_in
- + prev_speech[:, -overlap:].to(device=speech.device, dtype=speech.dtype) * fade_out
- )
- if overlap == int(speech.shape[-1]):
- return blended_head
- return torch.cat([blended_head, speech[:, overlap:]], dim=-1)
+ """Generate audio waveform from speech tokens.
- @torch.inference_mode()
- def forward_streaming(
- self,
- token: torch.Tensor,
- prompt_token: torch.Tensor,
- prompt_feat: torch.Tensor,
- embedding: torch.Tensor,
- *,
- cache_state: dict[str, torch.Tensor] | None = None,
- n_timesteps: int = 10,
- token_offset_tokens: int = 0,
- finalize: bool = False,
- ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]:
- """Decode streaming audio using cumulative mel + emitted-speech offset.
-
- This mirrors upstream CosyVoice3 streaming semantics more closely than
- waveform-domain overlap-add: keep a cumulative mel history per request,
- re-run causal HiFT on the history, and emit only the newly grown speech
- suffix. That preserves causal look-right handling without double
- trimming or duplicated overlap at chunk boundaries.
+ Args:
+ token: Speech tokens from talker stage [batch, seq_len]
+ prompt_token: Prompt speech tokens [batch, prompt_len]
+ prompt_feat: Prompt mel features [batch, feat_len, mel_dim]
+ embedding: Speaker embedding [batch, spk_dim]
+ n_timesteps: Number of diffusion steps
+
+ Returns:
+ Audio waveform [batch, 1, audio_len]
"""
- with nullcontext():
- feat = self._forward_mel(
- token=token,
- prompt_token=prompt_token,
- prompt_feat=prompt_feat,
- embedding=embedding,
- n_timesteps=n_timesteps,
- token_offset_tokens=token_offset_tokens,
- streaming=True,
- finalize=finalize,
- )
- hift_weight = self.hift.m_source.l_linear.weight
- chunk_mel = feat.to(device=hift_weight.device, dtype=hift_weight.dtype)
-
- cached_mel = None if not cache_state else cache_state.get("mel")
- speech_offset_obj = None if not cache_state else cache_state.get("speech_offset")
- try:
- speech_offset = int(speech_offset_obj) if speech_offset_obj is not None else 0
- except (TypeError, ValueError):
- speech_offset = 0
-
- if isinstance(cached_mel, torch.Tensor) and cached_mel.numel() > 0:
- cached_mel = cached_mel.to(device=chunk_mel.device, dtype=chunk_mel.dtype)
- tts_mel = torch.cat([cached_mel, chunk_mel], dim=-1) if chunk_mel.numel() > 0 else cached_mel
- else:
- tts_mel = chunk_mel
+ device = token.device
+ dtype = next(self.flow_model.parameters()).dtype
- if tts_mel.shape[-1] == 0:
- tts_speech = torch.zeros((chunk_mel.shape[0], 1, 0), device=chunk_mel.device, dtype=chunk_mel.dtype)
- else:
- with nullcontext():
- tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
-
- tts_speech = tts_speech.reshape(tts_speech.shape[0], -1)
- speech_offset = max(0, min(speech_offset, int(tts_speech.shape[-1])))
- emitted_speech = tts_speech[:, speech_offset:]
-
- if finalize:
- return emitted_speech.reshape(emitted_speech.shape[0], 1, -1), None
-
- new_state = {
- "mel": tts_mel.detach().cpu().contiguous(),
- "speech_offset": int(tts_speech.shape[-1]),
- }
- return emitted_speech.reshape(emitted_speech.shape[0], 1, -1), new_state
-
- @torch.inference_mode()
- def forward(
- self,
- token: torch.Tensor,
- prompt_token: torch.Tensor,
- prompt_feat: torch.Tensor,
- embedding: torch.Tensor,
- n_timesteps: int = 10,
- ) -> torch.Tensor:
- """Generate audio waveform from speech tokens."""
- feat = self._forward_mel(
- token=token,
- prompt_token=prompt_token,
- prompt_feat=prompt_feat,
- embedding=embedding,
+ # Normalize and project speaker embedding
+ embedding = embedding.to(device=device, dtype=dtype)
+ embedding = F.normalize(embedding, dim=1)
+ embedding = self.spk_embed_affine_layer(embedding)
+
+ # Prepare tokens
+ prompt_token = prompt_token.to(device=device)
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
+ prompt_token_len = torch.tensor([token_len1], device=device, dtype=torch.int32)
+ token_len = torch.tensor([token_len2], device=device, dtype=torch.int32)
+
+ # Concatenate prompt and target tokens
+ full_token = torch.cat([prompt_token, token], dim=1)
+ full_token_len = prompt_token_len + token_len
+
+ # Create mask
+ mask = (~make_pad_mask(full_token_len)).unsqueeze(-1).to(embedding)
+
+ # Token embedding
+ token_emb = self.input_embedding(torch.clamp(full_token, min=0)) * mask
+
+ # Pre-lookahead processing
+ h = self.pre_lookahead_layer(token_emb)
+ h = h.repeat_interleave(self.token_mel_ratio, dim=1)
+
+ # Calculate mel lengths
+ mel_len1 = prompt_feat.shape[1]
+ mel_len2 = h.shape[1] - mel_len1
+
+ # Build conditioning
+ conds = torch.zeros(
+ [1, mel_len1 + mel_len2, self.output_size],
+ device=device,
+ dtype=h.dtype,
+ )
+ conds[:, :mel_len1] = prompt_feat
+ conds = conds.transpose(1, 2)
+
+ # Create mel mask
+ mel_mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
+
+ # Run flow matching decoder
+ feat, _ = self.decoder(
+ mu=h.transpose(1, 2).contiguous(),
+ mask=mel_mask.unsqueeze(1),
+ spks=embedding,
+ cond=conds,
n_timesteps=n_timesteps,
- token_offset_tokens=0,
- streaming=False,
- finalize=True,
)
+ # Extract generated portion (after prompt)
+ feat = feat[:, :, mel_len1:]
+
# Run vocoder
hift_weight = self.hift.m_source.l_linear.weight
tts_mel = feat.to(device=hift_weight.device, dtype=hift_weight.dtype)
@@ -316,7 +238,7 @@ def forward(
dtype=tts_mel.dtype,
)
else:
- tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=True)
+ tts_speech, _ = self.hift.inference(speech_feat=tts_mel)
return tts_speech
diff --git a/vllm_omni/model_executor/models/cosyvoice3/pipeline.py b/vllm_omni/model_executor/models/cosyvoice3/pipeline.py
deleted file mode 100644
index 4480a0dd831..00000000000
--- a/vllm_omni/model_executor/models/cosyvoice3/pipeline.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""CosyVoice3 pipeline topology (frozen).
-
-Stage 0: Talker — text prompt → speech tokens (LLM autoregressive).
-Stage 1: Code2Wav — flow-matching decoder → acoustic features → waveform.
- * ``sync_process_input_func`` runs when ``deploy.async_chunk=false``:
- stage 1 builds full-sequence flow input via ``text2flow``.
- * ``async_chunk_process_next_stage_input_func`` runs when
- ``deploy.async_chunk=true``: stage 0 streams codec chunks to stage 1
- through the shared-memory connector.
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-_PROC = "vllm_omni.model_executor.stage_input_processors.cosyvoice3"
-
-COSYVOICE3_PIPELINE = PipelineConfig(
- model_type="cosyvoice3",
- model_arch="CosyVoice3Model",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="cosyvoice3_talker",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- owns_tokenizer=True,
- engine_output_type="latent",
- async_chunk_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"),
- sampling_constraints={
- # merged speech stop token (logsumexp of all 200 stop logits)
- "stop_token_ids": [6562],
- },
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="cosyvoice3_code2wav",
- execution_type=StageExecutionType.LLM_GENERATION,
- input_sources=(0,),
- final_output=True,
- final_output_type="audio",
- engine_output_type="latent",
- sync_process_input_func=f"{_PROC}.text2flow",
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/cosyvoice3/utils.py b/vllm_omni/model_executor/models/cosyvoice3/utils.py
index 0bf0cccb163..52c52655e8d 100644
--- a/vllm_omni/model_executor/models/cosyvoice3/utils.py
+++ b/vllm_omni/model_executor/models/cosyvoice3/utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
+import os
from functools import cache, lru_cache
import numpy as np
@@ -8,8 +9,7 @@
import torch.nn.functional as F
import torchaudio
import torchaudio.compliance.kaldi as kaldi
-
-from vllm_omni.utils.audio import mel_filter_bank
+from librosa.filters import mel as librosa_mel_fn
logger = logging.getLogger(__name__)
@@ -34,13 +34,8 @@ def _get_mel_basis(
fmax: float | None,
device_str: str,
) -> torch.Tensor:
- return mel_filter_bank(
- sr=sampling_rate,
- n_fft=n_fft,
- n_mels=num_mels,
- fmin=fmin,
- fmax=fmax,
- ).to(torch.device(device_str))
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ return torch.from_numpy(mel).float().to(torch.device(device_str))
@lru_cache
@@ -127,8 +122,42 @@ def exact_div(x, y):
@cache
def mel_filters(device, n_mels: int) -> torch.Tensor:
- """Compute mel filterbank matrix for projecting STFT into a Mel spectrogram."""
- return mel_filter_bank(sr=16000, n_fft=400, n_mels=n_mels).to(device)
+ """
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+ Allows decoupling librosa dependency; saved using:
+
+ np.savez_compressed(
+ "mel_filters.npz",
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+ mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
+ )
+ """
+ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
+
+ filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+ if not os.path.exists(filters_path):
+ source_url = "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz"
+ os.makedirs(os.path.dirname(filters_path), exist_ok=True)
+ try:
+ import urllib.request
+
+ with urllib.request.urlopen(source_url, timeout=30) as resp:
+ with open(filters_path, "wb") as f_out:
+ f_out.write(resp.read())
+ logger.info("Downloaded mel_filters.npz from %s", source_url)
+ except Exception as e:
+ raise FileNotFoundError(
+ "Missing CosyVoice3 mel filter asset:\n"
+ f" {filters_path}\n"
+ "Auto-download failed. Download it manually from:\n"
+ f" {source_url}\n"
+ "Example:\n"
+ f" mkdir -p {os.path.dirname(filters_path)} && "
+ f"curl -L {source_url} -o {filters_path}"
+ ) from e
+
+ with np.load(filters_path, allow_pickle=False) as f:
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(
diff --git a/vllm_omni/model_executor/models/dynin_omni/__init__.py b/vllm_omni/model_executor/models/dynin_omni/__init__.py
deleted file mode 100644
index 2a3bae8a9f4..00000000000
--- a/vllm_omni/model_executor/models/dynin_omni/__init__.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any
-
-from .dynin_omni import DyninOmniForConditionalGeneration
-from .dynin_omni_common import (
- get_dynin_magvit_attr,
- get_dynin_modeling_attr,
- get_dynin_sampling_attr,
-)
-
-if TYPE_CHECKING:
- from .dynin_omni_token2audio import DyninOmniToken2Audio
- from .dynin_omni_token2image import DyninOmniToken2Image
- from .dynin_omni_token2text import DyninOmniToken2Text
-
-
-_STAGE_EXPORTS = {
- "DyninOmniToken2Audio": (".dynin_omni_token2audio", "DyninOmniToken2Audio"),
- "DyninOmniToken2Image": (".dynin_omni_token2image", "DyninOmniToken2Image"),
- "DyninOmniToken2Text": (".dynin_omni_token2text", "DyninOmniToken2Text"),
-}
-
-_MODELING_EXPORTS = {"DyninOmniConfig", "DyninOmniModelLM", "VideoTokenMerger"}
-_MAGVIT_EXPORTS = {"VQGANEncoder", "VQGANDecoder", "LFQuantizer", "MAGVITv2"}
-
-
-def __getattr__(name: str) -> Any:
- if name in _STAGE_EXPORTS:
- module_name, attr_name = _STAGE_EXPORTS[name]
- module = __import__(module_name, globals(), locals(), [attr_name], 1)
- return getattr(module, attr_name)
-
- if name in _MODELING_EXPORTS:
- return get_dynin_modeling_attr(name)
-
- if name in _MAGVIT_EXPORTS:
- return get_dynin_magvit_attr(name)
-
- if name == "get_mask_schedule":
- return get_dynin_sampling_attr("get_mask_schedule")
-
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
-
-
-__all__ = [
- "DyninOmniForConditionalGeneration",
- "DyninOmniToken2Audio",
- "DyninOmniToken2Image",
- "DyninOmniToken2Text",
- "DyninOmniConfig",
- "DyninOmniModelLM",
- "VideoTokenMerger",
- "VQGANEncoder",
- "VQGANDecoder",
- "LFQuantizer",
- "MAGVITv2",
- "get_mask_schedule",
-]
diff --git a/vllm_omni/model_executor/models/dynin_omni/dynin_omni.py b/vllm_omni/model_executor/models/dynin_omni/dynin_omni.py
deleted file mode 100644
index 0caae158ef9..00000000000
--- a/vllm_omni/model_executor/models/dynin_omni/dynin_omni.py
+++ /dev/null
@@ -1,744 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Iterable, Mapping, Sequence
-from functools import cached_property
-from importlib import import_module
-from typing import Any
-
-import numpy as np
-import torch
-import torch.nn as nn
-from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs import MultiModalDataDict
-from vllm.inputs import MultiModalInput as MultiModalInputs
-from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (
- MultiModalFieldConfig,
- MultiModalKwargsItems,
- PlaceholderRange,
-)
-from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
-from vllm.multimodal.processing import (
- BaseDummyInputsBuilder,
- BaseMultiModalProcessor,
- BaseProcessingInfo,
- ProcessorInputs,
- PromptUpdate,
- TimingContext,
-)
-from vllm.sequence import IntermediateTensors
-from vllm.v1.outputs import SamplerOutput
-from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
-
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-
-from .dynin_omni_common import build_zero_input_embeddings
-
-try:
- from PIL import Image as PILImage
-except Exception: # pragma: no cover
- PILImage = None
-
-
-_MODALITY_ORDER = ("image", "video", "audio")
-
-_MODALITY_ALIASES = {
- "img2img": "image",
-}
-
-_MODALITY_INPUT_KEY_BY_NAME = {
- "image": "pixel_values",
- "video": "pixel_values_videos",
- "audio": "input_audio_features",
-}
-
-_MODALITY_PLACEHOLDER_BY_NAME = {
- "image": "<|soi|><|image|><|eoi|>",
- "video": "<|sov|><|video|><|eov|>",
- "audio": "<|soa|><|audio|><|eoa|>",
-}
-
-_MODALITY_INPUT_ALIASES = {
- "image": ("pixel_values", "image_embeds", "img2img"),
- "video": ("pixel_values_videos", "video_embeds"),
- "audio": ("input_audio_features", "audio_embeds"),
-}
-
-
-def _normalize_modality_name(modality: str) -> str:
- return _MODALITY_ALIASES.get(modality, modality)
-
-
-def _get_modality_count(mm_counts: Mapping[str, int], modality: str) -> int:
- canonical = _normalize_modality_name(modality)
- count = mm_counts.get(canonical, 0)
- for alias, target in _MODALITY_ALIASES.items():
- if target == canonical:
- count += mm_counts.get(alias, 0)
- return count
-
-
-def _normalize_mm_data_aliases(mm_data: MultiModalDataDict) -> MultiModalDataDict:
- normalized: dict[str, Any] = {}
- for modality, value in mm_data.items():
- canonical = _normalize_modality_name(modality)
- if canonical in normalized and normalized[canonical] is not None and value is not None:
- raise ValueError(
- "Dynin received duplicate multimodal inputs for "
- f"{canonical!r} via {modality!r}. "
- "Provide either the canonical modality or its alias, not both."
- )
- if canonical not in normalized or normalized[canonical] is None:
- normalized[canonical] = value
- return normalized
-
-
-def _get_placeholder_text(modality: str) -> str | None:
- modality = _normalize_modality_name(modality)
- for base_modality, placeholder in _MODALITY_PLACEHOLDER_BY_NAME.items():
- if modality.startswith(base_modality):
- return placeholder
- return None
-
-
-class DyninOmniProcessingInfo(BaseProcessingInfo):
- def get_data_parser(self) -> MultiModalDataParser:
- return DyninOmniMultiModalDataParser(
- expected_hidden_size=self._get_expected_hidden_size(),
- )
-
- def get_supported_mm_limits(self) -> Mapping[str, int | None]:
- limits = {modality: 1 for modality in _MODALITY_ORDER}
- for alias, target in _MODALITY_ALIASES.items():
- if target in limits:
- limits[alias] = limits[target]
- return limits
-
- def get_mm_max_tokens_per_item(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- ) -> Mapping[str, int] | None:
- del seq_len, mm_counts
- limits = {modality: 1 for modality in _MODALITY_ORDER}
- for alias, target in _MODALITY_ALIASES.items():
- if target in limits:
- limits[alias] = limits[target]
- return limits
-
-
-class DyninOmniDummyInputsBuilder(BaseDummyInputsBuilder[DyninOmniProcessingInfo]):
- def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
- chunks: list[str] = []
- for modality in _MODALITY_ORDER:
- placeholder = _get_placeholder_text(modality)
- if placeholder is None:
- continue
- chunks.extend([placeholder] * _get_modality_count(mm_counts, modality))
- return " ".join(chunks)
-
- def get_dummy_mm_data(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- mm_options: Mapping[str, BaseDummyOptions] | None = None,
- ) -> MultiModalDataDict:
- del seq_len
-
- mm_data: dict[str, Any] = {}
-
- num_images = _get_modality_count(mm_counts, "image")
- if num_images > 0:
- mm_data["image"] = self._get_dummy_images(
- width=224,
- height=224,
- num_images=num_images,
- overrides=mm_options.get("image") if mm_options else None,
- )
-
- num_videos = _get_modality_count(mm_counts, "video")
- if num_videos > 0:
- mm_data["video"] = self._get_dummy_videos(
- width=224,
- height=224,
- num_frames=8,
- num_videos=num_videos,
- overrides=mm_options.get("video") if mm_options else None,
- )
-
- num_audios = _get_modality_count(mm_counts, "audio")
- if num_audios > 0:
- mm_data["audio"] = self._get_dummy_audios(
- length=16000,
- num_audios=num_audios,
- overrides=mm_options.get("audio") if mm_options else None,
- )
-
- return mm_data
-
-
-class DyninOmniMultiModalDataParser(MultiModalDataParser):
- def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
- normalized = _normalize_mm_data_aliases(mm_data)
- mm_items = super().parse_mm_data(normalized)
-
- for alias, canonical in _MODALITY_ALIASES.items():
- if alias in mm_data and canonical in mm_items and alias not in mm_items:
- mm_items[alias] = mm_items[canonical]
-
- return mm_items
-
- def _get_audio_with_sr(self, audio: Any) -> tuple[np.ndarray, float | None]:
- audio_array, orig_sr = super()._get_audio_with_sr(audio)
- if self.audio_resampler.target_sr is None:
- return audio_array, None
- return audio_array, orig_sr
-
-
-class DyninOmniMultiModalProcessor(BaseMultiModalProcessor[DyninOmniProcessingInfo]):
- @staticmethod
- def _find_subsequence(
- haystack: list[int],
- needle: list[int],
- start: int,
- ) -> int | None:
- if not needle:
- return None
-
- max_start = len(haystack) - len(needle)
- if max_start < start:
- return None
-
- for idx in range(start, max_start + 1):
- if haystack[idx : idx + len(needle)] == needle:
- return idx
- return None
-
- @staticmethod
- def _make_disabled_embed_mask(length: int) -> torch.Tensor:
- return torch.zeros(length, dtype=torch.bool)
-
- @staticmethod
- def _encode_prompt_to_token_ids(
- prompt: str | list[int],
- tokenizer: Any | None,
- ) -> list[int]:
- if isinstance(prompt, str):
- if tokenizer is None:
- raise ValueError("Tokenizer is required to process string prompts for Dynin multimodal inputs.")
- return tokenizer.encode(prompt, add_special_tokens=False)
- return list(prompt)
-
- @staticmethod
- def _ensure_non_empty_prompt_ids(
- prompt_token_ids: list[int],
- tokenizer: Any | None,
- ) -> list[int]:
- if prompt_token_ids:
- return prompt_token_ids
-
- fallback_id = None
- if tokenizer is not None:
- fallback_id = getattr(tokenizer, "bos_token_id", None)
- if fallback_id is None:
- fallback_id = getattr(tokenizer, "eos_token_id", None)
- if fallback_id is None:
- fallback_id = getattr(tokenizer, "pad_token_id", None)
-
- return [0 if fallback_id is None else int(fallback_id)]
-
- @classmethod
- def _image_to_chw_float_tensor(cls, image: Any) -> torch.Tensor:
- if isinstance(image, torch.Tensor):
- tensor = image.detach()
- elif isinstance(image, np.ndarray):
- tensor = torch.from_numpy(image)
- elif PILImage is not None and isinstance(image, PILImage.Image):
- tensor = torch.from_numpy(np.asarray(image).copy())
- else:
- raise TypeError(f"Unsupported image item type: {type(image)!r}")
-
- if tensor.ndim == 2:
- tensor = tensor.unsqueeze(-1)
- if tensor.ndim != 3:
- raise ValueError(f"Expected 3D image tensor, got shape={tuple(tensor.shape)}")
-
- if tensor.shape[-1] in (1, 3, 4) and tensor.shape[0] not in (1, 3, 4):
- tensor = tensor.permute(2, 0, 1)
-
- if tensor.shape[0] == 1:
- tensor = tensor.repeat(3, 1, 1)
- if tensor.shape[0] == 4:
- tensor = tensor[:3]
-
- tensor = tensor.to(dtype=torch.float32)
- if tensor.numel() > 0 and torch.max(tensor) > 1.0:
- tensor = tensor / 255.0
- return tensor.contiguous()
-
- @classmethod
- def _video_to_tchw_float_tensor(cls, video: Any) -> torch.Tensor:
- if isinstance(video, (list, tuple)) and not isinstance(video, torch.Tensor):
- frames = [cls._image_to_chw_float_tensor(frame) for frame in video]
- if not frames:
- return torch.zeros((1, 3, 1, 1), dtype=torch.float32)
- return torch.stack(frames, dim=0).contiguous()
-
- if isinstance(video, torch.Tensor):
- tensor = video.detach()
- elif isinstance(video, np.ndarray):
- tensor = torch.from_numpy(video)
- else:
- raise TypeError(f"Unsupported video item type: {type(video)!r}")
-
- if tensor.ndim == 3:
- return cls._image_to_chw_float_tensor(tensor).unsqueeze(0).contiguous()
-
- if tensor.ndim != 4:
- raise ValueError(f"Expected 4D video tensor, got shape={tuple(tensor.shape)}")
-
- if tensor.shape[-1] in (1, 3, 4) and tensor.shape[1] not in (1, 3, 4):
- tensor = tensor.permute(0, 3, 1, 2)
-
- if tensor.shape[1] == 1:
- tensor = tensor.repeat(1, 3, 1, 1)
- if tensor.shape[1] == 4:
- tensor = tensor[:, :3]
-
- tensor = tensor.to(dtype=torch.float32)
- if tensor.numel() > 0 and torch.max(tensor) > 1.0:
- tensor = tensor / 255.0
- return tensor.contiguous()
-
- @staticmethod
- def _audio_to_float_tensor(audio: Any) -> torch.Tensor:
- if isinstance(audio, tuple) and len(audio) == 2:
- audio = audio[0]
-
- if isinstance(audio, torch.Tensor):
- tensor = audio.detach()
- elif isinstance(audio, np.ndarray):
- tensor = torch.from_numpy(audio)
- else:
- tensor = torch.as_tensor(audio)
-
- tensor = tensor.to(dtype=torch.float32).contiguous().view(-1)
- if tensor.numel() == 0:
- return torch.zeros((16000,), dtype=torch.float32)
-
- max_abs = torch.max(torch.abs(tensor))
- if max_abs > 1.0:
- tensor = tensor / max_abs
-
- return tensor.contiguous()
-
- @classmethod
- def _convert_modality_item(cls, modality: str, item: Any) -> torch.Tensor:
- if modality == "image":
- return cls._image_to_chw_float_tensor(item)
- if modality == "video":
- return cls._video_to_tchw_float_tensor(item)
- if modality == "audio":
- return cls._audio_to_float_tensor(item)
- raise ValueError(f"Unsupported modality for Dynin processor: {modality}")
-
- def _build_modality_kwargs(
- self,
- modality: str,
- modality_items: Sequence[Any],
- ) -> Sequence[Any]:
- modality = _normalize_modality_name(modality)
- input_key = _MODALITY_INPUT_KEY_BY_NAME[modality]
- tensor_items = [self._convert_modality_item(modality, item) for item in modality_items]
- mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
- {input_key: tensor_items},
- {input_key: MultiModalFieldConfig.batched(modality)},
- )
- return mm_kwargs[modality]
-
- def _build_placeholder_ranges(
- self,
- *,
- modality: str,
- item_count: int,
- prompt_token_ids: list[int],
- tokenizer: Any | None,
- search_start: int,
- ) -> tuple[list[PlaceholderRange], int]:
- ranges: list[PlaceholderRange] = []
-
- for _ in range(item_count):
- placeholder_text = _get_placeholder_text(modality)
- placeholder_token_ids: list[int] = []
-
- if placeholder_text and tokenizer is not None:
- placeholder_token_ids = tokenizer.encode(
- placeholder_text,
- add_special_tokens=False,
- )
-
- found_offset = None
- if placeholder_token_ids:
- found_offset = self._find_subsequence(
- prompt_token_ids,
- placeholder_token_ids,
- search_start,
- )
-
- if found_offset is None:
- found_offset = min(search_start, len(prompt_token_ids) - 1)
- placeholder_len = 1
- else:
- placeholder_len = len(placeholder_token_ids)
-
- ranges.append(
- PlaceholderRange(
- offset=found_offset,
- length=placeholder_len,
- is_embed=self._make_disabled_embed_mask(placeholder_len),
- )
- )
- search_start = found_offset + placeholder_len
-
- return ranges, search_start
-
- def _get_mm_fields_config(
- self,
- hf_inputs: Any,
- hf_processor_mm_kwargs: Mapping[str, object],
- ) -> Mapping[str, MultiModalFieldConfig]:
- del hf_inputs, hf_processor_mm_kwargs
- return {}
-
- def _get_prompt_updates(
- self,
- mm_items: MultiModalDataItems,
- hf_processor_mm_kwargs: Mapping[str, object],
- out_mm_kwargs: MultiModalKwargsItems,
- ) -> Sequence[PromptUpdate]:
- del mm_items, hf_processor_mm_kwargs, out_mm_kwargs
- return []
-
- def apply(
- self,
- inputs: ProcessorInputs,
- timing_ctx: TimingContext,
- ) -> MultiModalInputs:
- prompt = inputs.prompt
- mm_items = inputs.mm_data_items
-
- with timing_ctx.record("get_mm_hashes"):
- mm_hashes = inputs.get_mm_hashes(self.info.model_id)
-
- tokenizer = self.info.ctx.tokenizer
- prompt_token_ids = self._encode_prompt_to_token_ids(prompt, tokenizer)
- prompt_token_ids = self._ensure_non_empty_prompt_ids(prompt_token_ids, tokenizer)
-
- mm_kwargs_by_modality: dict[str, Sequence[Any]] = {}
- mm_placeholders: dict[str, list[PlaceholderRange]] = {}
- search_start = 0
- mm_counts = mm_items.get_all_counts()
-
- for modality in _MODALITY_ORDER:
- item_count = mm_counts.get(modality, 0)
- if item_count <= 0:
- continue
-
- modality_items = mm_items[modality].get_all()
- if len(modality_items) != item_count:
- raise RuntimeError(
- f"Parsed {len(modality_items)} items but expected {item_count} for modality={modality!r}"
- )
-
- mm_kwargs_by_modality[modality] = self._build_modality_kwargs(
- modality,
- modality_items,
- )
-
- placeholder_ranges, search_start = self._build_placeholder_ranges(
- modality=modality,
- item_count=item_count,
- prompt_token_ids=prompt_token_ids,
- tokenizer=tokenizer,
- search_start=search_start,
- )
- mm_placeholders[modality] = placeholder_ranges
-
- return MultiModalInputs(
- type="multimodal",
- prompt_token_ids=prompt_token_ids,
- mm_kwargs=MultiModalKwargsItems(mm_kwargs_by_modality),
- mm_hashes=mm_hashes,
- mm_placeholders=mm_placeholders,
- )
-
-
-class DyninOmniStageBase(nn.Module):
- stage_name = "Dynin stage"
-
- def make_empty_intermediate_tensors(
- self,
- batch_size: int,
- dtype: torch.dtype,
- device: torch.device,
- ) -> IntermediateTensors:
- del batch_size, dtype, device
- return IntermediateTensors({})
-
- def embed_input_ids(
- self,
- input_ids: torch.Tensor,
- multimodal_embeddings: Any = None,
- is_multimodal: torch.Tensor | None = None,
- **kwargs: Any,
- ) -> torch.Tensor:
- del multimodal_embeddings, is_multimodal, kwargs
- return build_zero_input_embeddings(
- input_ids=input_ids,
- hidden_size=self.hidden_size,
- stage_name=self.stage_name,
- )
-
- def load_weights(
- self,
- weights: Iterable[tuple[str, torch.Tensor]],
- ) -> set[str]:
- return {name for name, _ in weights}
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor | OmniOutput,
- sampling_metadata: Any = None,
- ) -> torch.Tensor | None:
- del hidden_states, sampling_metadata
- return None
-
-
-@MULTIMODAL_REGISTRY.register_processor(
- DyninOmniMultiModalProcessor,
- info=DyninOmniProcessingInfo,
- dummy_inputs=DyninOmniDummyInputsBuilder,
-)
-class DyninOmniForConditionalGeneration(nn.Module, SupportsMultiModal):
- supports_multimodal_raw_input_only = True
- STAGE_ALIAS = {
- "tokenizer": "token2text",
- "token2token": "token2text",
- "detok_text": "token2text",
- "token2img": "token2image",
- "token2wav": "token2audio",
- "token2speech": "token2audio",
- }
-
- STAGE_IMPL = {
- "token2text": (".dynin_omni_token2text", "DyninOmniToken2Text"),
- "token2image": (".dynin_omni_token2image", "DyninOmniToken2Image"),
- "token2audio": (".dynin_omni_token2audio", "DyninOmniToken2Audio"),
- }
-
- _STAGE_IMPL_CACHE: dict[str, type[nn.Module]] = {}
-
- @classmethod
- def get_placeholder_str(cls, modality: str, i: int) -> str | None:
- del i
- return _get_placeholder_text(modality)
-
- @classmethod
- def _resolve_stage_impl_class(cls, model_stage: str) -> type[nn.Module]:
- impl = cls._STAGE_IMPL_CACHE.get(model_stage)
- if impl is not None:
- return impl
-
- module_name, class_name = cls.STAGE_IMPL[model_stage]
- module = import_module(module_name, package=__package__)
- impl = getattr(module, class_name)
- cls._STAGE_IMPL_CACHE[model_stage] = impl
- return impl
-
- @classmethod
- def _normalize_stage_name(cls, raw_stage: str) -> str:
- normalized = cls.STAGE_ALIAS.get(raw_stage, raw_stage)
- if normalized not in cls.STAGE_IMPL:
- raise ValueError(
- "Unsupported DYNIN omni model_stage: "
- f"{raw_stage} (normalized={normalized}). "
- f"Supported: {sorted(cls.STAGE_IMPL.keys())}"
- )
- return normalized
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
-
- raw_stage = str(getattr(vllm_config.model_config, "model_stage", "token2text")).lower()
- self.model_stage = self._normalize_stage_name(raw_stage)
-
- impl_cls = self._resolve_stage_impl_class(self.model_stage)
- self.impl = impl_cls(vllm_config=vllm_config, prefix=prefix)
- self.model = self.impl
-
- self.has_preprocess = False
- self.has_postprocess = False
- self.have_multimodal_outputs = getattr(self.impl, "have_multimodal_outputs", True)
- self.requires_raw_input_tokens = getattr(self.impl, "requires_raw_input_tokens", True)
- self.language_model = self._resolve_language_model()
-
- def _resolve_language_model(self) -> Any | None:
- if hasattr(self.impl, "get_language_model"):
- language_model = self.impl.get_language_model()
- if language_model is not None:
- return language_model
-
- if hasattr(self.impl, "language_model"):
- language_model = getattr(self.impl, "language_model")
- if language_model is not None:
- return language_model
-
- if self.model_stage == "token2text":
- return getattr(self.impl, "model", None)
-
- return None
-
- def get_language_model(self) -> Any | None:
- return self.language_model
-
- @cached_property
- def sampler(self):
- if hasattr(self.model, "sampler"):
- return self.model.sampler
- if self.language_model is not None and hasattr(self.language_model, "sampler"):
- return self.language_model.sampler
- return Sampler()
-
- def init_multi_modal(self, thinker_config: Any = None) -> None:
- if hasattr(self.model, "init_multi_modal"):
- self.model.init_multi_modal(thinker_config)
-
- def _collect_multimodal_inputs(self, **kwargs: Any) -> dict[str, Any]:
- mm_inputs: dict[str, Any] = {}
- for modality, aliases in _MODALITY_INPUT_ALIASES.items():
- for alias in aliases:
- if alias in kwargs and kwargs[alias] is not None:
- mm_inputs[modality] = kwargs[alias]
- break
- return mm_inputs
-
- def _normalize_loaded_weight_names(
- self,
- loaded: set[str],
- expected_param_names: set[str],
- ) -> set[str]:
- if self.model_stage != "token2text":
- return loaded
-
- normalized_loaded: set[str] = set()
- prefixes = ("", "impl.", "impl.model.")
-
- for name in loaded:
- for prefix in prefixes:
- candidate = f"{prefix}{name}" if prefix else name
- if candidate in expected_param_names:
- normalized_loaded.add(candidate)
- break
-
- if len(normalized_loaded) < len(expected_param_names):
- normalized_loaded.update(expected_param_names)
-
- return normalized_loaded
-
- def forward(
- self,
- input_ids: torch.Tensor | None = None,
- positions: torch.Tensor | None = None,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: Any,
- ) -> OmniOutput:
- return self.model(
- input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- **kwargs,
- )
-
- def make_empty_intermediate_tensors(
- self,
- batch_size: int,
- dtype: torch.dtype,
- device: torch.device,
- ) -> IntermediateTensors:
- return self.model.make_empty_intermediate_tensors(batch_size, dtype, device)
-
- def embed_input_ids(
- self,
- input_ids: torch.Tensor,
- multimodal_embeddings: Any = None,
- is_multimodal: torch.Tensor | None = None,
- **kwargs: Any,
- ) -> torch.Tensor:
- squeezed_batch = False
- staged_input_ids = input_ids
-
- if input_ids.ndim == 0:
- staged_input_ids = input_ids.view(1, 1)
- squeezed_batch = True
- elif input_ids.ndim == 1:
- staged_input_ids = input_ids.unsqueeze(0)
- squeezed_batch = True
-
- embeddings = self.model.embed_input_ids(
- staged_input_ids,
- multimodal_embeddings=multimodal_embeddings,
- is_multimodal=is_multimodal,
- **kwargs,
- )
-
- if squeezed_batch and isinstance(embeddings, torch.Tensor):
- if embeddings.ndim == 3 and embeddings.shape[0] == 1:
- return embeddings.squeeze(0)
- if embeddings.ndim == 2 and input_ids.ndim == 0 and embeddings.shape[0] == 1:
- return embeddings
-
- return embeddings
-
- def embed_multimodal(self, **kwargs: Any) -> Any:
- if hasattr(self.model, "embed_multimodal"):
- return self.model.embed_multimodal(**kwargs)
-
- self._collect_multimodal_inputs(**kwargs)
- return None
-
- def load_weights(
- self,
- weights: Iterable[tuple[str, torch.Tensor]],
- ) -> set[str]:
- loaded = self.model.load_weights(weights)
- if loaded is None:
- loaded = set()
-
- expected_param_names = {name for name, _ in self.named_parameters()}
- if not expected_param_names:
- return loaded
-
- return self._normalize_loaded_weight_names(loaded, expected_param_names)
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor | OmniOutput,
- sampling_metadata: Any = None,
- ) -> torch.Tensor | None:
- return self.model.compute_logits(hidden_states, sampling_metadata=sampling_metadata)
-
- def sample(
- self,
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> SamplerOutput | None:
- if hasattr(self.model, "sample"):
- return self.model.sample(logits, sampling_metadata)
- if self.language_model is not None and hasattr(self.language_model, "sample"):
- return self.language_model.sample(logits, sampling_metadata)
- return None
diff --git a/vllm_omni/model_executor/models/dynin_omni/dynin_omni_common.py b/vllm_omni/model_executor/models/dynin_omni/dynin_omni_common.py
deleted file mode 100644
index 6166d8615c9..00000000000
--- a/vllm_omni/model_executor/models/dynin_omni/dynin_omni_common.py
+++ /dev/null
@@ -1,1241 +0,0 @@
-from __future__ import annotations
-
-import hashlib
-import importlib.util
-import os
-import sys
-import threading
-import types
-from collections.abc import Iterable
-from dataclasses import dataclass
-from enum import IntEnum
-from functools import lru_cache
-from pathlib import Path
-from typing import Any
-
-import torch
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-try:
- from huggingface_hub import snapshot_download
-except Exception: # pragma: no cover
- snapshot_download = None
-
-
-class DetokTarget(IntEnum):
- TEXT = 0
- AUDIO = 1
- IMAGE = 2
-
-
-TASK_TO_DETOK = {
- "mmu": DetokTarget.TEXT,
- "s2t": DetokTarget.TEXT,
- "mmu_fast": DetokTarget.TEXT,
- "mmu_fastdllm_v1": DetokTarget.TEXT,
- "v2t": DetokTarget.TEXT,
- "t2s": DetokTarget.AUDIO,
- "t2s_mmu_like": DetokTarget.AUDIO,
- "t2s_fixed": DetokTarget.AUDIO,
- "s2s": DetokTarget.AUDIO,
- "v2s": DetokTarget.AUDIO,
- "t2i": DetokTarget.IMAGE,
- "i2i": DetokTarget.IMAGE,
- "ti2ti": DetokTarget.IMAGE,
-}
-
-DEFAULT_VQ_IMAGE_SOURCE = "snu-aidas/magvitv2"
-DEFAULT_VQ_AUDIO_SOURCE = "snu-aidas/emova_speech_tokenizer_vllm"
-DEFAULT_MAGVIT_REMOTE_CODE_REPO = "snu-aidas/magvitv2"
-DEFAULT_DYNIN_REMOTE_CODE_REPO = "snu-aidas/Dynin-Omni"
-DYNIN_PROMPT_SOURCE_KEY = "dynin_prompt_source"
-DYNIN_PROMPT_SOURCE_OFFLINE_PREBUILT = "offline_prebuilt"
-
-DYNIN_TASK_DEFAULT_RUNTIME = {
- "t2t": ("mmu", "mmu", 0, "text"),
- "t2i": ("t2i", "t2i_gen", 2, "image"),
- "t2s": ("t2s_mmu_like", "t2s_gen", 1, "audio"),
- "i2i": ("i2i", "i2i", 2, "image"),
-}
-
-DYNIN_TASK_RUNTIME_FALLBACKS: dict[str, dict[str, Any]] = {
- "t2t": {
- "prompt_max_text_len": 1024,
- "max_new_tokens": 1024,
- "steps": 1024,
- "block_length": 16,
- "temperature": 0.0,
- "cfg_scale": 0.0,
- },
- "t2i": {
- "prompt_max_text_len": 128,
- "image_token_count": 1024,
- "mask_token_id": 126336,
- "codebook_size": 8192,
- "timesteps": 20,
- "guidance_scale": 3.5,
- "temperature": 1.0,
- },
- "i2i": {
- "prompt_max_text_len": 128,
- "mask_token_id": 126336,
- "codebook_size": 8192,
- "timesteps": 64,
- "guidance_scale": 3.5,
- "temperature": 1.0,
- "image_resolution": 336,
- "use_train_i2i_prompt": True,
- },
- "t2s": {
- "runtime_task": "t2s_mmu_like",
- "prompting_task": "t2s_gen",
- "prompt_max_text_len": 1024,
- "t2s_token_length": 512,
- "mask_token_id": 126336,
- "codebook_size": 8192,
- "audio_codebook_size": 4096,
- "steps": 512,
- "block_length": 128,
- "temperature": 1.0,
- "cfg_scale": 2.5,
- "t2s_condition": "gender-female_emotion-neutral_speed-normal_pitch-normal",
- },
-}
-
-DEFAULT_DYNIN_T2S_INSTRUCTION = "Please read the following text naturally."
-
-DYNIN_SPECIAL_TOKENS = (
- "<|soi|>",
- "<|eoi|>",
- "<|sov|>",
- "<|eov|>",
- "<|t2i|>",
- "<|mmu|>",
- "<|t2v|>",
- "<|v2v|>",
- "<|lvg|>",
- "<|i2i|>",
- "<|ti2ti|>",
- "<|v2t|>",
- "<|v2s|>",
- "<|s2t|>",
- "<|t2s|>",
- "<|s2s|>",
- "<|soa|>",
- "<|eoa|>",
-)
-
-_DYNIN_ONLINE_PROMPT_TOKEN_BY_TASK = {
- "t2i": "<|t2i|>",
- "i2i": "<|i2i|>",
- "t2s": "<|t2s|>",
-}
-
-_DYNIN_MODALITY_PLACEHOLDERS = (
- "<|soi|><|image|><|eoi|>",
- "<|sov|><|video|><|eov|>",
- "<|soa|><|audio|><|eoa|>",
-)
-
-_DYNIN_CONFIG_CANDIDATE_RELPATHS = (
- "configs/dynin_omni.yaml",
- "models/configs/dynin_omni.yaml",
- "vllm_omni/model_executor/models/dynin_omni/configs/dynin_omni.yaml",
- "vllm_omni/model_executor/stage_configs/dynin_omni.yaml",
- "dynin_omni.yaml",
-)
-
-_DYNIN_REMOTE_ALLOW_PATTERNS = ("*.py", "*.json", "*.yaml", "*.yml")
-
-_DYNIN_REMOTE_CACHE_LOCK = threading.Lock()
-_DYNIN_REMOTE_PACKAGE_BY_SNAPSHOT: dict[str, str] = {}
-_DYNIN_REMOTE_ATTR_CACHE: dict[tuple[str, str, str, str | None, bool], Any] = {}
-
-
-@dataclass(frozen=True)
-class DyninInferSources:
- model_source: str
- tokenizer_source: str
- vq_image_source: str
- vq_audio_source: str
- model_local_files_only: bool
- vq_image_local_files_only: bool
- vq_audio_local_files_only: bool
- config_path: str | None = None
-
- @property
- def local_files_only(self) -> bool:
- return self.model_local_files_only
-
-
-@dataclass(frozen=True)
-class RemoteCodeSettings:
- default_repo: str
- repo_env: str
- revision_env: str
- local_only_env: str
-
-
-DYNIN_REMOTE_SETTINGS = RemoteCodeSettings(
- default_repo=DEFAULT_DYNIN_REMOTE_CODE_REPO,
- repo_env="DYNIN_REMOTE_CODE_REPO_ID",
- revision_env="DYNIN_REMOTE_CODE_REVISION",
- local_only_env="DYNIN_REMOTE_CODE_LOCAL_FILES_ONLY",
-)
-
-MAGVIT_REMOTE_SETTINGS = RemoteCodeSettings(
- default_repo=DEFAULT_MAGVIT_REMOTE_CODE_REPO,
- repo_env="DYNIN_MAGVIT_REMOTE_CODE_REPO_ID",
- revision_env="DYNIN_MAGVIT_REMOTE_CODE_REVISION",
- local_only_env="DYNIN_MAGVIT_REMOTE_CODE_LOCAL_FILES_ONLY",
-)
-
-
-def unwrap_first_value(value: Any, default: Any = None) -> Any:
- if value is None:
- return default
- if isinstance(value, list):
- return default if not value else value[0]
- if isinstance(value, torch.Tensor):
- if value.numel() == 0:
- return default
- if value.numel() == 1:
- return value.item()
- return value
- return value
-
-
-def normalize_runtime_info(runtime_additional_information: Any) -> dict[str, Any]:
- if isinstance(runtime_additional_information, list):
- if not runtime_additional_information:
- return {}
- first = runtime_additional_information[0]
- return first if isinstance(first, dict) else {}
- if isinstance(runtime_additional_information, dict):
- return runtime_additional_information
- return {}
-
-
-def logical_dynin_task(task: Any) -> str:
- task_text = str(unwrap_first_value(task, "") or "").strip().lower()
- if task_text in ("t2s", "t2s_mmu_like", "t2s_fixed"):
- return "t2s"
- if task_text in ("t2i", "i2i"):
- return task_text
- return "t2t"
-
-
-def dynin_runtime_fallback(task: str, key: str, value: Any = None) -> Any:
- if isinstance(value, str):
- if value.strip() != "":
- return value
- elif value is not None:
- return value
- return DYNIN_TASK_RUNTIME_FALLBACKS.get(task, {}).get(key)
-
-
-def coerce_token_ids_1d(
- value: Any,
- ref_device: torch.device | None = None,
-) -> torch.Tensor:
- if isinstance(value, tuple):
- value = value[0]
-
- if isinstance(value, list):
- if not value:
- device = ref_device or torch.device("cpu")
- return torch.empty(0, dtype=torch.long, device=device)
- if isinstance(value[0], torch.Tensor):
- value = value[0]
- else:
- value = torch.tensor(
- value[0] if isinstance(value[0], list) else value,
- dtype=torch.long,
- )
-
- if not isinstance(value, torch.Tensor):
- value = torch.tensor(value, dtype=torch.long)
-
- if value.ndim == 0:
- value = value.unsqueeze(0)
- if value.ndim > 1:
- value = value[0]
-
- if ref_device is not None and value.device != ref_device:
- value = value.to(ref_device)
-
- return value.to(dtype=torch.long).contiguous()
-
-
-def _first_positive_int(value: Any) -> int | None:
- if value is None:
- return None
- if isinstance(value, torch.Tensor):
- if value.numel() != 1:
- return None
- value = value.item()
- try:
- value = int(value)
- except (TypeError, ValueError):
- return None
- return value if value > 0 else None
-
-
-def resolve_hidden_size(
- *,
- vllm_config: VllmConfig,
- model: Any | None = None,
- default: int = 1024,
-) -> int:
- if model is not None:
- try:
- embeddings = model.get_input_embeddings()
- weight = getattr(embeddings, "weight", None)
- if isinstance(weight, torch.Tensor) and weight.ndim >= 2:
- hidden_size = _first_positive_int(weight.shape[-1])
- if hidden_size is not None:
- return hidden_size
- except Exception:
- pass
-
- model_cfg = getattr(model, "config", None)
- for key in ("hidden_size", "d_model", "n_embd", "dim", "model_dim", "embed_dim"):
- hidden_size = _first_positive_int(getattr(model_cfg, key, None))
- if hidden_size is not None:
- return hidden_size
-
- for config_obj in (
- getattr(vllm_config.model_config, "hf_config", None),
- getattr(vllm_config.model_config, "hf_text_config", None),
- ):
- if config_obj is None:
- continue
- for key in ("hidden_size", "d_model", "n_embd", "dim", "model_dim", "embed_dim"):
- value = config_obj.get(key) if isinstance(config_obj, dict) else getattr(config_obj, key, None)
- hidden_size = _first_positive_int(value)
- if hidden_size is not None:
- return hidden_size
-
- return default
-
-
-def build_zero_input_embeddings(
- *,
- input_ids: torch.Tensor,
- hidden_size: int,
- stage_name: str,
- dtype: torch.dtype = torch.bfloat16,
-) -> torch.Tensor:
- if input_ids.ndim == 0:
- shape = (1, hidden_size)
- elif input_ids.ndim == 1:
- shape = (input_ids.shape[0], hidden_size)
- elif input_ids.ndim == 2:
- shape = (input_ids.shape[0], input_ids.shape[1], hidden_size)
- else:
- raise ValueError(f"Unsupported input_ids rank for {stage_name}: {input_ids.ndim}")
- return torch.zeros(shape, dtype=dtype, device=input_ids.device)
-
-
-def _to_bool(value: Any, default: bool = False) -> bool:
- if value is None:
- return default
- if isinstance(value, bool):
- return value
- if isinstance(value, (int, float)):
- return bool(value)
-
- text = str(value).strip().lower()
- if text in ("1", "true", "yes", "y", "on"):
- return True
- if text in ("0", "false", "no", "n", "off", "", "none", "null"):
- return False
- return default
-
-
-def _runtime_value(runtime_info: dict[str, Any], key: str) -> Any:
- return unwrap_first_value(runtime_info.get(key), None)
-
-
-def _runtime_first_value(runtime_info: dict[str, Any], keys: tuple[str, ...]) -> Any:
- for key in keys:
- value = _runtime_value(runtime_info, key)
- if value is not None:
- return value
- return None
-
-
-def _node_value(node: Any, key: str, default: Any = None) -> Any:
- if node is None:
- return default
- if isinstance(node, dict):
- return node.get(key, default)
- try:
- return node.get(key, default)
- except Exception:
- return getattr(node, key, default)
-
-
-def _looks_like_hf_repo_id(value: str | None) -> bool:
- if not isinstance(value, str):
- return False
- if value.count("/") != 1:
- return False
- org, name = value.split("/", 1)
- return bool(org and name)
-
-
-def _find_dynin_config_under_root(root: Path) -> Path | None:
- for rel_path in _DYNIN_CONFIG_CANDIDATE_RELPATHS:
- candidate = root.expanduser() / rel_path
- if candidate.exists():
- return candidate.resolve()
- return None
-
-
-@lru_cache(maxsize=16)
-def _resolve_dynin_config_from_hf_repo(repo_id: str) -> str | None:
- if not _looks_like_hf_repo_id(repo_id) or snapshot_download is None:
- return None
-
- try:
- snapshot_dir = (
- Path(
- snapshot_download(
- repo_id=repo_id,
- repo_type="model",
- allow_patterns=list(_DYNIN_CONFIG_CANDIDATE_RELPATHS),
- local_files_only=True,
- )
- )
- .expanduser()
- .resolve()
- )
- except Exception:
- return None
-
- found = _find_dynin_config_under_root(snapshot_dir)
- return str(found) if found is not None else None
-
-
-def _resolve_existing_path(path_like: Any, source_name: str) -> str | None:
- if path_like is None:
- return None
- text = str(path_like).strip()
- if not text:
- return None
-
- path = Path(text).expanduser()
- if path.is_file():
- return str(path.resolve())
-
- logger.warning(
- "DYNIN config path from %s does not exist: %s. Falling back to auto-discovery.",
- source_name,
- path,
- )
- return None
-
-
-def _resolve_config_path(vllm_config: VllmConfig, runtime_info: dict[str, Any]) -> str | None:
- for value, name in (
- (_runtime_value(runtime_info, "dynin_config_path"), "runtime_info.dynin_config_path"),
- (os.getenv("DYNIN_CONFIG_PATH"), "DYNIN_CONFIG_PATH"),
- (getattr(vllm_config.model_config, "dynin_config_path", None), "vllm_config.model_config.dynin_config_path"),
- ):
- resolved = _resolve_existing_path(value, name)
- if resolved:
- return resolved
-
- model_source = str(getattr(vllm_config.model_config, "model", "") or "")
- tokenizer_source = str(getattr(vllm_config.model_config, "tokenizer", "") or "")
- hf_config = getattr(vllm_config.model_config, "hf_config", None)
- hf_name_or_path = (
- hf_config.get("_name_or_path") if isinstance(hf_config, dict) else getattr(hf_config, "_name_or_path", None)
- )
-
- hf_repo_candidates: list[str] = []
- for source in (model_source, tokenizer_source, hf_name_or_path):
- if not _looks_like_hf_repo_id(source):
- continue
- source = str(source)
- if source not in hf_repo_candidates:
- hf_repo_candidates.append(source)
-
- for source in hf_repo_candidates:
- resolved = _resolve_dynin_config_from_hf_repo(source)
- if resolved is not None:
- logger.info("Resolved dynin config from Hugging Face cache for %s: %s", source, resolved)
- return resolved
-
- for source in (model_source, tokenizer_source):
- source_path = Path(source).expanduser()
- if source_path.is_dir():
- found = _find_dynin_config_under_root(source_path)
- if found is not None:
- return str(found)
-
- module_root = Path(__file__).resolve().parent
- for bundled in (
- module_root / "configs" / "dynin_omni.yaml",
- module_root / "models" / "configs" / "dynin_omni.yaml",
- module_root.parent / "stage_configs" / "dynin_omni.yaml",
- ):
- if bundled.exists():
- return str(bundled)
-
- return None
-
-
-@lru_cache(maxsize=16)
-def _load_omega_config(config_path: str) -> Any:
- try:
- from omegaconf import OmegaConf
- except ImportError as e:
- raise ImportError(
- f"omegaconf is required to load Dynin config files. Install it to read config: {config_path}"
- ) from e
- return OmegaConf.load(config_path)
-
-
-def resolve_dynin_infer_sources(
- *,
- vllm_config: VllmConfig,
- runtime_info: dict[str, Any] | None = None,
-) -> DyninInferSources:
- runtime_info = runtime_info or {}
-
- base_model_source = str(getattr(vllm_config.model_config, "model", ""))
- base_model_path = Path(base_model_source).expanduser()
- local_vllm_model_source = str(base_model_path) if base_model_path.is_dir() else None
-
- model_source = base_model_source
- tokenizer_source = model_source
- vq_image_source = DEFAULT_VQ_IMAGE_SOURCE
- vq_audio_source = DEFAULT_VQ_AUDIO_SOURCE
- model_local_files_only = False
- vq_image_local_files_only = False
- vq_audio_local_files_only = False
-
- resolver_source: str | None = base_model_source if base_model_source else None
- resolver_local_files_only: bool | None = True if base_model_path.is_dir() else None
- resolve_model_pretrained_source_fn = get_dynin_config_resolver_attr(
- "resolve_model_pretrained_source",
- source=resolver_source,
- local_files_only=resolver_local_files_only,
- )
- resolve_tokenizer_source_fn = get_dynin_config_resolver_attr(
- "resolve_tokenizer_source",
- source=resolver_source,
- local_files_only=resolver_local_files_only,
- )
- resolve_model_local_files_only_fn = get_dynin_config_resolver_attr(
- "resolve_model_local_files_only",
- source=resolver_source,
- local_files_only=resolver_local_files_only,
- )
- resolve_vq_cfg_block_fn = get_dynin_config_resolver_attr(
- "resolve_vq_cfg_block",
- source=resolver_source,
- local_files_only=resolver_local_files_only,
- )
- resolve_vq_repo_source_fn = get_dynin_config_resolver_attr(
- "resolve_vq_repo_source",
- source=resolver_source,
- local_files_only=resolver_local_files_only,
- )
-
- config_path = _resolve_config_path(vllm_config, runtime_info)
- if config_path:
- config_file = Path(config_path).expanduser()
- if config_file.exists():
- try:
- dynin_cfg = _load_omega_config(str(config_file))
- model_source = resolve_model_pretrained_source_fn(
- dynin_cfg,
- default=model_source,
- )
- tokenizer_source = resolve_tokenizer_source_fn(
- dynin_cfg,
- default=tokenizer_source,
- )
- model_local_files_only = resolve_model_local_files_only_fn(
- dynin_cfg,
- default=model_local_files_only,
- )
- vq_image_cfg = resolve_vq_cfg_block_fn(dynin_cfg, modality="image")
- vq_audio_cfg = resolve_vq_cfg_block_fn(dynin_cfg, modality="audio")
- vq_image_source = resolve_vq_repo_source_fn(
- vq_image_cfg,
- default=vq_image_source,
- )
- vq_audio_source = resolve_vq_repo_source_fn(
- vq_audio_cfg,
- default=vq_audio_source,
- )
- vq_image_local_files_only = _to_bool(
- _node_value(vq_image_cfg, "local_files_only", None),
- default=model_local_files_only,
- )
- vq_audio_local_files_only = _to_bool(
- _node_value(vq_audio_cfg, "local_files_only", None),
- default=model_local_files_only,
- )
- except Exception as e:
- logger.warning(
- "Failed to resolve DYNIN inference config from %s: %s",
- config_file,
- e,
- )
- else:
- logger.warning("DYNIN config path does not exist: %s", config_file)
-
- runtime_model_source = _runtime_value(runtime_info, "dynin_model_path")
- if runtime_model_source:
- model_source = str(runtime_model_source)
-
- runtime_tokenizer_source = _runtime_value(runtime_info, "tokenizer_path")
- if runtime_tokenizer_source:
- tokenizer_source = str(runtime_tokenizer_source)
-
- runtime_vq_image_source = _runtime_value(runtime_info, "vq_model_image_path")
- if runtime_vq_image_source is None:
- runtime_vq_image_source = _runtime_value(runtime_info, "vq_model_path_image")
- if runtime_vq_image_source:
- vq_image_source = str(runtime_vq_image_source)
-
- runtime_vq_audio_source = _runtime_value(runtime_info, "vq_model_audio_path")
- if runtime_vq_audio_source is None:
- runtime_vq_audio_source = _runtime_value(runtime_info, "vq_model_path_audio")
- if runtime_vq_audio_source:
- vq_audio_source = str(runtime_vq_audio_source)
-
- runtime_local_global = _runtime_value(runtime_info, "local_files_only")
- runtime_local_model = _runtime_first_value(
- runtime_info,
- ("model_local_files_only", "local_files_only_model"),
- )
- runtime_local_vq_image = _runtime_first_value(
- runtime_info,
- ("vq_model_image_local_files_only", "local_files_only_vq_image"),
- )
- runtime_local_vq_audio = _runtime_first_value(
- runtime_info,
- ("vq_model_audio_local_files_only", "local_files_only_vq_audio"),
- )
-
- if runtime_local_global is not None:
- global_local = _to_bool(runtime_local_global, default=False)
- if runtime_local_model is None:
- model_local_files_only = global_local
- if runtime_local_vq_image is None:
- vq_image_local_files_only = global_local
- if runtime_local_vq_audio is None:
- vq_audio_local_files_only = global_local
-
- if runtime_local_model is not None:
- model_local_files_only = _to_bool(
- runtime_local_model,
- default=model_local_files_only,
- )
- if runtime_local_vq_image is not None:
- vq_image_local_files_only = _to_bool(
- runtime_local_vq_image,
- default=vq_image_local_files_only,
- )
- if runtime_local_vq_audio is not None:
- vq_audio_local_files_only = _to_bool(
- runtime_local_vq_audio,
- default=vq_audio_local_files_only,
- )
-
- if runtime_local_global is None and runtime_local_model is None and local_vllm_model_source is not None:
- model_local_files_only = True
-
- if local_vllm_model_source is not None:
- if not runtime_model_source:
- if model_source != local_vllm_model_source:
- logger.info(
- "DYNIN infer model source overridden to local vLLM model path: %s (from %s)",
- local_vllm_model_source,
- model_source,
- )
- model_source = local_vllm_model_source
- if not runtime_tokenizer_source:
- tokenizer_source = local_vllm_model_source
-
- return DyninInferSources(
- model_source=model_source,
- tokenizer_source=tokenizer_source,
- vq_image_source=vq_image_source,
- vq_audio_source=vq_audio_source,
- model_local_files_only=model_local_files_only,
- vq_image_local_files_only=vq_image_local_files_only,
- vq_audio_local_files_only=vq_audio_local_files_only,
- config_path=config_path,
- )
-
-
-def _resolve_remote_source(source: str | None, settings: RemoteCodeSettings) -> str:
- if isinstance(source, str):
- stripped = source.strip()
- if stripped:
- source_path = Path(stripped).expanduser()
- if source_path.is_dir():
- return str(source_path.resolve())
- if _looks_like_hf_repo_id(stripped):
- return stripped
-
- env_repo = os.getenv(settings.repo_env)
- if _looks_like_hf_repo_id(env_repo):
- return str(env_repo).strip()
-
- return settings.default_repo
-
-
-def _resolve_remote_revision(revision: str | None, settings: RemoteCodeSettings) -> str | None:
- if isinstance(revision, str) and revision.strip():
- return revision.strip()
- env_revision = os.getenv(settings.revision_env)
- if isinstance(env_revision, str) and env_revision.strip():
- return env_revision.strip()
- return None
-
-
-def _resolve_remote_local_only(local_files_only: bool | None, settings: RemoteCodeSettings) -> bool:
- if local_files_only is not None:
- return bool(local_files_only)
- return _to_bool(os.getenv(settings.local_only_env), default=False)
-
-
-def _resolve_remote_snapshot_dir(
- *,
- source: str,
- revision: str | None,
- local_files_only: bool,
-) -> str:
- source_path = Path(source).expanduser()
- if source_path.is_dir():
- return str(source_path.resolve())
-
- if snapshot_download is None:
- raise RuntimeError("huggingface_hub is required to load remote code.")
-
- kwargs: dict[str, Any] = {
- "repo_id": source,
- "repo_type": "model",
- "allow_patterns": list(_DYNIN_REMOTE_ALLOW_PATTERNS),
- "local_files_only": bool(local_files_only),
- }
- if revision is not None:
- kwargs["revision"] = revision
-
- try:
- return str(snapshot_download(**kwargs))
- except TypeError:
- kwargs.pop("local_files_only", None)
- return str(snapshot_download(**kwargs))
-
-
-def _ensure_remote_package(snapshot_dir: str) -> str:
- with _DYNIN_REMOTE_CACHE_LOCK:
- existing = _DYNIN_REMOTE_PACKAGE_BY_SNAPSHOT.get(snapshot_dir)
- if existing is not None:
- return existing
-
- digest = hashlib.sha1(snapshot_dir.encode("utf-8")).hexdigest()[:12]
- package_name = f"_dynin_hf_remote_{digest}"
-
- package = types.ModuleType(package_name)
- package.__path__ = [snapshot_dir] # type: ignore[attr-defined]
- package.__file__ = str(Path(snapshot_dir) / "__init__.py")
-
- sys.modules.setdefault(package_name, package)
- _DYNIN_REMOTE_PACKAGE_BY_SNAPSHOT[snapshot_dir] = package_name
- return package_name
-
-
-def _load_remote_module(
- *,
- module_name: str,
- source: str,
- revision: str | None,
- local_files_only: bool,
-):
- snapshot_dir = _resolve_remote_snapshot_dir(
- source=source,
- revision=revision,
- local_files_only=local_files_only,
- )
-
- module_path = Path(snapshot_dir) / f"{module_name}.py"
- if not module_path.is_file():
- raise ImportError(f"Remote code module '{module_name}.py' not found under '{snapshot_dir}'. source={source!r}")
-
- package_name = _ensure_remote_package(snapshot_dir)
- full_name = f"{package_name}.{module_name}"
-
- existing = sys.modules.get(full_name)
- if existing is not None:
- return existing
-
- spec = importlib.util.spec_from_file_location(full_name, module_path)
- if spec is None or spec.loader is None:
- raise ImportError(f"Failed to create import spec for '{module_path}'.")
-
- module = importlib.util.module_from_spec(spec)
- module.__package__ = package_name
- sys.modules[full_name] = module
- try:
- spec.loader.exec_module(module)
- except Exception:
- sys.modules.pop(full_name, None)
- raise
- return module
-
-
-def resolve_remote_attr(
- attr_name: str,
- *,
- module_name: str,
- settings: RemoteCodeSettings,
- source: str | None = None,
- revision: str | None = None,
- local_files_only: bool | None = None,
- fallback_module_names: Iterable[str] = (),
- optional: bool = False,
-) -> Any | None:
- resolved_source = _resolve_remote_source(source, settings)
- resolved_revision = _resolve_remote_revision(revision, settings)
- resolved_local_only = _resolve_remote_local_only(local_files_only, settings)
-
- module_candidates = [module_name, *[m for m in fallback_module_names if m and m != module_name]]
- last_error: Exception | None = None
-
- for candidate in module_candidates:
- cache_key = (attr_name, candidate, resolved_source, resolved_revision, resolved_local_only)
- cached = _DYNIN_REMOTE_ATTR_CACHE.get(cache_key)
- if cached is not None:
- return cached
-
- try:
- module = _load_remote_module(
- module_name=candidate,
- source=resolved_source,
- revision=resolved_revision,
- local_files_only=resolved_local_only,
- )
- if hasattr(module, attr_name):
- value = getattr(module, attr_name)
- _DYNIN_REMOTE_ATTR_CACHE[cache_key] = value
- return value
- except Exception as e:
- last_error = e
-
- if optional:
- if last_error is not None:
- logger.debug(
- "Optional remote attr not found: attr=%s source=%s revision=%s err=%s",
- attr_name,
- resolved_source,
- resolved_revision,
- last_error,
- )
- return None
-
- raise ImportError(
- f"Failed to resolve '{attr_name}' from remote code "
- f"(source={resolved_source!r}, revision={resolved_revision!r}, modules={module_candidates})."
- ) from last_error
-
-
-_DYNIN_MODELING_REMOTE_EXPORTS = {
- "DyninOmniConfig": "DyninOmniConfig",
- "DyninOmniModelLM": "DyninOmniModelLM",
- "VideoTokenMerger": "VideoTokenMerger",
-}
-
-_DYNIN_SAMPLING_REMOTE_EXPORTS = {
- "log": "log",
- "gumbel_noise": "gumbel_noise",
- "gumbel_sample": "gumbel_sample",
- "top_k": "top_k",
- "mask_by_random_topk": "mask_by_random_topk",
- "cosine_schedule": "cosine_schedule",
- "linear_schedule": "linear_schedule",
- "pow": "pow",
- "sigmoid_schedule": "sigmoid_schedule",
- "get_mask_schedule": "get_mask_schedule",
- "top_k_top_p_filtering": "top_k_top_p_filtering",
-}
-
-_DYNIN_CONFIG_RESOLVER_REMOTE_EXPORTS = {
- "resolve_model_pretrained_source": "resolve_model_pretrained_source",
- "resolve_tokenizer_source": "resolve_tokenizer_source",
- "resolve_model_local_files_only": "resolve_model_local_files_only",
- "resolve_vq_cfg_block": "resolve_vq_cfg_block",
- "resolve_vq_repo_source": "resolve_vq_repo_source",
-}
-
-_DYNIN_MAGVIT_REMOTE_EXPORTS = {
- "VQGANEncoder": "VQGANEncoder",
- "VQGANDecoder": "VQGANDecoder",
- "LFQuantizer": "LFQuantizer",
- "MAGVITv2": "MAGVITv2",
-}
-
-
-def _get_export_attr(
- name: str,
- export_map: dict[str, str],
- *,
- module_name: str,
- settings: RemoteCodeSettings,
- source: str | None = None,
- revision: str | None = None,
- local_files_only: bool | None = None,
- optional: bool = False,
-) -> Any | None:
- attr_name = export_map.get(name)
- if attr_name is None:
- raise AttributeError(f"Unsupported export: {name!r}")
-
- return resolve_remote_attr(
- attr_name,
- module_name=module_name,
- settings=settings,
- source=source,
- revision=revision,
- local_files_only=local_files_only,
- optional=optional,
- )
-
-
-def get_dynin_modeling_attr(name: str) -> Any:
- return _get_export_attr(
- name,
- _DYNIN_MODELING_REMOTE_EXPORTS,
- module_name="modeling_dynin_omni",
- settings=DYNIN_REMOTE_SETTINGS,
- )
-
-
-def get_dynin_sampling_attr(name: str) -> Any:
- return _get_export_attr(
- name,
- _DYNIN_SAMPLING_REMOTE_EXPORTS,
- module_name="sampling",
- settings=DYNIN_REMOTE_SETTINGS,
- )
-
-
-def get_dynin_config_resolver_attr(
- name: str,
- *,
- source: str | None = None,
- revision: str | None = None,
- local_files_only: bool | None = None,
-) -> Any:
- attr_name = _DYNIN_CONFIG_RESOLVER_REMOTE_EXPORTS.get(name)
- if attr_name is None:
- raise AttributeError(f"Unsupported Dynin config_resolver export: {name!r}")
-
- if source is not None:
- value = resolve_remote_attr(
- attr_name,
- module_name="config_resolver",
- settings=DYNIN_REMOTE_SETTINGS,
- source=source,
- revision=revision,
- local_files_only=local_files_only,
- optional=True,
- )
- if value is not None:
- return value
-
- return resolve_remote_attr(
- attr_name,
- module_name="config_resolver",
- settings=DYNIN_REMOTE_SETTINGS,
- source=DEFAULT_DYNIN_REMOTE_CODE_REPO,
- revision=revision,
- local_files_only=local_files_only,
- optional=False,
- )
-
-
-def get_dynin_magvit_attr(
- name: str,
- *,
- source: str | None = None,
- revision: str | None = None,
- local_files_only: bool | None = None,
-) -> Any:
- attr_name = _DYNIN_MAGVIT_REMOTE_EXPORTS.get(name)
- if attr_name is None:
- raise AttributeError(f"Unsupported Dynin MAGVIT export: {name!r}")
-
- value = resolve_remote_attr(
- attr_name,
- module_name="modeling_magvitv2",
- settings=MAGVIT_REMOTE_SETTINGS,
- source=source,
- revision=revision,
- local_files_only=local_files_only,
- optional=True,
- )
- if value is not None:
- return value
-
- resolved_source = _resolve_remote_source(source, MAGVIT_REMOTE_SETTINGS)
- resolved_revision = _resolve_remote_revision(revision, MAGVIT_REMOTE_SETTINGS)
- resolved_local_only = _resolve_remote_local_only(local_files_only, MAGVIT_REMOTE_SETTINGS)
-
- if resolved_source != DEFAULT_MAGVIT_REMOTE_CODE_REPO:
- return resolve_remote_attr(
- attr_name,
- module_name="modeling_magvitv2",
- settings=MAGVIT_REMOTE_SETTINGS,
- source=DEFAULT_MAGVIT_REMOTE_CODE_REPO,
- revision=resolved_revision,
- local_files_only=resolved_local_only,
- optional=False,
- )
-
- raise ImportError(
- f"Failed to resolve MAGVIT attr '{attr_name}' from source={resolved_source!r} (revision={resolved_revision!r})."
- )
-
-
-def build_dynin_chat_prompt(content: str) -> str:
- return (
- f"<|start_header_id|>user<|end_header_id|>\n{content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
- )
-
-
-def extract_dynin_user_prompt_text(decoded_prompt: str) -> str:
- text = str(decoded_prompt or "")
- assistant_marker = "<|start_header_id|>assistant<|end_header_id|>"
- user_marker = "<|start_header_id|>user<|end_header_id|>"
- end_header_marker = "<|end_header_id|>"
- eot_marker = "<|eot_id|>"
-
- if assistant_marker in text:
- text = text.rsplit(assistant_marker, 1)[0]
- if eot_marker in text:
- text = text.rsplit(eot_marker, 1)[0]
- if user_marker in text:
- text = text.rsplit(user_marker, 1)[-1]
- if end_header_marker in text:
- text = text.split(end_header_marker, 1)[-1]
- return text.strip()
-
-
-def normalize_dynin_online_prompt_text(task: str, decoded_prompt: str) -> str:
- text = extract_dynin_user_prompt_text(decoded_prompt)
- if not text:
- text = str(decoded_prompt or "")
-
- for placeholder in _DYNIN_MODALITY_PLACEHOLDERS:
- text = text.replace(placeholder, " ")
-
- task_token = _DYNIN_ONLINE_PROMPT_TOKEN_BY_TASK.get(task)
- if task_token:
- text = text.replace(task_token, " ", 1)
-
- text = " ".join(text.split()).strip()
-
- if task == "t2s":
- if not text:
- text = "Hello. This is a default text-to-speech sample."
- text = build_dynin_chat_prompt(f"{DEFAULT_DYNIN_T2S_INSTRUCTION}\n{text}")
- elif task in {"t2i", "i2i"} and not text:
- text = "A high quality detailed image."
-
- return text
-
-
-def infer_dynin_online_task(
- *,
- decoded_prompt: str,
- has_image: bool = False,
- has_audio: bool = False,
- has_video: bool = False,
-) -> str:
- prompt = str(decoded_prompt or "")
- if "<|i2i|>" in prompt:
- return "i2i"
- if "<|t2i|>" in prompt and not has_audio and not has_video:
- return "t2i"
- if "<|t2s|>" in prompt and not has_audio and not has_video:
- return "t2s"
- return "t2t"
-
-
-def build_dynin_prompt_payload(
- *,
- task: str,
- text: str,
- image_tokens: torch.Tensor | None,
- image_placeholder_tokens: int,
- audio_placeholder_tokens: int,
- image_token_offset: int,
- mask_token_id: int,
- use_train_i2i_prompt: bool,
-) -> tuple[Any, str]:
- _, prompting_task, _, _ = DYNIN_TASK_DEFAULT_RUNTIME[task]
-
- if task == "t2t":
- payload = ([[]], [build_dynin_chat_prompt(text)])
- return payload, prompting_task
-
- if task == "t2i":
- image_placeholder = torch.full(
- (1, int(image_placeholder_tokens)),
- fill_value=int(mask_token_id),
- dtype=torch.long,
- )
- payload = ([text], image_placeholder)
- return payload, prompting_task
-
- if task == "i2i":
- if image_tokens is None:
- raise ValueError("i2i requires image tokens")
- src = image_tokens.view(1, -1).long() + int(image_token_offset)
- target_len = int(image_placeholder_tokens) if image_placeholder_tokens > 0 else int(src.shape[1])
- image_placeholder = torch.full(
- (1, target_len),
- fill_value=int(mask_token_id),
- dtype=torch.long,
- )
- if use_train_i2i_prompt:
- labels_placeholder = torch.full(
- (1, target_len),
- fill_value=-100,
- dtype=torch.long,
- )
- payload = ([text], src, image_placeholder, labels_placeholder)
- return payload, "i2i"
- payload = ([text], src, image_placeholder)
- return payload, "i2i_gen"
-
- if task == "t2s":
- audio_placeholder = torch.full(
- (1, int(audio_placeholder_tokens)),
- fill_value=int(mask_token_id),
- dtype=torch.long,
- )
- payload = ([text], audio_placeholder)
- return payload, prompting_task
-
- raise ValueError(f"Unsupported Dynin online bootstrap task: {task}")
-
-
-def _wrap_runtime_field(value: Any) -> list[Any]:
- return [value]
-
-
-def build_dynin_online_runtime_info(
- *,
- task: str,
- text_vocab_size: int,
- infer_sources: DyninInferSources,
- dynin_config_path: str | None = None,
- prompting_input: Any | None = None,
- attention_mask: list[int] | None = None,
- prompt_length: int | None = None,
- uncond_prompting_input: Any | None = None,
- image_token_count: int = 0,
- t2s_token_length: int | None = None,
- use_train_i2i_prompt: bool | None = None,
-) -> dict[str, Any]:
- runtime_task, prompting_task, detok_id, _ = DYNIN_TASK_DEFAULT_RUNTIME[task]
-
- prompt_max_text_len = int(dynin_runtime_fallback(task, "prompt_max_text_len", None) or 1024)
- max_new_tokens = int(dynin_runtime_fallback(task, "max_new_tokens", None) or 256)
- steps = int(dynin_runtime_fallback(task, "steps", None) or 256)
- block_length = int(dynin_runtime_fallback(task, "block_length", None) or 2)
- temperature = float(dynin_runtime_fallback(task, "temperature", None) or 0.0)
- cfg_scale = float(dynin_runtime_fallback(task, "cfg_scale", None) or 0.0)
- remasking = str(dynin_runtime_fallback(task, "remasking", None) or "low_confidence")
- timesteps = int(dynin_runtime_fallback(task, "timesteps", None) or 20)
- guidance_scale = float(dynin_runtime_fallback(task, "guidance_scale", None) or 0.0)
- mask_token_id = int(dynin_runtime_fallback(task, "mask_token_id", None) or 126336)
- codebook_size = int(dynin_runtime_fallback(task, "codebook_size", None) or 8192)
- audio_codebook_size = int(dynin_runtime_fallback(task, "audio_codebook_size", None) or 4096)
- image_resolution = int(dynin_runtime_fallback(task, "image_resolution", None) or 336)
- if image_token_count <= 0 and task in {"t2i", "i2i"}:
- fallback_count = dynin_runtime_fallback(task, "image_token_count", None)
- if fallback_count is not None:
- image_token_count = int(fallback_count)
- else:
- image_token_count = max(1, (image_resolution // 16) ** 2)
-
- if t2s_token_length is None:
- t2s_token_length = int(dynin_runtime_fallback(task, "t2s_token_length", None) or 383)
- t2s_condition = str(
- dynin_runtime_fallback(
- task,
- "t2s_condition",
- None,
- )
- or "gender-female_emotion-neutral_speed-normal_pitch-normal"
- )
- if use_train_i2i_prompt is None:
- use_train_i2i_prompt = bool(dynin_runtime_fallback(task, "use_train_i2i_prompt", task == "i2i"))
-
- runtime_info: dict[str, Any] = {
- "task": _wrap_runtime_field(runtime_task),
- "prompting_task": _wrap_runtime_field(prompting_task),
- "detok_id": _wrap_runtime_field(int(detok_id)),
- "prompt_max_text_len": _wrap_runtime_field(prompt_max_text_len),
- "prompting_max_text_len": _wrap_runtime_field(prompt_max_text_len),
- "cond_dropout_prob": _wrap_runtime_field(0.0),
- "prompting_cond_dropout_prob": _wrap_runtime_field(0.0),
- "tokenizer_path": _wrap_runtime_field(str(infer_sources.tokenizer_source)),
- "text_vocab_size": _wrap_runtime_field(int(text_vocab_size)),
- "model_local_files_only": _wrap_runtime_field(bool(infer_sources.model_local_files_only)),
- "max_new_tokens": _wrap_runtime_field(int(t2s_token_length if task == "t2s" else max_new_tokens)),
- "steps": _wrap_runtime_field(steps),
- "block_length": _wrap_runtime_field(block_length),
- "temperature": _wrap_runtime_field(temperature),
- "cfg_scale": _wrap_runtime_field(cfg_scale),
- "remasking": _wrap_runtime_field(remasking),
- "mask_id": _wrap_runtime_field(mask_token_id),
- "mask_token_id": _wrap_runtime_field(mask_token_id),
- "codebook_size": _wrap_runtime_field(codebook_size),
- "audio_codebook_size": _wrap_runtime_field(audio_codebook_size),
- "timesteps": _wrap_runtime_field(timesteps),
- "guidance_scale": _wrap_runtime_field(guidance_scale),
- "noise_type": _wrap_runtime_field("mask"),
- "noise_schedule_name": _wrap_runtime_field("cosine"),
- "noise_schedule_params": _wrap_runtime_field({}),
- "seq_len": _wrap_runtime_field(int(image_token_count)),
- "condition": _wrap_runtime_field(t2s_condition),
- "t2s_condition": _wrap_runtime_field(t2s_condition),
- "vq_model_image_path": _wrap_runtime_field(str(infer_sources.vq_image_source)),
- "vq_model_image_local_files_only": _wrap_runtime_field(bool(infer_sources.vq_image_local_files_only)),
- "vq_model_audio_path": _wrap_runtime_field(str(infer_sources.vq_audio_source)),
- "vq_model_audio_local_files_only": _wrap_runtime_field(bool(infer_sources.vq_audio_local_files_only)),
- "image_resolution": _wrap_runtime_field(image_resolution),
- "t2s_token_length": _wrap_runtime_field(int(t2s_token_length)),
- "use_train_i2i_prompt": _wrap_runtime_field(bool(use_train_i2i_prompt)),
- }
-
- if dynin_config_path:
- runtime_info["dynin_config_path"] = _wrap_runtime_field(str(dynin_config_path))
- if prompting_input is not None:
- runtime_info["prompting_input"] = _wrap_runtime_field(prompting_input)
- if uncond_prompting_input is not None:
- runtime_info["uncond_prompting_input"] = _wrap_runtime_field(uncond_prompting_input)
- if attention_mask:
- runtime_info["attention_mask"] = _wrap_runtime_field(list(attention_mask))
- if prompt_length is None and attention_mask:
- prompt_length = len(attention_mask)
- if prompt_length is not None:
- runtime_info["prompt_length"] = _wrap_runtime_field(int(prompt_length))
-
- return runtime_info
diff --git a/vllm_omni/model_executor/models/dynin_omni/dynin_omni_token2audio.py b/vllm_omni/model_executor/models/dynin_omni/dynin_omni_token2audio.py
deleted file mode 100644
index 8b4063d0796..00000000000
--- a/vllm_omni/model_executor/models/dynin_omni/dynin_omni_token2audio.py
+++ /dev/null
@@ -1,274 +0,0 @@
-from __future__ import annotations
-
-import os
-import tempfile
-from pathlib import Path
-from typing import Any
-
-import torch
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.sequence import IntermediateTensors
-
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-
-from .dynin_omni import DyninOmniStageBase
-from .dynin_omni_common import (
- DetokTarget,
- _looks_like_hf_repo_id,
- coerce_token_ids_1d,
- normalize_runtime_info,
- resolve_dynin_infer_sources,
- resolve_hidden_size,
- unwrap_first_value,
-)
-
-logger = init_logger(__name__)
-
-
-def _get_hf_token() -> str | None:
- return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
-
-
-def _ensure_remote_s2u_vendor_root(
- *,
- repo_id: str,
- local_files_only: bool,
-) -> str | None:
- if local_files_only or not _looks_like_hf_repo_id(repo_id):
- return None
-
- existing = os.environ.get("DYNIN_S2U_VENDOR_ROOT")
- if existing:
- existing_path = Path(existing).expanduser().resolve()
- if existing_path.is_dir():
- return str(existing_path)
-
- try:
- from huggingface_hub import snapshot_download
- except Exception as e:
- logger.warning("huggingface_hub unavailable; cannot fetch s2u_vendor from %s: %s", repo_id, e)
- return None
-
- token = _get_hf_token()
- last_error: Exception | None = None
- revisions: list[str | None] = [None]
-
- for revision in revisions:
- try:
- snapshot_dir = snapshot_download(
- repo_id=repo_id,
- revision=revision,
- allow_patterns=["s2u_vendor/**"],
- token=token,
- )
- except TypeError:
- try:
- snapshot_dir = snapshot_download(
- repo_id=repo_id,
- revision=revision,
- allow_patterns=["s2u_vendor/**"],
- )
- except Exception as e:
- last_error = e
- continue
- except Exception as e:
- last_error = e
- continue
-
- vendor_root = (Path(snapshot_dir) / "s2u_vendor").resolve()
- if vendor_root.is_dir():
- os.environ["DYNIN_S2U_VENDOR_ROOT"] = str(vendor_root)
- logger.info("Using remote S2U vendor root: %s", vendor_root)
- return str(vendor_root)
-
- if last_error is not None:
- logger.warning("Failed to download remote s2u_vendor from %s: %s", repo_id, last_error)
- return None
-
-
-class DyninOmniToken2Audio(DyninOmniStageBase):
- """Stage-3: token detokenization to speech (or pass-through)."""
-
- stage_name = "Dynin token2audio"
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- del prefix
- super().__init__()
- self.vllm_config = vllm_config
- self.have_multimodal_outputs = True
- self.requires_raw_input_tokens = True
- self.hidden_size = resolve_hidden_size(vllm_config=vllm_config)
- self._vq_audio = None
- self._vq_audio_path: str | None = None
- self._vq_audio_local_files_only: bool | None = None
-
- def forward(
- self,
- input_ids: torch.Tensor | None = None,
- positions: torch.Tensor | None = None,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: Any,
- ) -> OmniOutput:
- del positions, intermediate_tensors, inputs_embeds
- if input_ids is None:
- raise ValueError("token2audio stage requires input_ids")
-
- runtime_info = normalize_runtime_info(kwargs.get("runtime_additional_information"))
- detok_id = int(unwrap_first_value(runtime_info.get("detok_id"), 0))
- tokens = coerce_token_ids_1d(input_ids)
-
- if detok_id != DetokTarget.AUDIO:
- return OmniOutput(
- text_hidden_states=None,
- multimodal_outputs={
- "token_ids": tokens,
- "detok_id": torch.tensor([detok_id], dtype=torch.long, device=tokens.device),
- },
- )
-
- audio, sample_rate = self._decode_audio_tokens(tokens, runtime_info=runtime_info)
- return OmniOutput(
- text_hidden_states=None,
- multimodal_outputs={
- "speech": audio,
- "audio": audio,
- "sr": torch.tensor([sample_rate], dtype=torch.int, device=audio.device),
- "detok_id": torch.tensor([detok_id], dtype=torch.long, device=audio.device),
- },
- )
-
- def _decode_audio_tokens(self, tokens: torch.Tensor, runtime_info: dict[str, Any]) -> tuple[torch.Tensor, int]:
- # Follow DYNIN validation path:
- # token list -> "<|speech_x|>" string -> vq_model_audio.decode(...).
- vq_audio = self._ensure_vq_audio(runtime_info=runtime_info, ref_device=tokens.device)
-
- audio_codebook_size = int(unwrap_first_value(runtime_info.get("audio_codebook_size"), 4096))
- audio_vocab_offset = unwrap_first_value(
- runtime_info.get("audio_vocab_offset"),
- unwrap_first_value(runtime_info.get("t2s_vocab_start"), None),
- )
-
- token_ids = tokens.to(torch.long)
- if audio_vocab_offset is not None:
- off = int(audio_vocab_offset)
- token_ids = torch.where(token_ids >= off, token_ids - off, token_ids)
- token_ids = token_ids[(token_ids >= 0) & (token_ids < audio_codebook_size)]
- if token_ids.numel() == 0:
- raise RuntimeError("Audio detokenizer got no valid audio token ids.")
-
- speech_unit_str = " ".join(map(str, token_ids.detach().cpu().tolist()))
- speech_unit_for_decode = "".join(f"<|speech_{unit}|>" for unit in speech_unit_str.split(" ") if unit != "")
-
- condition = unwrap_first_value(
- runtime_info.get("condition"),
- unwrap_first_value(runtime_info.get("t2s_condition"), None),
- )
- output_wav_file = unwrap_first_value(runtime_info.get("output_wav_file"), None)
- created_tmp = False
- if output_wav_file is None:
- fd, tmp_wav = tempfile.mkstemp(prefix="dynin_t2s_", suffix=".wav")
- os.close(fd)
- output_wav_file = tmp_wav
- created_tmp = True
-
- audio_array = vq_audio.decode(speech_unit_for_decode, condition=condition, output_wav_file=output_wav_file)
- if created_tmp:
- try:
- os.remove(output_wav_file)
- except Exception:
- pass
- if not isinstance(audio_array, torch.Tensor):
- audio_array = torch.as_tensor(audio_array, dtype=torch.float32, device=tokens.device)
- else:
- audio_array = audio_array.to(device=tokens.device, dtype=torch.float32)
-
- if audio_array.ndim > 1:
- audio_array = audio_array.reshape(-1)
- audio_array = audio_array.contiguous()
-
- sample_rate = int(
- unwrap_first_value(
- runtime_info.get("sr"),
- unwrap_first_value(runtime_info.get("sample_rate"), 24000),
- )
- )
- try:
- cfg = getattr(vq_audio, "u2s_config", None)
- cfg_sr = getattr(cfg, "sampling_rate", None)
- if cfg_sr is None:
- cfg_sr = getattr(getattr(cfg, "data", None), "sampling_rate", None)
- if cfg_sr is not None:
- sample_rate = int(cfg_sr)
- except Exception:
- pass
- return audio_array, sample_rate
-
- def _ensure_vq_audio(self, runtime_info: dict[str, Any], ref_device: torch.device) -> Any:
- sources = resolve_dynin_infer_sources(vllm_config=self.vllm_config, runtime_info=runtime_info)
- model_path = str(sources.vq_audio_source)
- local_files_only = bool(sources.vq_audio_local_files_only)
-
- _ensure_remote_s2u_vendor_root(
- repo_id=model_path,
- local_files_only=local_files_only,
- )
-
- if (
- self._vq_audio is None
- or self._vq_audio_path != model_path
- or self._vq_audio_local_files_only != local_files_only
- ):
- logger.info(
- "Loading DYNIN audio detokenizer from %s (local_files_only=%s)",
- model_path,
- local_files_only,
- )
- try:
- from transformers import AutoModel
- except Exception as e:
- raise RuntimeError(
- "transformers is required to load EMOVASpeechTokenizer remote code from Hugging Face."
- ) from e
-
- try:
- self._vq_audio = AutoModel.from_pretrained(
- model_path,
- trust_remote_code=True,
- local_files_only=local_files_only,
- low_cpu_mem_usage=False,
- )
- except TypeError:
- try:
- self._vq_audio = AutoModel.from_pretrained(
- model_path,
- trust_remote_code=True,
- local_files_only=local_files_only,
- )
- except TypeError:
- self._vq_audio = AutoModel.from_pretrained(
- model_path,
- trust_remote_code=True,
- )
- except Exception as e:
- raise RuntimeError(
- f"Failed to load EMOVASpeechTokenizer from Hugging Face remote code for model path '{model_path}'."
- ) from e
-
- if not hasattr(self._vq_audio, "decode"):
- raise RuntimeError(
- "Loaded audio tokenizer does not expose decode(). "
- "Check HF config.json auto_map/model_type and ensure trust_remote_code=True."
- )
- self._vq_audio.eval()
- self._vq_audio.requires_grad_(False)
- self._vq_audio_path = model_path
- self._vq_audio_local_files_only = local_files_only
- if hasattr(self._vq_audio, "to"):
- self._vq_audio = self._vq_audio.to(ref_device)
- return self._vq_audio
-
- def embed_multimodal(self, **kwargs: Any) -> Any:
- del kwargs
- return None
diff --git a/vllm_omni/model_executor/models/dynin_omni/dynin_omni_token2image.py b/vllm_omni/model_executor/models/dynin_omni/dynin_omni_token2image.py
deleted file mode 100644
index 6b5110a77e2..00000000000
--- a/vllm_omni/model_executor/models/dynin_omni/dynin_omni_token2image.py
+++ /dev/null
@@ -1,150 +0,0 @@
-from __future__ import annotations
-
-import os
-from typing import Any
-
-import torch
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.sequence import IntermediateTensors
-
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-
-from .dynin_omni import DyninOmniStageBase
-from .dynin_omni_common import (
- DetokTarget,
- _to_bool,
- coerce_token_ids_1d,
- get_dynin_magvit_attr,
- normalize_runtime_info,
- resolve_dynin_infer_sources,
- resolve_hidden_size,
- unwrap_first_value,
-)
-
-logger = init_logger(__name__)
-
-
-class DyninOmniToken2Image(DyninOmniStageBase):
- """Stage-2: token detokenization to image (or pass-through)."""
-
- stage_name = "Dynin token2image"
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- del prefix
- super().__init__()
-
- self.vllm_config = vllm_config
- self.have_multimodal_outputs = True
- self.requires_raw_input_tokens = True
- self.hidden_size = resolve_hidden_size(vllm_config=vllm_config)
- self._vq_model = None
- self._vq_model_path: str | None = None
- self._vq_local_files_only: bool | None = None
-
- def forward(
- self,
- input_ids: torch.Tensor | None = None,
- positions: torch.Tensor | None = None,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: Any,
- ) -> OmniOutput:
- del positions, intermediate_tensors, inputs_embeds
- if input_ids is None:
- raise ValueError("token2image stage requires input_ids")
- runtime_info = normalize_runtime_info(kwargs.get("runtime_additional_information"))
- detok_id = int(unwrap_first_value(runtime_info.get("detok_id"), 0))
- tokens = coerce_token_ids_1d(input_ids)
-
- if detok_id != DetokTarget.IMAGE:
- return OmniOutput(
- text_hidden_states=None,
- multimodal_outputs={
- "token_ids": tokens,
- "detok_id": torch.tensor([detok_id], dtype=torch.long, device=tokens.device),
- },
- )
-
- image = self._decode_image_tokens(tokens, runtime_info=runtime_info)
- return OmniOutput(
- text_hidden_states=None,
- multimodal_outputs={
- "image": image,
- "detok_id": torch.tensor([detok_id], dtype=torch.long, device=image.device),
- },
- )
-
- def _decode_image_tokens(self, tokens: torch.Tensor, runtime_info: dict[str, Any]) -> torch.Tensor:
- # Follow DYNIN validation path:
- # tokens -> clamp -> vq_model.decode_code -> (x+1)/2 -> [0,1].
- vq_model = self._ensure_vq_model(runtime_info=runtime_info, ref_device=tokens.device)
- codebook_size = int(unwrap_first_value(runtime_info.get("codebook_size"), 8192))
- image_vocab_offset = unwrap_first_value(runtime_info.get("image_vocab_offset"), None)
- if image_vocab_offset is None:
- text_vocab_size = unwrap_first_value(runtime_info.get("text_vocab_size"), None)
- num_new_special_tokens = int(unwrap_first_value(runtime_info.get("num_new_special_tokens"), 0))
- if text_vocab_size is not None:
- image_vocab_offset = int(text_vocab_size) + num_new_special_tokens
-
- token_ids = tokens.to(torch.long)
- if image_vocab_offset is not None:
- off = int(image_vocab_offset)
- token_ids = torch.where(token_ids >= off, token_ids - off, token_ids)
- token_ids = torch.clamp(token_ids, min=0, max=max(0, codebook_size - 1))
- token_ids = token_ids.unsqueeze(0)
-
- decoded = vq_model.decode_code(token_ids)
- decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
- if decoded.ndim != 4 or decoded.shape[0] == 0:
- raise RuntimeError(f"Unexpected MAGVIT decode output shape: {tuple(decoded.shape)}")
- return decoded[0].contiguous()
-
- def _ensure_vq_model(self, runtime_info: dict[str, Any], ref_device: torch.device) -> Any:
- sources = resolve_dynin_infer_sources(vllm_config=self.vllm_config, runtime_info=runtime_info)
- model_path = str(sources.vq_image_source)
- local_files_only = bool(sources.vq_image_local_files_only)
- if self._vq_model is None or self._vq_model_path != model_path or self._vq_local_files_only != local_files_only:
- disable_xet = unwrap_first_value(
- runtime_info.get("hf_hub_disable_xet"),
- unwrap_first_value(runtime_info.get("disable_hf_xet"), True),
- )
- if _to_bool(disable_xet, default=True):
- os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
- logger.info(
- "Loading DYNIN image detokenizer from %s (local_files_only=%s)",
- model_path,
- local_files_only,
- )
- try:
- MAGVITv2 = get_dynin_magvit_attr(
- "MAGVITv2",
- source=model_path,
- local_files_only=local_files_only,
- )
- try:
- self._vq_model = MAGVITv2.from_pretrained(
- model_path,
- local_files_only=local_files_only,
- )
- except TypeError:
- self._vq_model = MAGVITv2.from_pretrained(model_path)
- except Exception as e:
- raise RuntimeError(
- "Failed to load MAGVITv2 from local DYNIN submodel implementation "
- f"for model path '{model_path}'. "
- "If your environment cannot access huggingface.co, set "
- "additional_information.vq_model_image_path to a local MAGVITv2 directory "
- "and set additional_information.vq_model_image_local_files_only=true."
- ) from e
- self._vq_model.eval()
- self._vq_model.requires_grad_(False)
- self._vq_model_path = model_path
- self._vq_local_files_only = local_files_only
- if hasattr(self._vq_model, "to"):
- self._vq_model = self._vq_model.to(ref_device)
- return self._vq_model
-
- def embed_multimodal(self, **kwargs: Any) -> Any:
- del kwargs
- return None
diff --git a/vllm_omni/model_executor/models/dynin_omni/dynin_omni_token2text.py b/vllm_omni/model_executor/models/dynin_omni/dynin_omni_token2text.py
deleted file mode 100644
index fb5ac170295..00000000000
--- a/vllm_omni/model_executor/models/dynin_omni/dynin_omni_token2text.py
+++ /dev/null
@@ -1,1580 +0,0 @@
-from __future__ import annotations
-
-import inspect
-import json
-from contextlib import contextmanager
-from typing import Any
-
-import torch
-import torch.nn.functional as F
-from transformers import AutoTokenizer
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.sequence import IntermediateTensors
-
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-
-from .dynin_omni import DyninOmniStageBase
-from .dynin_omni_common import (
- DYNIN_PROMPT_SOURCE_KEY,
- DYNIN_PROMPT_SOURCE_OFFLINE_PREBUILT,
- DYNIN_REMOTE_SETTINGS,
- DYNIN_SPECIAL_TOKENS,
- TASK_TO_DETOK,
- DetokTarget,
- _to_bool,
- build_dynin_online_runtime_info,
- build_dynin_prompt_payload,
- coerce_token_ids_1d,
- dynin_runtime_fallback,
- get_dynin_magvit_attr,
- get_dynin_modeling_attr,
- get_dynin_sampling_attr,
- infer_dynin_online_task,
- logical_dynin_task,
- normalize_dynin_online_prompt_text,
- normalize_runtime_info,
- resolve_dynin_infer_sources,
- resolve_hidden_size,
- resolve_remote_attr,
- unwrap_first_value,
-)
-
-logger = init_logger(__name__)
-
-TASK_TO_PROMPTING_TASK = {
- "t2i": "t2i_gen",
- "i2i": "i2i_gen",
- "ti2ti": "ti2ti_gen",
- "t2s": "t2s_gen",
- "t2s_mmu_like": "t2s_gen",
- "t2s_fixed": "t2s_fixed_gen",
- "s2s": "s2s_gen",
- "v2s": "v2s_gen",
- "mmu": "mmu",
- "mmu_fast": "mmu",
- "mmu_fastdllm_v1": "mmu",
- "s2t": "s2t",
- "v2t": "v2t",
-}
-
-TASK_TO_GENERATE_FN = {
- "t2i": "t2i_generate",
- "i2i": "i2i_generate",
- "ti2ti": "ti2ti_generate",
- "t2s": "t2s_generate",
- "t2s_mmu_like": "t2s_generate_mmu_like",
- "t2s_fixed": "t2s_fixed_generate",
- "s2s": "t2s_generate_mmu_like",
- "v2s": "t2s_generate_mmu_like",
- "s2t": "s2t_generate",
- "mmu": "mmu_generate",
- "t2t": "generate",
- "mmu_fast": "mmu_generate_fast",
- "mmu_fastdllm_v1": "mmu_generate_fastdllm_v1",
- "v2t": "mmu_generate",
-}
-
-TASKS_USING_UNI_PROMPTING = set(TASK_TO_PROMPTING_TASK.keys())
-PROMPT_PAYLOAD_REQUIRED_TASKS = {
- "t2i",
- "i2i",
- "ti2ti",
- "t2s",
- "t2s_mmu_like",
- "t2s_fixed",
- "s2s",
- "v2s",
-}
-
-GENERATE_RUNTIME_KWARG_KEYS = (
- "uncond_input_ids",
- "uncond_attention_mask",
- "noise_schedule",
- "generator",
- "config",
- "uni_prompting",
- "resolution",
- "max_new_tokens",
- "steps",
- "block_length",
- "temperature",
- "top_k",
- "eot_token",
- "cfg_scale",
- "remasking",
- "mask_id",
- "attention_mask",
- "timesteps",
- "guidance_scale",
- "noise_type",
- "seq_len",
- "mask_token_id",
- "codebook_size",
- "audio_codebook_size",
- "use_cache",
- "threshold",
- "factor",
-)
-
-PASSTHROUGH_GENERATE_KWARG_KEYS = (
- "attention_mask",
- "uncond_input_ids",
- "uncond_attention_mask",
- "noise_schedule",
- "uni_prompting",
- "generator",
- "noise_type",
-)
-
-PROMPTING_PAYLOAD_KEYS = (
- "prompting_input",
- "prompting_inputs",
- "dynin_inputs",
- "model_inputs",
- "raw_inputs",
-)
-
-UNCOND_PROMPTING_PAYLOAD_KEYS = (
- "uncond_prompting_input",
- "uncond_prompting_inputs",
-)
-
-PROMPTING_META_KEYS = (
- "uncond_prompting_input",
- "uncond_prompting_inputs",
- "uni_prompting",
- "prompting_task",
- "prompting_config",
-)
-
-MM_INPUT_ALIASES = {
- "image": ("pixel_values", "image_embeds", "img2img"),
- "video": ("pixel_values_videos", "video_embeds"),
- "audio": ("input_audio_features", "audio_embeds"),
-}
-
-
-class DyninOmniToken2Text(DyninOmniStageBase):
- """Stage-1: DYNIN generation + text detokenization or pass-through."""
-
- stage_name = "Dynin token2text"
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- del prefix
- super().__init__()
-
- self.vllm_config = vllm_config
- self.have_multimodal_outputs = True
- self.requires_raw_input_tokens = True
-
- self._infer_sources = resolve_dynin_infer_sources(vllm_config=vllm_config)
- if self._infer_sources.config_path:
- logger.info(
- "DYNIN token2text using inference config: %s",
- self._infer_sources.config_path,
- )
-
- self.model = self._load_text_model(
- self._infer_sources.model_source,
- local_files_only=self._infer_sources.model_local_files_only,
- )
- self.model.eval()
- self.model.requires_grad_(False)
-
- self.hidden_size = resolve_hidden_size(
- vllm_config=vllm_config,
- model=self.model,
- )
-
- self.tokenizer: Any | None = None
- self._tokenizer_path: str | None = None
- self._uni_prompting: Any | None = None
- self._uni_prompting_init_spec: tuple[Any, ...] | None = None
- self._prompt_vq_model: Any | None = None
- self._prompt_vq_model_path: str | None = None
- self._prompt_vq_local_files_only: bool | None = None
- self._cached_mm_inputs: dict[str, Any] = {}
-
- try:
- self._set_tokenizer(
- self._infer_sources.tokenizer_source,
- local_files_only=self._infer_sources.model_local_files_only,
- )
- except Exception:
- self.tokenizer = None
- self._tokenizer_path = None
-
- @staticmethod
- def _load_text_model(model_path: str, *, local_files_only: bool = False) -> Any:
- try:
- dynin_model_cls = get_dynin_modeling_attr("DyninOmniModelLM")
- try:
- return dynin_model_cls.from_pretrained(
- model_path,
- torch_dtype=torch.bfloat16,
- local_files_only=local_files_only,
- )
- except TypeError:
- return dynin_model_cls.from_pretrained(
- model_path,
- torch_dtype=torch.bfloat16,
- )
- except Exception as e:
- raise RuntimeError(
- f"Failed to load DyninOmniModelLM via remote Dynin code for model path '{model_path}'."
- ) from e
-
- @staticmethod
- def _load_tokenizer_from_source(
- source: str,
- *,
- local_files_only: bool = False,
- trust_remote_code: bool = False,
- ) -> Any:
- load_kwargs = {
- "trust_remote_code": trust_remote_code,
- "local_files_only": _to_bool(local_files_only, default=False),
- }
- try:
- return AutoTokenizer.from_pretrained(source, **load_kwargs)
- except TypeError:
- load_kwargs.pop("local_files_only", None)
- return AutoTokenizer.from_pretrained(source, **load_kwargs)
-
- def _set_tokenizer(self, source: str, *, local_files_only: bool) -> None:
- try:
- tokenizer = self._load_tokenizer_from_source(
- source,
- local_files_only=local_files_only,
- trust_remote_code=False,
- )
- except Exception as e:
- logger.info(
- "Falling back to trust_remote_code=True tokenizer loading for %s: %s",
- source,
- e,
- )
- tokenizer = self._load_tokenizer_from_source(
- source,
- local_files_only=local_files_only,
- trust_remote_code=True,
- )
-
- self.tokenizer = tokenizer
- self._tokenizer_path = source
- self._reset_uni_prompting_cache()
-
- def _reset_uni_prompting_cache(self) -> None:
- self._uni_prompting = None
- self._uni_prompting_init_spec = None
-
- def get_language_model(self) -> Any:
- return self.model
-
- @staticmethod
- def _merge_runtime_info_missing_values(
- runtime_info: dict[str, Any],
- fallback_info: dict[str, Any],
- ) -> dict[str, Any]:
- merged = dict(runtime_info)
- for key, value in fallback_info.items():
- if unwrap_first_value(merged.get(key), None) is None:
- merged[key] = value
- return merged
-
- def _runtime_info_needs_bootstrap(
- self,
- runtime_info: dict[str, Any],
- logical_task_name: str,
- ) -> bool:
- task = str(unwrap_first_value(runtime_info.get("task"), "") or "").lower()
- detok_id = unwrap_first_value(runtime_info.get("detok_id"), None)
- prompt_length = unwrap_first_value(runtime_info.get("prompt_length"), None)
-
- if not task or detok_id is None:
- return True
- if prompt_length is None:
- return True
- if (
- task in PROMPT_PAYLOAD_REQUIRED_TASKS
- and self._find_first_payload(
- runtime_info=runtime_info,
- kwargs={},
- keys=PROMPTING_PAYLOAD_KEYS,
- )
- is None
- ):
- return True
- if logical_task_name in {"t2i", "i2i"}:
- for key in ("codebook_size", "text_vocab_size", "vq_model_image_path"):
- if unwrap_first_value(runtime_info.get(key), None) is None:
- return True
- if logical_task_name == "t2s":
- for key in ("audio_codebook_size", "condition", "vq_model_audio_path"):
- if unwrap_first_value(runtime_info.get(key), None) is None:
- return True
- return False
-
- def _decode_prompt_for_bootstrap(
- self,
- input_ids: torch.Tensor,
- runtime_info: dict[str, Any],
- ) -> str:
- self._maybe_load_runtime_tokenizer(runtime_info)
- if self.tokenizer is None:
- return ""
- token_ids = coerce_token_ids_1d(input_ids).detach().cpu().tolist()
- try:
- return str(self.tokenizer.decode(token_ids, skip_special_tokens=False))
- except Exception:
- return ""
-
- def _bootstrap_runtime_info_if_needed(
- self,
- *,
- input_ids: torch.Tensor,
- runtime_info: dict[str, Any],
- kwargs: dict[str, Any],
- ) -> dict[str, Any]:
- if unwrap_first_value(runtime_info.get(DYNIN_PROMPT_SOURCE_KEY), None) == DYNIN_PROMPT_SOURCE_OFFLINE_PREBUILT:
- return runtime_info
-
- mm_inputs = self._collect_mm_inputs(**kwargs)
- decoded_prompt = ""
-
- task_value = unwrap_first_value(runtime_info.get("task"), None)
- if task_value is None:
- decoded_prompt = self._decode_prompt_for_bootstrap(input_ids, runtime_info)
- logical_task_name = infer_dynin_online_task(
- decoded_prompt=decoded_prompt,
- has_image="image" in mm_inputs,
- has_audio="audio" in mm_inputs,
- has_video="video" in mm_inputs,
- )
- else:
- logical_task_name = logical_dynin_task(task_value)
-
- if not self._runtime_info_needs_bootstrap(runtime_info, logical_task_name):
- return runtime_info
-
- self._maybe_load_runtime_tokenizer(runtime_info)
- if self.tokenizer is None:
- logger.warning("Unable to bootstrap Dynin runtime info because tokenizer is unavailable.")
- return runtime_info
-
- if not decoded_prompt:
- decoded_prompt = self._decode_prompt_for_bootstrap(input_ids, runtime_info)
-
- text_vocab_size = int(len(self.tokenizer))
- prompt_len = int(coerce_token_ids_1d(input_ids).numel())
- dynin_config_path = self._infer_sources.config_path
-
- base_runtime_info = build_dynin_online_runtime_info(
- task=logical_task_name,
- text_vocab_size=text_vocab_size,
- infer_sources=self._infer_sources,
- dynin_config_path=dynin_config_path,
- attention_mask=([1] * prompt_len) if logical_task_name == "t2t" else None,
- prompt_length=prompt_len if logical_task_name == "t2t" else None,
- )
- merged_runtime_info = self._merge_runtime_info_missing_values(runtime_info, base_runtime_info)
-
- payload_required = logical_task_name in {"t2i", "i2i", "t2s"}
- existing_prompt_payload = self._find_first_payload(
- runtime_info=merged_runtime_info,
- kwargs=kwargs,
- keys=PROMPTING_PAYLOAD_KEYS,
- )
- has_prompt_payload = existing_prompt_payload is not None
- needs_prompt_length = unwrap_first_value(merged_runtime_info.get("prompt_length"), None) is None
- if not payload_required:
- return merged_runtime_info
-
- use_train_i2i_prompt = _to_bool(
- unwrap_first_value(
- merged_runtime_info.get("use_train_i2i_prompt"),
- dynin_runtime_fallback(logical_task_name, "use_train_i2i_prompt", logical_task_name == "i2i"),
- ),
- default=logical_task_name == "i2i",
- )
- t2s_token_length = int(
- dynin_runtime_fallback(
- logical_task_name,
- "t2s_token_length",
- unwrap_first_value(merged_runtime_info.get("t2s_token_length"), None),
- )
- or 383
- )
- image_resolution = int(
- dynin_runtime_fallback(
- logical_task_name,
- "image_resolution",
- unwrap_first_value(merged_runtime_info.get("image_resolution"), None),
- )
- or 336
- )
-
- image_token_count = int(
- dynin_runtime_fallback(
- logical_task_name,
- "image_token_count",
- unwrap_first_value(merged_runtime_info.get("seq_len"), None),
- )
- or 0
- )
- image_tokens: torch.Tensor | None = None
- if logical_task_name == "i2i" and (not has_prompt_payload or image_token_count <= 0):
- image_tokens = self._encode_prompt_image_tokens(
- runtime_info=merged_runtime_info,
- mm_inputs=mm_inputs,
- resolution=image_resolution,
- )
- image_token_count = int(image_tokens.numel())
-
- mask_token_id = int(unwrap_first_value(merged_runtime_info.get("mask_token_id"), 126336))
- prompting_input = self._unwrap_singleton(existing_prompt_payload)
- prompting_task = str(
- unwrap_first_value(
- merged_runtime_info.get("prompting_task"),
- TASK_TO_PROMPTING_TASK.get(
- str(unwrap_first_value(merged_runtime_info.get("task"), "mmu")).lower(),
- "mmu",
- ),
- )
- )
- if not has_prompt_payload:
- prompt_text = normalize_dynin_online_prompt_text(logical_task_name, decoded_prompt)
- prompting_input, prompting_task = build_dynin_prompt_payload(
- task=logical_task_name,
- text=prompt_text,
- image_tokens=image_tokens,
- image_placeholder_tokens=image_token_count,
- audio_placeholder_tokens=t2s_token_length,
- image_token_offset=text_vocab_size,
- mask_token_id=mask_token_id,
- use_train_i2i_prompt=use_train_i2i_prompt,
- )
-
- prompt_runtime_info = build_dynin_online_runtime_info(
- task=logical_task_name,
- text_vocab_size=text_vocab_size,
- infer_sources=self._infer_sources,
- dynin_config_path=dynin_config_path,
- image_token_count=image_token_count,
- t2s_token_length=t2s_token_length,
- use_train_i2i_prompt=use_train_i2i_prompt,
- )
- prompt_runtime_info["prompting_task"] = [str(prompting_task)]
- prompt_runtime_info["prompting_input"] = [prompting_input]
- merged_runtime_info = self._merge_runtime_info_missing_values(merged_runtime_info, prompt_runtime_info)
-
- if not needs_prompt_length and has_prompt_payload:
- return merged_runtime_info
-
- uni_prompting = self._get_or_create_uni_prompting(
- runtime_info=merged_runtime_info,
- kwargs=kwargs,
- )
- if uni_prompting is not None:
- prepared_input_ids, prepared_attention_mask = self._prepare_prompting_input(
- payload=prompting_input,
- task=str(unwrap_first_value(merged_runtime_info.get("task"), "mmu")),
- runtime_info=merged_runtime_info,
- kwargs=kwargs,
- uni_prompting=uni_prompting,
- ref_device=input_ids.device,
- )
- if prepared_input_ids is not None:
- prepared_prompt_len = int(prepared_input_ids.shape[-1])
- prepared_attention_list: list[int] | None = None
- if prepared_attention_mask is not None:
- prepared_attention_list = prepared_attention_mask.view(-1).detach().cpu().tolist()
- final_runtime_info = build_dynin_online_runtime_info(
- task=logical_task_name,
- text_vocab_size=text_vocab_size,
- infer_sources=self._infer_sources,
- dynin_config_path=dynin_config_path,
- prompting_input=prompting_input,
- attention_mask=prepared_attention_list,
- prompt_length=prepared_prompt_len,
- image_token_count=image_token_count,
- t2s_token_length=t2s_token_length,
- use_train_i2i_prompt=use_train_i2i_prompt,
- )
- final_runtime_info["prompting_task"] = [str(prompting_task)]
-
- guidance_scale = float(unwrap_first_value(merged_runtime_info.get("guidance_scale"), 0.0))
- if logical_task_name in {"t2i", "i2i"} and guidance_scale > 0:
- uncond_prompting_input, _ = build_dynin_prompt_payload(
- task=logical_task_name,
- text="",
- image_tokens=image_tokens,
- image_placeholder_tokens=image_token_count,
- audio_placeholder_tokens=t2s_token_length,
- image_token_offset=text_vocab_size,
- mask_token_id=mask_token_id,
- use_train_i2i_prompt=use_train_i2i_prompt,
- )
- final_runtime_info["uncond_prompting_input"] = [uncond_prompting_input]
-
- merged_runtime_info = self._merge_runtime_info_missing_values(
- merged_runtime_info,
- final_runtime_info,
- )
-
- return merged_runtime_info
-
- @staticmethod
- def _build_downstream_runtime_info(runtime_info: dict[str, Any]) -> dict[str, Any]:
- bridge_keys = (
- "task",
- "detok_id",
- "dynin_config_path",
- "codebook_size",
- "audio_codebook_size",
- "text_vocab_size",
- "num_new_special_tokens",
- "image_vocab_offset",
- "audio_vocab_offset",
- "t2s_vocab_start",
- "condition",
- "t2s_condition",
- "vq_model_image_path",
- "vq_model_image_local_files_only",
- "vq_model_audio_path",
- "vq_model_audio_local_files_only",
- "model_local_files_only",
- "local_files_only",
- "hf_hub_disable_xet",
- "disable_hf_xet",
- )
- return {key: runtime_info[key] for key in bridge_keys if key in runtime_info}
-
- @staticmethod
- def _jsonify_runtime_value(value: Any) -> Any:
- if isinstance(value, torch.Tensor):
- return value.detach().cpu().tolist()
- if isinstance(value, (list, tuple)):
- return [DyninOmniToken2Text._jsonify_runtime_value(item) for item in value]
- if isinstance(value, dict):
- return {str(key): DyninOmniToken2Text._jsonify_runtime_value(val) for key, val in value.items()}
- if isinstance(value, (str, int, float, bool)) or value is None:
- return value
- return str(value)
-
- def _encode_runtime_info_tensor(
- self,
- runtime_info: dict[str, Any],
- *,
- device: torch.device,
- ) -> torch.Tensor | None:
- if not runtime_info:
- return None
- payload = {key: self._jsonify_runtime_value(value) for key, value in runtime_info.items()}
- encoded = json.dumps(
- payload,
- ensure_ascii=False,
- separators=(",", ":"),
- sort_keys=True,
- ).encode("utf-8")
- if not encoded:
- return None
- return torch.tensor(list(encoded), dtype=torch.uint8, device=device)
-
- def forward(
- self,
- input_ids: torch.Tensor | None = None,
- positions: torch.Tensor | None = None,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: Any,
- ) -> OmniOutput:
- del positions, intermediate_tensors, inputs_embeds
-
- if input_ids is None:
- raise ValueError("token2text stage requires input_ids")
- try:
- runtime_info = normalize_runtime_info(kwargs.get("runtime_additional_information"))
- runtime_info = self._bootstrap_runtime_info_if_needed(
- input_ids=input_ids,
- runtime_info=runtime_info,
- kwargs=kwargs,
- )
- task = str(unwrap_first_value(runtime_info.get("task"), "mmu")).lower()
-
- detok_id = int(
- unwrap_first_value(
- runtime_info.get("detok_id"),
- TASK_TO_DETOK.get(task, DetokTarget.TEXT),
- )
- )
-
- token_ids = self._generate_token_ids(
- task=task,
- input_ids=input_ids,
- runtime_info=runtime_info,
- kwargs=kwargs,
- )
- bridge_runtime_info = self._build_downstream_runtime_info(runtime_info)
- runtime_info_tensor = self._encode_runtime_info_tensor(
- bridge_runtime_info,
- device=token_ids.device,
- )
-
- if detok_id != int(DetokTarget.TEXT):
- multimodal_outputs = {
- "token_ids": token_ids,
- "detok_id": torch.tensor(
- [detok_id],
- dtype=torch.long,
- device=token_ids.device,
- ),
- }
- if runtime_info_tensor is not None:
- multimodal_outputs["runtime_info_json"] = runtime_info_tensor
- return OmniOutput(
- text_hidden_states=None,
- multimodal_outputs=multimodal_outputs,
- )
-
- decode_tokens = self._extract_decode_tokens(token_ids, runtime_info=runtime_info)
- multimodal_outputs = {
- "token_ids": token_ids,
- "text_tokens": decode_tokens,
- "detok_id": torch.tensor(
- [detok_id],
- dtype=torch.long,
- device=token_ids.device,
- ),
- }
- if runtime_info_tensor is not None:
- multimodal_outputs["runtime_info_json"] = runtime_info_tensor
-
- return OmniOutput(
- text_hidden_states=None,
- multimodal_outputs=multimodal_outputs,
- )
- finally:
- self._cached_mm_inputs = {}
-
- def _generate_token_ids(
- self,
- task: str,
- input_ids: torch.Tensor,
- runtime_info: dict[str, Any],
- kwargs: dict[str, Any],
- ) -> torch.Tensor:
- precomputed = self._get_precomputed_token_ids(runtime_info)
- if precomputed is not None:
- return coerce_token_ids_1d(precomputed, ref_device=input_ids.device)
-
- gen_fn_name = TASK_TO_GENERATE_FN.get(task, "mmu_generate")
- gen_fn = self._resolve_generate_fn(gen_fn_name)
-
- gen_kwargs = self._collect_generate_kwargs(runtime_info=runtime_info, kwargs=kwargs)
-
- if "noise_schedule" not in gen_kwargs:
- noise_schedule = self._resolve_noise_schedule(
- runtime_info=runtime_info,
- kwargs=kwargs,
- )
- if noise_schedule is not None:
- gen_kwargs["noise_schedule"] = noise_schedule
-
- if task in TASKS_USING_UNI_PROMPTING and "uni_prompting" not in gen_kwargs:
- uni_prompting = self._get_or_create_uni_prompting(
- runtime_info=runtime_info,
- kwargs=kwargs,
- )
- if uni_prompting is not None:
- gen_kwargs["uni_prompting"] = uni_prompting
-
- should_prepare_prompting_inputs = task in TASKS_USING_UNI_PROMPTING or self._contains_prompting_payload(
- runtime_info=runtime_info, kwargs=kwargs
- )
- if should_prepare_prompting_inputs:
- input_ids, gen_kwargs = self._prepare_prompting_inputs_if_needed(
- task=task,
- input_ids=input_ids,
- runtime_info=runtime_info,
- kwargs=kwargs,
- gen_kwargs=gen_kwargs,
- )
-
- input_ids, gen_kwargs = self._normalize_generate_inputs(
- input_ids=input_ids,
- gen_kwargs=gen_kwargs,
- ref_device=input_ids.device,
- )
- gen_kwargs = self._filter_supported_generate_kwargs(
- gen_fn=gen_fn,
- gen_kwargs=gen_kwargs,
- fn_name=gen_fn_name,
- )
-
- generated = self._call_generate_fn(
- gen_fn=gen_fn,
- input_ids=input_ids,
- gen_kwargs=gen_kwargs,
- )
- return coerce_token_ids_1d(generated, ref_device=input_ids.device)
-
- @staticmethod
- def _get_precomputed_token_ids(runtime_info: dict[str, Any]) -> Any | None:
- precomputed = runtime_info.get("generated_token_ids")
- if precomputed is None:
- precomputed = runtime_info.get("token_ids")
- return precomputed
-
- def _resolve_generate_fn(self, fn_name: str) -> Any:
- if not hasattr(self.model, fn_name):
- raise RuntimeError(
- f"DYNIN model does not expose '{fn_name}'. "
- "Pass additional_information.generated_token_ids or adjust task mapping."
- )
- return getattr(self.model, fn_name)
-
- @staticmethod
- def _collect_generate_kwargs(
- *,
- runtime_info: dict[str, Any],
- kwargs: dict[str, Any],
- ) -> dict[str, Any]:
- gen_kwargs: dict[str, Any] = {}
-
- for key in GENERATE_RUNTIME_KWARG_KEYS:
- if key in runtime_info:
- gen_kwargs[key] = unwrap_first_value(runtime_info[key])
-
- for key in PASSTHROUGH_GENERATE_KWARG_KEYS:
- if key not in gen_kwargs and key in kwargs:
- gen_kwargs[key] = kwargs[key]
-
- return gen_kwargs
-
- @staticmethod
- def _contains_prompting_payload(
- runtime_info: dict[str, Any],
- kwargs: dict[str, Any],
- ) -> bool:
- keys = PROMPTING_PAYLOAD_KEYS + PROMPTING_META_KEYS
- return any(key in runtime_info for key in keys) or any(key in kwargs for key in keys)
-
- @staticmethod
- def _filter_supported_generate_kwargs(
- *,
- gen_fn: Any,
- gen_kwargs: dict[str, Any],
- fn_name: str,
- ) -> dict[str, Any]:
- if not gen_kwargs:
- return gen_kwargs
-
- try:
- signature = inspect.signature(gen_fn)
- except (TypeError, ValueError):
- return gen_kwargs
-
- params = signature.parameters
- accepts_var_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
- if accepts_var_kwargs:
- return gen_kwargs
-
- allowed_keys = {
- name
- for name, param in params.items()
- if param.kind
- in (
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- inspect.Parameter.KEYWORD_ONLY,
- )
- }
- filtered = {k: v for k, v in gen_kwargs.items() if k in allowed_keys}
-
- removed_keys = sorted(set(gen_kwargs.keys()) - set(filtered.keys()))
- if removed_keys:
- logger.debug("Filtered unsupported kwargs for %s: %s", fn_name, removed_keys)
-
- return filtered
-
- @staticmethod
- def _call_generate_fn(
- *,
- gen_fn: Any,
- input_ids: torch.Tensor,
- gen_kwargs: dict[str, Any],
- ) -> Any:
- try:
- signature = inspect.signature(gen_fn)
- params = signature.parameters
- except (TypeError, ValueError):
- params = {}
-
- if "idx" in params:
- return gen_fn(idx=input_ids, **gen_kwargs)
- if "input_ids" in params:
- return gen_fn(input_ids=input_ids, **gen_kwargs)
-
- try:
- return gen_fn(input_ids, **gen_kwargs)
- except TypeError:
- try:
- return gen_fn(idx=input_ids, **gen_kwargs)
- except TypeError:
- return gen_fn(input_ids=input_ids, **gen_kwargs)
-
- def _normalize_generate_inputs(
- self,
- *,
- input_ids: torch.Tensor,
- gen_kwargs: dict[str, Any],
- ref_device: torch.device,
- ) -> tuple[torch.Tensor, dict[str, Any]]:
- normalized_input_ids = self._coerce_long_tensor_2d(input_ids, ref_device)
- if normalized_input_ids is None:
- normalized_input_ids = input_ids
-
- normalized_kwargs = dict(gen_kwargs)
- for key in ("attention_mask", "uncond_input_ids", "uncond_attention_mask"):
- if key not in normalized_kwargs:
- continue
- normalized_value = self._coerce_long_tensor_2d(
- normalized_kwargs[key],
- ref_device,
- )
- if normalized_value is not None:
- normalized_kwargs[key] = normalized_value
-
- return normalized_input_ids, normalized_kwargs
-
- def _get_or_create_uni_prompting(
- self,
- runtime_info: dict[str, Any],
- kwargs: dict[str, Any],
- ) -> Any | None:
- runtime_uni_prompting = runtime_info.get("uni_prompting")
- if runtime_uni_prompting is not None:
- runtime_uni_prompting = self._unwrap_singleton(runtime_uni_prompting)
- if runtime_uni_prompting is not None:
- return runtime_uni_prompting
-
- kwargs_uni_prompting = self._unwrap_singleton(kwargs.get("uni_prompting"))
- if kwargs_uni_prompting is not None:
- return kwargs_uni_prompting
-
- self._maybe_load_runtime_tokenizer(runtime_info)
- if self.tokenizer is None:
- return None
-
- use_reserved_token = _to_bool(
- unwrap_first_value(
- runtime_info.get("use_reserved_token"),
- unwrap_first_value(runtime_info.get("prompting_use_reserved_token"), True),
- ),
- default=True,
- )
-
- max_text_len_value = unwrap_first_value(
- runtime_info.get("prompt_max_text_len"),
- unwrap_first_value(
- runtime_info.get("prompting_max_text_len"),
- unwrap_first_value(runtime_info.get("max_text_len"), None),
- ),
- )
- cond_dropout_value = unwrap_first_value(
- runtime_info.get("cond_dropout_prob"),
- unwrap_first_value(runtime_info.get("prompting_cond_dropout_prob"), None),
- )
- max_audio_len_value = unwrap_first_value(
- runtime_info.get("max_audio_len"),
- unwrap_first_value(runtime_info.get("t2s_token_length"), None),
- )
- max_audio_len_short_value = unwrap_first_value(
- runtime_info.get("max_audio_len_short"),
- None,
- )
-
- max_text_len: int | None = None
- if max_text_len_value is not None:
- try:
- parsed = int(max_text_len_value)
- if parsed > 0:
- max_text_len = parsed
- except Exception:
- pass
-
- cond_dropout_prob: float | None = None
- if cond_dropout_value is not None:
- try:
- cond_dropout_prob = float(cond_dropout_value)
- except Exception:
- pass
-
- max_audio_len: int | None = None
- if max_audio_len_value is not None:
- try:
- parsed = int(max_audio_len_value)
- if parsed > 0:
- max_audio_len = max(parsed, 512)
- except Exception:
- pass
-
- max_audio_len_short: int | None = None
- if max_audio_len_short_value is not None:
- try:
- parsed = int(max_audio_len_short_value)
- if parsed > 0:
- max_audio_len_short = parsed
- except Exception:
- pass
- elif max_audio_len is not None:
- max_audio_len_short = max(256, max_audio_len // 2)
-
- if self._uni_prompting is not None:
- if max_text_len is None and hasattr(self._uni_prompting, "max_text_len"):
- try:
- existing_max_text_len = int(getattr(self._uni_prompting, "max_text_len"))
- if existing_max_text_len > 0:
- max_text_len = existing_max_text_len - 1
- except Exception:
- pass
- if cond_dropout_prob is None and hasattr(self._uni_prompting, "cond_dropout_prob"):
- try:
- cond_dropout_prob = float(getattr(self._uni_prompting, "cond_dropout_prob"))
- except Exception:
- pass
-
- desired_spec = (
- id(self.tokenizer),
- use_reserved_token,
- max_text_len,
- cond_dropout_prob,
- max_audio_len,
- max_audio_len_short,
- )
-
- if self._uni_prompting is not None and self._uni_prompting_init_spec != desired_spec:
- self._reset_uni_prompting_cache()
-
- if self._uni_prompting is None:
- try:
- universal_prompting_cls = resolve_remote_attr(
- "UniversalPrompting",
- module_name="prompting_utils",
- settings=DYNIN_REMOTE_SETTINGS,
- source=self._infer_sources.model_source,
- local_files_only=self._infer_sources.model_local_files_only,
- fallback_module_names=("modeling_dynin_omni",),
- optional=True,
- )
- except Exception:
- universal_prompting_cls = None
-
- try:
- if universal_prompting_cls is None:
- raise ImportError("UniversalPrompting is not available in the configured remote Dynin code.")
-
- init_kwargs: dict[str, Any] = {
- "use_reserved_token": use_reserved_token,
- "special_tokens": DYNIN_SPECIAL_TOKENS,
- "ignore_id": -100,
- }
- if max_text_len is not None:
- init_kwargs["max_text_len"] = max_text_len
- if cond_dropout_prob is not None:
- init_kwargs["cond_dropout_prob"] = cond_dropout_prob
- if max_audio_len is not None:
- init_kwargs["max_audio_len"] = max_audio_len
- if max_audio_len_short is not None:
- init_kwargs["max_audio_len_short"] = max_audio_len_short
-
- try:
- self._uni_prompting = universal_prompting_cls(self.tokenizer, **init_kwargs)
- except TypeError:
- trimmed_audio_kwargs = dict(init_kwargs)
- trimmed_audio_kwargs.pop("max_audio_len", None)
- trimmed_audio_kwargs.pop("max_audio_len_short", None)
- try:
- self._uni_prompting = universal_prompting_cls(self.tokenizer, **trimmed_audio_kwargs)
- except TypeError:
- minimal_kwargs = dict(trimmed_audio_kwargs)
- minimal_kwargs.pop("special_tokens", None)
- minimal_kwargs.pop("ignore_id", None)
- self._uni_prompting = universal_prompting_cls(self.tokenizer, **minimal_kwargs)
- self._uni_prompting_init_spec = desired_spec
- except Exception as e:
- logger.warning("Failed to initialize UniversalPrompting: %s", e)
- self._reset_uni_prompting_cache()
-
- return self._uni_prompting
-
- @staticmethod
- def _unwrap_singleton(value: Any) -> Any:
- if isinstance(value, list) and len(value) == 1:
- return value[0]
- return value
-
- @classmethod
- def _coerce_schedule_params(cls, value: Any) -> dict[str, Any]:
- value = cls._unwrap_singleton(value)
- if value is None:
- return {}
- if isinstance(value, dict):
- return {str(k): v for k, v in value.items()}
- if hasattr(value, "items"):
- try:
- return {str(k): v for k, v in dict(value).items()}
- except Exception:
- return {}
- if isinstance(value, str):
- text = value.strip()
- if not text:
- return {}
- try:
- parsed = json.loads(text)
- except Exception:
- return {}
- if isinstance(parsed, dict):
- return {str(k): v for k, v in parsed.items()}
- return {}
-
- def _resolve_noise_schedule(
- self,
- runtime_info: dict[str, Any],
- kwargs: dict[str, Any],
- ) -> Any | None:
- runtime_noise_schedule = unwrap_first_value(
- runtime_info.get("noise_schedule"),
- kwargs.get("noise_schedule"),
- )
- runtime_noise_schedule = self._unwrap_singleton(runtime_noise_schedule)
- if callable(runtime_noise_schedule):
- return runtime_noise_schedule
-
- schedule_name: str | None = None
- if isinstance(runtime_noise_schedule, str) and runtime_noise_schedule.strip():
- schedule_name = runtime_noise_schedule.strip()
-
- if schedule_name is None:
- for key in ("noise_schedule_name", "mask_schedule", "schedule"):
- value = unwrap_first_value(runtime_info.get(key), None)
- if value is None and key in kwargs:
- value = self._unwrap_singleton(kwargs.get(key))
- if isinstance(value, str) and value.strip():
- schedule_name = value.strip()
- break
-
- if schedule_name is None:
- return None
-
- schedule_params = self._coerce_schedule_params(
- unwrap_first_value(
- runtime_info.get("noise_schedule_params"),
- kwargs.get("noise_schedule_params"),
- )
- )
-
- try:
- get_mask_schedule = get_dynin_sampling_attr("get_mask_schedule")
- return get_mask_schedule(schedule_name, **schedule_params)
- except Exception as e:
- logger.warning(
- "Failed to resolve mask schedule '%s' with params=%s: %s",
- schedule_name,
- schedule_params,
- e,
- )
- return None
-
- @staticmethod
- def _coerce_long_tensor_2d(
- value: Any,
- device: torch.device,
- ) -> torch.Tensor | None:
- if value is None:
- return None
- out = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
- if out.ndim == 1:
- out = out.unsqueeze(0)
- if out.ndim > 2:
- out = out.view(out.shape[0], -1)
- return out.to(device=device, dtype=torch.long).contiguous()
-
- @staticmethod
- def _config_get(config_obj: Any, key: str) -> Any:
- if config_obj is None:
- return None
- if isinstance(config_obj, dict):
- return config_obj.get(key)
- if hasattr(config_obj, "get"):
- try:
- return config_obj.get(key)
- except Exception:
- return None
- return None
-
- @classmethod
- def _is_numeric_token_structure(cls, value: Any) -> bool:
- if isinstance(value, torch.Tensor):
- return True
- if isinstance(value, bool):
- return True
- if isinstance(value, int):
- return True
- if isinstance(value, float):
- return float(value).is_integer()
- if isinstance(value, (list, tuple)):
- if not value:
- return False
- return all(cls._is_numeric_token_structure(v) for v in value)
- return False
-
- @classmethod
- def _materialize_prompting_payload(cls, value: Any, ref_device: torch.device) -> Any:
- if isinstance(value, torch.Tensor):
- return value.to(device=ref_device, dtype=torch.long).contiguous()
- if isinstance(value, dict):
- return {k: cls._materialize_prompting_payload(v, ref_device) for k, v in value.items()}
- if isinstance(value, (list, tuple)):
- if cls._is_numeric_token_structure(value):
- try:
- return torch.as_tensor(value, dtype=torch.long, device=ref_device)
- except Exception:
- pass
- converted = [cls._materialize_prompting_payload(v, ref_device) for v in value]
- return tuple(converted) if isinstance(value, tuple) else converted
- return value
-
- @contextmanager
- def _temporary_prompting_overrides(self, uni_prompting: Any, prompting_cfg: Any):
- restore_values: dict[str, Any] = {}
- try:
- max_text_len_override = self._config_get(prompting_cfg, "max_text_len_override")
- if max_text_len_override is not None and hasattr(uni_prompting, "max_text_len"):
- try:
- override_int = int(max_text_len_override)
- if override_int > 0:
- restore_values["max_text_len"] = getattr(uni_prompting, "max_text_len")
- setattr(uni_prompting, "max_text_len", override_int + 1)
- except Exception:
- pass
- yield
- finally:
- for attr_name, original_value in restore_values.items():
- try:
- setattr(uni_prompting, attr_name, original_value)
- except Exception:
- pass
-
- def _prepare_prompting_input(
- self,
- *,
- payload: Any,
- task: str,
- runtime_info: dict[str, Any],
- kwargs: dict[str, Any],
- uni_prompting: Any,
- ref_device: torch.device,
- ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
- if payload is None:
- return None, None
-
- payload = self._unwrap_singleton(payload)
- prompting_task = str(
- self._unwrap_singleton(
- unwrap_first_value(
- runtime_info.get("prompting_task"),
- TASK_TO_PROMPTING_TASK.get(task, task),
- )
- )
- )
- prompting_cfg = self._unwrap_singleton(
- unwrap_first_value(
- runtime_info.get("prompting_config"),
- kwargs.get("prompting_config"),
- )
- )
-
- if isinstance(payload, dict):
- if payload.get("task") is not None:
- prompting_task = str(payload["task"])
- if payload.get("config") is not None:
- prompting_cfg = payload["config"]
- payload = payload.get("input", payload.get("inputs", payload.get("data", payload)))
-
- payload = self._materialize_prompting_payload(payload, ref_device)
-
- try:
- with self._temporary_prompting_overrides(uni_prompting, prompting_cfg):
- prepared = uni_prompting(payload, prompting_task, config=prompting_cfg)
- except Exception as e:
- logger.warning(
- "UniversalPrompting failed for task=%s prompting_task=%s: %s",
- task,
- prompting_task,
- e,
- )
- return None, None
-
- if isinstance(prepared, tuple):
- prepared_input_ids = prepared[0] if len(prepared) > 0 else None
- prepared_attention_mask = prepared[1] if len(prepared) > 1 else None
- else:
- prepared_input_ids = prepared
- prepared_attention_mask = None
-
- return (
- self._coerce_long_tensor_2d(prepared_input_ids, ref_device),
- self._coerce_long_tensor_2d(prepared_attention_mask, ref_device),
- )
-
- def _prepare_prompting_inputs_if_needed(
- self,
- *,
- task: str,
- input_ids: torch.Tensor,
- runtime_info: dict[str, Any],
- kwargs: dict[str, Any],
- gen_kwargs: dict[str, Any],
- ) -> tuple[torch.Tensor, dict[str, Any]]:
- uni_prompting = gen_kwargs.get("uni_prompting")
- if uni_prompting is None:
- uni_prompting = self._get_or_create_uni_prompting(
- runtime_info=runtime_info,
- kwargs=kwargs,
- )
- if uni_prompting is not None:
- gen_kwargs["uni_prompting"] = uni_prompting
-
- if uni_prompting is None:
- return input_ids, gen_kwargs
-
- payload = self._find_first_payload(
- runtime_info=runtime_info,
- kwargs=kwargs,
- keys=PROMPTING_PAYLOAD_KEYS,
- )
-
- if payload is not None:
- prepared_input_ids, prepared_attention_mask = self._prepare_prompting_input(
- payload=payload,
- task=task,
- runtime_info=runtime_info,
- kwargs=kwargs,
- uni_prompting=uni_prompting,
- ref_device=input_ids.device,
- )
- if prepared_input_ids is not None:
- input_ids = prepared_input_ids
- if prepared_attention_mask is not None and "attention_mask" not in gen_kwargs:
- gen_kwargs["attention_mask"] = prepared_attention_mask
-
- uncond_payload = self._find_first_payload(
- runtime_info=runtime_info,
- kwargs=kwargs,
- keys=UNCOND_PROMPTING_PAYLOAD_KEYS,
- )
- if uncond_payload is not None and "uncond_input_ids" not in gen_kwargs:
- uncond_input_ids, uncond_attention_mask = self._prepare_prompting_input(
- payload=uncond_payload,
- task=task,
- runtime_info=runtime_info,
- kwargs=kwargs,
- uni_prompting=uni_prompting,
- ref_device=input_ids.device,
- )
- if uncond_input_ids is not None:
- gen_kwargs["uncond_input_ids"] = uncond_input_ids
- if uncond_attention_mask is not None and "uncond_attention_mask" not in gen_kwargs:
- gen_kwargs["uncond_attention_mask"] = uncond_attention_mask
-
- return input_ids, gen_kwargs
-
- @staticmethod
- def _find_first_payload(
- *,
- runtime_info: dict[str, Any],
- kwargs: dict[str, Any],
- keys: tuple[str, ...],
- ) -> Any | None:
- for key in keys:
- if key in runtime_info:
- return runtime_info[key]
- if key in kwargs:
- return kwargs[key]
- return None
-
- def _extract_decode_tokens(
- self,
- tokens: torch.Tensor,
- runtime_info: dict[str, Any],
- ) -> torch.Tensor:
- prompt_len = int(
- unwrap_first_value(
- runtime_info.get("prompt_length"),
- unwrap_first_value(
- runtime_info.get("prompt_len"),
- unwrap_first_value(runtime_info.get("prompt_token_len"), 0),
- ),
- )
- )
-
- decode_tokens = tokens
- if 0 < prompt_len < tokens.numel():
- decode_tokens = tokens[prompt_len:]
-
- text_vocab_size = unwrap_first_value(runtime_info.get("text_vocab_size"), None)
- if text_vocab_size is None and self.tokenizer is not None:
- text_vocab_size = len(self.tokenizer)
-
- if text_vocab_size is not None:
- vocab_size = int(text_vocab_size)
- valid = decode_tokens[(decode_tokens >= 0) & (decode_tokens < vocab_size)]
- if valid.numel() > 0:
- decode_tokens = valid
-
- return decode_tokens.contiguous()
-
- def _decode_text(self, tokens: torch.Tensor, runtime_info: dict[str, Any]) -> str:
- self._maybe_load_runtime_tokenizer(runtime_info)
- if self.tokenizer is None:
- return ""
- try:
- return self.tokenizer.decode(
- tokens.detach().cpu().tolist(),
- skip_special_tokens=True,
- )
- except Exception:
- return ""
-
- def _maybe_load_runtime_tokenizer(self, runtime_info: dict[str, Any]) -> None:
- tokenizer_path = unwrap_first_value(runtime_info.get("tokenizer_path"), None)
- if tokenizer_path is not None:
- tokenizer_path = str(tokenizer_path)
-
- runtime_local_files_only = unwrap_first_value(
- runtime_info.get("local_files_only_model"),
- unwrap_first_value(
- runtime_info.get("model_local_files_only"),
- unwrap_first_value(
- runtime_info.get("local_files_only"),
- self._infer_sources.model_local_files_only,
- ),
- ),
- )
- local_only = _to_bool(
- runtime_local_files_only,
- default=self._infer_sources.model_local_files_only,
- )
-
- if tokenizer_path and tokenizer_path != self._tokenizer_path:
- try:
- logger.info("Loading DYNIN text tokenizer from %s", tokenizer_path)
- self._set_tokenizer(tokenizer_path, local_files_only=local_only)
- except Exception as e:
- logger.warning("Failed to load tokenizer from %s: %s", tokenizer_path, e)
-
- def _ensure_prompt_vq_model(self, runtime_info: dict[str, Any], ref_device: torch.device) -> Any:
- sources = resolve_dynin_infer_sources(vllm_config=self.vllm_config, runtime_info=runtime_info)
- model_path = str(sources.vq_image_source)
- local_files_only = bool(sources.vq_image_local_files_only)
- if (
- self._prompt_vq_model is None
- or self._prompt_vq_model_path != model_path
- or self._prompt_vq_local_files_only != local_files_only
- ):
- logger.info(
- "Loading DYNIN prompt VQ encoder from %s (local_files_only=%s)",
- model_path,
- local_files_only,
- )
- magvit_cls = get_dynin_magvit_attr(
- "MAGVITv2",
- source=model_path,
- local_files_only=local_files_only,
- )
- try:
- self._prompt_vq_model = magvit_cls.from_pretrained(
- model_path,
- local_files_only=local_files_only,
- )
- except TypeError:
- self._prompt_vq_model = magvit_cls.from_pretrained(model_path)
- self._prompt_vq_model.eval()
- self._prompt_vq_model.requires_grad_(False)
- self._prompt_vq_model_path = model_path
- self._prompt_vq_local_files_only = local_files_only
- if hasattr(self._prompt_vq_model, "to"):
- self._prompt_vq_model = self._prompt_vq_model.to(ref_device)
- return self._prompt_vq_model
-
- @staticmethod
- def _prepare_prompt_image_tensor(
- image: Any,
- *,
- resolution: int,
- device: torch.device,
- ) -> torch.Tensor:
- tensor = image if isinstance(image, torch.Tensor) else torch.as_tensor(image)
- if tensor.ndim == 4:
- tensor = tensor[0]
- if tensor.ndim != 3:
- raise ValueError(f"Unsupported image tensor shape for Dynin bootstrap: {tuple(tensor.shape)}")
-
- if tensor.shape[0] not in (1, 3, 4) and tensor.shape[-1] in (1, 3, 4):
- tensor = tensor.permute(2, 0, 1)
- if tensor.shape[0] == 1:
- tensor = tensor.repeat(3, 1, 1)
- if tensor.shape[0] == 4:
- tensor = tensor[:3]
-
- tensor = tensor.to(device=device, dtype=torch.float32)
- if tensor.numel() > 0 and tensor.max() > 1.0:
- tensor = tensor / 255.0
-
- tensor = tensor.unsqueeze(0)
- _, _, height, width = tensor.shape
- short_side = max(1, min(int(height), int(width)))
- scale = float(resolution) / float(short_side)
- new_height = max(1, int(round(height * scale)))
- new_width = max(1, int(round(width * scale)))
- tensor = F.interpolate(
- tensor,
- size=(new_height, new_width),
- mode="bicubic",
- align_corners=False,
- )
- top = max(0, (new_height - resolution) // 2)
- left = max(0, (new_width - resolution) // 2)
- tensor = tensor[:, :, top : top + resolution, left : left + resolution]
- if tensor.shape[-2:] != (resolution, resolution):
- tensor = F.interpolate(
- tensor,
- size=(resolution, resolution),
- mode="bicubic",
- align_corners=False,
- )
- tensor = torch.clamp(tensor, min=0.0, max=1.0)
- return ((tensor - 0.5) / 0.5).contiguous()
-
- def _encode_prompt_image_tokens(
- self,
- *,
- runtime_info: dict[str, Any],
- mm_inputs: dict[str, Any],
- resolution: int,
- ) -> torch.Tensor:
- image_value = mm_inputs.get("image")
- image_items = self._split_mm_items(image_value)
- if not image_items:
- raise ValueError("Dynin online i2i bootstrap requires an image input.")
-
- device = self._default_mm_device()
- image_tensor = self._prepare_prompt_image_tensor(
- image_items[0],
- resolution=resolution,
- device=device,
- )
- vq_model = self._ensure_prompt_vq_model(runtime_info=runtime_info, ref_device=device)
- with torch.no_grad():
- token_ids = vq_model.get_code(image_tensor)
- token_ids = torch.as_tensor(token_ids, dtype=torch.long).detach().cpu()
- if token_ids.ndim == 2 and token_ids.shape[0] == 1:
- token_ids = token_ids[0]
- return token_ids.contiguous()
-
- @staticmethod
- def _split_mm_items(value: Any) -> list[Any]:
- if value is None:
- return []
- if isinstance(value, torch.Tensor):
- if value.ndim == 0:
- return [value]
- return [value[i] for i in range(value.shape[0])]
- if isinstance(value, list):
- return value
- if isinstance(value, tuple):
- if len(value) == 2 and isinstance(value[1], (int, float)):
- return [value]
- return list(value)
- return [value]
-
- def _default_mm_device(self) -> torch.device:
- try:
- return next(self.model.parameters()).device
- except StopIteration:
- return torch.device("cpu")
-
- @staticmethod
- def _coerce_mm_item_to_float_tensor(
- item: Any,
- *,
- device: torch.device,
- ) -> torch.Tensor:
- if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], (int, float)):
- item = item[0]
-
- if isinstance(item, torch.Tensor):
- tensor = item.detach().to(device=device, dtype=torch.float32)
- else:
- tensor = torch.as_tensor(item, dtype=torch.float32, device=device)
-
- return tensor.contiguous()
-
- def _build_deterministic_mm_embedding(
- self,
- item: Any,
- *,
- device: torch.device,
- ) -> torch.Tensor:
- tensor = self._coerce_mm_item_to_float_tensor(item, device=device)
- if tensor.numel() == 0:
- return torch.zeros((1, self.hidden_size), dtype=torch.bfloat16, device=device)
-
- flattened = tensor.view(-1)
- first = flattened[0]
- last = flattened[-1]
- mean = flattened.mean()
- std = flattened.std(unbiased=False)
- abs_mean = flattened.abs().mean()
- max_abs = flattened.abs().max()
- l2 = torch.linalg.vector_norm(flattened) / max(float(flattened.numel()), 1.0)
-
- base = torch.stack([first, last, mean, std, abs_mean, max_abs, l2], dim=0)
- denom = torch.clamp(base.abs().max(), min=1.0)
- base = base / denom
-
- repeats = (self.hidden_size + base.numel() - 1) // base.numel()
- embedding = base.repeat(repeats)[: self.hidden_size].to(dtype=torch.bfloat16)
- return embedding.unsqueeze(0).contiguous()
-
- def _collect_mm_inputs(self, **kwargs: Any) -> dict[str, Any]:
- mm_inputs: dict[str, Any] = {}
- for modality, aliases in MM_INPUT_ALIASES.items():
- for alias in aliases:
- if alias in kwargs and kwargs[alias] is not None:
- mm_inputs[modality] = kwargs[alias]
- break
- for modality, value in self._cached_mm_inputs.items():
- if modality not in mm_inputs and value is not None:
- mm_inputs[modality] = value
- return mm_inputs
-
- def embed_multimodal(self, **kwargs: Any) -> Any:
- mm_inputs = self._collect_mm_inputs(**kwargs)
- self._cached_mm_inputs = dict(mm_inputs)
- if not mm_inputs:
- return None
-
- device = self._default_mm_device()
- mm_embeddings: list[torch.Tensor] = []
-
- for modality in ("image", "video", "audio"):
- value = mm_inputs.get(modality)
- if value is None:
- continue
- for item in self._split_mm_items(value):
- mm_embeddings.append(self._build_deterministic_mm_embedding(item, device=device))
-
- return tuple(mm_embeddings) if mm_embeddings else None
diff --git a/vllm_omni/model_executor/models/fish_speech/dac_encoder.py b/vllm_omni/model_executor/models/fish_speech/dac_encoder.py
index cdf0da992fc..397530ca340 100644
--- a/vllm_omni/model_executor/models/fish_speech/dac_encoder.py
+++ b/vllm_omni/model_executor/models/fish_speech/dac_encoder.py
@@ -54,9 +54,6 @@ def _load_dac_codec(
if "generator" in state_dict:
state_dict = state_dict["generator"]
codec.load_state_dict(state_dict, strict=False)
- # Encoder path only uses encoder + quantizer.forward(); prune the
- # decoder before moving to device to avoid unnecessary GPU allocation.
- codec.decoder = None
codec = codec.to(device=device, dtype=dtype)
codec.eval()
diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py
index daef2be1856..e121b03371c 100644
--- a/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py
+++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_dac_decoder.py
@@ -141,13 +141,6 @@ def _ensure_codec_loaded(self) -> None:
self._bake_weight_norm(codec)
self._cache_attention_masks(codec)
- # Decode path only uses quantizer.decode() + decoder; prune
- # encode-only components before moving to device to avoid
- # unnecessary GPU allocation.
- codec.encoder = None
- codec.quantizer.pre_module = None
- codec.quantizer.downsample = None
-
device = self.vllm_config.device_config.device
codec = codec.to(device=device, dtype=torch.float32)
codec.eval()
@@ -230,9 +223,8 @@ def forward(
for i, info in enumerate(runtime_additional_information):
if i >= len(left_context_size):
break
- meta = info.get("meta", {}) if isinstance(info, dict) else {}
- if "left_context_size" in meta:
- left_context_size[i] = meta["left_context_size"]
+ if "left_context_size" in info:
+ left_context_size[i] = info["left_context_size"]
for i, req_ids in enumerate(request_ids_list):
if req_ids.numel() < 1:
diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
index 22a2744ff5d..8bbb643ebec 100644
--- a/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
+++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_fast_ar.py
@@ -310,7 +310,6 @@ def __init__(
self._compiled_model_fwd: object | None = None
self._compile_attempted = False
self._compile_failed = False
- self._disable_compile_for_graph = False
def _ensure_buffers(self, bsz: int, device: torch.device, dtype: torch.dtype) -> None:
max_seq = self._num_codebooks + 1 # hidden_state + num_codebooks codes
@@ -328,20 +327,11 @@ def _setup_compile(self) -> None:
if self._compile_attempted:
return
self._compile_attempted = True
- if self._disable_compile_for_graph:
- try:
- self._compiled_model_fwd = torch.compile(
- self.model.forward,
- dynamic=True,
- options={"epilogue_fusion": False},
- )
- except Exception as exc:
- logger.warning("Fast AR torch.compile (graph mode) failed: %s", exc)
- self._compiled_model_fwd = self.model.forward
- return
try:
self._compiled_model_fwd = torch.compile(
self.model.forward,
+ # Keep the helper compiler separate from vLLM's outer
+ # cudagraph-managed Stage-0 execution.
mode="default",
dynamic=True,
fullgraph=False,
@@ -376,10 +366,10 @@ def warmup_compile(
@torch.inference_mode()
def _run_model(self, step_input: torch.Tensor, step_pos_ids: torch.Tensor, bsz: int) -> torch.Tensor:
- if self._disable_compile_for_graph:
- model_fwd = self._compiled_model_fwd or self.model.forward
- else:
- model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
+ # Default-on compile only pays off for single-request decode. For
+ # batched decode, eager preserves loaded throughput and avoids the
+ # regression seen with batch>1 compiled execution.
+ model_fwd = self._compiled_model_fwd if bsz == 1 else self.model.forward
try:
return model_fwd(step_input, step_pos_ids)
except Exception as exc:
diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
index e46451180d3..4ad2a1fa63b 100644
--- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
+++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
@@ -14,6 +14,7 @@
import dataclasses
import math
+import os
from collections.abc import Iterable
from typing import Any
@@ -32,7 +33,6 @@
from vllm.sequence import IntermediateTensors
from vllm_omni.model_executor.models.output_templates import OmniOutput
-from vllm_omni.utils.voice_cache import VoiceEmbeddingCache
from .configuration_fish_speech import FishSpeechConfig, FishSpeechFastARConfig, FishSpeechSlowARConfig
from .dac_encoder import _load_dac_codec, encode_reference_audio_codes
@@ -193,9 +193,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.has_preprocess = True
self.has_postprocess = True
self.mtp_hidden_size = int(self.text_config.hidden_size)
- self.talker_mtp_output_key = ("codes", "audio")
- self.gpu_resident_buffer_keys: set[tuple[str, str]] = {("hidden_states", "last")}
- self.talker_mtp_graph_safe = True
+ self.talker_mtp_output_key = "audio_codes"
+ self.gpu_resident_buffer_keys: set[str] = {"last_slow_ar_hidden"}
# Qwen3 transformer backbone.
self.model = Qwen3Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
@@ -237,8 +236,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
slow_ar_config=self.text_config,
prefix="fast_ar",
)
- if self.talker_mtp_graph_safe:
- self.fast_ar._disable_compile_for_graph = True
# Constant logit mask: allow only semantic tokens + im_end.
vocab = int(self.text_config.vocab_size)
@@ -253,9 +250,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
semantic_mask[im_end_id] = True
self.register_buffer("_semantic_allowed_mask", semantic_mask, persistent=False)
- # In-memory LRU cache for DAC-encoded reference audio codes.
- self._voice_cache = VoiceEmbeddingCache()
-
# Tokeniser (lazy).
self._tokenizer = None
@@ -330,7 +324,7 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
for info in info_dicts:
if not isinstance(info, dict):
continue
- ac = info.get("codes", {}).get("audio")
+ ac = info.get("audio_codes")
if isinstance(ac, torch.Tensor):
audio_codes_list.append(ac)
@@ -365,8 +359,7 @@ def preprocess(
if span_len > 1:
# --- Prefill ---
- embed = info_dict.get("embed", {})
- prompt_embeds_buf = embed.get("prefill")
+ prompt_embeds_buf = info_dict.get("slow_ar_prompt_embeds")
is_first_prefill = not isinstance(prompt_embeds_buf, torch.Tensor) or prompt_embeds_buf.ndim != 2
dev = input_ids.device
@@ -382,8 +375,8 @@ def preprocess(
next_offset = min(span_len, total_prompt_len)
info_update: dict[str, Any] = {
- "embed": {"prefill": prompt_embeds_buf if next_offset < total_prompt_len else None},
- "meta": {"prefill_offset": next_offset},
+ "slow_ar_prompt_embeds": prompt_embeds_buf if next_offset < total_prompt_len else None,
+ "prefill_offset": next_offset,
}
take = prompt_embeds_buf[:span_len]
@@ -400,7 +393,7 @@ def preprocess(
device=dev,
dtype=torch.long,
)
- info_update["codes"] = {"audio": zeros}
+ info_update["audio_codes"] = zeros
input_ids_out = input_ids.clone()
input_ids_out[:] = self._audio_pad_token_id
@@ -408,8 +401,7 @@ def preprocess(
else:
# Subsequent prefill chunk.
- meta = info_dict.get("meta", {})
- offset = int(meta.get("prefill_offset", 0) or 0)
+ offset = int(info_dict.get("prefill_offset", 0) or 0)
total_prompt_len = int(prompt_embeds_buf.shape[0])
s = max(0, min(offset, total_prompt_len))
e = max(0, min(offset + span_len, total_prompt_len))
@@ -428,21 +420,20 @@ def preprocess(
input_ids.clone().fill_(self._audio_pad_token_id),
prompt_embeds,
{
- "embed": {"prefill": prompt_embeds_buf if next_offset < total_prompt_len else None},
- "meta": {"prefill_offset": next_offset},
- "codes": {"audio": zeros},
+ "slow_ar_prompt_embeds": prompt_embeds_buf if next_offset < total_prompt_len else None,
+ "prefill_offset": next_offset,
+ "audio_codes": zeros,
},
)
# --- Decode: span_len == 1 ---
dev = input_ids.device
- hs = info_dict.get("hidden_states", {})
- last_hidden = hs.get("last")
+ last_hidden = info_dict.get("last_slow_ar_hidden")
if not isinstance(last_hidden, torch.Tensor):
# First decode step after prefill -- just embed the token directly.
logger.warning(
- "preprocess decode: hidden_states.last not found (keys=%s), "
+ "preprocess decode: last_slow_ar_hidden not found (keys=%s), "
"returning plain embed (mtp_inputs will NOT be set)",
list(info_dict.keys()),
)
@@ -471,8 +462,8 @@ def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]:
logger.debug("postprocess: empty hidden_states")
return {}
last = hidden_states[-1, :].detach().contiguous()
- logger.debug("postprocess: saved hidden_states.last shape=%s", tuple(last.shape))
- return {"hidden_states": {"last": last}}
+ logger.debug("postprocess: saved last_slow_ar_hidden shape=%s", tuple(last.shape))
+ return {"last_slow_ar_hidden": last}
# -------------------- prompt construction --------------------
@@ -527,52 +518,17 @@ def _build_structured_voice_clone_prefill_embeds(self, info_dict: dict[str, Any]
tokenizer = self._get_tokenizer()
ref_text = info_dict.get("ref_text")
text = info_dict.get("text")
+ ref_audio_path = info_dict.get("ref_audio_path")
ref_audio_sr = info_dict.get("ref_audio_sr")
if not isinstance(ref_text, str) or not isinstance(text, str):
raise ValueError("Fish Speech structured voice clone requires string text and ref_text")
-
- # --- Voice cache: reuse DAC codes for uploaded (named) voices ---
- _voice_cache_key: str | None = None
- voice_name = info_dict.get("voice_name")
- voice_created_at = info_dict.get("voice_created_at")
- if isinstance(voice_name, str) and voice_name:
- _created_at = float(voice_created_at) if voice_created_at is not None else 0.0
- if _created_at <= 0:
- logger.warning(
- "Voice '%s' has no created_at timestamp; DAC code caching disabled for this request",
- voice_name,
- )
- else:
- _voice_cache_key = self._voice_cache.make_cache_key(
- voice_name,
- xvec_only=False,
- created_at=_created_at,
- )
- _cached = self._voice_cache.get(_voice_cache_key)
- if _cached is not None:
- ref_codes_fq = _cached["ref_codes_fq"].to(
- device=self.codebook_embeddings.weight.device,
- dtype=torch.long,
- )
- _voice_cache_key = None # hit → don't store again
- logger.debug("Voice cache HIT for Fish Speech voice '%s'", voice_name)
- return self._apply_codebook_embeddings(
- tokenizer,
- text,
- ref_text,
- ref_codes_fq,
- )
-
+ if not isinstance(ref_audio_path, str) or not ref_audio_path:
+ raise ValueError("Fish Speech structured voice clone requires ref_audio_path")
if not isinstance(ref_audio_sr, int):
raise ValueError("Fish Speech structured voice clone requires integer ref_audio_sr")
- ref_audio_wav_raw = info_dict.get("ref_audio_wav")
- if ref_audio_wav_raw is None:
- raise ValueError("Fish Speech structured voice clone requires ref_audio_wav")
- if isinstance(ref_audio_wav_raw, torch.Tensor):
- ref_audio_wav = ref_audio_wav_raw.cpu().numpy()
- else:
- ref_audio_wav = np.asarray(ref_audio_wav_raw, dtype=np.float32)
+ ref_audio_wav = np.load(ref_audio_path)
+ os.remove(ref_audio_path)
ref_codes_fq = encode_reference_audio_codes(
self.model_path,
@@ -580,25 +536,6 @@ def _build_structured_voice_clone_prefill_embeds(self, info_dict: dict[str, Any]
ref_audio_sr,
device=self.codebook_embeddings.weight.device,
)
-
- # Cache miss: store DAC codes for future reuse.
- if _voice_cache_key is not None:
- self._voice_cache.put(
- _voice_cache_key,
- {"ref_codes_fq": ref_codes_fq.detach().cpu()},
- )
- logger.debug("Voice cache STORE for Fish Speech voice '%s'", voice_name)
-
- return self._apply_codebook_embeddings(tokenizer, text, ref_text, ref_codes_fq)
-
- def _apply_codebook_embeddings(
- self,
- tokenizer: Any,
- text: str,
- ref_text: str,
- ref_codes_fq: torch.Tensor,
- ) -> torch.Tensor:
- """Build prefill embeddings from DAC codes and inject codebook conditioning."""
semantic_token_ids = (ref_codes_fq[:, 0] + self._semantic_begin_id).tolist()
prompt_ids, _, _ = build_fish_voice_clone_prompt_ids(
tokenizer,
@@ -652,7 +589,6 @@ def talker_mtp(
input_embeds: torch.Tensor,
last_talker_hidden: torch.Tensor,
text_step: torch.Tensor,
- **kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor]:
"""GPU fast-path: run Fast AR to predict residual codebook codes.
@@ -687,13 +623,18 @@ def talker_mtp(
inputs_embeds_out = input_embeds.reshape(bsz, -1).clone()
semantic_mask = (input_ids[:, 0] >= self._semantic_begin_id) & (input_ids[:, 0] <= self._semantic_end_id)
- semantic_codes = audio_codes.clamp(min=0, max=self._codebook_size - 1)
- offsets = (
- torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size
- ).unsqueeze(0)
- codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)
- norm_embeds = (inputs_embeds_out + codebook_sum) / math.sqrt(self._num_codebooks + 1)
- inputs_embeds_out = torch.where(semantic_mask.unsqueeze(-1), norm_embeds, inputs_embeds_out)
+ if semantic_mask.any():
+ semantic_codes = audio_codes[semantic_mask].clamp(min=0)
+ offsets = (
+ torch.arange(self._num_codebooks, device=dev, dtype=semantic_codes.dtype) * self._codebook_size
+ ).unsqueeze(0)
+ codebook_sum = self.codebook_embeddings(semantic_codes + offsets).sum(dim=1).to(dtype=torch.bfloat16)
+
+ # Normalize by sqrt(num_codebooks + 1) as in the reference model
+ # (scale_codebook_embeddings=True for fish_qwen3_omni).
+ inputs_embeds_out[semantic_mask] = (inputs_embeds_out[semantic_mask] + codebook_sum) / math.sqrt(
+ self._num_codebooks + 1
+ )
return inputs_embeds_out, audio_codes.to(dtype=torch.long)
@@ -804,15 +745,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if truncated:
logger.info("Truncated %d RoPE cos_sin_cache buffers to bf16 precision", truncated)
- if not getattr(self, "talker_mtp_graph_safe", False):
- try:
- self.fast_ar.warmup_compile(
- device=self.codebook_embeddings.weight.device,
- dtype=torch.bfloat16,
- batch_sizes=(1,),
- )
- except Exception as exc:
- logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc)
+ try:
+ self.fast_ar.warmup_compile(
+ device=self.codebook_embeddings.weight.device,
+ dtype=torch.bfloat16,
+ batch_sizes=(1,),
+ )
+ except Exception as exc:
+ logger.warning("Fish Speech Fast AR compile warmup failed: %s", exc)
codec_device = self.codebook_embeddings.weight.device
_load_dac_codec(
diff --git a/vllm_omni/model_executor/models/fish_speech/pipeline.py b/vllm_omni/model_executor/models/fish_speech/pipeline.py
deleted file mode 100644
index baee1b5e72f..00000000000
--- a/vllm_omni/model_executor/models/fish_speech/pipeline.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Fish Speech S2 Pro pipeline topology (frozen).
-
-Stage 0: slow_ar — text → RVQ codec tokens (LLM autoregressive).
-Stage 1: dac_decoder — RVQ tokens → audio waveform (LLM_GENERATION).
-
-The HF config top-level reports ``model_type = "fish_qwen3_omni"`` (the
-OmniConfig that bundles slow-AR and fast-AR sub-configs), which is why the
-registry key follows the HF name rather than the human-readable "fish_speech".
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-_PROC = "vllm_omni.model_executor.stage_input_processors.fish_speech"
-
-FISH_SPEECH_PIPELINE = PipelineConfig(
- model_type="fish_qwen3_omni",
- model_arch="FishSpeechSlowARForConditionalGeneration",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="fish_speech_slow_ar",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- owns_tokenizer=True,
- engine_output_type="latent",
- async_chunk_process_next_stage_input_func=(f"{_PROC}.slow_ar_to_dac_decoder_async_chunk"),
- sampling_constraints={
- "detokenize": False,
- # <|im_end|> — stop when the model emits end-of-turn.
- "stop_token_ids": [151645],
- },
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="dac_decoder",
- model_arch="FishSpeechDACDecoder",
- execution_type=StageExecutionType.LLM_GENERATION,
- input_sources=(0,),
- final_output=True,
- final_output_type="audio",
- engine_output_type="audio",
- sampling_constraints={"detokenize": True},
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/fish_speech/prompt_utils.py b/vllm_omni/model_executor/models/fish_speech/prompt_utils.py
index 8b8d8559ead..923e97b63af 100644
--- a/vllm_omni/model_executor/models/fish_speech/prompt_utils.py
+++ b/vllm_omni/model_executor/models/fish_speech/prompt_utils.py
@@ -38,7 +38,10 @@ def _encode_plain_text(tokenizer: Any, text: str) -> list[int]:
def _encode_control_token(tokenizer: Any, token: str) -> list[int]:
- token_id = tokenizer.convert_tokens_to_ids(token)
+ vocab = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else {}
+ token_id = vocab.get(token)
+ if token_id is None:
+ token_id = tokenizer.convert_tokens_to_ids(token)
if token_id is None or token_id == getattr(tokenizer, "unk_token_id", None):
raise ValueError(f"Fish Speech tokenizer is missing required control token: {token}")
return [int(token_id)]
diff --git a/vllm_omni/model_executor/models/glm_image/__init__.py b/vllm_omni/model_executor/models/glm_image/__init__.py
index 6153f07c0a0..d37044c09f1 100644
--- a/vllm_omni/model_executor/models/glm_image/__init__.py
+++ b/vllm_omni/model_executor/models/glm_image/__init__.py
@@ -1,16 +1,3 @@
-def __getattr__(name: str):
- """Lazy import to avoid importing transformers.models.glm_image at module init.
-
- The AR model depends on ``transformers.models.glm_image`` which is only
- available when the model weights are present. Importing it eagerly in
- ``__init__.py`` breaks the pipeline registry lookup for environments that
- don't have this custom transformers extension installed.
- """
- if name == "GlmImageForConditionalGeneration":
- from .glm_image_ar import GlmImageForConditionalGeneration
-
- return GlmImageForConditionalGeneration
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
-
+from .glm_image_ar import GlmImageForConditionalGeneration
__all__ = ["GlmImageForConditionalGeneration"]
diff --git a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py
index 411af5fa096..31eed9b2cb9 100644
--- a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py
+++ b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py
@@ -21,7 +21,6 @@
# limitations under the License.
"""Inference-only GLM-Image model compatible with HuggingFace weights."""
-import math
import os
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal
@@ -128,14 +127,6 @@ def _get_subparsers(self):
parsers["img2img"] = self._parse_image_data
return parsers
- def parse_mm_data(self, mm_data, **kwargs):
- # Normalize "img2img" to "image" so the rest of the pipeline
- # (mm_hashes, _merge_mm_kwargs) uses a single modality key.
- normalized = {}
- for k, v in mm_data.items():
- normalized["image" if k == "img2img" else k] = v
- return super().parse_mm_data(normalized, **kwargs)
-
class GlmImageProcessingInfo(BaseProcessingInfo):
"""
@@ -330,13 +321,6 @@ class GlmImageMultiModalProcessor(BaseMultiModalProcessor[GlmImageProcessingInfo
- Grid dimension calculation for M-RoPE position encoding
"""
- def _cached_apply_hf_processor(self, inputs, timing_ctx):
- # i2i: prompt text must be modified based on mm data presence,
- # and grid computation requires all images together — bypass cache.
- if inputs.mm_data_items.get_all_counts().get("image", 0) > 0:
- return self._apply_hf_processor(inputs, timing_ctx)
- return super()._cached_apply_hf_processor(inputs, timing_ctx)
-
def _call_hf_processor(
self,
prompt: str,
@@ -362,10 +346,6 @@ def _call_hf_processor(
target_h = mm_kwargs.get("target_h", 1024) if mm_kwargs else 1024
target_w = mm_kwargs.get("target_w", 1024) if mm_kwargs else 1024
- logger.debug(
- f"_call_hf_processor: target dimensions for generation: {target_h}x{target_w}, mm_kwargs={mm_kwargs}"
- )
-
if not mm_data or not mm_data.get("images"):
# Text-to-image mode
if processor is not None:
@@ -586,58 +566,6 @@ def _apply_hf_processor_mm_only(
tensor_type="pt",
)
- def _apply_hf_processor_text_only(
- self, prompt_text: str, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object]
- ) -> list[int]:
- prompt_ids, _, _ = super()._apply_hf_processor_text_mm(
- prompt_text=prompt_text,
- mm_items=MultiModalDataItems({}),
- hf_processor_mm_kwargs=hf_processor_mm_kwargs,
- tokenization_kwargs=tokenization_kwargs,
- )
- return prompt_ids
-
- def _build_generation_grids(self, hf_processor_mm_kwargs: Mapping[str, object]) -> torch.Tensor:
- """Build generation grids for M-RoPE decode positions.
-
- For GLM-Image generation, decode order is:
- 1) small preview grid
- 2) large target grid
- 3) EOS
-
- We store grids as [large, small] to match HF processor behavior, and
- decode logic consumes them in reverse order.
- """
-
- target_h = (
- hf_processor_mm_kwargs.get("target_h") if isinstance(hf_processor_mm_kwargs.get("target_h"), int) else None
- )
- target_w = (
- hf_processor_mm_kwargs.get("target_w") if isinstance(hf_processor_mm_kwargs.get("target_w"), int) else None
- )
- if target_h is None or target_w is None:
- target_h = (
- hf_processor_mm_kwargs.get("height") if isinstance(hf_processor_mm_kwargs.get("height"), int) else 1024
- )
- target_w = (
- hf_processor_mm_kwargs.get("width") if isinstance(hf_processor_mm_kwargs.get("width"), int) else 1024
- )
-
- factor = 32
- target_h = (target_h // factor) * factor
- target_w = (target_w // factor) * factor
- token_h = target_h // factor
- token_w = target_w // factor
-
- ratio = token_h / token_w if token_w > 0 else 1.0
- small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
- small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
-
- return torch.tensor(
- [[1, token_h, token_w], [1, small_token_h, small_token_w]],
- dtype=torch.long,
- )
-
def _apply_hf_processor_main(
self,
prompt: str | list[int],
@@ -666,145 +594,126 @@ def _apply_hf_processor_main(
logger.debug(f"_apply_hf_processor_main: mm_counts={mm_counts}, num_images={num_images}")
- if num_images == 0 and isinstance(prompt, str):
+ if num_images == 0 or enable_hf_prompt_update:
# t2i mode or normal flow - use parent implementation
- prompt_ids = self._apply_hf_processor_text_only(
- prompt_text=prompt,
- hf_processor_mm_kwargs=hf_processor_mm_kwargs,
- tokenization_kwargs=tokenization_kwargs,
- )
- mm_processed_data = self._apply_hf_processor_mm_only(
+ return super()._apply_hf_processor_main(
+ prompt=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
+ enable_hf_prompt_update=enable_hf_prompt_update,
)
- # t2i has no source images, so mm features cannot provide image_grid_thw.
- # Provide explicit generation grids for M-RoPE to avoid fallback token parsing
- # (which can degrade high-resolution spatial positions, e.g. 1920x1920).
- try:
- mrope_grid_thw = self._build_generation_grids(hf_processor_mm_kwargs)
- mm_processed_data["mrope_image_grid_thw"] = mrope_grid_thw
- logger.debug(
- "_apply_hf_processor_main t2i: mrope_image_grid_thw=%s",
- mrope_grid_thw.tolist(),
- )
- except Exception as e:
- logger.warning("_apply_hf_processor_main t2i: failed to set mrope_image_grid_thw: %s", e)
-
- return prompt_ids, mm_processed_data, False
+ # i2i mode with enable_hf_prompt_update=False (cache miss scenario)
+ # We need to build prompt_ids with image placeholders
+ logger.debug(f"_apply_hf_processor_main: i2i mode with enable_hf_prompt_update=False, num_images={num_images}")
- # i2i mode: use unified HF processor path only.
- # This avoids drift between duplicated manual/HF i2i implementations.
- logger.debug(
- "_apply_hf_processor_main: i2i mode (enable_hf_prompt_update=%s), num_images=%s",
- enable_hf_prompt_update,
- num_images,
+ # Get mm data from our overridden _apply_hf_processor_mm_only
+ mm_processed_data = self._apply_hf_processor_mm_only(
+ mm_items=mm_items,
+ hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+ tokenization_kwargs=tokenization_kwargs,
)
- if not isinstance(prompt, str):
- # Online OpenAI chat preprocessing can arrive here with tokenized
- # prompts (list[int]) before serving_chat replaces engine prompt
- # with the clean text prompt. Do not fail the whole request.
- logger.warning(
- "_apply_hf_processor_main i2i: got tokenized prompt type=%s; "
- "using compatibility path for preprocessing",
- type(prompt).__name__,
- )
-
- prompt_ids = list(prompt)
- mm_processed_data = self._apply_hf_processor_mm_only(
- mm_items=mm_items,
- hf_processor_mm_kwargs=hf_processor_mm_kwargs,
- tokenization_kwargs=tokenization_kwargs,
- )
-
- # Preserve full grids for M-RoPE decode (source + target), while
- # keeping image_grid_thw source-only for MM batching.
- try:
- source_grid_thw = mm_processed_data.get("image_grid_thw")
- if source_grid_thw is not None and isinstance(source_grid_thw, torch.Tensor):
+ # In this path we do NOT call HF apply_chat_template, so we must still
+ # provide full grids (source + target) for M-RoPE to compute decode positions.
+ # Keep `image_grid_thw` source-only for MM batching/validation.
+ try:
+ source_grid_thw = mm_processed_data.get("image_grid_thw")
+ if source_grid_thw is not None and isinstance(source_grid_thw, torch.Tensor):
+ # Compute target grid following HF GlmImageProcessor: factor=32.
+ # Prefer explicit target_h/target_w if present, otherwise fall back.
+ target_h = (
+ hf_processor_mm_kwargs.get("target_h")
+ if isinstance(hf_processor_mm_kwargs.get("target_h"), int)
+ else None
+ )
+ target_w = (
+ hf_processor_mm_kwargs.get("target_w")
+ if isinstance(hf_processor_mm_kwargs.get("target_w"), int)
+ else None
+ )
+ if target_h is None or target_w is None:
+ # Some callers pass generation size as height/width.
target_h = (
- hf_processor_mm_kwargs.get("target_h")
- if isinstance(hf_processor_mm_kwargs.get("target_h"), int)
- else None
+ hf_processor_mm_kwargs.get("height")
+ if isinstance(hf_processor_mm_kwargs.get("height"), int)
+ else 1024
)
target_w = (
- hf_processor_mm_kwargs.get("target_w")
- if isinstance(hf_processor_mm_kwargs.get("target_w"), int)
- else None
+ hf_processor_mm_kwargs.get("width")
+ if isinstance(hf_processor_mm_kwargs.get("width"), int)
+ else 1024
)
- if target_h is None or target_w is None:
- target_h = (
- hf_processor_mm_kwargs.get("height")
- if isinstance(hf_processor_mm_kwargs.get("height"), int)
- else 1024
- )
- target_w = (
- hf_processor_mm_kwargs.get("width")
- if isinstance(hf_processor_mm_kwargs.get("width"), int)
- else 1024
- )
- factor = 32
- token_h = max(1, target_h // factor)
- token_w = max(1, target_w // factor)
- target_grid = torch.tensor([[1, token_h, token_w]], dtype=source_grid_thw.dtype)
- mm_processed_data["mrope_image_grid_thw"] = torch.cat([source_grid_thw, target_grid], dim=0)
- except Exception:
- pass
+ factor = 32
+ target_h = (target_h // factor) * factor
+ target_w = (target_w // factor) * factor
+ token_h = target_h // factor
+ token_w = target_w // factor
+ target_grid = torch.tensor([[1, token_h, token_w]], dtype=source_grid_thw.dtype)
- # Prompt updates will expand image placeholders in this compatibility path.
- return prompt_ids, mm_processed_data, False
+ mm_processed_data["mrope_image_grid_thw"] = torch.cat([source_grid_thw, target_grid], dim=0)
+ except Exception:
+ # Best-effort only; M-RoPE has additional fallbacks.
+ pass
- images = mm_items.get_items("image", ImageProcessorItems)
- image_list = [images.get(i) for i in range(images.get_count())]
- if not image_list:
- raise ValueError("GLM-Image i2i requires at least one source image in mm_items")
-
- hf_inputs = self._call_hf_processor(
- prompt=prompt,
- mm_data={"images": image_list},
- mm_kwargs=hf_processor_mm_kwargs,
- tok_kwargs=tokenization_kwargs,
- )
+ # Build prompt_ids with image placeholders
+ # _apply_prompt_updates will replace each [image_token_id] with expanded tokens
+ tokenizer = self.info.get_tokenizer()
+ image_token_id = tokenizer.convert_tokens_to_ids("<|image|>")
+
+ if isinstance(prompt, str):
+ # Match HF GlmImageProcessor behavior: append target grid tokens + BOS.
+ # This helps M-RoPE/grid parsing and keeps i2i vs t2i behavior aligned.
+ try:
+ grid_bos = getattr(tokenizer, "grid_bos_token", "")
+ grid_eos = getattr(tokenizer, "grid_eos_token", "")
+ bos = getattr(tokenizer, "bos_token", "")
+
+ # Use the same target sizes we used for mrope grids when available.
+ target_h = (
+ hf_processor_mm_kwargs.get("target_h")
+ if isinstance(hf_processor_mm_kwargs.get("target_h"), int)
+ else None
+ )
+ target_w = (
+ hf_processor_mm_kwargs.get("target_w")
+ if isinstance(hf_processor_mm_kwargs.get("target_w"), int)
+ else None
+ )
+ if target_h is None or target_w is None:
+ target_h = (
+ hf_processor_mm_kwargs.get("height")
+ if isinstance(hf_processor_mm_kwargs.get("height"), int)
+ else 1024
+ )
+ target_w = (
+ hf_processor_mm_kwargs.get("width")
+ if isinstance(hf_processor_mm_kwargs.get("width"), int)
+ else 1024
+ )
- input_ids = hf_inputs.get("input_ids")
- if input_ids is None:
- raise ValueError("HF i2i processor returned no input_ids")
+ factor = 32
+ target_h = (target_h // factor) * factor
+ target_w = (target_w // factor) * factor
+ token_h = target_h // factor
+ token_w = target_w // factor
- if isinstance(input_ids, torch.Tensor):
- prompt_ids = input_ids[0].tolist() if input_ids.dim() > 1 else input_ids.tolist()
+ expanded_prompt = f"{prompt}{grid_bos}{token_h} {token_w}{grid_eos}{bos}"
+ text_ids = tokenizer.encode(expanded_prompt, add_special_tokens=False)
+ except Exception:
+ text_ids = tokenizer.encode(prompt, add_special_tokens=False)
else:
- prompt_ids = (
- input_ids[0]
- if isinstance(input_ids, list) and input_ids and isinstance(input_ids[0], list)
- else list(input_ids)
- )
+ text_ids = list(prompt)
- mm_processed_data = BatchFeature(dict(), tensor_type="pt")
- for key in ("pixel_values", "image_grid_thw", "mrope_image_grid_thw"):
- value = hf_inputs.get(key)
- if value is not None:
- mm_processed_data[key] = value
+ # Prepend image placeholders - one per image
+ prompt_ids = [image_token_id] * num_images + text_ids
- image_grid_thw = mm_processed_data.get("image_grid_thw")
- mrope_grid_thw = mm_processed_data.get("mrope_image_grid_thw")
- hf_config = self.info.get_hf_config()
- image_token_id = getattr(hf_config, "image_token_id", 167855)
- image_token_count = prompt_ids.count(image_token_id)
- logger.debug(
- "_apply_hf_processor_main i2i(HF): num_images=%s, prompt_len=%s, image_token_count=%s, "
- "source_grid_shape=%s, mrope_grid_shape=%s",
- num_images,
- len(prompt_ids),
- image_token_count,
- tuple(image_grid_thw.shape) if image_grid_thw is not None else None,
- tuple(mrope_grid_thw.shape) if mrope_grid_thw is not None else None,
- )
+ logger.debug(f"_apply_hf_processor_main: built prompt_ids with {num_images} image placeholders")
- # HF processor already expanded image placeholders in input_ids.
- return prompt_ids, mm_processed_data, True
+ # Return is_update_applied=False so _apply_prompt_updates will expand the placeholders
+ return prompt_ids, mm_processed_data, False
def _get_mm_fields_config(
self,
@@ -2328,13 +2237,13 @@ def forward(
upsampled_token_ids.append(tokens_upsampled.view(-1))
prior_token_image_ids_info = {
- "ids": {"prior_image": upsampled_token_ids},
+ "prior_token_image_ids": upsampled_token_ids,
"image_grid_thw": image_grid_thw.tolist(),
}
# Debug: log prior_token_image_ids_info
shapes = [t.shape for t in upsampled_token_ids]
- logger.debug(
+ logger.info(
f"[GlmImageModel.forward] Built prior_token_image_ids_info: "
f"num_images={len(upsampled_token_ids)}, shapes={shapes}, "
f"image_grid_thw={image_grid_thw.tolist()}"
@@ -2540,11 +2449,13 @@ def _process_image_input(
# image_grid_thw is NOT included because:
# 1. vLLM's pooling_output expects dict[str, torch.Tensor], not mixed types
# 2. ar2diffusion doesn't need it - the grid info is already encoded in tensor shape
- prior_token_info = {"ids": {"prior_image": upsampled_token_ids}}
+ prior_token_info = {
+ "prior_token_image_ids": upsampled_token_ids,
+ }
# Debug: log prior_token_info
shapes = [t.shape for t in upsampled_token_ids]
- logger.debug(
+ logger.info(
f"[_process_image_input] Built prior_token_info: "
f"num_images={len(upsampled_token_ids)}, shapes={shapes}, "
f"image_grid_thw={image_grid_thw.tolist()}"
@@ -2756,23 +2667,9 @@ def get_mrope_input_positions(
# Input format: "textH Wh w" where =image_start_token_id=16384
# For 1024x1024: H=32, W=32 (large), h=16, w=16 (small preview)
if not image_grid_thw:
- # Preferred path for t2i: use explicit target size propagated from
- # serving/request sampling params. This avoids fragile grid parsing
- # from token IDs and matches HF processor grid construction.
- target_h = kwargs.get("target_h")
- target_w = kwargs.get("target_w")
- if isinstance(target_h, int) and isinstance(target_w, int) and target_h > 0 and target_w > 0:
- factor = 32
- token_h = target_h // factor
- token_w = target_w // factor
- ratio = token_h / token_w if token_w > 0 else 1.0
- small_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
- small_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
- image_grid_thw = [[1, token_h, token_w], [1, small_h, small_w]]
-
# Try to parse from kwargs (passed from processor)
hf_config_arg = kwargs.get("hf_config")
- if (not image_grid_thw) and hf_config_arg is not None and hasattr(hf_config_arg, "image_grid_thw"):
+ if hf_config_arg is not None and hasattr(hf_config_arg, "image_grid_thw"):
image_grid_thw = hf_config_arg.image_grid_thw
# If still empty, try to infer from input tokens
@@ -2826,29 +2723,19 @@ def get_mrope_input_positions(
prompt_ends_with_start = len(input_tokens) > 0 and input_tokens[-1] == image_start_token_id
if prompt_ends_with_start and len(image_grid_thw) == num_source_images and num_source_images > 0:
# i2i mode: source grids exist but no target grids
- # Prefer explicit target size propagated from request sampling params.
- # This avoids fragile grid parsing from token IDs for non-1024 i2i.
- target_h = kwargs.get("target_h")
- target_w = kwargs.get("target_w")
- if isinstance(target_h, int) and isinstance(target_w, int) and target_h > 0 and target_w > 0:
- factor = 32
- token_h = target_h // factor
- token_w = target_w // factor
- image_grid_thw = list(image_grid_thw) + [[1, token_h, token_w]]
- else:
- # Parse target grids from prompt tokens or use defaults
- parsed_grids = self._parse_grid_from_tokens(input_tokens, hf_config)
- if parsed_grids:
- # parsed_grids contains all grids mentioned in prompt
- # For i2i, add only the generation target grids
- if len(parsed_grids) > num_source_images:
- image_grid_thw = list(image_grid_thw) + parsed_grids[num_source_images:]
- else:
- # Fallback: add default 1024x1024 generation grid (1 target for i2i)
- image_grid_thw = list(image_grid_thw) + [[1, 32, 32]]
+ # Parse target grids from prompt tokens or use defaults
+ parsed_grids = self._parse_grid_from_tokens(input_tokens, hf_config)
+ if parsed_grids:
+ # parsed_grids contains all grids mentioned in prompt
+ # For i2i, add only the generation target grids
+ if len(parsed_grids) > num_source_images:
+ image_grid_thw = list(image_grid_thw) + parsed_grids[num_source_images:]
else:
- # Fallback to default 1024x1024 grid for generation
+ # Fallback: add default 1024x1024 generation grids (1 target for i2i)
image_grid_thw = list(image_grid_thw) + [[1, 32, 32]]
+ else:
+ # Fallback to default 1024x1024 grids for generation
+ image_grid_thw = list(image_grid_thw) + [[1, 32, 32]]
llm_pos_ids_list: list[torch.Tensor] = []
diff --git a/vllm_omni/model_executor/models/glm_image/pipeline.py b/vllm_omni/model_executor/models/glm_image/pipeline.py
deleted file mode 100644
index bd041292c6d..00000000000
--- a/vllm_omni/model_executor/models/glm_image/pipeline.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""GLM-Image pipeline topologies (frozen).
-Two-stage (default):
- Stage 0: AR — multimodal understanding + token_ids generation
- Stage 1: DiT — diffusion image generation
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-GLM_IMAGE_PIPELINE = PipelineConfig(
- model_type="glm_image",
- model_arch="GlmImageForConditionalGeneration",
- hf_architectures=("GlmImageForConditionalGeneration",),
- diffusers_class_name="GlmImagePipeline",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="ar",
- execution_type=StageExecutionType.LLM_AR,
- requires_multimodal_data=True,
- input_sources=(),
- final_output=False,
- owns_tokenizer=True,
- model_arch="GlmImageForConditionalGeneration",
- engine_output_type="token_ids",
- model_subdir="vision_language_encoder",
- tokenizer_subdir="processor",
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="dit",
- execution_type=StageExecutionType.DIFFUSION,
- input_sources=(0,),
- requires_multimodal_data=True,
- final_output=True,
- final_output_type="image",
- model_arch="GlmImagePipeline",
- custom_process_input_func="vllm_omni.model_executor.stage_input_processors.glm_image.ar2diffusion",
- omni_kv_config={"need_recv_cache": False},
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
index 6304eeab29b..6d25274f901 100644
--- a/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
+++ b/vllm_omni/model_executor/models/hunyuan_image3/hunyuan_image3.py
@@ -77,9 +77,7 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.tensor_schema import TensorSchema
-from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.sampler import Sampler
from vllm_omni.model_executor.models.hunyuan_image3.autoencoder_kl_3d import AutoencoderKLConv3D
from vllm_omni.model_executor.models.hunyuan_image3.siglip2 import LightProjector, Siglip2VisionTransformer
@@ -177,11 +175,8 @@ def contains_unexpected_keyword(name, keywords):
return True
return False
- skipped_unexpected: set[str] = set()
-
for name, loaded_weight in weights:
if contains_unexpected_keyword(name, unexpected_keywords):
- skipped_unexpected.add(name)
continue
if "rotary_emb.inv_freq" in name:
@@ -367,17 +362,6 @@ def contains_unexpected_keyword(name, keywords):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
-
- if skipped_unexpected:
- logger.warning_once(
- "Skipped %d weights matching unexpected_keywords "
- "(e.g. vae, vision_model, patch_embed, timestep_emb). "
- "If upstream renamed components, these may be silently "
- "lost. Skipped names: %s",
- len(skipped_unexpected),
- sorted(skipped_unexpected)[:10],
- )
-
return loaded_params
@@ -1165,8 +1149,6 @@ class HunyuanImage3ForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
HunyuanImage3Inputs: TypeAlias = HunyuanImage3PixelInputs
- prefer_model_sampler = True
-
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1217,10 +1199,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.lm_head = PPMissingLayer()
- # --- AR-stage components ---
- # These are needed for image encoding in the AR stage.
- # If a future text-only stage is added, gate on vllm_config.model_config.model_stage.
-
# vae
self.vae = AutoencoderKLConv3D.from_config(config.vae)
self.patch_embed = UNetDown(
@@ -1248,63 +1226,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._mrope_joint_img_sep_token_id = tokenizer.convert_tokens_to_ids("")
self._mrope_max_num_patches = config.vit_processor.get("max_num_patches", 729)
- # Special token IDs for logits processors (stage transitions).
- # These mirror the official tokenization_hunyuan_image_3.py setup.
- self._end_of_think_id = tokenizer.convert_tokens_to_ids(" ")
- self._recaption_id = tokenizer.convert_tokens_to_ids("")
- self._end_of_recaption_id = tokenizer.convert_tokens_to_ids(" ")
- self._answer_id = tokenizer.convert_tokens_to_ids("")
- self._end_of_answer_id = tokenizer.convert_tokens_to_ids(" ")
- image_base_size = getattr(config, "image_base_size", 1024)
- self._size_token_id = tokenizer.convert_tokens_to_ids(f"")
- self._start_ratio_id = tokenizer.convert_tokens_to_ids("")
- self._end_ratio_id = tokenizer.convert_tokens_to_ids("")
- ratio_33 = tokenizer.convert_tokens_to_ids("")
- ratio_36 = tokenizer.convert_tokens_to_ids("")
- self._ratio_other_slices = [(ratio_33, ratio_36 + 1)]
- # Build the full set of ratio token IDs for use as stop tokens.
- self._all_ratio_ids = set(range(self._start_ratio_id, self._end_ratio_id + 1))
- for s, e in self._ratio_other_slices:
- self._all_ratio_ids.update(range(s, e))
-
- # Determine mode: comprehension (I2T/T2T) vs generation (IT2I/T2I).
- engine_output_type = getattr(vllm_config.model_config, "engine_output_type", None)
- self._is_comprehension = engine_output_type in (None, "text")
-
- # For comprehension mode, block image generation tokens but allow
- # text structure tokens (, , etc.) so the model can
- # follow its natural generation pattern. Stop tokens in YAML will
- # terminate at or EOS.
- self._blocked_token_ids: set[int] = set()
- if self._is_comprehension:
- self._blocked_token_ids.update(
- [
- self._mrope_boi_token_id, #
- self._mrope_eoi_token_id, #
- self._size_token_id, #
- ]
- )
- self._blocked_token_ids.update(self._all_ratio_ids)
-
- # For generation mode, build stage transition map.
- # Official logic: → [],
- # → [, , ]
- # After , restrict vocab to ratio tokens only.
- # Stage-transition forced sequences, keyed by trigger token.
- self._stage_transitions: dict[int, list[int]] = {}
- if not self._is_comprehension:
- self._stage_transitions[self._end_of_think_id] = [
- self._recaption_id,
- ]
- self._stage_transitions[self._end_of_recaption_id] = [
- self._answer_id,
- self._mrope_boi_token_id,
- self._size_token_id,
- ]
-
- self._sampler: Sampler | None = None
- self._eos_token_id: int = tokenizer.eos_token_id
-
self._replace_rotary_embeddings()
def _replace_rotary_embeddings(self):
@@ -1336,12 +1257,6 @@ def _replace_rotary_embeddings(self):
head_dim,
rope_theta,
)
- if replaced == 0:
- raise RuntimeError(
- "HunyuanImage3: _replace_rotary_embeddings replaced 0 layers. "
- "The custom interleaved 2D mRoPE is not active — model outputs "
- "will be incorrect. Check that model.layers[*].self_attn.rotary_emb exists."
- )
def _parse_and_validate_image_input(
self,
@@ -1359,10 +1274,6 @@ def _parse_and_validate_image_input(
if vit_pixel_values is None or vae_pixel_values is None:
return None
- # Handle empty batch (e.g., during profiling with 0 images / T2T mode)
- if vit_pixel_values.numel() == 0 or vae_pixel_values.numel() == 0:
- return None
-
return HunyuanImage3PixelInputs(
type="pixel_values",
pixel_values={
@@ -1561,112 +1472,6 @@ def compute_logits(
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
- # ------------------------------------------------------------------
- # Custom sampler — applies HunyuanImage3-specific logits processors
- # before the standard sampling step.
- #
- # Comprehension (I2T / T2T):
- # Block generation-specific special tokens so sampling can't
- # accidentally produce , , ratio tokens, etc.
- #
- # Generation (IT2I / T2I think):
- # 1. _StageTransitionLogitsProcessor — force token sequences at
- # transition boundaries ( → , etc.)
- # 2. _ConditionalSliceVocabLogitsProcessor — after ,
- # restrict vocab to ratio tokens only (greedy).
- # ------------------------------------------------------------------
-
- def sample(
- self,
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> SamplerOutput | None:
- if logits is None or logits.numel() == 0:
- return None
-
- if self._sampler is None:
- self._sampler = Sampler()
-
- min_score = torch.finfo(logits.dtype).min
-
- assert logits.shape[0] == 1, f"HunyuanImage3 sampler requires max_num_seqs=1, got batch size {logits.shape[0]}"
-
- for req_idx in range(logits.shape[0]):
- decoded_tokens: list[int] = (
- sampling_metadata.output_token_ids[req_idx] if req_idx < len(sampling_metadata.output_token_ids) else []
- )
- last_token = decoded_tokens[-1] if decoded_tokens else -1
-
- if self._is_comprehension:
- for tid in self._blocked_token_ids:
- logits[req_idx, tid] = min_score
- else:
- forced = self._get_forced_token(decoded_tokens)
- if forced is not None:
- logits[req_idx].fill_(min_score)
- logits[req_idx, forced] = 0
- elif last_token == self._size_token_id:
- self._apply_ratio_restriction(logits, req_idx, min_score)
- elif last_token in self._all_ratio_ids:
- logits[req_idx].fill_(min_score)
- logits[req_idx, self._eos_token_id] = 0
-
- return self._sampler(logits=logits, sampling_metadata=sampling_metadata)
-
- def _get_forced_token(self, decoded_tokens: list[int]) -> int | None:
- """Derive the next forced token from output history (stateless).
-
- Scans decoded_tokens backwards for the most recent trigger token,
- then prefix-matches the forced sequence against what followed.
- Returns the next token to force, or None if the sequence is complete
- or history has diverged from the expected forced sequence.
- """
- for i in range(len(decoded_tokens) - 1, -1, -1):
- trigger = decoded_tokens[i]
- if trigger not in self._stage_transitions:
- continue
-
- forced_seq = self._stage_transitions[trigger]
- emitted = decoded_tokens[i + 1 :]
-
- matched = 0
- for expected, actual in zip(forced_seq, emitted):
- if actual != expected:
- # History diverged from the expected forced sequence.
- # Stop applying transition forcing for safety.
- return None
- matched += 1
-
- if matched < len(forced_seq):
- return forced_seq[matched]
- return None
-
- return None
-
- def _apply_ratio_restriction(
- self,
- logits: torch.Tensor,
- req_idx: int,
- min_score: float,
- ) -> None:
- """Port of official _ConditionalSliceVocabLogitsProcessor.__call__.
-
- After the size token, only allow ratio tokens and pick greedily.
- """
- original = logits[req_idx].clone()
- logits[req_idx].fill_(min_score)
- # Allow primary ratio range.
- logits[req_idx, self._start_ratio_id : self._end_ratio_id + 1] = original[
- self._start_ratio_id : self._end_ratio_id + 1
- ]
- # Allow extra ratio slices.
- for s, e in self._ratio_other_slices:
- logits[req_idx, s:e] = original[s:e]
- # Force greedy: keep only the argmax.
- max_id = logits[req_idx].argmax().item()
- logits[req_idx].fill_(min_score)
- logits[req_idx, max_id] = 0
-
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype, device: torch.device
) -> IntermediateTensors:
@@ -1702,9 +1507,9 @@ def get_mrope_input_positions(
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec] | None = None,
*,
- hf_config: PretrainedConfig | None = None,
- image_grid_thw: list[list[int]] | torch.Tensor | None = None,
- video_grid_thw: list[list[int]] | torch.Tensor | None = None,
+ hf_config: PretrainedConfig,
+ image_grid_thw: list[list[int]] | torch.Tensor,
+ video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: int | None = None,
diff --git a/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2.py b/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2.py
index 6311216dc0c..be56a14a606 100644
--- a/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2.py
+++ b/vllm_omni/model_executor/models/mammoth_moda2/mammoth_moda2.py
@@ -608,17 +608,16 @@ def _apply_t2i_token_constraints(self, logits: torch.Tensor) -> torch.Tensor:
num_reqs = int(logits.shape[0])
for i in range(num_reqs):
runtime_info = runtime_infos[i] if isinstance(runtime_infos[i], dict) else {}
- meta = runtime_info.get("meta", {})
- omni_task = meta.get("omni_task")
+ omni_task = runtime_info.get("omni_task")
if not isinstance(omni_task, list) or not omni_task or omni_task[0] != "t2i":
# Text/understanding/chat: forbid sampling from the extra gen vocab.
logits[i, self.language_model.base_vocab_size :] = neg_inf
continue
- ar_width = meta["ar_width"][0]
- eol_token_id = meta["eol_token_id"][0]
- visual_start = meta["visual_token_start_id"][0]
- visual_end = meta["visual_token_end_id"][0]
+ ar_width = runtime_info["ar_width"][0]
+ eol_token_id = runtime_info["eol_token_id"][0]
+ visual_start = runtime_info["visual_token_start_id"][0]
+ visual_end = runtime_info["visual_token_end_id"][0]
generated_len = runtime_info["generated_len"]
row = logits[i]
diff --git a/vllm_omni/model_executor/models/mimo_audio/mimo_audio.py b/vllm_omni/model_executor/models/mimo_audio/mimo_audio.py
index ee924dc4669..22a9a911130 100644
--- a/vllm_omni/model_executor/models/mimo_audio/mimo_audio.py
+++ b/vllm_omni/model_executor/models/mimo_audio/mimo_audio.py
@@ -799,7 +799,7 @@ def forward(
return OmniOutput(
text_hidden_states=text_hidden_states.reshape(-1, text_hidden_states.shape[-1]),
- multimodal_outputs={"codes": {"audio": next_speech_tokens}},
+ multimodal_outputs={"code_predictor_codes": next_speech_tokens},
)
if self.model_stage == "code2wav":
diff --git a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_code2wav.py b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_code2wav.py
index e6d70915ce2..7c2e87d5f5c 100644
--- a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_code2wav.py
+++ b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_code2wav.py
@@ -486,10 +486,10 @@ def _split_flat_codes_for_requests(
return [ids]
if runtime_additional_information and all(
- isinstance(info.get("meta", {}).get("code_flat_numel"), int) and int(info["meta"]["code_flat_numel"]) > 0
+ isinstance(info.get("code_flat_numel"), int) and int(info["code_flat_numel"]) > 0
for info in runtime_additional_information
):
- sizes = [int(info["meta"]["code_flat_numel"]) for info in runtime_additional_information]
+ sizes = [int(info["code_flat_numel"]) for info in runtime_additional_information]
if sum(sizes) == n:
parts: list[torch.Tensor] = []
offset = 0
@@ -517,11 +517,11 @@ def _mimo_codec_runtime_lists(
if not runtime_additional_information:
return left_frames, chunk_frames
for i in range(min(num_req, len(runtime_additional_information))):
- meta = runtime_additional_information[i].get("meta", {})
- if "left_context_size" in meta:
- left_frames[i] = int(meta["left_context_size"])
- if "codec_chunk_frames" in meta:
- chunk_frames[i] = int(meta["codec_chunk_frames"])
+ info = runtime_additional_information[i]
+ if "left_context_size" in info:
+ left_frames[i] = int(info["left_context_size"])
+ if "codec_chunk_frames" in info:
+ chunk_frames[i] = int(info["codec_chunk_frames"])
return left_frames, chunk_frames
def chunked_decode_streaming(
diff --git a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py
index 85fe4b0051c..56cb8788ee2 100644
--- a/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py
+++ b/vllm_omni/model_executor/models/mimo_audio/mimo_audio_llm.py
@@ -50,7 +50,6 @@
PromptUpdate,
PromptUpdateDetails,
)
-from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema
@@ -151,6 +150,7 @@ def __init__(self, model: "MiMoAudioLLMForConditionalGeneration", max_batch_size
dtype = next(model.hidden_states_downcast.parameters()).dtype
hidden_size = model.local_config.hidden_size
+ self.pool = torch.cuda.graph_pool_handle()
self.input_tensor = torch.zeros((max_batch_size, 1, hidden_size), dtype=dtype, device=device)
self.sampler = MiMoLocalSamplerTensor(
temperature=torch.ones(max_batch_size, dtype=torch.float32, device=device),
@@ -231,7 +231,7 @@ def capture(
cuda_graph = torch.cuda.CUDAGraph()
if eager_run_first:
model.base_local_forward(input_tensor, local_sampler=sampler)
- with torch.cuda.graph(cuda_graph, pool=current_platform.get_global_graph_pool()):
+ with torch.cuda.graph(cuda_graph, buffer.pool):
output_tensor = model.base_local_forward(input_tensor, local_sampler=sampler)
return cls(
@@ -263,6 +263,7 @@ def __init__(self, model: "MiMoAudioLLMForConditionalGeneration", max_batch_size
hidden_size = model.input_local_config.hidden_size
group_size = model.group_size
+ self.pool = torch.cuda.graph_pool_handle()
self.input_tensor = torch.zeros((max_batch_size, group_size, hidden_size), dtype=dtype, device=device)
self.lock = threading.Lock()
@@ -310,7 +311,7 @@ def capture(
out = model.input_local_transformer(inputs_embeds=input_tensor, return_dict=True, is_causal=False)
_ = out.last_hidden_state
- with torch.cuda.graph(cuda_graph, pool=current_platform.get_global_graph_pool()):
+ with torch.cuda.graph(cuda_graph, buffer.pool):
out = model.input_local_transformer(inputs_embeds=input_tensor, return_dict=True, is_causal=False)
output_tensor = out.last_hidden_state
diff --git a/vllm_omni/model_executor/models/mimo_audio/pipeline.py b/vllm_omni/model_executor/models/mimo_audio/pipeline.py
deleted file mode 100644
index 1cc15af5064..00000000000
--- a/vllm_omni/model_executor/models/mimo_audio/pipeline.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""MiMo Audio pipeline topology (frozen).
-
-Stage 0: fused thinker+talker — multimodal understanding + text + RVQ codes.
-Stage 1: Code2Wav — RVQ codes → waveform.
-
-MiMoAudioConfig inherits from Qwen2Config, so the HF ``model_type`` field
-reports ``qwen2`` — the registry's model_type-based auto-detect can't
-disambiguate. ``hf_architectures`` lets ``StageConfigFactory`` fall back to
-matching ``hf_config.architectures`` instead.
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-_PROC = "vllm_omni.model_executor.stage_input_processors.mimo_audio"
-
-MIMO_AUDIO_PIPELINE = PipelineConfig(
- model_type="mimo_audio",
- # HF ``architectures: ["MiMoAudioModel"]`` is also the registry key in
- # ``model_executor/models/registry.py``; it resolves to the internal
- # class ``MiMoAudioForConditionalGeneration`` in ``mimo_audio.py``.
- model_arch="MiMoAudioModel",
- hf_architectures=("MiMoAudioModel", "MiMoV2ASRForCausalLM"),
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="fused_thinker_talker",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- final_output=True,
- final_output_type="text",
- owns_tokenizer=True,
- engine_output_type="latent",
- async_chunk_process_next_stage_input_func=(f"{_PROC}.llm2code2wav_async_chunk"),
- sampling_constraints={"detokenize": True},
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="code2wav",
- execution_type=StageExecutionType.LLM_GENERATION,
- input_sources=(0,),
- final_output=True,
- final_output_type="audio",
- engine_output_type="audio",
- sync_process_input_func=f"{_PROC}.llm2code2wav",
- sampling_constraints={"detokenize": False},
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/__init__.py b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py
deleted file mode 100644
index 4cd086e6426..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-
-from .ming_flash_omni import MingFlashOmniForConditionalGeneration
-from .ming_flash_omni_talker import MingFlashOmniTalkerForConditionalGeneration
-from .ming_flash_omni_thinker import (
- MingFlashOmniThinkerDummyInputsBuilder,
- MingFlashOmniThinkerForConditionalGeneration,
- MingFlashOmniThinkerMultiModalProcessor,
- MingFlashOmniThinkerProcessingInfo,
-)
-
-__all__ = [
- "MingFlashOmniForConditionalGeneration",
- "MingFlashOmniTalkerForConditionalGeneration",
- "MingFlashOmniThinkerForConditionalGeneration",
- "MingFlashOmniThinkerProcessingInfo",
- "MingFlashOmniThinkerMultiModalProcessor",
- "MingFlashOmniThinkerDummyInputsBuilder",
-]
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py b/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py
deleted file mode 100644
index 6ca19901141..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py
+++ /dev/null
@@ -1,246 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright 2024 ANT Group and the HuggingFace Inc. team.
-# Copyright (c) 2022 OpenAI
-# Adapted from Ming repository modeling_whisper_encoder.py
-# https://github.com/inclusionAI/Ming
-
-import operator
-from collections.abc import Iterable
-from itertools import accumulate
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from vllm.logger import init_logger
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
-from vllm_omni.diffusion.attention.backends.utils.fa import HAS_FLASH_ATTN, flash_attn_varlen_func
-from vllm_omni.model_executor.models.whisper_utils import Conv1d, Linear, sinusoids
-
-logger = init_logger(__name__)
-
-
-class MultiHeadAttention(nn.Module):
- """Multi-head attention with packed sequence support.
- Adapted from Qwen3-TTS WhisperEncoder.
- """
-
- def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True):
- super().__init__()
- self.n_head = n_head
- self.query = Linear(n_state, n_state)
- self.key = Linear(n_state, n_state, bias=False)
- self.value = Linear(n_state, n_state)
- self.out = Linear(n_state, n_state)
-
- if use_flash_attn and not HAS_FLASH_ATTN:
- logger.warning("flash-attn is not available. Fallback to manual PyTorch version")
- self.use_flash_attn = use_flash_attn and HAS_FLASH_ATTN
-
- def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
- """Forward pass with packed sequence support.
-
- Args:
- x: [total_tokens, n_state] packed sequence
- cu_seqlens: [num_seqs + 1] cumulative sequence lengths, e.g. [0, len1, len1+len2, ...]
-
- Returns:
- [total_tokens, n_state] attention output
- """
- q = self.query(x)
- k = self.key(x)
- v = self.value(x)
-
- n_ctx, n_state = q.shape
- head_dim = n_state // self.n_head
-
- q = q.view(n_ctx, self.n_head, head_dim)
- k = k.view(n_ctx, self.n_head, head_dim)
- v = v.view(n_ctx, self.n_head, head_dim)
-
- # Try flash attention varlen
- if self.use_flash_attn and cu_seqlens is not None and q.dtype in [torch.float16, torch.bfloat16]:
- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
- else:
- attn_output = self._manual_attention(q, k, v, cu_seqlens)
-
- # Reshape back: [T, H, D] -> [T, H*D]
- attn_output = attn_output.contiguous().view(n_ctx, n_state)
- return self.out(attn_output)
-
- def _manual_attention(
- self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor
- ) -> torch.Tensor:
- """Manual attention for variable-length sequences (fallback)."""
- _, n_head, head_dim = q.shape
- scale = head_dim**-0.5
-
- # Unpack sequences and pad to max length
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
- batch_size = len(seqlens)
- max_seqlen = max(seqlens)
-
- # Create padded tensors
- q_padded = torch.zeros(batch_size, max_seqlen, n_head, head_dim, dtype=q.dtype, device=q.device)
- k_padded = torch.zeros_like(q_padded)
- v_padded = torch.zeros_like(q_padded)
-
- # Fill with actual sequences
- for i in range(batch_size):
- start_idx = cu_seqlens[i]
- end_idx = cu_seqlens[i + 1]
- seq_len = seqlens[i]
- q_padded[i, :seq_len] = q[start_idx:end_idx]
- k_padded[i, :seq_len] = k[start_idx:end_idx]
- v_padded[i, :seq_len] = v[start_idx:end_idx]
-
- # Transpose for attention: [B, H, T, D]
- q_padded = q_padded.transpose(1, 2)
- k_padded = k_padded.transpose(1, 2)
- v_padded = v_padded.transpose(1, 2)
-
- # Create attention mask for variable lengths: 0 for valid positions, -inf for padding
- padding_mask = (
- torch.arange(max_seqlen, device=q.device)[None, :] >= torch.tensor(seqlens, device=q.device)[:, None]
- )
- attn_mask = torch.zeros(batch_size, 1, 1, max_seqlen, dtype=q.dtype, device=q.device)
- attn_mask = attn_mask.masked_fill(padding_mask.unsqueeze(1).unsqueeze(2), -torch.finfo(q.dtype).max)
-
- # Compute attention
- attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
- attn_scores = attn_scores + attn_mask
- attn_weights = F.softmax(attn_scores, dim=-1)
- context = torch.matmul(attn_weights, v_padded)
-
- # Transpose back: [B, H, T, D] -> [B, T, H, D]
- context = context.transpose(1, 2).contiguous()
- output_packed = torch.cat([context[i, : seqlens[i]] for i in range(batch_size)], dim=0)
-
- return output_packed
-
-
-class ResidualAttentionBlock(nn.Module):
- """Whisper-style residual attention block with packed sequence support.
-
- Adapted from
- https://github.com/openai/whisper/blob/v20250625/whisper/model.py
- vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
- """
-
- def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True):
- super().__init__()
- self.attn = MultiHeadAttention(n_state, n_head, use_flash_attn=use_flash_attn)
- self.attn_ln = nn.LayerNorm(n_state)
-
- n_mlp = n_state * 4
- self.mlp = nn.Sequential(
- Linear(n_state, n_mlp),
- nn.GELU(),
- Linear(n_mlp, n_state),
- )
- self.mlp_ln = nn.LayerNorm(n_state)
-
- def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
- x = x + self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
- x = x + self.mlp(self.mlp_ln(x))
- return x
-
-
-class WhisperAudioEncoder(nn.Module):
- """Whisper audio encoder for Ming with packed sequence support.
-
- Adapted from
- https://github.com/openai/whisper/blob/v20250625/whisper/model.py
- vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
- """
-
- def __init__(
- self,
- n_mels: int = 128,
- n_ctx: int = 15000,
- n_state: int = 1280,
- n_head: int = 20,
- n_layer: int = 32,
- use_flash_attn: bool = True,
- ):
- super().__init__()
- self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
- self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
- # self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
- self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
- self.blocks = nn.ModuleList(
- [ResidualAttentionBlock(n_state, n_head, use_flash_attn=use_flash_attn) for _ in range(n_layer)]
- )
- self.ln_post = nn.LayerNorm(n_state)
- self.audio_emb_dim = n_state
-
- self.n_layer = n_layer
- self.n_mels = n_mels
- self.use_flash_attn = use_flash_attn
-
- def forward(
- self,
- x_list: list[torch.Tensor],
- audio_lens: list[int],
- ) -> torch.Tensor:
- """Forward pass with packed sequence format for variable-length inputs.
-
- Args:
- x_list: List of [n_mels, T_i] mel spectrogram features for each audio
- audio_lens: List of original audio lengths in frames
-
- Returns:
- [total_T', n_state] packed encoded audio features, where
- total_T' is the sum of all encoded sequence lengths
- """
- # Cast inputs to model dtype
- target_dtype = self.conv1.weight.dtype
- x_list = [x.to(target_dtype) for x in x_list]
-
- encoded_list = []
- encoded_lens = []
- for mel_spec in x_list:
- # mel_spec: [n_mels, T] - process through conv layers
- x = mel_spec.unsqueeze(0) # [1, n_mels, T]
- x = F.gelu(self.conv1(x))
- x = F.gelu(self.conv2(x))
- x = x.squeeze(0).transpose(0, 1) # [T', n_state]
-
- # Add positional embedding
- seq_len = x.shape[0]
- positional_embedding = self.positional_embedding[:seq_len, :]
- x = (x + positional_embedding).to(x.dtype)
-
- encoded_list.append(x)
- encoded_lens.append(seq_len)
-
- x_packed = torch.cat(encoded_list, dim=0) # [total_T', n_state]
-
- cu_seqlens = list(accumulate(encoded_lens, func=operator.add, initial=0))
- cu_seqlens = torch.tensor(cu_seqlens, device=x_packed.device, dtype=torch.int32)
-
- for block in self.blocks:
- x_packed = block(x_packed, cu_seqlens=cu_seqlens)
-
- x_packed = self.ln_post(x_packed)
- return x_packed
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- params_dict: dict[str, torch.Tensor] = {
- **dict(self.named_parameters(remove_duplicate=False)),
- **dict(self.named_buffers()),
- }
- loaded_params: set[str] = set()
-
- for name, loaded_weight in weights:
- if name not in params_dict:
- logger.warning("Skipping unknown audio encoder weight: %s", name)
- continue
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
-
- return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/audio_vae.py b/vllm_omni/model_executor/models/ming_flash_omni/audio_vae.py
deleted file mode 100644
index 9d5c266b4fe..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/audio_vae.py
+++ /dev/null
@@ -1,390 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright (c) Ant Group. All rights reserved.
-# Adapted from:
-# https://github.com/inclusionAI/Ming/tree/e58533db227031990c5a6864dcf5f08fb53ed0d2/AudioVAE
-
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers import PretrainedConfig, PreTrainedModel, Qwen2Config, Qwen2Model
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-try:
- import flash_attn # noqa: F401
-except (ImportError, ModuleNotFoundError):
- flash_attn = None
- logger.warning(
- "flash_attn is not available, the model may not yield the "
- "exactly same result as the transformers implementation "
- "in the audio tower part."
- )
-
-
-class AudioVAEConfig(PretrainedConfig):
- model_type = "audio_vae"
-
- def __init__(
- self,
- sample_rate: int = 44100,
- enc_kwargs: dict | None = None,
- dec_kwargs: dict | None = None,
- init_method: str = "kaiming",
- patch_size: int = 4,
- **kwargs,
- ):
- self.sample_rate = sample_rate
- self.enc_kwargs = enc_kwargs or {}
- self.dec_kwargs = dec_kwargs or {}
- self.init_method = init_method
- self.patch_size = patch_size
- super().__init__(**kwargs)
-
-
-class ISTFT(nn.Module):
- def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
- super().__init__()
- if padding not in ["center", "same"]:
- raise ValueError("Padding must be 'center' or 'same'.")
- self.padding = padding
- self.n_fft = n_fft
- self.hop_length = hop_length
- self.win_length = win_length
- window = torch.hann_window(win_length)
- self.register_buffer("window", window)
- self.buffer_len = self.win_length - self.hop_length
-
- def _buffer_process(self, x, buffer, pad, last_chunk=False, streaming=False):
- if streaming:
- if buffer is None:
- x = x[:, pad:]
- if buffer is not None:
- x[:, : self.buffer_len] += buffer
- buffer = x[:, -self.buffer_len :]
- if not last_chunk:
- x = x[:, : -self.buffer_len]
- else:
- x = x[:, :-pad]
- else:
- x = x[:, pad:-pad]
- return x, buffer
-
- def forward(self, spec, audio_buffer=None, window_buffer=None, streaming=False, last_chunk=False):
- if self.padding == "center":
- return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
- elif self.padding == "same":
- pad = (self.win_length - self.hop_length) // 2
- else:
- raise ValueError("Padding must be 'center' or 'same'.")
-
- B, N, T = spec.shape
- ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
- ifft = ifft * self.window[None, :, None]
-
- output_size = (T - 1) * self.hop_length + self.win_length
- y = torch.nn.functional.fold(
- ifft,
- output_size=(1, output_size),
- kernel_size=(1, self.win_length),
- stride=(1, self.hop_length),
- )[:, 0, 0, :]
-
- y, audio_buffer = self._buffer_process(y, audio_buffer, pad, last_chunk=last_chunk, streaming=streaming)
-
- window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
- window_envelope = (
- torch.nn.functional.fold(
- window_sq,
- output_size=(1, output_size),
- kernel_size=(1, self.win_length),
- stride=(1, self.hop_length),
- )
- .squeeze(0)
- .squeeze(0)
- )
-
- window_envelope, window_buffer = self._buffer_process(
- window_envelope, window_buffer, pad, last_chunk=last_chunk, streaming=streaming
- )
- window_envelope = window_envelope.squeeze()
-
- assert (window_envelope > 1e-11).all()
- y = y / window_envelope
-
- return y, audio_buffer, window_buffer
-
-
-class ISTFTHead(nn.Module):
- def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
- super().__init__()
- out_dim = n_fft + 2
- self.out = nn.Linear(dim, out_dim)
- self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
-
- def forward(self, x, audio_buffer=None, window_buffer=None, streaming=False, last_chunk=False):
- x_pred = self.out(x)
- x_pred = x_pred.transpose(1, 2)
- mag, p = x_pred.chunk(2, dim=1)
- mag = torch.exp(mag)
- mag = torch.clip(mag, max=1e2)
- x = torch.cos(p)
- y = torch.sin(p)
- S = mag * (x + 1j * y)
- audio, audio_buffer, window_buffer = self.istft(
- S, audio_buffer=audio_buffer, window_buffer=window_buffer, streaming=streaming, last_chunk=last_chunk
- )
- return audio.unsqueeze(1), x_pred, audio_buffer, window_buffer
-
-
-class StreamingLinearUpsample(nn.Module):
- def __init__(self, scale_factor=4):
- super().__init__()
- self.scale_factor = scale_factor
- self.upsampler = nn.Upsample(scale_factor=scale_factor, mode="linear", align_corners=False)
-
- def forward(self, x, state=None, is_last=False):
- if state is None:
- state = {"prev_chunk": None, "history_last": None, "is_first": True}
-
- if x is None and not is_last:
- return None, state
-
- if state["is_first"] and is_last:
- out = self.upsampler(x.transpose(1, 2)).transpose(1, 2)
- return out, None
-
- output_chunks = []
-
- if state["is_first"]:
- state["prev_chunk"] = x
- state["is_first"] = False
- if not is_last:
- return None, state
-
- if state["prev_chunk"] is not None:
- p = state["prev_chunk"].transpose(1, 2)
-
- if state["history_last"] is None:
- lookahead = x[:, :1, :].transpose(1, 2)
- inp = torch.cat([p, lookahead], dim=2)
- up = self.upsampler(inp)
- out_prev = up[:, :, : p.size(2) * self.scale_factor]
- else:
- lookahead = x[:, :1, :].transpose(1, 2)
- inp = torch.cat([state["history_last"], p, lookahead], dim=2)
- up = self.upsampler(inp)
- start = self.scale_factor
- end = start + p.size(2) * self.scale_factor
- out_prev = up[:, :, start:end]
-
- output_chunks.append(out_prev.transpose(1, 2))
- state["history_last"] = p[:, :, -1:]
- state["prev_chunk"] = x
-
- if is_last:
- p = state["prev_chunk"].transpose(1, 2)
- inp = torch.cat([state["history_last"], p], dim=2)
- up = self.upsampler(inp)
- out_last = up[:, :, self.scale_factor :]
- output_chunks.append(out_last.transpose(1, 2))
- state = None
-
- final_out = torch.cat(output_chunks, dim=1) if output_chunks else None
- return final_out, state
-
-
-class Decoder(nn.Module):
- def __init__(self, decoder_args, output_dim=320, latent_dim=64, patch_size=-1):
- super().__init__()
- config = Qwen2Config.from_dict(config_dict=decoder_args)
- if flash_attn is None:
- config._attn_implementation = "sdpa"
- self.decoder = Qwen2Model(config)
- self.output_dim = output_dim
- self.latent_dim = latent_dim
- self.fc1 = nn.Linear(latent_dim, config.hidden_size)
- self.hop_length = output_dim
- self.head = ISTFTHead(
- dim=config.hidden_size, n_fft=self.hop_length * 4, hop_length=self.hop_length, padding="same"
- )
- self.patch_size = patch_size
- if self.patch_size != -1:
- self.upsampling = StreamingLinearUpsample(scale_factor=patch_size)
-
- def low_level_reconstruct(self, x, past_key_values=None, use_cache=False, stream_state=None, last_chunk=False):
- upsample_state, audio_buffer, window_buffer = stream_state
- bsz, device, dtype = x.size(0), x.device, x.dtype
- x = self.fc1(x)
- if self.patch_size != -1:
- if use_cache:
- x, upsample_state = self.upsampling(x, state=upsample_state, is_last=last_chunk)
- if x is None:
- stream_state = (upsample_state, audio_buffer, window_buffer)
- return torch.empty(bsz, 1, 0, device=device, dtype=dtype), stream_state, past_key_values
- else:
- x = self.upsampling.upsampler(x.transpose(1, 2)).transpose(1, 2)
-
- hidden_states_list = []
-
- if use_cache and getattr(self.decoder.config, "sliding_window", None) is not None:
- sw_size = self.decoder.config.sliding_window
- target_len = sw_size - 1
- if past_key_values is None:
- past_len = 0
- elif hasattr(past_key_values, "get_seq_length"):
- past_len = past_key_values.get_seq_length()
- elif isinstance(past_key_values, tuple) and len(past_key_values) > 0:
- past_len = past_key_values[0][0].shape[-2]
- else:
- past_len = 0
-
- curr_len = x.shape[1]
-
- if past_len < target_len and (past_len + curr_len) >= sw_size:
- fill_len = target_len - past_len
- x_fill = x[:, :fill_len, :]
- outputs = self.decoder(inputs_embeds=x_fill, past_key_values=past_key_values, use_cache=use_cache)
- hidden_states_list.append(outputs.last_hidden_state)
- past_key_values = outputs.past_key_values
- x = x[:, fill_len:, :]
-
- outputs = self.decoder(inputs_embeds=x, past_key_values=past_key_values, use_cache=use_cache)
- hidden_states_list.append(outputs.last_hidden_state)
- past_key_values = outputs.past_key_values
-
- if len(hidden_states_list) > 1:
- full_hidden_state = torch.cat(hidden_states_list, dim=1)
- else:
- full_hidden_state = hidden_states_list[0]
-
- x_out, _, audio_buffer, window_buffer = self.head(
- full_hidden_state,
- streaming=use_cache,
- audio_buffer=audio_buffer,
- window_buffer=window_buffer,
- last_chunk=last_chunk,
- )
-
- stream_state = (upsample_state, audio_buffer, window_buffer)
- return x_out, stream_state, past_key_values
-
-
-class Encoder(nn.Module):
- def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64, patch_size=-1):
- super().__init__()
- config = Qwen2Config.from_dict(config_dict=encoder_args)
- if flash_attn is None:
- config._attn_implementation = "sdpa"
- self.encoder = Qwen2Model(config)
- self.input_dim = input_dim
- self.hop_size = hop_size
- self.latent_dim = latent_dim
- self.fc1 = nn.Linear(input_dim, config.hidden_size, bias=False)
- self.fc2 = nn.Linear(config.hidden_size, config.hidden_size)
- self.fc3 = nn.Linear(config.hidden_size, latent_dim * 2)
- self.norm = nn.LayerNorm(config.hidden_size)
- self.patch_size = patch_size
- if patch_size != -1:
- config.num_hidden_layers = 4
- self.aggregator = Qwen2Model(config)
- self.cls_embed = nn.Parameter(torch.rand(1, 1, config.hidden_size))
- self.cls_embed.data.normal_(0, 0.02)
-
- def get_frames(self, x):
- num_frames_total = (x.size(-1) + self.hop_size - 1) // self.hop_size
- expected_len = (num_frames_total - 1) * self.hop_size + self.input_dim
- padding_needed = expected_len - x.size(-1)
- waveform = F.pad(x, (0, padding_needed), value=0.0)
- frames = waveform.unfold(dimension=-1, size=self.input_dim, step=self.hop_size)
- return frames
-
- def pad_patch_insert_cls(self, x):
- bsz, _, dim = x.size()
- num_frame = x.size(1)
- r = num_frame % self.patch_size
- pad_num = self.patch_size - r if r else 0
- x = F.pad(x, (0, 0, 0, pad_num), value=0.0)
- x = x.reshape(-1, self.patch_size, dim)
- x = torch.cat((x, self.cls_embed.expand(x.size(0), -1, -1)), dim=1)
- x = x.reshape(bsz, -1, dim)
- return x
-
- def forward(self, waveform):
- x = self.get_frames(waveform)
- x = self.fc1(x)
- x = self.fc2(x)
- x = self.encoder(inputs_embeds=x)
- x = x.last_hidden_state
-
- if self.patch_size != -1:
- x = self.pad_patch_insert_cls(x)
- x = self.aggregator(inputs_embeds=x)
- x = x.last_hidden_state
- bsz, _, dim = x.size()
- x = x.reshape(-1, self.patch_size + 1, dim)
- x = x[:, -1:, :].reshape(bsz, -1, dim)
-
- x = self.fc3(x)
- return x, waveform.unsqueeze(1)
-
-
-class AudioVAE(PreTrainedModel):
- config_class = AudioVAEConfig
-
- def __init__(self, config: AudioVAEConfig):
- super().__init__(config)
- self.encoder = Encoder(
- encoder_args=config.enc_kwargs["backbone"],
- input_dim=config.enc_kwargs["input_dim"],
- hop_size=config.enc_kwargs.get("hop_size", 320),
- latent_dim=config.enc_kwargs["latent_dim"],
- patch_size=config.patch_size,
- )
- self.decoder = Decoder(
- decoder_args=config.dec_kwargs["backbone"],
- output_dim=config.dec_kwargs["output_dim"],
- latent_dim=config.dec_kwargs["latent_dim"],
- patch_size=config.patch_size,
- )
- self.post_init()
-
- def _init_weights(self, module):
- std = 0.02
- if isinstance(module, nn.Linear):
- if self.config.init_method == "kaiming":
- nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
- else:
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
- def encode_latent(self, waveform, waveform_length):
- from diffusers.models.autoencoders.autoencoder_oobleck import OobleckDiagonalGaussianDistribution
-
- frame_num = torch.ceil(waveform_length / self.config.enc_kwargs["input_dim"]).to(torch.int32)
- if self.config.patch_size != -1:
- frame_num = torch.ceil(frame_num / self.config.patch_size)
- h, y = self.encoder(waveform)
- h = h.transpose(1, 2)
-
- posterior = OobleckDiagonalGaussianDistribution(h)
- latent = posterior.sample()
- latent = latent.transpose(1, 2)
- return latent, frame_num
-
- def decode(self, latent, past_key_values=None, use_cache=False, stream_state=(None, None, None), last_chunk=False):
- waveform, stream_state, past_key_values = self.decoder.low_level_reconstruct(
- latent,
- past_key_values=past_key_values,
- use_cache=use_cache,
- stream_state=stream_state,
- last_chunk=last_chunk,
- )
- return waveform, stream_state, past_key_values
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py
deleted file mode 100644
index 462c2043864..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved.
-# Adapted from Ming repository modeling_bailingmm2.py
-# https://github.com/inclusionAI/Ming
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Ming-flash-omni-2.0 thinker / image-gen wrapper.
-
-This class is the multimodal-registered entry point for Ming stages that
-share the thinker's backbone: comprehension / text generation (`thinker`)
-and diffusion conditioning for image generation (`imagegen`, not yet
-implemented).
-
-The talker deliberately lives elsewhere. Upstream Ming hands text (not
-hidden states) from the thinker to the talker, and the talker then
-tokenises that string with its own Qwen2 tokenizer and runs an entirely
-self-contained LLM + CFM + AudioVAE pipeline. Because it has no
-multimodal inputs, it belongs in the non-MM-registered
-`MingFlashOmniTalkerForConditionalGeneration` — routing it through
-this wrapper would force it through vLLM's multimodal preprocess path
-and trigger a hidden-size mismatch between the outer Ming config
-(4096, thinker's LLM) and the talker's Qwen2 backbone (896).
-"""
-
-from collections.abc import Iterable
-
-import torch
-import torch.nn as nn
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.model_executor.models.interfaces import (
- SupportsMRoPE,
- SupportsMultiModal,
- SupportsPP,
-)
-from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.models.utils import (
- init_vllm_registered_model,
- maybe_prefix,
-)
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.sequence import IntermediateTensors
-
-from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights
-from vllm_omni.transformers_utils.configs.ming_flash_omni import (
- BailingMM2Config,
- MingFlashOmniConfig,
-)
-
-from .ming_flash_omni_thinker import (
- MingFlashOmniThinkerDummyInputsBuilder,
- MingFlashOmniThinkerMultiModalProcessor,
- MingFlashOmniThinkerProcessingInfo,
-)
-
-logger = init_logger(__name__)
-
-
-@MULTIMODAL_REGISTRY.register_processor(
- MingFlashOmniThinkerMultiModalProcessor,
- info=MingFlashOmniThinkerProcessingInfo,
- dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder,
-)
-class MingFlashOmniForConditionalGeneration(
- nn.Module,
- SupportsMultiModal,
- SupportsPP,
- SupportsMRoPE,
- CustomProcessMixin,
-):
- """Ming-flash-omni-2.0 thinker + image-gen wrapper."""
-
- supports_multimodal = True
- requires_raw_input_tokens: bool = True
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- self.have_multimodal_outputs = True
- self.has_preprocess = False
- self.has_postprocess = False
-
- config = vllm_config.model_config.hf_config
- self.config = config
- self.model_stage = vllm_config.model_config.model_stage
-
- if self.model_stage == "talker":
- raise ValueError(
- "MingFlashOmniForConditionalGeneration does not support "
- "model_stage='talker'. Use "
- "model_arch='MingFlashOmniTalkerForConditionalGeneration' "
- "directly — the talker has a self-contained LLM that "
- "tokenises text itself and does not need the multimodal "
- "preprocess path. See stage_configs/ming_flash_omni.yaml "
- "stage 1 and stage_configs/ming_flash_omni_tts.yaml."
- )
-
- if self.model_stage == "thinker":
- if isinstance(config, MingFlashOmniConfig):
- thinker_config: BailingMM2Config = config.thinker_config
- else:
- thinker_config: BailingMM2Config = config
-
- thinker_vllm_config = vllm_config.with_hf_config(
- thinker_config, architectures=["MingFlashOmniThinkerForConditionalGeneration"]
- )
- self.thinker = init_vllm_registered_model(
- vllm_config=thinker_vllm_config,
- prefix=maybe_prefix(prefix, "thinker"),
- architectures=["MingFlashOmniThinkerForConditionalGeneration"],
- )
- self.model = self.thinker
- self.make_empty_intermediate_tensors = self.thinker.make_empty_intermediate_tensors
-
- elif self.model_stage == "imagegen":
- # TODO: Implement image generator stage
- raise NotImplementedError(
- "Image generation stage is not yet implemented. Please use model_stage='thinker' for now."
- )
-
- else:
- raise ValueError(
- f"Invalid model_stage: {self.model_stage!r}. Must be one of: 'thinker', 'imagegen'. "
- f"For the talker stage, use MingFlashOmniTalkerForConditionalGeneration directly."
- )
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs,
- ) -> OmniOutput:
- return self.model.forward(
- input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- **kwargs,
- )
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor,
- sampling_metadata=None,
- ) -> torch.Tensor | None:
- if hasattr(self.model, "compute_logits"):
- return self.model.compute_logits(hidden_states, sampling_metadata)
- return None
-
- def sample(
- self,
- logits: torch.Tensor,
- sampling_metadata,
- ):
- if hasattr(self.model, "sample"):
- return self.model.sample(logits, sampling_metadata)
- return None
-
- def get_mrope_input_positions(self, *args, **kwargs):
- if hasattr(self.model, "get_mrope_input_positions"):
- return self.model.get_mrope_input_positions(*args, **kwargs)
- raise NotImplementedError("get_mrope_input_positions not available on current stage")
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- stripped = ((name.removeprefix("thinker."), value) for name, value in weights)
- thinker_loaded = self.thinker.load_weights(stripped)
- return add_prefix_to_loaded_weights(thinker_loaded, "thinker")
-
- def get_mm_mapping(self) -> MultiModelKeys:
- return MultiModelKeys.from_string_field(
- language_model="thinker.language_model",
- connector=["thinker.linear_proj.", "thinker.linear_proj_audio."],
- tower_model=["thinker.vision.", "thinker.audio."],
- )
-
- @property
- def sampler(self):
- return getattr(self.model, "sampler", None)
-
- def embed_input_ids(
- self,
- input_ids: torch.Tensor,
- multimodal_embeddings=None,
- *,
- is_multimodal=None,
- ) -> torch.Tensor:
- return self.model.embed_input_ids(
- input_ids,
- multimodal_embeddings,
- is_multimodal=is_multimodal,
- )
-
- def embed_multimodal(self, **kwargs):
- return self.model.embed_multimodal(**kwargs)
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_talker.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_talker.py
deleted file mode 100644
index 08ed9e85476..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_talker.py
+++ /dev/null
@@ -1,586 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright (c) Ant Group. All rights reserved.
-# Adapted from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/modeling_bailing_talker.py
-"""Ming-flash-omni-2.0 talker (TTS) stage model."""
-
-from __future__ import annotations
-
-import glob as glob_module
-import os
-from collections.abc import Iterable
-from dataclasses import dataclass
-from functools import cached_property
-from typing import Any
-
-import torch
-import torch.nn as nn
-from safetensors.torch import load_file
-from transformers import AutoTokenizer, Qwen2Config, Qwen2Model
-from transformers.utils.hub import cached_file
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.model_executor.models.utils import AutoWeightsLoader
-from vllm.sequence import IntermediateTensors
-
-from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
-from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-from vllm_omni.transformers_utils.configs.ming_flash_omni import MingFlashOmniTalkerConfig
-
-from .audio_vae import AudioVAE, AudioVAEConfig
-from .prompt_utils import DEFAULT_PROMPT as MING_DEFAULT_PROMPT
-from .talker_module import CFM, Aggregator, DiT, MingAudioGenerator, build_tts_input
-from .text_processing import segment_and_normalize
-from .voice_presets import VoicePresetRegistry
-
-logger = init_logger(__name__)
-
-
-@dataclass(slots=True)
-class _GenerationParams:
- """Resolved sampling / decoding parameters for one forward call."""
-
- prompt: str
- instruction: str | None
- cfg: float
- sigma: float
- temperature: float
- max_steps: int
- use_zero_spk_emb: bool
- max_text_length: int
- use_static_cache: bool
- stream_decode: bool
-
-
-@dataclass(slots=True)
-class _VoiceContext:
- """Voice cloning inputs resolved from request info + presets."""
-
- spk_emb: Any # list[Tensor] | Tensor | list[float] | None
- prompt_text: str | None
- prompt_wav_lat: torch.Tensor | None
- prompt_wav_emb: torch.Tensor | None
- already_projected: bool
-
-
-class MingFlashOmniTalkerForConditionalGeneration(nn.Module, CustomProcessMixin):
- """Ming-flash-omni-2.0 talker stage: text -> audio waveform.
-
- Uses Qwen2 LLM + CFM (Conditional Flow Matching with DiT) + Aggregator
- in an autoregressive loop to produce continuous audio latents, then
- AudioVAE decodes latents to waveforms.
- """
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- self.have_multimodal_outputs = True
- self.has_preprocess = False
- self.has_postprocess = False
-
- self.vllm_config = vllm_config
- root_config = vllm_config.model_config.hf_config
-
- model_path = vllm_config.model_config.model
- self._model_path = model_path
- self.talker_dir = (
- os.path.join(model_path, "talker") if os.path.isdir(os.path.join(model_path, "talker")) else model_path
- )
-
- # When used standalone (model_arch=MingFlashOmniTalkerForConditionalGeneration),
- # the root hf_config may be BailingMM2Config (thinker-only) due to model file structure
- # Resolve talker config from talker/config.json in that case.
- config = (
- root_config
- if isinstance(root_config, MingFlashOmniTalkerConfig)
- else self._resolve_talker_config(root_config, self.talker_dir, model_path)
- )
- self.config = config
-
- self._standalone = prefix in ("", "talker")
- if self._standalone:
- self.allow_patterns_overrides = ["talker/model*.safetensors"]
- self.fall_back_to_pt_during_load = False
-
- # LLM
- llm_config = self._resolve_llm_config(config, self.talker_dir, model_path)
- llm_config._attn_implementation = "sdpa"
- self.llm_config = llm_config
- self.hidden_size = llm_config.hidden_size
- self.latent_dim = config.latent_dim
- self.patch_size = config.patch_size
- self.his_patch_size = config.history_patch_size
- self.cfg_strength = config.cfg_strength
-
- self.model = Qwen2Model(llm_config)
- self.cfm = CFM(
- DiT(llm_input_dim=self.hidden_size, **config.flowmodel),
- steps=config.steps,
- )
- self.aggregator = Aggregator(llm_input_dim=self.hidden_size, **config.aggregator)
- self.stop_head = nn.Linear(self.hidden_size, 2, bias=True)
- # CAMPPlus 192-dim -> hidden
- self.spk_head = nn.Linear(192, self.hidden_size, bias=True)
-
- # AudioVAE
- self.audio_vae, self._vae_weight_source = self._init_audio_vae(config, self.talker_dir, model_path)
-
- self._use_cuda_graphs = not vllm_config.model_config.enforce_eager
-
- self.audio_generator = MingAudioGenerator(
- config=self.config,
- llm_config=self.llm_config,
- model=self.model,
- cfm=self.cfm,
- aggregator=self.aggregator,
- stop_head=self.stop_head,
- audio_vae=self.audio_vae,
- patch_size=self.patch_size,
- his_patch_size=self.his_patch_size,
- latent_dim=self.latent_dim,
- cfg_strength=self.cfg_strength,
- use_cuda_graphs=self._use_cuda_graphs,
- )
- self.voice_presets = VoicePresetRegistry(
- talker_dir=self.talker_dir,
- model_path=self._model_path,
- download_dir=vllm_config.load_config.download_dir,
- audio_vae=self.audio_vae,
- aggregator=self.aggregator,
- spk_head=self.spk_head,
- patch_size=self.patch_size,
- )
-
- @property
- def device(self) -> torch.device:
- return next(self.model.parameters()).device
-
- @property
- def dtype(self) -> torch.dtype:
- return next(self.model.parameters()).dtype
-
- @cached_property
- def tokenizer(self):
- # Lazy Qwen2 tokenizer resolution:
- # 1. Try local dirs first (talker/llm, talker, and then model root).
- # 2. HF repo-id fallback: talker/llm is the canonical tokenizer location.
- candidates = (os.path.join(self.talker_dir, "llm"), self.talker_dir, self._model_path)
- for path in candidates:
- if os.path.isdir(path):
- try:
- logger.debug("Resolving talker tokenizer from local dir %s", path)
- return AutoTokenizer.from_pretrained(path, trust_remote_code=True)
- except Exception:
- continue
- for subfolder in ("talker/llm", "llm"):
- try:
- logger.debug("Resolving talker tokenizer from HF subfolder %s", subfolder)
- return AutoTokenizer.from_pretrained(self._model_path, subfolder=subfolder, trust_remote_code=True)
- except Exception:
- continue
- logger.debug("Falling back to raw model_path tokenizer resolution")
- return AutoTokenizer.from_pretrained(self._model_path, trust_remote_code=True)
-
- @staticmethod
- def _resolve_talker_config(config, talker_dir: str, model_path: str) -> MingFlashOmniTalkerConfig:
- """Resolve MingFlashOmniTalkerConfig when the root config is not one.
-
- This happens in standalone TTS mode where hf_config is BailingMM2Config.
- """
- # If the root config wraps a talker_config, use it
- talker_config = getattr(config, "talker_config", None)
- if isinstance(talker_config, MingFlashOmniTalkerConfig):
- return talker_config
-
- # Try loading from talker/config.json
- if os.path.isdir(talker_dir):
- try:
- resolved = MingFlashOmniTalkerConfig.from_pretrained(talker_dir)
- logger.info("Resolved talker config from %s", talker_dir)
- return resolved
- except Exception:
- pass
-
- try:
- resolved = MingFlashOmniTalkerConfig.from_pretrained(model_path, subfolder="talker", trust_remote_code=True)
- logger.info("Resolved talker config from %s/talker (HF hub)", model_path)
- return resolved
- except Exception as e:
- raise ValueError(
- f"Cannot resolve MingFlashOmniTalkerConfig. The root config "
- f"is {type(config).__name__}, and talker/config.json was not "
- f"found at {talker_dir} or via HF hub: {e}"
- ) from e
-
- @staticmethod
- def _resolve_llm_config(config: MingFlashOmniTalkerConfig, talker_dir: str, model_path: str) -> Qwen2Config:
- """Resolve the Qwen2 LLM config for the talker backbone."""
-
- if config.llm_config is not None:
- return Qwen2Config(**config.llm_config) if isinstance(config.llm_config, dict) else config.llm_config
-
- # Try local talker/llm directory
- llm_dir = os.path.join(talker_dir, "llm")
- if os.path.isdir(llm_dir):
- return Qwen2Config.from_pretrained(llm_dir)
-
- # HF hub fallback
- for subfolder in ("talker/llm", "llm"):
- try:
- return Qwen2Config.from_pretrained(model_path, subfolder=subfolder, trust_remote_code=True)
- except Exception:
- continue
-
- raise ValueError(
- f"Cannot find talker LLM config at {llm_dir}. "
- "Either provide llm_config in MingFlashOmniTalkerConfig or "
- "ensure the model path contains talker/llm/config.json."
- )
-
- @staticmethod
- def _init_audio_vae(
- config: MingFlashOmniTalkerConfig, talker_dir: str, model_path: str
- ) -> tuple[AudioVAE | None, str | tuple[str, str] | None]:
- """Initialize AudioVAE and return (vae, weight_source).
-
- weight_source is either a local directory path (str) or an
- (repo_id, subfolder) tuple for HF hub downloads, or None.
- """
- vae_path = config.audio_vae_path or os.path.join(talker_dir, "vae")
-
- # Try local directory first
- if os.path.isdir(vae_path):
- try:
- vae_config = AudioVAEConfig.from_pretrained(vae_path)
- vae = AudioVAE(vae_config)
- logger.info("Initialized AudioVAE from %s (sr=%d)", vae_path, vae_config.sample_rate)
- return vae, vae_path
- except Exception as e:
- logger.warning("Failed to initialize AudioVAE from %s: %s", vae_path, e)
- return None, None
-
- # HF hub fallback
- for subfolder in ("talker/vae", "vae"):
- try:
- vae_config = AudioVAEConfig.from_pretrained(model_path, subfolder=subfolder, trust_remote_code=True)
- vae = AudioVAE(vae_config)
- logger.info(f"Initialized AudioVAE from {model_path}/{subfolder}")
- return vae, (model_path, subfolder)
- except Exception:
- continue
-
- logger.info("AudioVAE not found at %s; waveform decoding unavailable", vae_path)
- return None, None
-
- def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata=None) -> torch.Tensor | None:
- return None
-
- def sample(self, logits: torch.Tensor, sampling_metadata):
- return None
-
- def embed_input_ids(
- self,
- input_ids: torch.Tensor,
- multimodal_embeddings=None,
- is_multimodal=None,
- ) -> torch.Tensor:
- return self.model.get_input_embeddings()(input_ids)
-
- def make_empty_intermediate_tensors(
- self, batch_size: int, dtype: torch.dtype, device: torch.device
- ) -> IntermediateTensors | None:
- return None
-
- def get_dummy_runtime_additional_information(self, num_reqs: int) -> list[dict[str, object]]:
- info: dict[str, object] = {"text": "dummy", "use_zero_spk_emb": True, "max_steps": 1}
- return [info for _ in range(num_reqs)]
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- runtime_additional_information: list[dict] | None = None,
- **kwargs,
- ) -> OmniOutput:
- """Run TTS generation and return audio output.
-
- The full autoregressive generation loop is executed inside this method.
- """
- additional_info = self._extract_additional_info(runtime_additional_information)
- params = self._resolve_generation_params(additional_info)
- voice = self._resolve_voice(additional_info)
-
- latents = self._generate_latents(
- input_ids=input_ids,
- inputs_embeds=inputs_embeds,
- text=additional_info.get("text", ""),
- params=params,
- voice=voice,
- )
- return self._decode_to_output(latents, stream_decode=params.stream_decode)
-
- @staticmethod
- def _extract_additional_info(
- runtime_additional_information: list[dict] | None,
- ) -> dict[str, Any]:
- if runtime_additional_information and len(runtime_additional_information) > 0:
- return runtime_additional_information[0] or {}
- return {}
-
- def _resolve_generation_params(self, additional_info: dict[str, Any]) -> _GenerationParams:
- # "omni" : thinker -> talker hand-off with hardcoded defaults
- # "instruct": standalone TTS with caller-supplied sampling knobs
- ming_task = additional_info.get("ming_task", "instruct")
-
- if ming_task == "omni":
- prompt = MING_DEFAULT_PROMPT
- instruction = None
- use_zero_spk_emb = additional_info.get("spk_emb") is None
- cfg = 2.0
- sigma = 0.25
- temperature = 0.0
- max_steps = 200
- else:
- prompt = additional_info.get("prompt", MING_DEFAULT_PROMPT)
- instruction = additional_info.get("instruction", None)
- use_zero_spk_emb = additional_info.get("use_zero_spk_emb", False)
- cfg = additional_info.get("cfg", self.cfg_strength)
- sigma = additional_info.get("sigma", 0.25)
- temperature = additional_info.get("temperature", 0.0)
- max_steps = int(additional_info.get("max_steps", additional_info.get("max_decode_steps", 200)))
-
- return _GenerationParams(
- prompt=prompt,
- instruction=instruction,
- cfg=cfg,
- sigma=sigma,
- temperature=temperature,
- max_steps=max_steps,
- use_zero_spk_emb=use_zero_spk_emb,
- max_text_length=int(additional_info.get("max_text_length", 50)),
- use_static_cache=bool(additional_info.get("use_static_cache", True)),
- stream_decode=bool(additional_info.get("stream_decode", True)),
- )
-
- def _resolve_voice(self, additional_info: dict[str, Any]) -> _VoiceContext:
- spk_emb = additional_info.get("spk_emb", None)
- prompt_text = additional_info.get("prompt_text", None)
- prompt_wav_lat = additional_info.get("prompt_wav_lat", None)
- prompt_wav_emb = additional_info.get("prompt_wav_emb", None)
- already_projected = False
-
- voice_name = additional_info.get("voice_name", None)
- if voice_name and spk_emb is None and voice_name in self.voice_presets:
- preset = self.voice_presets.get(voice_name) or {}
- prompt_wav_lat = preset.get("prompt_wav_lat")
- prompt_wav_emb = preset.get("prompt_wav_emb")
- spk_emb = preset.get("spk_emb")
- already_projected = True
- if prompt_text is None:
- prompt_text = preset.get("prompt_text")
-
- return _VoiceContext(
- spk_emb=spk_emb,
- prompt_text=prompt_text,
- prompt_wav_lat=prompt_wav_lat,
- prompt_wav_emb=prompt_wav_emb,
- already_projected=already_projected,
- )
-
- def _project_spk_emb(
- self, spk_emb: Any, already_projected: bool, use_zero_spk_emb: bool
- ) -> list[torch.Tensor] | None:
- if spk_emb is None:
- if use_zero_spk_emb:
- return [torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)]
- return None
-
- if already_projected:
- return spk_emb if isinstance(spk_emb, list) else [spk_emb]
-
- if isinstance(spk_emb, torch.Tensor):
- tensors = [spk_emb]
- elif isinstance(spk_emb, list) and spk_emb and isinstance(spk_emb[0], (int, float)):
- tensors = [torch.tensor(spk_emb, dtype=self.dtype).unsqueeze(0)]
- elif isinstance(spk_emb, list):
- tensors = spk_emb
- else:
- tensors = [spk_emb]
- return [self.spk_head(t.to(device=self.device, dtype=self.dtype)) for t in tensors]
-
- def _generate_latents(
- self,
- *,
- input_ids: torch.Tensor,
- inputs_embeds: torch.Tensor | None,
- text: str,
- params: _GenerationParams,
- voice: _VoiceContext,
- ) -> list[torch.Tensor]:
- generator = self.audio_generator
-
- if inputs_embeds is not None:
- # Caller pre-built embeddings — run a single AR pass.
- return generator.generate_latents(
- inputs_embeds=inputs_embeds,
- prompt_wav_lat=voice.prompt_wav_lat,
- max_steps=params.max_steps,
- cfg=params.cfg,
- sigma=params.sigma,
- temperature=params.temperature,
- use_static_cache=params.use_static_cache,
- )
-
- spk_emb = self._project_spk_emb(voice.spk_emb, voice.already_projected, params.use_zero_spk_emb)
- text_segments = segment_and_normalize(text, max_length=params.max_text_length) if text else []
-
- if not text_segments:
- # vLLM passes 1D input_ids; Qwen2Model expects (batch, seq).
- inputs_embeds = self.model.get_input_embeddings()(input_ids.to(self.device)).unsqueeze(0)
- return generator.generate_latents(
- inputs_embeds=inputs_embeds,
- prompt_wav_lat=voice.prompt_wav_lat,
- max_steps=params.max_steps,
- cfg=params.cfg,
- sigma=params.sigma,
- temperature=params.temperature,
- use_static_cache=params.use_static_cache,
- )
-
- all_latents: list[torch.Tensor] = []
- for segment in text_segments:
- seg_embeds, _ = build_tts_input(
- tokenizer=self.tokenizer,
- embed_tokens=self.model.get_input_embeddings(),
- device=self.device,
- dtype=torch.bfloat16,
- text=segment,
- prompt=params.prompt,
- spk_emb=spk_emb,
- instruction=params.instruction,
- prompt_text=voice.prompt_text,
- prompt_wav_emb=voice.prompt_wav_emb,
- )
- effective_max_steps = generator.duration_capped_steps(len(segment), params.max_steps)
- all_latents.extend(
- generator.generate_latents(
- inputs_embeds=seg_embeds,
- prompt_wav_lat=voice.prompt_wav_lat,
- max_steps=effective_max_steps,
- cfg=params.cfg,
- sigma=params.sigma,
- temperature=params.temperature,
- use_static_cache=params.use_static_cache,
- )
- )
- return all_latents
-
- def _decode_to_output(self, latents: list[torch.Tensor], *, stream_decode: bool) -> OmniOutput:
- multimodal_outputs: dict[str, Any] = {}
- if latents and self.audio_vae is not None:
- waveform = self.audio_generator.decode_to_waveform(latents, stream_decode=stream_decode)
- if not stream_decode:
- waveform = self.audio_generator.trim_trailing_silence(waveform)
- multimodal_outputs["audio"] = waveform.detach().float().cpu()
- multimodal_outputs["sr"] = torch.tensor(self.audio_vae.config.sample_rate)
- elif latents:
- all_lat = torch.cat(latents, dim=1)
- multimodal_outputs["audio_latents"] = all_lat.detach().float().cpu()
-
- return OmniOutput(text_hidden_states=None, multimodal_outputs=multimodal_outputs)
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- """Load weights for all talker components.
-
- The talker's HF checkpoint (talker/model.safetensors) stores
- weights with prefixes matching this module's submodule names directly.
- And AudioVAE weights live in a separate file under talker/vae/
- """
- # Standalone: bypass the default loader's iterator (torch.load on
- # .safetensors crashes) and read talker/model*.safetensors directly.
- if self._standalone:
- weights = self._iter_talker_safetensors()
-
- loader = AutoWeightsLoader(
- self,
- skip_prefixes=["audio_vae."], # loaded separately
- skip_substrs=["rotary_embed.inv_freq"], # non-persistent buffer
- )
- loaded = loader.load_weights(weights)
- logger.info("Loaded %d talker weights from checkpoint", len(loaded))
-
- if self.audio_vae is not None and self._vae_weight_source is not None:
- loaded.update(self._load_vae_weights())
-
- # Register voice presets after all weights (incl. VAE) are loaded.
- try:
- self.voice_presets.load_presets_from_manifest(device=self.device, dtype=self.dtype)
- except Exception as e: # pragma: no cover — best-effort
- logger.warning("Voice preset loading failed (non-fatal): %s", e)
-
- return loaded
-
- def _iter_talker_safetensors(self) -> Iterable[tuple[str, torch.Tensor]]:
- """Yield (name, tensor) pairs from talker/model*.safetensors."""
- model_path = self._model_path
- # Try local path first
- for candidate in (os.path.join(model_path, "talker"), model_path):
- sf_files = sorted(glob_module.glob(os.path.join(candidate, "model*.safetensors")))
- if sf_files:
- for sf_path in sf_files:
- yield from load_file(sf_path, device="cpu").items()
- return
-
- # HF hub fallback: download only the talker checkpoint files
- model_root = download_weights_from_hf_specific(
- model_path,
- self.vllm_config.load_config.download_dir,
- allow_patterns=["talker/model*.safetensors"],
- )
- talker_dir = os.path.join(model_root, "talker")
- sf_files = sorted(glob_module.glob(os.path.join(talker_dir, "model*.safetensors")))
- if not sf_files:
- raise RuntimeError(f"No talker safetensors found under {model_root}. Expected talker/model*.safetensors.")
- for sf_path in sf_files:
- yield from load_file(sf_path, device="cpu").items()
-
- def _load_vae_weights(self) -> set[str]:
- """Load AudioVAE weights from talker/vae/model.safetensors."""
- if self.audio_vae is None or self._vae_weight_source is None:
- return set()
-
- # Resolve safetensors file paths from the weight source
- safetensors_files: list[str] = []
- source = self._vae_weight_source
- if isinstance(source, str):
- # Local directory path
- safetensors_files = sorted(glob_module.glob(os.path.join(source, "*.safetensors")))
- elif isinstance(source, tuple):
- # (repo_id, subfolder) for HF hub
- repo_id, subfolder = source
- for filename in ("model.safetensors", "diffusion_pytorch_model.safetensors"):
- try:
- cached = cached_file(repo_id, filename, subfolder=subfolder)
- except Exception:
- cached = None
- if cached is not None:
- safetensors_files.append(cached)
- break
-
- if not safetensors_files:
- logger.warning("No AudioVAE safetensors files found for source=%s", source)
- return set()
-
- vae_state_keys = set(self.audio_vae.state_dict().keys())
- vae_loader = AutoWeightsLoader(self.audio_vae)
- loaded: set[str] = set()
- for sf_path in safetensors_files:
- file_weights = load_file(sf_path, device="cpu")
- matched = ((name, tensor) for name, tensor in file_weights.items() if name in vae_state_keys)
- loaded.update(f"audio_vae.{name}" for name in vae_loader.load_weights(matched))
-
- logger.info("Loaded %d AudioVAE weights from %s", len(loaded), source)
- return loaded
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py
deleted file mode 100644
index bde7477b945..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py
+++ /dev/null
@@ -1,893 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright 2024 ANT Group and the HuggingFace Inc. team.
-# Adapted from Ming repository modeling_bailingmm2.py and processing_bailingmm2.py
-# https://github.com/inclusionAI/Ming
-
-"""Ming-flash-omni-2.0 Thinker stage implementation (multimodal understanding)."""
-
-from collections.abc import Iterable, Iterator, Mapping, Sequence
-from typing import Annotated, Any
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers.feature_extraction_utils import BatchFeature
-from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs import MultiModalDataDict
-from vllm.logger import init_logger
-from vllm.model_executor.models.interfaces import (
- MultiModalEmbeddings,
- SupportsMRoPE,
- SupportsMultiModal,
- SupportsPP,
-)
-from vllm.model_executor.models.qwen2_5_vl import (
- Qwen2_5_VLImageInputs,
- Qwen2_5_VLImagePixelInputs,
- Qwen2_5_VLVideoInputs,
- Qwen2_5_VLVideoPixelInputs,
-)
-from vllm.model_executor.models.qwen2_vl import (
- Qwen2VLProcessingInfo,
-)
-from vllm.model_executor.models.utils import (
- AutoWeightsLoader,
- WeightsMapper,
- _merge_multimodal_embeddings,
- maybe_prefix,
-)
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (
- MultiModalFeatureSpec,
- MultiModalFieldConfig,
- MultiModalKwargsItems,
-)
-from vllm.multimodal.parse import (
- AudioProcessorItems,
- ImageProcessorItems,
- MultiModalDataItems,
- MultiModalDataParser,
- VideoProcessorItems,
-)
-from vllm.multimodal.processing import (
- BaseDummyInputsBuilder,
- BaseMultiModalProcessor,
- PromptReplacement,
- PromptUpdate,
- PromptUpdateDetails,
-)
-from vllm.sequence import IntermediateTensors
-from vllm.utils.tensor_schema import TensorSchema, TensorShape
-
-from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config
-from vllm_omni.transformers_utils.processors.ming import (
- PLACEHOLDER_AUDIO_TOKEN_IN_TEXT,
- PLACEHOLDER_IMAGE_TOKEN_IN_TEXT,
- PLACEHOLDER_VIDEO_TOKEN_IN_TEXT,
- MingFlashOmniProcessor,
- MingWhisperFeatureExtractor,
-)
-
-from .audio_encoder import WhisperAudioEncoder
-from .modeling_bailing_moe_v2 import BailingMoeV2ForCausalLM
-from .projectors import AudioProjector, VisionProjector
-from .vision_encoder import MingVisionEncoder
-
-logger = init_logger(__name__)
-
-
-class MingAudioInput(TensorSchema):
- """
- Dimensions:
- - b: Batch size
- - l: Total audio frames (clips concatenated along the time axis)
- - nm: Number of mel bins
- - N: Max number of audio clips per batch item
- """
-
- audio_feats: Annotated[
- torch.Tensor,
- TensorShape("b", "l", "nm"),
- ]
-
- audio_feats_lengths: Annotated[
- torch.Tensor,
- TensorShape("b", "N"),
- ]
-
-
-class MingFlashOmniThinkerProcessingInfo(Qwen2VLProcessingInfo):
- def get_hf_config(self) -> BailingMM2Config:
- return self.ctx.get_hf_config(BailingMM2Config)
-
- def get_hf_processor(self, **kwargs: object):
- return self.ctx.get_hf_processor(MingFlashOmniProcessor, **kwargs)
-
- def get_target_channels(self) -> int:
- # See `_normalize_audio_tensor` in vllm_omni/transformers_utils/processors/ming.py
- return 1
-
- def get_supported_mm_limits(self) -> Mapping[str, int | None]:
- return {"image": None, "video": None, "audio": None}
-
- def get_mm_max_tokens_per_item(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- ) -> Mapping[str, int]:
- mm_counts = mm_counts or {}
- requested_modalities = {m for m, c in mm_counts.items() if c > 0}
- mm_max_tokens: dict[str, int] = {}
-
- if requested_modalities & {"image", "video"}:
- vl_tokens = super().get_mm_max_tokens_per_item(
- seq_len=seq_len,
- mm_counts=mm_counts,
- )
- mm_max_tokens.update({m: vl_tokens[m] for m in ["image", "video"] if m in requested_modalities})
-
- if "audio" in requested_modalities:
- # TODO: consider computing from audio config
- mm_max_tokens["audio"] = 3000
-
- return mm_max_tokens
-
- def get_feature_extractor(self, **kwargs: object) -> MingWhisperFeatureExtractor:
- hf_processor = self.get_hf_processor(**kwargs)
- feature_extractor = hf_processor.audio_processor
- assert isinstance(feature_extractor, MingWhisperFeatureExtractor)
- return feature_extractor
-
- def get_data_parser(self):
- feature_extractor = self.get_feature_extractor()
- return MultiModalDataParser(
- target_sr=feature_extractor.sampling_rate,
- target_channels=self.get_target_channels(),
- expected_hidden_size=self._get_expected_hidden_size(),
- )
-
-
-class MingFlashOmniThinkerDummyInputsBuilder(BaseDummyInputsBuilder[MingFlashOmniThinkerProcessingInfo]):
- def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
- num_images = mm_counts.get("image", 0)
- num_videos = mm_counts.get("video", 0)
- num_audios = mm_counts.get("audio", 0)
-
- hf_processor = self.info.get_hf_processor()
-
- audio_token: str = hf_processor.audio_token
- image_token: str = hf_processor.image_token
- video_token: str = hf_processor.video_token
-
- return image_token * num_images + video_token * num_videos + audio_token * num_audios
-
- def get_dummy_mm_data(
- self,
- seq_len: int,
- mm_counts: Mapping[str, int],
- mm_options: Mapping[str, BaseDummyOptions] | None = None,
- ) -> MultiModalDataDict:
- num_images = mm_counts.get("image", 0)
- num_videos = mm_counts.get("video", 0)
- num_audios = mm_counts.get("audio", 0)
-
- # Default dimensions for dummy data
- image_width, image_height = 448, 448
- video_width, video_height = 448, 448
- num_frames = 8
- audio_duration = 3.0 # seconds
- sample_rate = 16000
-
- audio_length = int(audio_duration * sample_rate)
-
- mm_data: MultiModalDataDict = {
- "image": self._get_dummy_images(
- width=image_width,
- height=image_height,
- num_images=num_images,
- ),
- "video": self._get_dummy_videos(
- width=video_width,
- height=video_height,
- num_frames=num_frames,
- num_videos=num_videos,
- ),
- "audio": [(np.random.randn(audio_length).astype(np.float32), sample_rate) for _ in range(num_audios)],
- }
-
- return mm_data
-
-
-class MingFlashOmniThinkerMultiModalProcessor(BaseMultiModalProcessor[MingFlashOmniThinkerProcessingInfo]):
- """Multimodal processor for Ming-flash-omni Thinker stage.
-
- Handles preprocessing of 1) image, 2) video, and 3) audio inputs,
- and expands placeholder tokens to the correct number of patch tokens.
- """
-
- def _get_prompt_updates(
- self,
- mm_items: MultiModalDataItems,
- hf_processor_mm_kwargs: Mapping[str, Any],
- out_mm_kwargs: MultiModalKwargsItems,
- ) -> Sequence[PromptUpdate]:
- tokenizer = self.info.get_tokenizer()
- # might want to add a fallback to resolve token ids
- # vocab = tokenizer.get_vocab()
- thinker_config = self.info.get_hf_config()
-
- # patch/delimiter token IDs (used in replacement sequences)
- image_start_token_id = thinker_config.llm_config.image_start_token
- image_patch_token_id = thinker_config.llm_config.image_patch_token
- image_end_token_id = thinker_config.llm_config.image_end_token
-
- video_start_token_id = thinker_config.llm_config.video_start_token
- frame_patch_token_id = thinker_config.llm_config.video_patch_token
- video_end_token_id = thinker_config.llm_config.video_end_token
-
- audio_start_token_id = thinker_config.llm_config.audio_start_token
- audio_patch_token_id = thinker_config.llm_config.audio_patch_token
- audio_end_token_id = thinker_config.llm_config.audio_end_token
-
- vision_config = thinker_config.vision_config
- spatial_merge_size = vision_config.spatial_merge_size if vision_config else 2
-
- newline_token_ids: list[int] = tokenizer.encode("\n", add_special_tokens=False)
-
- out_mm_data = out_mm_kwargs.get_data()
-
- def get_replacement_image(item_idx: int) -> PromptUpdateDetails:
- """Generate token sequence for an image."""
- grid_thw = out_mm_data.get("image_grid_thw")
- if grid_thw is None:
- raise ValueError(
- "image_grid_thw missing from processor output; "
- "cannot determine image patch count for prompt replacement."
- )
- if isinstance(grid_thw, torch.Tensor):
- thw = grid_thw[item_idx]
- num_patches = int(thw.prod().item()) // (spatial_merge_size**2)
- else:
- thw = grid_thw[item_idx]
- num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2)
-
- # Build token sequence: *N \n
- # the newline token is added in purpose from original model processing
- tokens: list[int] = []
- tokens.append(image_start_token_id)
- tokens.extend([image_patch_token_id] * num_patches)
- tokens.append(image_end_token_id)
- # Refer to Ming's BailingMM2Processor._expand_image_tokens
- # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/processing_bailingmm2.py
- tokens.extend(newline_token_ids)
-
- # Only tokens receive multimodal embeddings
- return PromptUpdateDetails.select_token_id(tokens, image_patch_token_id)
-
- def get_replacement_video(item_idx: int) -> PromptUpdateDetails:
- """Generate token sequence for a video."""
- grid_thw = out_mm_data.get("video_grid_thw", None)
- if grid_thw is None:
- raise ValueError(
- "video_grid_thw missing from processor output; "
- "cannot determine video patch count for prompt replacement."
- )
- if isinstance(grid_thw, torch.Tensor):
- thw = grid_thw[item_idx]
- num_patches = int(thw.prod().item()) // (spatial_merge_size**2)
- else:
- thw = grid_thw[item_idx]
- num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2)
-
- # Build token sequence: *N \n
- # the newline token is added in purpose from original model processing
- tokens: list[int] = []
- tokens.append(video_start_token_id)
- tokens.extend([frame_patch_token_id] * num_patches)
- tokens.append(video_end_token_id)
- tokens.extend(newline_token_ids)
-
- # Only tokens receive multimodal embeddings
- return PromptUpdateDetails.select_token_id(tokens, frame_patch_token_id)
-
- def get_replacement_audio(item_idx: int) -> PromptUpdateDetails:
- """Generate token sequence for an audio."""
- encoder_feats_lengths = out_mm_data.get("encoder_feats_lengths", None)
- if encoder_feats_lengths is None:
- raise ValueError(
- "encoder_feats_lengths missing from processor output; "
- "cannot determine audio patch count for prompt replacement."
- )
- if isinstance(encoder_feats_lengths, torch.Tensor):
- num_patches = int(encoder_feats_lengths[item_idx].item())
- else:
- num_patches = encoder_feats_lengths[item_idx]
-
- # Build token sequence: *N
- tokens: list[int] = []
- tokens.append(audio_start_token_id)
- tokens.extend([audio_patch_token_id] * num_patches)
- tokens.append(audio_end_token_id)
-
- # Only tokens receive multimodal embeddings
- return PromptUpdateDetails.select_token_id(tokens, audio_patch_token_id)
-
- # Build prompt updates and process replacement
- updates: list[PromptUpdate] = []
-
- if "image" in mm_items and mm_items.get_items("image", ImageProcessorItems):
- updates.append(
- PromptReplacement(
- modality="image",
- target=PLACEHOLDER_IMAGE_TOKEN_IN_TEXT,
- replacement=get_replacement_image,
- )
- )
- if "video" in mm_items and mm_items.get_items("video", VideoProcessorItems):
- updates.append(
- PromptReplacement(
- modality="video",
- target=PLACEHOLDER_VIDEO_TOKEN_IN_TEXT,
- replacement=get_replacement_video,
- )
- )
- if "audio" in mm_items and mm_items.get_items("audio", AudioProcessorItems):
- updates.append(
- PromptReplacement(
- modality="audio",
- target=PLACEHOLDER_AUDIO_TOKEN_IN_TEXT,
- replacement=get_replacement_audio,
- )
- )
- return updates
-
- def _get_mm_fields_config(
- self,
- hf_inputs: BatchFeature,
- hf_processor_mm_kwargs: Mapping[str, object],
- ) -> Mapping[str, MultiModalFieldConfig]:
- config: dict[str, MultiModalFieldConfig] = {}
-
- # Image fields, pixel_values is flat (concatenated patches from all images)
- image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
- if "pixel_values" in hf_inputs:
- image_sizes = image_grid_thw.prod(-1)
- config["pixel_values"] = MultiModalFieldConfig.flat_from_sizes(
- "image",
- image_sizes,
- )
- if "image_grid_thw" in hf_inputs:
- config["image_grid_thw"] = MultiModalFieldConfig.batched("image")
-
- # Video fields, same flat layout as images
- video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
- if "pixel_values_videos" in hf_inputs:
- video_sizes = video_grid_thw.prod(-1)
- config["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes(
- "video",
- video_sizes,
- )
- if "video_grid_thw" in hf_inputs:
- config["video_grid_thw"] = MultiModalFieldConfig.batched("video")
-
- # Audio fields
- if "audio_feats" in hf_inputs:
- config["audio_feats"] = MultiModalFieldConfig.batched("audio")
- if "audio_feats_lengths" in hf_inputs:
- config["audio_feats_lengths"] = MultiModalFieldConfig.batched("audio")
- if "encoder_feats_lengths" in hf_inputs:
- config["encoder_feats_lengths"] = MultiModalFieldConfig.batched("audio")
- if "placeholder_audio_loc_lens" in hf_inputs:
- config["placeholder_audio_loc_lens"] = MultiModalFieldConfig.batched("audio")
-
- return config
-
- def _hf_processor_applies_updates(
- self,
- prompt_text: str,
- mm_items: MultiModalDataItems,
- hf_processor_mm_kwargs: Mapping[str, object],
- tokenization_kwargs: Mapping[str, object],
- ) -> bool:
- return False
-
- def _call_hf_processor(
- self,
- prompt: str,
- mm_data: Mapping[str, object],
- mm_kwargs: Mapping[str, object],
- tok_kwargs: Mapping[str, object],
- ) -> BatchFeature:
- """Call sub-processors for multimodal inputs and tokenize.
-
- We call the image/audio sub-processors directly (instead of going
- through `MingFlashOmniProcessor.__call__`) so that the high-level
- placeholder tokens remain **unexpanded** in the tokenized output.
- """
- hf_processor = self.info.get_hf_processor()
- tokenizer = self.info.get_tokenizer()
-
- data: dict[str, object] = {}
-
- images = mm_data.get("images", None)
- if images is not None:
- image_outputs = hf_processor.image_processor(
- images=images,
- videos=None,
- return_tensors="pt",
- )
- data.update(image_outputs)
-
- videos = mm_data.get("videos", None)
- if videos is not None:
- video_outputs = hf_processor.image_processor(
- images=None,
- videos=videos,
- return_tensors="pt",
- )
- # Rename keys to distinguish from images
- if "pixel_values" in video_outputs:
- video_outputs["pixel_values_videos"] = video_outputs.pop("pixel_values")
- if "image_grid_thw" in video_outputs:
- video_outputs["video_grid_thw"] = video_outputs.pop("image_grid_thw")
- data.update(video_outputs)
-
- audios = mm_data.get("audios", None)
- if audios is not None:
- # vLLM's AudioProcessorItems provides raw numpy arrays (already resampled).
- # MingWhisperAudioProcessor expects (waveform, sr) tuples,
- # so wrap them with the target sample rate.
- target_sr = hf_processor.audio_processor.sampling_rate
- audio_tuples = [(a, target_sr) if not isinstance(a, tuple) else a for a in audios]
-
- audio_outputs = hf_processor.audio_processor(
- audio_tuples,
- return_tensors="pt",
- )
- data.update(audio_outputs)
-
- # Tokenize text with placeholders still intact
- text_outputs = tokenizer(prompt, return_tensors="pt", **tok_kwargs)
- data.update(text_outputs)
-
- return BatchFeature(data=data)
-
-
-@MULTIMODAL_REGISTRY.register_processor(
- MingFlashOmniThinkerMultiModalProcessor,
- info=MingFlashOmniThinkerProcessingInfo,
- dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder,
-)
-class MingFlashOmniThinkerForConditionalGeneration(
- nn.Module,
- SupportsMultiModal,
- SupportsPP,
- SupportsMRoPE,
- CustomProcessMixin,
-):
- """Ming Thinker stage: multimodal understanding
- (text + image + video + audio) -> text generation.
- """
-
- hf_to_vllm_mapper = WeightsMapper(
- orig_to_new_prefix={"model.": "language_model."},
- )
-
- @classmethod
- def get_placeholder_str(cls, modality: str, i: int) -> str | None:
- # vllm_omni/transformers_utils/processors/ming.py
- if modality.startswith("image"):
- return ""
- elif modality.startswith("video"):
- return ""
- elif modality.startswith("audio"):
- return ""
-
- raise ValueError("Only image, video, or audio modality is supported")
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
-
- config = vllm_config.model_config.hf_config
-
- thinker_config: BailingMM2Config = config
- if (
- thinker_config.llm_config is None
- or thinker_config.vision_config is None
- or thinker_config.audio_config is None
- ):
- raise ValueError(
- "MingFlashOmniThinker requires `llm_config`, `vision_config`, and `audio_config` in `thinker_config`."
- )
-
- llm_config = thinker_config.llm_config
-
- self.config = llm_config
- self.thinker_config = thinker_config
- self.have_multimodal_outputs = True
-
- # Initialize LLM as a component
- with self._mark_language_model(vllm_config):
- llm_vllm_config = vllm_config.with_hf_config(llm_config)
- self.language_model = BailingMoeV2ForCausalLM(
- vllm_config=llm_vllm_config, prefix=maybe_prefix(prefix, "llm")
- )
-
- # Ming thinker is inherently multimodal; initialize both towers eagerly.
- with self._mark_tower_model(vllm_config, {"image", "video"}):
- self.vision = MingVisionEncoder(
- vision_config=thinker_config.vision_config,
- quant_config=vllm_config.quant_config,
- prefix=maybe_prefix(prefix, "vision"),
- )
- self.linear_proj = VisionProjector(
- vision_dim=self.vision.image_emb_dim,
- llm_dim=llm_config.hidden_size,
- mlp_depth=getattr(thinker_config, "mlp_depth", 2),
- )
- logger.info("Initialized MingVisionEncoder and VisionProjector")
-
- audio_cfg = thinker_config.audio_config
- whisper_cfg = getattr(audio_cfg, "whisper_encoder_config", {}) or {}
- with self._mark_tower_model(vllm_config, "audio"):
- self.audio = WhisperAudioEncoder(
- **whisper_cfg,
- use_flash_attn=True,
- )
- self.linear_proj_audio = AudioProjector(
- audio_dim=self.audio.audio_emb_dim,
- llm_dim=llm_config.hidden_size,
- ds_kernel_size=getattr(audio_cfg, "ds_kernel_size", 3),
- ds_stride=getattr(audio_cfg, "ds_stride", 2),
- mlp_depth=getattr(thinker_config, "mlp_depth", 1),
- )
- logger.info("Initialized WhisperAudioEncoder and AudioProjector")
-
- # Expose interfaces
- self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors
-
- logger.info("MingFlashOmniThinker initialized with vision and audio towers")
-
- def extract_image_feature(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
- """Extract and project image features.
-
- Args:
- pixel_values: Flattened pixel values from vision processor.
- grid_thw: [num_images, 3] tensor of (t, h, w) grid dimensions.
-
- Returns:
- [seq_len, hidden_size] L2-normalized image embeddings.
- """
- if self.vision is None:
- raise ValueError("Vision encoder not initialized")
-
- with torch.amp.autocast(pixel_values.device.type, dtype=torch.bfloat16):
- image_embeds = self.vision(pixel_values, grid_thw=grid_thw)
-
- if self.vision.use_deepstack:
- image_embeds = image_embeds[:, : self.vision.image_emb_dim]
-
- image_embeds = self.linear_proj(image_embeds)
- image_embeds = F.normalize(image_embeds, dim=-1)
- return image_embeds
-
- def extract_audio_feature(
- self, audio_feats: torch.Tensor, audio_feats_lengths: torch.Tensor
- ) -> tuple[torch.Tensor, ...]:
- """Extract and project audio features.
-
- Args:
- audio_feats: [B, L_total, n_mels] wrapped mel features — multiple audio
- clips per batch item are concatenated along the time dimension
- (time-first, as produced by MingWhisperFeatureExtractor).
- audio_feats_lengths: [B, N] lengths of each audio clip per batch item.
- N is the max number of clips per item; zero-padded entries are skipped.
-
- Returns:
- Tuple of per-clip [T'_i, hidden_size] projected audio embeddings.
- """
- if self.audio is None:
- raise ValueError("Audio encoder not initialized")
-
- # Unwrap packed [B, L_total, n_mels] into a list of [n_mels, T'_i] tensors,
- # one per audio clip, as expected by WhisperAudioEncoder.
- x_list: list[torch.Tensor] = []
- audio_lens: list[int] = []
- for i in range(audio_feats_lengths.shape[0]):
- feat_index = 0
- for j in range(audio_feats_lengths.shape[1]):
- feat_len = int(audio_feats_lengths[i, j].item())
- if feat_len == 0:
- break
- mel_seg = audio_feats[i, feat_index : feat_index + feat_len].transpose(0, 1)
- x_list.append(mel_seg)
- audio_lens.append(feat_len)
- feat_index += feat_len
-
- audio_packed = self.audio(x_list, audio_lens)
-
- # Compute per-clip lengths after Whisper Conv1d (kernel=3, stride=2, pad=1)
- encoded_lens = [(audio_len - 3 + 2) // 2 + 1 for audio_len in audio_lens]
-
- # Project packed
- proj_packed, proj_lens = self.linear_proj_audio.forward_packed(audio_packed, encoded_lens)
-
- normalize = getattr(self.thinker_config.audio_config, "norm_query_embeds", False)
- if normalize:
- proj_packed = F.normalize(proj_packed, dim=-1)
-
- proj_packed = proj_packed.to(audio_feats.dtype)
-
- # Split into per-clip tensors
- return proj_packed.split(proj_lens)
-
- def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
- """Parse and validate multimodal kwargs into per-modality dicts."""
- mm_input_by_modality: dict[str, Qwen2_5_VLImageInputs | Qwen2_5_VLVideoInputs | MingAudioInput] = {}
-
- for key in kwargs:
- if key == "pixel_values" and "image" not in mm_input_by_modality:
- pixel_values = kwargs.get("pixel_values")
- image_grid_thw = kwargs.get("image_grid_thw")
- if pixel_values is not None and image_grid_thw is not None:
- mm_input_by_modality["image"] = Qwen2_5_VLImagePixelInputs(
- type="pixel_values",
- pixel_values=pixel_values, # type: ignore[arg-type]
- image_grid_thw=image_grid_thw, # type: ignore[arg-type]
- )
- elif key == "pixel_values_videos" and "video" not in mm_input_by_modality:
- pixel_values_videos = kwargs.get("pixel_values_videos")
- video_grid_thw = kwargs.get("video_grid_thw")
- second_per_grid_ts = kwargs.get("second_per_grid_ts")
- if pixel_values_videos is not None and video_grid_thw is not None:
- mm_input_by_modality["video"] = Qwen2_5_VLVideoPixelInputs(
- type="pixel_values_videos",
- pixel_values_videos=pixel_values_videos, # type: ignore[arg-type]
- video_grid_thw=video_grid_thw, # type: ignore[arg-type]
- second_per_grid_ts=second_per_grid_ts, # type: ignore[arg-type]
- )
- elif key == "audio_feats" and "audio" not in mm_input_by_modality:
- audio_feats = kwargs.get("audio_feats")
- audio_feats_lengths = kwargs.get("audio_feats_lengths")
- if audio_feats is not None and audio_feats_lengths is not None:
- mm_input_by_modality["audio"] = MingAudioInput(
- audio_feats=audio_feats, # type: ignore[arg-type]
- audio_feats_lengths=audio_feats_lengths, # type: ignore[arg-type]
- )
-
- return mm_input_by_modality
-
- def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> list[torch.Tensor]:
- # Splits the flat [total_tokens, D] output of extract_image_feature
- # into one tensor per image.
- pixel_values = image_input["pixel_values"]
- image_grid_thw = image_input["image_grid_thw"]
- image_embeds = self.extract_image_feature(pixel_values, image_grid_thw)
- merge_unit = self.thinker_config.vision_config.spatial_merge_size**2
- sizes = (image_grid_thw.prod(dim=-1) // merge_unit).tolist()
- return list(image_embeds.split([int(s) for s in sizes], dim=0))
-
- def _process_video_input(self, video_input: Qwen2_5_VLVideoInputs) -> list[torch.Tensor]:
- pixel_values_videos = video_input["pixel_values_videos"]
- video_grid_thw = video_input["video_grid_thw"]
- video_embeds = self.extract_image_feature(pixel_values_videos, video_grid_thw)
- merge_unit = self.thinker_config.vision_config.spatial_merge_size**2
- sizes = (video_grid_thw.prod(dim=-1) // merge_unit).tolist()
- return list(video_embeds.split([int(s) for s in sizes], dim=0))
-
- def _process_audio_input(self, audio_input: MingAudioInput) -> list[torch.Tensor]:
- return list(self.extract_audio_feature(audio_input["audio_feats"], audio_input["audio_feats_lengths"]))
-
- def _compute_modality_masks(self, input_ids: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
- """Compute vision and audio MoE-routing masks from input_ids.
-
- Returns:
- Tuple of (vision_mask, audio_mask), each shape [seq_len] bool.
- """
- llm_config = self.config
-
- # vision mask
- vision_mask = torch.zeros_like(input_ids, dtype=torch.bool)
- image_token = llm_config.image_patch_token
- video_token = llm_config.video_patch_token
- vision_mask = vision_mask | (input_ids == image_token)
- vision_mask = vision_mask | (input_ids == video_token)
-
- # audio mask
- audio_token = llm_config.audio_patch_token
- audio_mask = input_ids == audio_token
-
- return vision_mask, audio_mask
-
- def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
- mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
- if not mm_input_by_modality:
- return []
-
- # preserve the order of modalities
- multimodal_embeddings: tuple[torch.Tensor, ...] = ()
-
- for modality, mm_input in mm_input_by_modality.items():
- if modality == "image":
- multimodal_embeddings += tuple(self._process_image_input(mm_input)) # type: ignore[arg-type]
- elif modality == "video":
- multimodal_embeddings += tuple(self._process_video_input(mm_input)) # type: ignore[arg-type]
- elif modality == "audio":
- multimodal_embeddings += tuple(self._process_audio_input(mm_input)) # type: ignore[arg-type]
-
- return multimodal_embeddings
-
- def embed_input_ids(
- self,
- input_ids: torch.Tensor,
- multimodal_embeddings: MultiModalEmbeddings | None = None,
- *,
- is_multimodal: torch.Tensor | None = None,
- handle_oov_mm_token: bool = False,
- ) -> torch.Tensor:
- inputs_embeds = self.language_model.model.word_embeddings(input_ids)
-
- if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
- return inputs_embeds
-
- assert is_multimodal is not None, "`is_multimodal` mask required when `multimodal_embeddings` provided"
- return _merge_multimodal_embeddings(
- inputs_embeds=inputs_embeds,
- multimodal_embeddings=multimodal_embeddings,
- is_multimodal=is_multimodal,
- )
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs,
- ) -> OmniOutput:
- # Compute MoE modality masks on every device
- image_mask, audio_mask = self._compute_modality_masks(input_ids)
- hidden_states = self.language_model.forward(
- input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- image_mask=image_mask,
- audio_mask=audio_mask,
- )
-
- # Capture embeddings for downstream stages
- multimodal_outputs = {
- "final_hidden_states": hidden_states,
- }
-
- return OmniOutput(
- text_hidden_states=hidden_states,
- multimodal_outputs=multimodal_outputs,
- )
-
- def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) -> torch.Tensor | None:
- return self.language_model.compute_logits(hidden_states, sampling_metadata)
-
- def sample(self, logits: torch.Tensor, sampling_metadata):
- return self.language_model.sample(logits, sampling_metadata)
-
- @property
- def sampler(self):
- return self.language_model.sampler
-
- def iter_mm_features(
- self,
- mm_features: list[MultiModalFeatureSpec],
- ) -> Iterator[tuple[int, str, dict[str, Any]]]:
- """Iterate over image/video features sorted by token position.
-
- Yields: (offset, modality, feature_data) where feature_data contains:
- - image: {"grid_t", "grid_h", "grid_w", "second_per_grid_t"}
- - video: {"grid_t", "grid_h", "grid_w", "second_per_grid_t"}
-
- Audio features are not yielded: Ming assigns them sequential
- text positions (same T/H/W value) rather than 3D grid positions.
- """
- spatial_merge_size = self.config.spatial_merge_size
-
- for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
- if mm_feature.data is None:
- continue
-
- offset = mm_feature.mm_position.offset
- modality = mm_feature.modality
-
- if modality == "image":
- t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
- yield (
- offset,
- "image",
- {
- "grid_t": int(t),
- "grid_h": int(h) // spatial_merge_size,
- "grid_w": int(w) // spatial_merge_size,
- "second_per_grid_t": 0.0,
- },
- )
- elif modality == "video":
- t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
- second_per_grid_t = 1.0
- spgt_field = mm_feature.data.get("second_per_grid_ts")
- if spgt_field is not None:
- second_per_grid_t = float(spgt_field.data.item())
- yield (
- offset,
- "video",
- {
- "grid_t": int(t),
- "grid_h": int(h) // spatial_merge_size,
- "grid_w": int(w) // spatial_merge_size,
- "second_per_grid_t": second_per_grid_t,
- },
- )
-
- def get_mrope_input_positions(
- self,
- input_tokens: list[int],
- mm_features: list[MultiModalFeatureSpec] | None = None,
- **kwargs: object,
- ) -> tuple[torch.Tensor, int]:
- """Compute M-RoPE input positions using mm_features directly."""
- llm_config = self.config
- tokens_per_second: int = getattr(llm_config, "tokens_per_second", 2)
- seq_len = len(input_tokens)
-
- llm_pos_ids_list: list[np.ndarray] = []
- st = 0 # index of next unprocessed token
-
- for patch_offset, _modality, data in self.iter_mm_features(mm_features or []):
- text_len = patch_offset - st
- st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
- if text_len > 0:
- llm_pos_ids_list.append(np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx)
- st_idx += text_len
-
- # 3-D grid positions for patch tokens
- grid_t: int = data["grid_t"]
- grid_h: int = data["grid_h"]
- grid_w: int = data["grid_w"]
- second_per_grid_t: float = data["second_per_grid_t"]
-
- t_raw = np.arange(grid_t)
- if second_per_grid_t > 0:
- t_index = (t_raw * second_per_grid_t * tokens_per_second).astype(np.int64)
- else:
- t_index = t_raw.astype(np.int64)
- t_index = np.repeat(t_index, grid_h * grid_w)
-
- h_index = np.tile(np.arange(grid_h).repeat(grid_w), grid_t)
- w_index = np.tile(np.arange(grid_w), grid_t * grid_h)
-
- llm_pos_ids_list.append(np.stack([t_index, h_index, w_index]) + st_idx)
-
- num_patches = grid_t * grid_h * grid_w
- st = patch_offset + num_patches
-
- if st < seq_len:
- st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
- tail_len = seq_len - st
- llm_pos_ids_list.append(np.broadcast_to(np.arange(tail_len), (3, tail_len)) + st_idx)
-
- if llm_pos_ids_list:
- position_ids = torch.from_numpy(np.concatenate(llm_pos_ids_list, axis=1).astype(np.int64)) # (3, seq_len)
- else:
- # text-only, simple sequential positions
- position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(3, -1)
-
- mrope_position_delta = int(position_ids.max().item()) + 1 - seq_len
- return position_ids, mrope_position_delta
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
- return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py b/vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py
deleted file mode 100644
index 1ff362c5b9d..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py
+++ /dev/null
@@ -1,896 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
-# Adapted from Ming
-# https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections.abc import Iterable
-
-import torch
-from torch import nn
-from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
-from vllm.config.cache import CacheConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.attention import Attention
-from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (
- MergedColumnParallelLinear,
- QKVParallelLinear,
- ReplicatedLinear,
- RowParallelLinear,
-)
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import (
- PPMissingLayer,
- WeightsMapper,
- make_empty_intermediate_tensors_factory,
- make_layers,
- maybe_prefix,
-)
-from vllm.sequence import IntermediateTensors
-from vllm.v1.outputs import SamplerOutput
-from vllm.v1.sample.sampler import Sampler
-
-from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
-from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMoeV2Config
-
-logger = init_logger(__name__)
-
-
-class MingVideoRopeMRotaryEmbedding(MRotaryEmbedding):
- """MRotaryEmbedding with Ming's video_rope cos/sin interleaving.
-
- Unlike standard mrope which maps contiguous frequency sections to T/H/W,
- video_rope alternates H/W frequencies element-wise in the spatial section
- and places temporal frequencies at the end:
- Standard mrope: [T T T ... H H H ... W W W ...]
- Video rope: [H W H W ... H W ... T T T ...]
-
- Refer to Ming's BailingMoeV2RotaryEmbedding3D
- https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py#L174
- """
-
- def _remap_video_rope(
- self,
- cos: torch.Tensor,
- sin: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """Remap 3D cos/sin to video_rope interleaved layout.
-
- Args:
- cos, sin: [3, num_tokens, rotary_dim // 2]
- Returns:
- cos, sin: [num_tokens, rotary_dim // 2]
-
- Refer to Ming's apply_3d_rotary_pos_emb
- https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py#L226
- """
- assert self.mrope_section is not None
- hw_size = self.mrope_section[1] + self.mrope_section[2]
-
- result_cos = torch.empty_like(cos[0])
- result_sin = torch.empty_like(sin[0])
-
- # Spatial frequencies: even indices from H (dim 1), odd from W (dim 2)
- result_cos[:, 0:hw_size:2] = cos[1, :, 0:hw_size:2]
- result_cos[:, 1:hw_size:2] = cos[2, :, 1:hw_size:2]
- result_sin[:, 0:hw_size:2] = sin[1, :, 0:hw_size:2]
- result_sin[:, 1:hw_size:2] = sin[2, :, 1:hw_size:2]
-
- # Temporal frequencies at the end
- result_cos[:, hw_size:] = cos[0, :, hw_size:]
- result_sin[:, hw_size:] = sin[0, :, hw_size:]
-
- return result_cos, result_sin
-
- def forward_native(
- self,
- positions: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor | None = None,
- offsets: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- assert positions.ndim == 1 or positions.ndim == 2
- assert key is not None
-
- cos_sin_cache = self._match_cos_sin_cache_dtype(query)
- num_tokens = positions.shape[-1]
- cos_sin = cos_sin_cache[positions]
- cos, sin = cos_sin.chunk(2, dim=-1)
-
- if positions.ndim == 2:
- cos, sin = self._remap_video_rope(cos, sin)
-
- query_shape = query.shape
- query = query.view(num_tokens, -1, self.head_size)
- query_rot = query[..., : self.rotary_dim]
- query_pass = query[..., self.rotary_dim :]
- query_rot = self.apply_rotary_emb.forward_native(query_rot, cos, sin)
- query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
- key_shape = key.shape
- key = key.view(num_tokens, -1, self.head_size)
- key_rot = key[..., : self.rotary_dim]
- key_pass = key[..., self.rotary_dim :]
- key_rot = self.apply_rotary_emb.forward_native(key_rot, cos, sin)
- key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
-
- return query, key
-
- def forward_cuda(
- self,
- positions: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor | None = None,
- offsets: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- # No custom Triton kernel for video_rope; fall back to native for 3D
- # TODO: Consider custom optimization
- if positions.ndim == 2:
- return self.forward_native(positions, query, key, offsets)
- return super().forward_cuda(positions, query, key, offsets)
-
- def forward_cpu(
- self,
- positions: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor | None = None,
- offsets: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- return self.forward_native(positions, query, key, offsets)
-
-
-class BailingMoeV2MLP(nn.Module):
- def __init__(
- self,
- config: BailingMoeV2Config,
- intermediate_size: int,
- hidden_act: str = "silu",
- quant_config: QuantizationConfig | None = None,
- reduce_results: bool = True,
- prefix: str = "",
- ):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = intermediate_size
-
- self.gate_up_proj = MergedColumnParallelLinear(
- self.hidden_size,
- [self.intermediate_size] * 2,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.gate_up_proj",
- )
- self.down_proj = RowParallelLinear(
- self.intermediate_size,
- self.hidden_size,
- bias=False,
- quant_config=quant_config,
- reduce_results=reduce_results,
- prefix=f"{prefix}.down_proj",
- )
-
- if hidden_act != "silu":
- raise ValueError(f"Unsupported activation: {hidden_act}")
- self.act_fn = SiluAndMul()
-
- def forward(self, x):
- gate_up, _ = self.gate_up_proj(x)
- x = self.act_fn(gate_up)
- x, _ = self.down_proj(x)
- return x
-
-
-class BailingMoeV2Gate(nn.Module):
- """MoE routing gate with grouped expert selection."""
-
- def __init__(
- self,
- config: BailingMoeV2Config,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
- ):
- super().__init__()
- self.config = config
- self.top_k = config.num_experts_per_tok
- self.num_experts = config.num_experts
-
- self.n_group = config.n_group
- self.topk_group = config.topk_group
-
- self.gating_dim = config.hidden_size
-
- self.gate = ReplicatedLinear(
- self.gating_dim,
- self.num_experts,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.gate",
- )
-
- self.routed_scaling_factor = config.routed_scaling_factor
-
- self.expert_bias = nn.Parameter(torch.zeros(self.num_experts), requires_grad=False)
-
- def group_limited_topk(self, scores: torch.Tensor):
- """Group-limited top-k selection for expert routing."""
- num_tokens, _ = scores.size()
- # Organize experts into groups
- group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
- group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
- group_mask = torch.zeros_like(group_scores)
- group_mask.scatter_(1, group_idx, 1)
-
- # Mask experts based on selected groups
- score_mask = (
- group_mask.unsqueeze(-1)
- .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
- .reshape(num_tokens, -1)
- )
-
- masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
- probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1, sorted=False)
-
- return probs, top_indices
-
- def forward(self, hidden_states):
- # compute gating score
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
- logits, _ = self.gate(hidden_states)
-
- logits = logits.float()
- scores = torch.sigmoid(logits)
-
- scores_for_routing = scores + self.expert_bias
- _, topk_idx = self.group_limited_topk(scores_for_routing)
-
- scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)
-
- topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
- topk_weight = topk_weight * self.routed_scaling_factor
-
- return topk_idx, topk_weight, logits
-
-
-def _unpack_multi_routing(
- hidden_states: torch.Tensor,
- gating_output: torch.Tensor,
- topk: int,
- renormalize: bool,
-) -> tuple[torch.Tensor, torch.Tensor]:
- """Stateless routing function that unpacks pre-computed routing results.
-
- Used as `custom_routing_function` for `FusedMoE`. The caller is expected
- to pack (topk_weight, topk_idx) into `gating_output` before
- calling FusedMoE.forward(), and this function unpacks them.
-
- Args:
- gating_output: [num_tokens, top_k * 2]
- - [:, :top_k], topk_weight (float)
- - [:, top_k:], topk_idx (float, cast back to int)
- """
- topk_weight = gating_output[:, :topk].contiguous()
- topk_idx = gating_output[:, topk:]
- return topk_weight.to(torch.float32), topk_idx.to(torch.int32)
-
-
-class BailingMoeV2SparseMoeBlock(nn.Module):
- """Sparse MoE block with MultiRouter support for multimodal routing.
-
- Keep the custom multi-router gating logic external.
- """
-
- def __init__(
- self,
- config: BailingMoeV2Config,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
- ):
- super().__init__()
- self.config = config
- self.tp_size = get_tensor_model_parallel_world_size()
- self.num_experts_per_tok = config.num_experts_per_tok
-
- if isinstance(self.config.num_shared_experts, int) and self.config.num_shared_experts > 0:
- self.shared_experts = BailingMoeV2MLP(
- config=self.config,
- intermediate_size=self.config.moe_intermediate_size * self.config.num_shared_experts,
- quant_config=quant_config,
- reduce_results=False,
- prefix=f"{prefix}.shared_experts",
- )
- else:
- self.shared_experts = None
-
- self.experts = SharedFusedMoE(
- shared_experts=self.shared_experts,
- num_experts=config.num_experts,
- top_k=config.num_experts_per_tok,
- hidden_size=config.hidden_size,
- intermediate_size=config.moe_intermediate_size,
- custom_routing_function=_unpack_multi_routing,
- renormalize=False, # we handle normalization in the gate
- reduce_results=True,
- quant_config=quant_config,
- prefix=f"{prefix}.experts",
- )
-
- self.experts.expert_mapping = FusedMoE.make_expert_params_mapping(
- self.experts,
- ckpt_gate_proj_name="gate_proj",
- ckpt_down_proj_name="down_proj",
- ckpt_up_proj_name="up_proj",
- num_experts=config.num_experts,
- )
-
- self.router_type = self.config.router_type
- if self.router_type == "topN":
- self.gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.gate")
- elif self.router_type == "MultiRouter":
- self.gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.gate")
- self.image_gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.image_gate")
- self.audio_gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.audio_gate")
- else:
- raise ValueError(f"Unsupported router_type: {self.router_type}")
-
- @staticmethod
- def _normalize_mask(
- mask: torch.Tensor,
- bsz: int,
- seq_len: int,
- name: str,
- ) -> torch.Tensor:
- """Validate and reshape a modality mask to [bsz*seq_len, 1] bool."""
- N = bsz * seq_len
- if mask.ndim == 1:
- # vLLM path: flat tokens [N]
- assert mask.shape[0] == N, f"{name} length {mask.shape[0]} != N={N}"
- elif mask.ndim == 2:
- assert mask.shape == (bsz, seq_len), f"{name} shape {mask.shape} != ({bsz}, {seq_len})"
- elif mask.ndim == 3:
- assert mask.shape == (bsz, seq_len, 1), f"{name} shape {mask.shape} != ({bsz}, {seq_len}, 1)"
- else:
- raise ValueError(f"Unsupported {name} shape: {mask.shape}")
-
- return mask.reshape(N, 1).bool()
-
- def forward(self, hidden_states, image_mask: torch.Tensor, audio_mask: torch.Tensor):
- # TODO(yuanheng-zhao): revise the shapes in the flow
- assert 2 <= hidden_states.dim() <= 3, f"{self.__class__.__name__} only supports 2D or 3D inputs"
- input_is_2d = hidden_states.ndim == 2
- if input_is_2d:
- hidden_states = hidden_states.unsqueeze(0)
-
- bsz, seq_len, h = hidden_states.shape
-
- if self.router_type == "MultiRouter":
- image_mask = self._normalize_mask(image_mask, bsz, seq_len, "image_mask").to(hidden_states.device)
- audio_mask = self._normalize_mask(audio_mask, bsz, seq_len, "audio_mask").to(hidden_states.device)
-
- # if image_mask is not None and audio_mask is not None:
- # assert torch.logical_and(image_mask, audio_mask).sum() == 0
-
- image_topk_idx, image_topk_weight, _ = self.image_gate(hidden_states)
- audio_topk_idx, audio_topk_weight, _ = self.audio_gate(hidden_states)
- topk_idx, topk_weight, _ = self.gate(hidden_states)
-
- topk_idx = torch.where(image_mask, image_topk_idx, topk_idx)
- topk_weight = torch.where(image_mask, image_topk_weight, topk_weight)
- topk_idx = torch.where(audio_mask, audio_topk_idx, topk_idx)
- topk_weight = torch.where(audio_mask, audio_topk_weight, topk_weight)
- else:
- topk_idx, topk_weight, _ = self.gate(hidden_states)
-
- # Pack pre-computed routing into a single tensor
- packed_routing = torch.cat(
- [
- topk_weight.to(hidden_states.dtype),
- topk_idx.to(hidden_states.dtype),
- ],
- dim=-1,
- )
-
- # SharedFusedMoE expects 2D hidden_states
- hidden_states_2d = hidden_states.view(-1, h)
- result = self.experts(hidden_states_2d, packed_routing)
-
- if self.shared_experts is not None:
- shared_output, fused_out = result
- else:
- shared_output, fused_out = None, result
-
- final_hidden_states = fused_out + shared_output if shared_output is not None else fused_out
-
- final_hidden_states = final_hidden_states.view(bsz, seq_len, h)
-
- return final_hidden_states.squeeze(0) if input_is_2d else final_hidden_states
-
-
-class BailingMoeV2Attention(nn.Module):
- """Multi-headed attention using vLLM's Attention layer with 3D RoPE support."""
-
- def __init__(
- self,
- config: BailingMoeV2Config,
- layer_idx: int,
- cache_config: CacheConfig | None = None,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
- ):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
-
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.num_kv_heads = config.num_key_value_heads
- self.head_dim = config.head_dim
-
- tp_size = get_tensor_model_parallel_world_size()
- assert self.num_heads % tp_size == 0
- self.num_heads = self.num_heads // tp_size
- self.num_kv_heads = max(1, self.num_kv_heads // tp_size)
-
- self.q_size = self.num_heads * self.head_dim
- self.kv_size = self.num_kv_heads * self.head_dim
- self.scaling = self.head_dim**-0.5
-
- partial_rotary_factor = config.partial_rotary_factor
- self.rope_dim = int(self.head_dim * partial_rotary_factor)
-
- total_num_heads = config.num_attention_heads
- total_num_kv_heads = config.num_key_value_heads
- self.qkv_proj = QKVParallelLinear(
- self.hidden_size,
- self.head_dim,
- total_num_heads,
- total_num_kv_heads,
- bias=config.use_qkv_bias,
- quant_config=quant_config,
- prefix=f"{prefix}.qkv_proj",
- )
-
- self.dense = RowParallelLinear(
- total_num_heads * self.head_dim,
- self.hidden_size,
- bias=config.use_bias,
- quant_config=quant_config,
- prefix=f"{prefix}.dense",
- )
-
- # apply vLLM RMSNorm here rather than BailingMoeV2RMSNorm, diff might exist
- self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
- self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
- # 3D Rotary embeddings for multimodal
- if config.rope_scaling is None:
- raise ValueError("rope_scaling must not be None")
-
- rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- mrope_section = config.rope_scaling.get("mrope_section", [8, 12, 12])
-
- if rope_type == "video_rope":
- # Ming-specific video_rope with custom H/W interleaving
- self.rotary_emb = MingVideoRopeMRotaryEmbedding(
- head_size=self.head_dim,
- rotary_dim=self.rope_dim,
- max_position_embeddings=config.max_position_embeddings,
- base=config.rope_theta,
- is_neox_style=True,
- dtype=torch.get_default_dtype(),
- mrope_section=mrope_section,
- )
- else:
- # Standard m_rope (rope_type "3D", "default", or None)
- rope_scaling = dict(config.rope_scaling)
- rope_scaling["rope_type"] = "default" # normalize for get_rope dispatch
- rope_scaling["mrope_section"] = mrope_section
- self.rotary_emb = get_rope(
- head_size=self.head_dim,
- max_position=config.max_position_embeddings,
- is_neox_style=True,
- rope_parameters={
- "rope_theta": config.rope_theta,
- "partial_rotary_factor": config.partial_rotary_factor,
- **rope_scaling,
- },
- )
-
- self.attn = Attention(
- self.num_heads,
- self.head_dim,
- self.scaling,
- num_kv_heads=self.num_kv_heads,
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=f"{prefix}.attn",
- )
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- ) -> torch.Tensor:
- """Forward pass for attention with 3D RoPE.
-
- Args:
- positions: Position IDs, shape (3, num_tokens) for 3D rope
- or (num_tokens,) for text-only
- hidden_states: Input hidden states, shape (num_tokens, hidden_size)
-
- Returns:
- Attention output tensor, shape (num_tokens, hidden_size)
- """
- qkv, _ = self.qkv_proj(hidden_states)
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-
- num_tokens = q.shape[0]
- q = self.q_norm(q.view(num_tokens, self.num_heads, self.head_dim)).view(num_tokens, self.q_size)
- k = self.k_norm(k.view(num_tokens, self.num_kv_heads, self.head_dim)).view(num_tokens, self.kv_size)
-
- q, k = self.rotary_emb(positions, q, k)
-
- attn_output = self.attn(q, k, v)
-
- output, _ = self.dense(attn_output)
- return output
-
-
-class BailingMoeV2DecoderLayer(nn.Module):
- """Decoder layer with attention and MoE MLP."""
-
- def __init__(
- self,
- config: BailingMoeV2Config,
- layer_idx: int,
- cache_config: CacheConfig | None = None,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
- ):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.layer_idx = layer_idx
-
- self.attention = BailingMoeV2Attention(
- config=config,
- layer_idx=layer_idx,
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=f"{prefix}.attention",
- )
-
- # MLP or MoE based on layer index
- if config.num_experts is not None and layer_idx >= config.first_k_dense_replace:
- self.mlp = BailingMoeV2SparseMoeBlock(
- config=config,
- quant_config=quant_config,
- prefix=f"{prefix}.mlp",
- )
- self.is_moe = True
- else:
- self.mlp = BailingMoeV2MLP(
- config=config,
- intermediate_size=config.intermediate_size,
- quant_config=quant_config,
- prefix=f"{prefix}.mlp",
- )
- self.is_moe = False
-
- # apply vLLM RMSNorm to replace BailingMoeV2RMSNorm, diff might exist
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- residual: torch.Tensor | None,
- image_mask: torch.Tensor | None = None,
- audio_mask: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """Forward pass for decoder layer.
-
- Args:
- positions: Position IDs
- hidden_states: Input hidden states
- residual: Residual connection from previous layer
- image_mask: Mask for image tokens (for MultiRouter MoE)
- audio_mask: Mask for audio tokens (for MultiRouter MoE)
-
- Returns:
- Tuple of (hidden_states, residual)
- """
- if residual is None:
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- else:
- hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
- hidden_states = self.attention(
- positions=positions,
- hidden_states=hidden_states,
- )
-
- hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-
- if self.is_moe:
- hidden_states = self.mlp(hidden_states, image_mask, audio_mask)
- else:
- # Dense MLP only takes hidden_states (no routing masks)
- hidden_states = self.mlp(hidden_states)
-
- return hidden_states, residual
-
-
-@support_torch_compile(
- dynamic_arg_dims={
- "input_ids": 0,
- "positions": -1,
- "intermediate_tensors": 0,
- "inputs_embeds": 0,
- "image_mask": 0,
- "audio_mask": 0,
- }
-)
-class BailingMoeV2Model(nn.Module):
- """BailingMoeV2 Model adapted from:
-
- Ming repo BailingMoeV2Model
- https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py
- vLLM repo BailingMoeModel
- https://github.com/vllm-project/vllm/blob/7291d1b288558d48508e1a17c37b0aa170332264/vllm/model_executor/models/bailing_moe.py
- """
-
- def __init__(
- self,
- *,
- vllm_config: VllmConfig,
- prefix: str = "",
- ):
- super().__init__()
-
- # BailingMoeV2Config
- config = vllm_config.model_config.hf_text_config
-
- cache_config = vllm_config.cache_config
- quant_config = vllm_config.quant_config
-
- self.config = config
- self.quant_config = quant_config
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
-
- if get_pp_group().is_first_rank or (self.tie_word_embeddings and get_pp_group().is_last_rank):
- self.word_embeddings = VocabParallelEmbedding(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=f"{prefix}.word_embeddings",
- )
- else:
- self.word_embeddings = PPMissingLayer()
-
- # Decoder layers with later pipeline parallelism support
- self.start_layer, self.end_layer, self.layers = make_layers(
- config.num_hidden_layers,
- lambda prefix: BailingMoeV2DecoderLayer(
- config=config,
- layer_idx=int(prefix.split(".")[-1]),
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=prefix,
- ),
- prefix=f"{prefix}.layers",
- )
-
- if get_pp_group().is_last_rank:
- # apply vLLM RMSNorm to replace BailingMoeV2RMSNorm, diff might exist
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- else:
- self.norm = PPMissingLayer()
-
- self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
- ["hidden_states", "residual"], config.hidden_size
- )
-
- def get_input_embeddings(self):
- return self.word_embeddings
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- image_mask: torch.Tensor | None = None,
- audio_mask: torch.Tensor | None = None,
- ) -> torch.Tensor | IntermediateTensors:
- if get_pp_group().is_first_rank:
- if inputs_embeds is not None:
- hidden_states = inputs_embeds
- else:
- hidden_states = self.word_embeddings(input_ids)
- residual = None
- else:
- assert intermediate_tensors is not None
- hidden_states = intermediate_tensors["hidden_states"]
- residual = intermediate_tensors["residual"]
-
- for layer in self.layers[self.start_layer : self.end_layer]:
- hidden_states, residual = layer(
- positions,
- hidden_states,
- residual,
- image_mask=image_mask,
- audio_mask=audio_mask,
- )
-
- if not get_pp_group().is_last_rank:
- return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
-
- hidden_states, _ = self.norm(hidden_states, residual)
- return hidden_states
-
-
-class BailingMoeV2ForCausalLM(nn.Module, CustomProcessMixin):
- """BailingMoeV2 model for causal language modeling, adapted for vLLM.
-
- Inherits from CustomProcessMixin to support custom preprocessing and postprocessing
- for integration with omni model pipelines.
- """
-
- packed_modules_mapping = {
- "qkv_proj": ["q_proj", "k_proj", "v_proj"],
- "gate_up_proj": ["gate_proj", "up_proj"],
- }
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
-
- # BailingMoeV2Config
- config = vllm_config.model_config.hf_text_config
- quant_config = vllm_config.quant_config
-
- self.config = config
- self.quant_config = quant_config
-
- self.model = BailingMoeV2Model(
- vllm_config=vllm_config,
- prefix=maybe_prefix(prefix, "model"),
- )
-
- self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
- if get_pp_group().is_last_rank:
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=maybe_prefix(prefix, "lm_head"),
- )
- if self.tie_word_embeddings:
- self.lm_head.weight = self.model.word_embeddings.weight
- else:
- self.lm_head = PPMissingLayer()
-
- self.logits_processor = LogitsProcessor(config.vocab_size)
- self.sampler = Sampler()
- self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- image_mask: torch.Tensor | None = None,
- audio_mask: torch.Tensor | None = None,
- ):
- hidden_states = self.model(
- input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds,
- image_mask=image_mask,
- audio_mask=audio_mask,
- )
- return hidden_states
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor,
- sampling_metadata,
- ) -> torch.Tensor | None:
- logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
- return logits
-
- def sample(
- self,
- logits: torch.Tensor,
- sampling_metadata,
- ) -> SamplerOutput | None:
- next_tokens = self.sampler(logits, sampling_metadata)
- return next_tokens
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- stacked_params_mapping = [
- # (param_name, weight_name, shard_id)
- # BailingMoE stores fused QKV in checkpoint as query_key_value
- ("qkv_proj", "query_key_value", None),
- # Dense MLP and shared_experts gate/up are stored separately
- ("gate_up_proj", "gate_proj", 0),
- ("gate_up_proj", "up_proj", 1),
- ]
-
- # Gate router linear layers: checkpoint `{r}.weight` -> model `{r}.gate.weight`
- gate_name_mapper = WeightsMapper(
- orig_to_new_substr={f".{r}.weight": f".{r}.gate.weight" for r in ("gate", "image_gate", "audio_gate")}
- )
-
- # FusedMoE expert params mapping is identical across all MoE layers
- expert_params_mapping: list[tuple[str, str, int, str]] = []
- for layer in self.model.layers:
- if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
- expert_params_mapping = layer.mlp.experts.expert_mapping
- break
-
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- loaded_params: set[str] = set()
-
- for name, loaded_weight in gate_name_mapper.apply(weights):
- for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name not in name or "mlp.experts" in name:
- continue
- name = name.replace(weight_name, param_name)
- param = params_dict.get(name)
- if param is not None:
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- loaded_params.add(name)
- break
- else:
- for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
- param = params_dict.get(name)
- if param is not None:
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id)
- loaded_params.add(name)
- break
- else:
- param = params_dict.get(name)
- if param is not None:
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
-
- return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/projectors.py b/vllm_omni/model_executor/models/ming_flash_omni/projectors.py
deleted file mode 100644
index 42e53d1c635..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/projectors.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright (c) Ant Group. All rights reserved.
-# Adapted from Ming repository modeling_bailingmm2.py
-# https://github.com/inclusionAI/Ming
-
-from collections.abc import Iterable
-
-import torch
-import torch.nn as nn
-from vllm.logger import init_logger
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
-logger = init_logger(__name__)
-
-
-class Transpose(nn.Module):
- """Used in nn.Sequential pipelines."""
-
- def __init__(self, dim0: int, dim1: int):
- super().__init__()
- self.dim0 = dim0
- self.dim1 = dim1
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return x.transpose(self.dim0, self.dim1)
-
-
-class VisionProjector(nn.Module):
- """MLP projector from vision encoder output to LLM hidden space.
-
- Args:
- vision_dim: Vision encoder output dimension (out_hidden_size).
- llm_dim: LLM hidden dimension.
- mlp_depth: Number of linear layers (>= 1).
- """
-
- def __init__(self, vision_dim: int, llm_dim: int, mlp_depth: int = 1):
- super().__init__()
- layers: list[nn.Module] = [nn.Linear(vision_dim, llm_dim)]
- for _ in range(1, mlp_depth):
- layers.append(nn.GELU())
- layers.append(nn.Linear(llm_dim, llm_dim))
- self.proj = nn.Sequential(*layers)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Project vision features.
-
- Args:
- x: [seq_len, vision_dim] or [B, seq_len, vision_dim]
-
- Returns:
- Projected features with last dim = llm_dim.
- """
- return self.proj(x)
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- params_dict = dict(self.named_parameters())
- loaded_params: set[str] = set()
- for name, loaded_weight in weights:
- if not name.startswith("proj."):
- name = f"proj.{name}"
- if name not in params_dict:
- logger.warning("Skipping unknown vision projector weight: %s", name)
- continue
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
- return loaded_params
-
-
-class AudioProjector(nn.Module):
- """Projector for audio features.
-
- Args:
- audio_dim: Audio encoder output dimension (n_state).
- llm_dim: LLM hidden dimension.
- ds_kernel_size: Conv1d kernel size for downsampling.
- ds_stride: Conv1d stride for downsampling.
- mlp_depth: Total number of projection layers (>= 1).
- """
-
- def __init__(
- self,
- audio_dim: int,
- llm_dim: int,
- ds_kernel_size: int = 3,
- ds_stride: int = 2,
- mlp_depth: int = 1,
- ):
- super().__init__()
- self.ds_kernel_size = ds_kernel_size
- self.ds_stride = ds_stride
-
- layers: list[nn.Module] = [
- nn.Conv1d(
- audio_dim,
- llm_dim,
- kernel_size=ds_kernel_size,
- stride=ds_stride,
- padding=ds_kernel_size // 2,
- ),
- Transpose(-1, -2), # [B, llm_dim, T'] -> [B, T', llm_dim]
- ]
- for _ in range(1, mlp_depth):
- layers.append(nn.GELU())
- layers.append(nn.Linear(llm_dim, llm_dim))
- self.proj = nn.Sequential(*layers)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Project audio features with temporal downsampling.
-
- Args:
- x: [B, T, audio_dim] audio encoder output (channel-last).
-
- Returns:
- [B, T', llm_dim] projected features (channel-last),
- where T' = (T - ds_kernel_size + 2*(ds_kernel_size//2)) // ds_stride + 1.
- """
- # Conv1d expects [B, C, T], so transpose input
- x = x.transpose(-1, -2) # [B, audio_dim, T]
- return self.proj(x)
-
- def forward_packed(
- self,
- packed: torch.Tensor,
- encoded_lens: list[int],
- ) -> tuple[torch.Tensor, list[int]]:
- """Project packed audio features from the Whisper encoder.
-
- Args:
- packed: [total_T', audio_dim] packed encoder output.
- encoded_lens: Length of each clip after Whisper encoding.
-
- Returns:
- Tuple of:
- - [total_T'', llm_dim] packed projected features.
- - List of projected lengths per clip.
- """
- conv1d = self.proj[0]
- mlp = self.proj[2:]
-
- # Split packed tensor per clip for Conv1d
- segments = packed.split(encoded_lens)
- conv_segments = []
- proj_lens: list[int] = []
- for seg in segments:
- out = conv1d(seg.transpose(0, 1).unsqueeze(0)) # [1, llm_dim, T'_i]
- out = out.squeeze(0).transpose(0, 1) # [T'_i, llm_dim]
- conv_segments.append(out)
- proj_lens.append(out.shape[0])
-
- packed_proj = torch.cat(conv_segments, dim=0) # [total_T'', llm_dim]
- packed_proj = mlp(packed_proj)
- return packed_proj, proj_lens
-
- def compute_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
- """Compute output sequence length after Conv1d downsampling.
-
- Args:
- input_length: Original mel spectrogram lengths.
-
- Returns:
- Output lengths after both convolutions.
- """
- length = (input_length - 3 + 2 * 1) // 2 + 1
- length = (length - self.ds_kernel_size + 2 * (self.ds_kernel_size // 2)) // self.ds_stride + 1
- return length
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- params_dict = dict(self.named_parameters())
- loaded_params: set[str] = set()
- for name, loaded_weight in weights:
- if not name.startswith("proj."):
- name = f"proj.{name}"
- if name not in params_dict:
- logger.warning("Skipping unknown audio projector weight: %s", name)
- continue
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded_params.add(name)
- return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/prompt_utils.py b/vllm_omni/model_executor/models/ming_flash_omni/prompt_utils.py
deleted file mode 100644
index 4271114bc2d..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/prompt_utils.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright (c) Ant Group. All rights reserved.
-# Adapted from Ming repo's usage cookbook:
-# https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/cookbook.ipynb
-"""Shared prompt-building helpers for Ming-flash-omni standalone talker."""
-
-import copy
-import json
-from typing import Any
-
-DEFAULT_PROMPT = "Please generate speech based on the following description.\n"
-
-BASE_CAPTION_TEMPLATE: dict[str, Any] = {
- "audio_sequence": [
- {
- "序号": 1,
- "说话人": "speaker_1",
- "方言": None,
- "风格": None,
- "语速": None,
- "基频": None,
- "音量": None,
- "情感": None,
- "BGM": {
- "Genre": None,
- "Mood": None,
- "Instrument": None,
- "Theme": None,
- "ENV": None,
- "SNR": None,
- },
- "IP": None,
- }
- ]
-}
-
-
-def create_instruction(user_input: dict[str, Any]) -> str:
- """Return a JSON caption string for ``audio_sequence[0]``.
-
- Only keys already present on the base template are merged in; unknown
- keys are silently ignored to keep the output schema stable.
- """
- caption = copy.deepcopy(BASE_CAPTION_TEMPLATE)
- item = caption["audio_sequence"][0]
- for key, value in user_input.items():
- if key in item:
- item[key] = value
- return json.dumps(caption, ensure_ascii=False)
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/spk_embedding.py b/vllm_omni/model_executor/models/ming_flash_omni/spk_embedding.py
deleted file mode 100644
index 68dbfe65021..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/spk_embedding.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright (c) Ant Group. All rights reserved.
-# Ported from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/modeling_bailing_talker.py
-
-from __future__ import annotations
-
-import torch
-
-
-class SpkembExtractor:
- """CAMPPlus ONNX-based speaker embedding extractor (runs on CPU)."""
-
- def __init__(self, campplus_model: str, target_sr: int = 16000):
- import onnxruntime
- import torchaudio.compliance.kaldi as kaldi
-
- self.kaldi = kaldi
- option = onnxruntime.SessionOptions()
- option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
- option.intra_op_num_threads = 2
- self.campplus_session = onnxruntime.InferenceSession(
- campplus_model, sess_options=option, providers=["CPUExecutionProvider"]
- )
- self.target_sr = target_sr
-
- def _extract_spk_embedding(self, speech):
- feat = self.kaldi.fbank(speech, num_mel_bins=80, dither=0, sample_frequency=16000)
- feat = feat - feat.mean(dim=0, keepdim=True)
- embedding = (
- self.campplus_session.run(
- None,
- {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()},
- )[0]
- .flatten()
- .tolist()
- )
- embedding = torch.tensor([embedding])
- return embedding
-
- def __call__(self, waveform, **kwargs) -> torch.Tensor | None:
- spk_emb = self._extract_spk_embedding(waveform)
- return spk_emb
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/talker_module.py b/vllm_omni/model_executor/models/ming_flash_omni/talker_module.py
deleted file mode 100644
index 80acbaad06a..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/talker_module.py
+++ /dev/null
@@ -1,1151 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright (c) Ant Group. All rights reserved.
-# Ported from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/dit.py
-#
-# Ported from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/modules.py
-# Ported from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/cfm.py
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# Partial of the following source code
-# is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# GLIDE: https://github.com/openai/glide-text2im
-# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
-# --------------------------------------------------------
-import logging
-import math
-from functools import cached_property
-from queue import Queue
-from threading import Lock
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers import PreTrainedTokenizerBase, Qwen2Config, Qwen2Model, StaticCache
-from vllm.logger import init_logger
-from x_transformers.x_transformers import RotaryEmbedding, apply_rotary_pos_emb
-
-from .audio_vae import AudioVAE
-
-logger = init_logger(__name__)
-
-
-########################################################################
-# DiT Modules
-# Ported from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/modules.py
-# Ported from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/dit.py
-########################################################################
-
-
-class RMSNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1e-6):
- super().__init__()
- self.eps = eps
- self.weight = nn.Parameter(torch.ones(dim))
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
- x = x.to(self.weight.dtype)
- x = F.rms_norm(x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps)
- return x
-
-
-class FeedForward(nn.Module):
- def __init__(
- self, dim: int, dim_out: int | None = None, mult: float = 4, dropout: float = 0.0, approximate: str = "none"
- ):
- super().__init__()
- inner_dim = int(dim * mult)
- dim_out = dim_out if dim_out is not None else dim
-
- activation = nn.GELU(approximate=approximate)
- project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
- self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.ff(x)
-
-
-class Attention(nn.Module):
- def __init__(
- self,
- dim: int,
- heads: int = 8,
- dim_head: int = 64,
- dropout: float = 0.0,
- qk_norm: str | None = None,
- pe_attn_head: int | None = None,
- attn_mask_enabled: bool = True,
- ):
- super().__init__()
- self.dim = dim
- self.heads = heads
- self.inner_dim = dim_head * heads
- self.dropout = dropout
-
- self.to_q = nn.Linear(dim, self.inner_dim)
- self.to_k = nn.Linear(dim, self.inner_dim)
- self.to_v = nn.Linear(dim, self.inner_dim)
- if qk_norm is None:
- self.q_norm = None
- self.k_norm = None
- elif qk_norm == "rms_norm":
- self.q_norm = RMSNorm(dim_head)
- self.k_norm = RMSNorm(dim_head)
- else:
- raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
-
- self.to_out = nn.ModuleList([])
- self.to_out.append(nn.Linear(self.inner_dim, dim))
- self.to_out.append(nn.Dropout(dropout))
-
- self.pe_attn_head = pe_attn_head
- self.attn_mask_enabled = attn_mask_enabled
-
- def forward(
- self,
- x: torch.Tensor,
- mask: torch.Tensor | None = None,
- rope: tuple[torch.Tensor, torch.Tensor | None] | None = None,
- ) -> torch.Tensor:
- batch_size = x.shape[0]
-
- query = self.to_q(x)
- key = self.to_k(x)
- value = self.to_v(x)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // self.heads
- query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
-
- if self.q_norm is not None:
- query = self.q_norm(query)
- if self.k_norm is not None:
- key = self.k_norm(key)
-
- if rope is not None:
- freqs, xpos_scale = rope
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
-
- if self.pe_attn_head is not None:
- on = self.pe_attn_head
- query[:, :on, :, :] = apply_rotary_pos_emb(query[:, :on, :, :], freqs, q_xpos_scale)
- key[:, :on, :, :] = apply_rotary_pos_emb(key[:, :on, :, :], freqs, k_xpos_scale)
- else:
- query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
- key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
-
- if self.attn_mask_enabled and mask is not None:
- valid_sample_indices = mask.any(dim=1)
- final_output = torch.zeros_like(query).to(query.device)
-
- attn_mask = mask[valid_sample_indices]
- query = query[valid_sample_indices]
- key = key[valid_sample_indices]
- value = value[valid_sample_indices]
- attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
- attn_mask = attn_mask.expand(valid_sample_indices.sum().item(), self.heads, query.shape[-2], key.shape[-2])
- else:
- attn_mask = None
-
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
- if self.attn_mask_enabled and mask is not None:
- final_output[valid_sample_indices] = x
- x = final_output
-
- x = x.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
- x = x.to(query.dtype)
-
- x = self.to_out[0](x)
- x = self.to_out[1](x)
-
- if mask is not None:
- mask = mask.unsqueeze(-1)
- x = x.masked_fill(~mask, 0.0)
-
- return x
-
-
-class DiTBlock(nn.Module):
- """A DiT block with pre-norm and residual connections."""
-
- def __init__(
- self,
- hidden_size: int,
- num_heads: int,
- mlp_ratio: float = 4.0,
- dropout: float = 0.1,
- qk_norm: str | None = None,
- pe_attn_head: int | None = None,
- attn_mask_enabled: bool = True,
- **kwargs,
- ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size)
- self.attn = Attention(
- dim=hidden_size,
- heads=num_heads,
- dim_head=hidden_size // num_heads,
- dropout=dropout,
- qk_norm=qk_norm,
- pe_attn_head=pe_attn_head,
- attn_mask_enabled=attn_mask_enabled,
- )
- self.norm2 = RMSNorm(hidden_size)
- self.mlp = FeedForward(dim=hidden_size, mult=mlp_ratio, dropout=dropout, approximate="tanh")
-
- def forward(
- self,
- x: torch.Tensor,
- mask: torch.Tensor | None,
- rope: tuple[torch.Tensor, torch.Tensor | None] | None,
- ) -> torch.Tensor:
- x = x + self.attn(self.norm1(x), mask=mask, rope=rope)
- x = x + self.mlp(self.norm2(x))
- return x
-
-
-class FinalLayer(nn.Module):
- """The final layer of DiT."""
-
- def __init__(self, hidden_size: int, out_channels: int):
- super().__init__()
- self.norm_final = RMSNorm(hidden_size)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = self.norm_final(x)
- x = self.linear(x)
- return x
-
-
-class SinusPositionEmbedding(nn.Module):
- def __init__(self, dim: int):
- super().__init__()
- self.dim = dim
-
- def forward(self, x: torch.Tensor, scale: float = 1000) -> torch.Tensor:
- device = x.device
- half_dim = self.dim // 2
- emb = math.log(10000) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
- emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
- emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
- return emb
-
-
-class TimestepEmbedder(nn.Module):
- def __init__(self, dim: int, freq_embed_dim: int = 256):
- super().__init__()
- self.time_embed = SinusPositionEmbedding(freq_embed_dim)
- self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
-
- def forward(self, timestep: torch.Tensor) -> torch.Tensor:
- time_hidden = self.time_embed(timestep)
- time_hidden = time_hidden.to(timestep.dtype)
- time = self.time_mlp(time_hidden)
- return time
-
-
-class CondEmbedder(nn.Module):
- """Embeds LLM hidden states with optional CFG dropout."""
-
- def __init__(self, input_feature_size: int, hidden_size: int):
- super().__init__()
- self.cond_embedder = nn.Linear(input_feature_size, hidden_size)
-
- def forward(self, llm_cond: torch.Tensor) -> torch.Tensor:
- return self.cond_embedder(llm_cond)
-
-
-class DiT(nn.Module):
- """Diffusion model with a Transformer backbone for audio latent generation."""
-
- def __init__(
- self,
- in_channels: int = 64,
- hidden_size: int = 1024,
- depth: int = 28,
- num_heads: int = 16,
- mlp_ratio: float = 4.0,
- llm_cond_dim: int = 896,
- **kwargs,
- ):
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.num_heads = num_heads
-
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.x_embedder = nn.Linear(in_channels, hidden_size)
- self.c_embedder = CondEmbedder(llm_cond_dim, hidden_size)
- if "spk_dim" in kwargs:
- self.spk_embedder = nn.Linear(kwargs["spk_dim"], hidden_size)
- else:
- self.spk_embedder = None
- self.hidden_size = hidden_size
-
- self.rotary_embed = RotaryEmbedding(hidden_size // num_heads)
-
- self.blocks = nn.ModuleList(
- [DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)]
- )
- self.final_layer = FinalLayer(hidden_size, self.out_channels)
-
- def forward(
- self,
- x: torch.Tensor,
- t: torch.Tensor,
- c: torch.Tensor,
- latent_history: torch.Tensor,
- spk_emb: torch.Tensor | None = None,
- ) -> torch.Tensor:
- x = torch.cat([latent_history, x], dim=1)
- x = self.x_embedder(x)
- t = self.t_embedder(t).unsqueeze(1)
- c = self.c_embedder(c)
- y = t + c
- if spk_emb is None:
- assert self.spk_embedder is None
- x = torch.cat([y, x], dim=1)
- else:
- x = torch.cat([self.spk_embedder(spk_emb), y, x], dim=1)
- rope = self.rotary_embed.forward_from_seq_len(x.shape[1])
-
- for block in self.blocks:
- x = block(x, None, rope)
- x = self.final_layer(x)
- return x
-
- def forward_with_cfg(
- self,
- x: torch.Tensor,
- t: torch.Tensor,
- c: torch.Tensor,
- latent_history: torch.Tensor,
- spk_emb: torch.Tensor | None = None,
- ) -> torch.Tensor:
- """Forward with classifier-free guidance (doubles batch for CFG)."""
- x = torch.cat([x, x], dim=0)
- latent_history = torch.cat([latent_history, latent_history], dim=0)
- fake_latent = torch.zeros_like(c)
- c = torch.cat([c, fake_latent], dim=0)
- if t.ndim == 0:
- t = t.repeat(x.shape[0])
- if spk_emb is not None:
- spk_emb = torch.cat([spk_emb, spk_emb], dim=0)
- model_out = self.forward(x, t, c, latent_history, spk_emb)
- return model_out[:, -x.shape[1] :, :]
-
-
-#########################################################################################
-# CFM
-# Ported from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/cfm.py
-#########################################################################################
-
-
-def get_epss_timesteps(n, device, dtype):
- dt = 1 / 32
- predefined_timesteps = {
- 5: [0, 2, 4, 8, 16, 32],
- 6: [0, 2, 4, 6, 8, 16, 32],
- 7: [0, 2, 4, 6, 8, 16, 24, 32],
- 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
- 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
- 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
- }
- t = predefined_timesteps.get(n, [])
- if not t:
- return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
- return dt * torch.tensor(t, device=device, dtype=dtype)
-
-
-class CFM(nn.Module):
- """Conditional Flow Matching module for audio latent generation."""
-
- def __init__(self, model: nn.Module, steps: int = 10, sway_sampling_coef: float | None = -1.0):
- """
- Args:
- model: DiT used for the velocity prediction.
- steps: number of integration steps per sample call.
- sway_sampling_coef: coefficient used to skew the integration
- grid towards low-noise timesteps. Defaults to -1.0 which
- packs more steps near t=0, where prediction error is highest.
- Set to `None` to use the linear grid as-is.
- """
- super().__init__()
- self.model = model
- self.steps = steps
- self.sway_sampling_coef = sway_sampling_coef
-
- @torch.no_grad()
- def sample(
- self,
- llm_cond: torch.Tensor,
- lat_cond: torch.Tensor,
- y0: torch.Tensor,
- t: torch.Tensor,
- sde_args: torch.Tensor,
- sde_rnd: torch.Tensor,
- ):
- """Sample audio latent via ODE/SDE integration with CFG.
-
- Args:
- llm_cond: LLM hidden state (B, 1, hidden_size)
- lat_cond: latent history (B, his_patch_size, latent_dim)
- y0: initial noise (B, patch_size, latent_dim)
- t: timesteps from get_epss_timesteps
- sde_args: [cfg_strength, sigma, temperature]
- sde_rnd: random noise for SDE steps (steps, B, patch_size, latent_dim)
- """
-
- def fn(fn_t, x):
- pred_cfg = self.model.forward_with_cfg(x, fn_t, llm_cond, lat_cond, None)
- pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
- return pred + (pred - null_pred) * sde_args[0]
-
- if self.sway_sampling_coef is not None:
- t = t + self.sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
-
- for step in range(self.steps):
- dt = t[step + 1] - t[step]
- y0 = y0 + fn(t[step], y0) * dt
- y0 = y0 + sde_args[1] * (sde_args[2] ** 0.5) * (dt.abs() ** 0.5) * sde_rnd[step]
-
- return y0
-
-
-class CFMGraphExecutor:
- """CUDA graph-accelerated executor for CFM + Aggregator + StopHead pipeline."""
-
- def __init__(self, config, cfm, aggregator, stop_head: nn.Linear):
- self.config = config
- self.cfm = cfm
- self.aggregator = aggregator
- self.stop_head = stop_head
- self.initialized = False
-
- self.last_hidden_state_placeholder = None
- self.his_lat_placeholder = None
- self.randn_like_placeholder = None
- self.t_placeholder = None
- self.sde_args_placeholder = None
- self.sde_rnd_placeholder = None
- self.gen_lat_placeholder = None
- self.inputs_embeds_placeholder = None
- self.stop_out_placeholder = None
- self.graph = None
-
- def execute(
- self,
- input_tensor: torch.Tensor,
- his_lat: torch.Tensor,
- cfg_strength: float = 2.0,
- sigma: float = 0.25,
- temperature: float = 0.0,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- bat_size, his_patch_size, z_dim = his_lat.shape
- randn_tensor = torch.randn(
- (bat_size, self.config.patch_size, z_dim), device=input_tensor.device, dtype=input_tensor.dtype
- )
- t = get_epss_timesteps(self.config.steps, device=input_tensor.device, dtype=input_tensor.dtype)
- sde_rnd = torch.randn(
- (self.config.steps, *randn_tensor.shape), device=input_tensor.device, dtype=input_tensor.dtype
- )
-
- if not self.initialized:
- self._initialize_graph(input_tensor, his_lat, randn_tensor, sde_rnd)
-
- self.last_hidden_state_placeholder.copy_(input_tensor)
- self.his_lat_placeholder.copy_(his_lat)
- self.randn_like_placeholder.copy_(randn_tensor)
- self.t_placeholder.copy_(t)
- self.sde_args_placeholder[0] = cfg_strength
- self.sde_args_placeholder[1] = sigma
- self.sde_args_placeholder[2] = temperature
- self.sde_rnd_placeholder.copy_(sde_rnd)
-
- self.graph.replay()
-
- gen_lat = torch.empty_like(self.gen_lat_placeholder)
- gen_lat.copy_(self.gen_lat_placeholder)
-
- inputs_embeds = torch.empty_like(self.inputs_embeds_placeholder)
- inputs_embeds.copy_(self.inputs_embeds_placeholder)
-
- stop_out = torch.empty_like(self.stop_out_placeholder)
- stop_out.copy_(self.stop_out_placeholder)
-
- return gen_lat, inputs_embeds, stop_out
-
- def _initialize_graph(
- self,
- input_tensor: torch.Tensor,
- his_lat: torch.Tensor,
- randn_tensor: torch.Tensor,
- sde_rnd: torch.Tensor,
- ) -> None:
- self.last_hidden_state_placeholder = torch.empty_like(input_tensor)
- self.his_lat_placeholder = torch.empty_like(his_lat)
- self.randn_like_placeholder = torch.empty_like(randn_tensor)
- self.t_placeholder = get_epss_timesteps(self.config.steps, device=input_tensor.device, dtype=input_tensor.dtype)
- self.sde_args_placeholder = torch.empty(3, device=input_tensor.device, dtype=input_tensor.dtype)
- self.sde_rnd_placeholder = torch.empty_like(sde_rnd)
-
- self.graph = torch.cuda.CUDAGraph()
- with torch.cuda.graph(self.graph):
- self.gen_lat_placeholder = self.cfm.sample(
- self.last_hidden_state_placeholder,
- self.his_lat_placeholder,
- self.randn_like_placeholder,
- self.t_placeholder,
- self.sde_args_placeholder,
- self.sde_rnd_placeholder,
- )
- self.inputs_embeds_placeholder = self.aggregator(self.gen_lat_placeholder)
- self.stop_out_placeholder = self.stop_head(self.last_hidden_state_placeholder[:, -1, :]).softmax(dim=-1)
-
- self.initialized = True
-
-
-class CFMGraphExecutorPool:
- """Thread-safe pool of CFMGraphExecutors for concurrent inference."""
-
- def __init__(self, config, cfm, aggregator, stop_head: nn.Linear, pool_size: int = 1):
- self.config = config
- self.cfm = cfm
- self.aggregator = aggregator
- self.stop_head = stop_head
- self.pool_size = pool_size
- self.pool = Queue(maxsize=pool_size)
- self.lock = Lock()
-
- for _ in range(pool_size):
- executor = CFMGraphExecutor(config, cfm, aggregator, stop_head)
- self.pool.put(executor)
-
- def acquire(self) -> CFMGraphExecutor:
- return self.pool.get()
-
- def release(self, executor: CFMGraphExecutor) -> None:
- self.pool.put(executor)
-
- def execute(
- self,
- input_tensor: torch.Tensor,
- his_lat: torch.Tensor,
- cfg_strength: float = 2.0,
- sigma: float = 0.25,
- temperature: float = 0.0,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- executor = self.acquire()
- try:
- return executor.execute(
- input_tensor, his_lat, cfg_strength=cfg_strength, sigma=sigma, temperature=temperature
- )
- finally:
- self.release(executor)
-
-
-########################################################################
-# Audio Postprocess
-# Adapted from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/modeling_bailing_talker.py
-########################################################################
-
-
-@torch.no_grad()
-def resample(waveform: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor:
- """Resample a waveform via linear interpolation (no torchaudio dep).
-
- Args:
- waveform: Tensor shaped ``(..., num_samples)``.
- orig_sr: Source sample rate (Hz); must be > 0.
- target_sr: Target sample rate (Hz); must be > 0.
-
- Raises:
- ValueError: If sample rates are non-positive, the waveform is empty,
- or the resampled length would round to zero.
- """
- if orig_sr <= 0:
- raise ValueError(f"orig_sr must be positive, got {orig_sr}")
- if target_sr <= 0:
- raise ValueError(f"target_sr must be positive, got {target_sr}")
- if waveform.numel() == 0 or waveform.shape[-1] == 0:
- raise ValueError("waveform must contain at least one sample")
- if orig_sr == target_sr:
- return waveform
-
- ratio = target_sr / orig_sr
- new_len = int(waveform.shape[-1] * ratio)
- if new_len <= 0:
- raise ValueError(
- f"resampled waveform would be empty for input length {waveform.shape[-1]}, "
- f"orig_sr={orig_sr}, target_sr={target_sr}"
- )
- return torch.nn.functional.interpolate(
- waveform.unsqueeze(0),
- size=new_len,
- mode="linear",
- align_corners=False,
- ).squeeze(0)
-
-
-def trim_trailing_silence(
- waveform: torch.Tensor,
- sample_rate: int,
- sil_th: float = 1e-3,
- tail_silence_s: float = 0.3,
-) -> torch.Tensor:
- """Trim low-energy tail while keeping a short trailing silence.
-
- Works on 2-D ``(channels, samples)`` or 3-D ``(batch, channels, samples)``
- tensors. Any other shape is returned unchanged.
- """
- if waveform.numel() == 0:
- return waveform
-
- original_dim = waveform.dim()
- if original_dim == 3:
- speech = waveform[:, 0, :]
- elif original_dim == 2:
- speech = waveform
- else:
- return waveform
-
- frame_step = int(sample_rate * 0.1)
- frame_size = int(sample_rate * 0.1)
- if speech.shape[-1] < frame_size:
- keep = min(speech.shape[-1], int(tail_silence_s * sample_rate))
- trimmed = speech[..., :keep]
- else:
- num_frame = (speech.shape[-1] - frame_size) // frame_step + 1
- cur_len = (num_frame - 1) * frame_step + frame_size
- speech = speech[..., :cur_len]
- spe_frames = speech.unfold(-1, frame_size, frame_step)
- scores = spe_frames.abs().mean(dim=-1)
- scores = scores.mean(dim=list(range(scores.dim() - 1)))
- idx = scores.shape[0] - 1
- while idx >= 0 and scores[idx] <= sil_th:
- idx -= 1
- if idx < 0:
- keep = min(speech.shape[-1], int(tail_silence_s * sample_rate))
- trimmed = speech[..., :keep]
- else:
- non_sil_len = idx * frame_step + frame_size + int(tail_silence_s * sample_rate)
- non_sil_len = min(non_sil_len, speech.shape[-1])
- trimmed = speech[..., :non_sil_len]
-
- if original_dim == 3:
- return trimmed.unsqueeze(1)
- return trimmed
-
-
-def silence_holder(
- speech: torch.Tensor,
- sample_rate: int,
- sil_cache: dict | None = None,
- last_chunk: bool = True,
- sil_th: float = 1e-3,
- last_sil: float = 0.3,
-) -> tuple[torch.Tensor, dict]:
- """Ming-style streaming silence holder.
-
- Used during streaming VAE decode to defer emission of silent regions
- until a non-silent frame arrives (or the stream ends). ``sil_cache``
- is carried across chunks and updated in place.
- """
- if speech.numel() == 0:
- return speech, sil_cache or {"holder": [], "buffer": []}
-
- frame_step = int(sample_rate * 0.1)
- frame_size = int(sample_rate * 0.1)
- if sil_cache is None:
- sil_cache = {"holder": [], "buffer": []}
-
- if sil_cache["buffer"]:
- speech = torch.cat([*sil_cache["buffer"], speech], dim=-1)
- sil_cache["buffer"] = []
-
- if speech.shape[-1] < frame_size:
- sil_cache["buffer"].append(speech)
- if last_chunk:
- speech = torch.cat(sil_cache["holder"] + sil_cache["buffer"], dim=-1)
- return speech[..., : int(last_sil * sample_rate)], sil_cache
- return torch.zeros((*speech.shape[:-1], 0), device=speech.device, dtype=speech.dtype), sil_cache
-
- num_frame = (speech.shape[-1] - frame_size) // frame_step + 1
- cur_len = (num_frame - 1) * frame_step + frame_size
- if speech.shape[-1] > cur_len:
- sil_cache["buffer"].append(speech[..., cur_len:])
- speech = speech[..., :cur_len]
-
- spe_frames = speech.unfold(-1, frame_size, frame_step)
- scores = spe_frames.abs().mean(dim=-1)
- scores = scores.mean(dim=list(range(scores.dim() - 1)))
- idx = scores.shape[0] - 1
- while idx >= 0 and scores[idx] <= sil_th:
- idx -= 1
-
- if idx < 0:
- sil_cache["holder"].append(speech)
- if last_chunk:
- speech = torch.cat(sil_cache["holder"] + sil_cache["buffer"], dim=-1)
- return speech[..., : int(last_sil * sample_rate)], sil_cache
- return torch.zeros((*speech.shape[:-1], 0), device=speech.device, dtype=speech.dtype), sil_cache
-
- non_sil_len = idx * frame_step + frame_size
- if last_chunk:
- non_sil_len += int(last_sil * sample_rate)
- non_sil_len = min(non_sil_len, speech.shape[-1])
- speech_out = torch.cat([*sil_cache["holder"], speech[..., :non_sil_len]], dim=-1)
- sil_cache["holder"] = []
- if non_sil_len < speech.shape[-1]:
- sil_cache["holder"].append(speech[..., non_sil_len:])
- return speech_out, sil_cache
-
-
-########################################################################
-# Audio Postprocess
-# Ported from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/talker_module/aggregator.py
-########################################################################
-
-
-class Aggregator(nn.Module):
- """Maps generated audio latent patches back to LLM embedding space."""
-
- def __init__(
- self,
- in_channels: int = 64,
- hidden_size: int = 1152,
- depth: int = 28,
- num_heads: int = 16,
- mlp_ratio: float = 4.0,
- llm_input_dim: int = 896,
- **kwargs,
- ):
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.num_heads = num_heads
-
- self.word_embedder = nn.Embedding(1, hidden_size)
- self.x_embedder = nn.Linear(in_channels, hidden_size)
- self.hidden_size = hidden_size
-
- self.rotary_embed = RotaryEmbedding(hidden_size // num_heads)
-
- self.blocks = nn.ModuleList(
- [DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)]
- )
- self.final_layer = FinalLayer(hidden_size, llm_input_dim)
-
- def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
- x = self.x_embedder(x)
- cls_embed = self.word_embedder(torch.zeros((x.shape[0], 1), dtype=torch.long, device=x.device))
- x = torch.cat([cls_embed, x], dim=1)
-
- rope = self.rotary_embed.forward_from_seq_len(x.shape[1])
- if mask is not None:
- mask_pad = mask.clone().detach()[:, :1]
- mask = torch.cat([mask_pad, mask], dim=-1)
- for block in self.blocks:
- x = block(x, mask, rope)
- x = self.final_layer(x)
- x = x[:, :1, :]
- return x
-
-
-########################################################################
-# Prompt Builder
-# Adapted from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/modeling_bailing_talker.py
-########################################################################
-
-_MUSIC_TAGS = ("Genre: ", "Mood: ", "Instrument: ", "Theme: ", "Duration: ")
-
-
-def _looks_like_music_prompt(text: str) -> bool:
- return all(tag in text for tag in _MUSIC_TAGS)
-
-
-def build_tts_input(
- *,
- tokenizer: PreTrainedTokenizerBase,
- embed_tokens: torch.nn.Module,
- device: torch.device,
- dtype: torch.dtype,
- text: str,
- prompt: str,
- spk_emb: list[torch.Tensor] | None = None,
- instruction: str | None = None,
- prompt_text: str | None = None,
- prompt_wav_emb: torch.Tensor | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
- """Build (inputs_embeds, input_ids) for one TTS segment.
-
- Args:
- tokenizer: HF tokenizer
- embed_tokens: The LLM's input-embedding module
- device: Device to place the returned tensors on.
- dtype: dtype for the returned `inputs_embeds`.
- text: Text to synthesize.
- prompt: System-level instruction prompt prepended to the user turn.
- spk_emb: Optional list of speaker embeddings already projected into
- LLM hidden dim; each is injected at a `<|vision_start|>` slot.
- instruction: Optional free-form instruction
- prompt_text: Reference text for zero-shot voice cloning.
- prompt_wav_emb: Reference-wav embeddings to inject.
- """
- spk_emb_prompt: list[int] = []
- if spk_emb is not None:
- for i in range(len(spk_emb)):
- spk_emb_prompt.extend(
- tokenizer.encode(f" speaker_{i + 1}:")
- + tokenizer.encode("<|vision_start|>")
- + tokenizer.encode("<|vision_pad|>")
- + tokenizer.encode("<|vision_end|>\n")
- )
-
- instruction_prompt: list[int] = []
- if instruction is not None:
- instruction_prompt = tokenizer.encode(instruction) + tokenizer.encode("<|im_end|>")
-
- prompt_text_token: list[int] = []
- prompt_latent_token: list[int] = []
- if prompt_wav_emb is not None and prompt_text is not None:
- prompt_text_token = tokenizer.encode(prompt_text)
- prompt_latent_token = tokenizer.encode("") * prompt_wav_emb.size(1)
-
- prompt2 = [] if _looks_like_music_prompt(text) else tokenizer.encode(" Text input:\n")
-
- input_part = (
- tokenizer.encode("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n")
- + tokenizer.encode("<|im_start|>user\n")
- + tokenizer.encode(prompt)
- + spk_emb_prompt
- + prompt2
- + prompt_text_token
- + tokenizer.encode(text)
- + tokenizer.encode("<|im_end|>\n")
- + tokenizer.encode("<|im_start|>assistant\n")
- + instruction_prompt
- + tokenizer.encode("")
- + prompt_latent_token
- )
-
- input_ids = torch.tensor(input_part, dtype=torch.long, device=device).unsqueeze(0)
- inputs_embeds = embed_tokens(input_ids).to(device=device, dtype=dtype)
-
- # inject speaker embeddings
- if spk_emb is not None:
- spk_token_id = tokenizer.encode("<|vision_start|>")
- assert len(spk_token_id) == 1, "<|vision_start|> must tokenize to a single id"
- spk_indices = torch.where(input_ids[0] == spk_token_id[0])[0]
- assert len(spk_indices) > 0, "expected at least one <|vision_start|> slot"
- for i, se in enumerate(spk_emb):
- inputs_embeds[0, spk_indices[i] + 1] = se
-
- # inject prompt-wav embeddings after
- if prompt_wav_emb is not None and prompt_text is not None:
- audio_token_id = tokenizer.encode("")
- assert len(audio_token_id) == 1, " must tokenize to a single id"
- audio_indices = torch.where(input_ids[0] == audio_token_id[0])[0]
- assert len(audio_indices) > 0, "expected at least one slot"
- start = audio_indices[0] + 1
- inputs_embeds[0, start : start + prompt_wav_emb.size(1), :] = prompt_wav_emb[0]
-
- return inputs_embeds, input_ids
-
-
-########################################################################
-# Audio Generator
-########################################################################
-
-
-class MingAudioGenerator:
- """Generator driving prefill -> AR decode -> VAE decode
- for a single TTS request. The generator is stateless across requests.
- """
-
- def __init__(
- self,
- config,
- llm_config: Qwen2Config,
- model: Qwen2Model,
- cfm: CFM,
- aggregator: Aggregator,
- stop_head: torch.nn.Module,
- audio_vae: AudioVAE | None,
- patch_size: int,
- his_patch_size: int,
- latent_dim: int,
- cfg_strength: float,
- use_cuda_graphs: bool,
- ) -> None:
- self._config = config
- self._llm_config = llm_config
- self._model = model
- self._cfm = cfm
- self._aggregator = aggregator
- self._stop_head = stop_head
- self._audio_vae = audio_vae
-
- self.patch_size = patch_size
- self.his_patch_size = his_patch_size
- self.latent_dim = latent_dim
- self.cfg_strength = cfg_strength
-
- self._use_cuda_graphs = use_cuda_graphs
-
- @cached_property
- def _sampler_pool(self) -> CFMGraphExecutorPool | None:
- device = next(self._model.parameters()).device
- if self._use_cuda_graphs and device.type == "cuda":
- return CFMGraphExecutorPool(self._config, self._cfm, self._aggregator, self._stop_head, pool_size=1)
- return None
-
- def duration_capped_steps(self, text_len: int, requested_max_steps: int) -> int:
- """Apply the original Ming duration heuristic as a cap on decode steps."""
- if self._audio_vae is None:
- return requested_max_steps
-
- sample_rate = float(self._audio_vae.config.sample_rate)
- vae_patch_size = float(getattr(self._audio_vae.config, "patch_size", 4))
- hop_size = float(getattr(self._audio_vae.decoder, "hop_length", 320))
- seconds_per_step = (self.patch_size * vae_patch_size * hop_size) / sample_rate
- if seconds_per_step <= 0:
- return requested_max_steps
-
- max_duration_s = max(2.0, float(text_len) * (5818.0 / 16000.0))
- max_steps_by_duration = max(1, int(max_duration_s / seconds_per_step))
- return min(requested_max_steps, max_steps_by_duration)
-
- @torch.no_grad()
- def generate_latents(
- self,
- inputs_embeds: torch.Tensor,
- *,
- prompt_wav_lat: torch.Tensor | None = None,
- min_new_token: int = 10,
- max_steps: int = 1000,
- cfg: float | None = None,
- sigma: float = 0.25,
- temperature: float = 0.0,
- use_static_cache: bool = True,
- ) -> list[torch.Tensor]:
- """Autoregressive LLM + CFM sampling loop"""
- if cfg is None:
- cfg = self.cfg_strength
- device = next(self._model.parameters()).device
- dtype = next(self._model.parameters()).dtype
-
- his_lat = self._init_his_lat(prompt_wav_lat, device, dtype)
- past_key_values, max_cache_len = self._init_kv_cache(use_static_cache, device, dtype)
- prefill_len = inputs_embeds.shape[1]
- all_latents: list[torch.Tensor] = []
-
- for step in range(min(max_steps, max_cache_len - prefill_len)):
- last_hs = self.llm_step(
- inputs_embeds,
- step=step,
- past_key_values=past_key_values,
- use_static_cache=use_static_cache,
- )
- gen_lat, inputs_embeds, stop_out = self.cfm_sample_step(
- last_hs, his_lat, cfg=cfg, sigma=sigma, temperature=temperature
- )
- his_lat = self._update_his_lat(his_lat, gen_lat)
- all_latents.append(gen_lat)
-
- stop_prob = stop_out.cpu()[0, 1].item()
-
- if logger.isEnabledFor(logging.DEBUG):
- if step % 50 == 0 or step < 5:
- logger.debug(
- "step=%d stop_prob=%.4f hs_norm=%.4f lat_norm=%.4f emb_norm=%.4f",
- step,
- stop_prob,
- last_hs.float().norm().item(),
- gen_lat.float().norm().item(),
- inputs_embeds.float().norm().item(),
- )
-
- if step > min_new_token and stop_prob > 0.5:
- logger.info("Stopping at step %d with stop_prob=%.4f", step, stop_prob)
- break
-
- return all_latents
-
- def cfm_sample_step(
- self,
- last_hidden_state: torch.Tensor,
- his_lat: torch.Tensor,
- *,
- cfg: float | None = None,
- sigma: float = 0.25,
- temperature: float = 0.0,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Run one CFM sampling step.
-
- This is the CFM one-shot sampling step with CUDA-graph fast path.
- """
- if cfg is None:
- cfg = self.cfg_strength
-
- if self._sampler_pool is not None:
- return self._sampler_pool.execute(last_hidden_state, his_lat, cfg, sigma, temperature)
-
- bat_size, _, z_dim = his_lat.shape
- randn_tensor = torch.randn(
- (bat_size, self.patch_size, z_dim),
- device=last_hidden_state.device,
- dtype=last_hidden_state.dtype,
- )
- t = get_epss_timesteps(self._config.steps, device=last_hidden_state.device, dtype=last_hidden_state.dtype)
- sde_rnd = torch.randn(
- (self._config.steps, *randn_tensor.shape),
- device=last_hidden_state.device,
- dtype=last_hidden_state.dtype,
- )
- sde_args = torch.tensor(
- [cfg, sigma, temperature],
- device=last_hidden_state.device,
- dtype=last_hidden_state.dtype,
- )
-
- gen_lat = self._cfm.sample(last_hidden_state, his_lat, randn_tensor, t, sde_args, sde_rnd)
- inputs_embeds = self._aggregator(gen_lat)
- stop_out = self._stop_head(last_hidden_state[:, -1, :]).softmax(dim=-1)
-
- return gen_lat, inputs_embeds, stop_out
-
- def decode_to_waveform(self, latents: list[torch.Tensor], stream_decode: bool = True) -> torch.Tensor:
- """Decode accumulated latents to waveform via AudioVAE."""
- if self._audio_vae is None:
- raise RuntimeError("AudioVAE not loaded. Cannot decode audio latents to waveform.")
-
- if stream_decode:
- return self._stream_decode(latents)
-
- all_lat = torch.cat(latents, dim=1)
- waveform, _, _ = self._audio_vae.decode(
- all_lat, use_cache=False, stream_state=(None, None, None), last_chunk=True
- )
- return waveform
-
- def llm_step(
- self,
- inputs_embeds: torch.Tensor,
- *,
- step: int,
- past_key_values: StaticCache | None,
- use_static_cache: bool,
- ) -> torch.Tensor:
- if step == 0 or not use_static_cache:
- outputs = self._model(
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=True,
- )
- else:
- past_seen_tokens = past_key_values.get_seq_length()
- cache_position = torch.arange(
- past_seen_tokens,
- past_seen_tokens + inputs_embeds.shape[1],
- device=inputs_embeds.device,
- )
- outputs = self._model(
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=True,
- cache_position=cache_position,
- )
- return outputs.last_hidden_state[:, -1:, :]
-
- def _init_his_lat(
- self, prompt_wav_lat: torch.Tensor | None, device: torch.device, dtype: torch.dtype
- ) -> torch.Tensor:
- his_lat = torch.zeros(1, self.his_patch_size, self.latent_dim, device=device, dtype=dtype)
- if prompt_wav_lat is not None:
- start_index = self.his_patch_size - prompt_wav_lat.size(1)
- if start_index < 0:
- his_lat[:] = prompt_wav_lat[:, -start_index:, :]
- else:
- his_lat[:, start_index:, :] = prompt_wav_lat
- return his_lat
-
- def _init_kv_cache(
- self, use_static_cache: bool, device: torch.device, dtype: torch.dtype
- ) -> tuple[StaticCache | None, int]:
- max_cache_len = 2048
- if not use_static_cache:
- return None, max_cache_len
- cache = StaticCache(
- config=self._llm_config,
- max_batch_size=1,
- max_cache_len=max_cache_len,
- device=device,
- dtype=dtype,
- )
- return cache, max_cache_len
-
- def _update_his_lat(self, his_lat: torch.Tensor, gen_lat: torch.Tensor) -> torch.Tensor:
- if self.his_patch_size == self.patch_size:
- return gen_lat
- if self.his_patch_size > self.patch_size:
- return torch.cat([his_lat[:, self.patch_size - self.his_patch_size :], gen_lat], dim=1)
- raise NotImplementedError(f"his_patch_size ({self.his_patch_size}) < patch_size ({self.patch_size})")
-
- # VAE streaming decode
- def _stream_decode(self, latents: list[torch.Tensor]) -> torch.Tensor:
- sr = int(self._audio_vae.config.sample_rate)
- vae_cache = {"past_key_values": None, "stream_state": (None, None, None)}
- sil_cache: dict | None = None
- wav_chunks: list[torch.Tensor] = []
-
- for i, lat in enumerate(latents):
- last_chunk = i == (len(latents) - 1)
- speech, stream_state, past_key_values = self._audio_vae.decode(
- lat,
- past_key_values=vae_cache["past_key_values"],
- use_cache=True,
- stream_state=vae_cache["stream_state"],
- last_chunk=last_chunk,
- )
- vae_cache = {"past_key_values": past_key_values, "stream_state": stream_state}
- speech_chunk = speech[0].detach().float()
- speech_chunk, sil_cache = silence_holder(
- speech_chunk,
- sr,
- sil_cache=sil_cache,
- last_chunk=last_chunk,
- )
- if speech_chunk.numel() > 0:
- wav_chunks.append(speech_chunk)
-
- if not wav_chunks:
- device = next(self._model.parameters()).device
- dtype = next(self._model.parameters()).dtype
- return torch.zeros((1, 1, 0), device=device, dtype=dtype)
- return torch.cat(wav_chunks, dim=-1).unsqueeze(0)
-
- # Post-decode helper
- def trim_trailing_silence(self, waveform: torch.Tensor) -> torch.Tensor:
- if self._audio_vae is None:
- return waveform
- return trim_trailing_silence(waveform, int(self._audio_vae.config.sample_rate))
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/text_processing.py b/vllm_omni/model_executor/models/ming_flash_omni/text_processing.py
deleted file mode 100644
index 436b92f1428..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/text_processing.py
+++ /dev/null
@@ -1,535 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright (c) Ant Group. All rights reserved.
-# Adapted from:
-# https://github.com/inclusionAI/Ming/tree/e58533db227031990c5a6864dcf5f08fb53ed0d2/front
-
-"""Text segmentation and normalization utilities for Ming TTS."""
-
-from __future__ import annotations
-
-import re
-import string
-
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-# Ported from
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/front/toolkit.py
-_TOKENIZE_PATTERN = re.compile(r"(?:[a-zA-Z]\.)+|[a-zA-Z]+(?:['\-][a-zA-Z]+)*|\d+(?:\.\d+)?|[\u4e00-\u9fff]|\s+|\S")
-
-
-def tokenize_mixed_text(text: str) -> list[str]:
- return re.findall(_TOKENIZE_PATTERN, text)
-
-
-def tokenize_mixed_text_iterator(text: str):
- for match in _TOKENIZE_PATTERN.finditer(text):
- yield match.group(0)
-
-
-# Ported from
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/front/text_segment_cut.py
-def is_chinese(text: str) -> bool:
- return bool(re.search(r"[\u4e00-\u9fff]", text))
-
-
-def get_semantic_length(text: str) -> int:
- """1 CJK char = 1 unit; 1 contiguous English word = 1 unit."""
- chinese_char_count = len(re.findall(r"[\u4e00-\u9fa5]", text))
- english_word_count = len(re.findall(r"[a-zA-Z]+", text))
- return chinese_char_count + english_word_count
-
-
-def has_valid_content(text: str) -> bool:
- punctuation_and_whitespace = string.punctuation + string.whitespace
- for char in text:
- if char not in punctuation_and_whitespace:
- return True
- return False
-
-
-def append_text_fragment(
- fragments: list[str],
- new_text: str,
- max_len: int,
- min_tail_length: int,
-) -> list[str]:
- new_text = new_text.lstrip(",,:;" + string.whitespace)
- if not has_valid_content(new_text):
- return fragments
- if not fragments:
- fragments.append(new_text)
- return fragments
-
- last_fragment = fragments[-1]
- last_semantic_len = get_semantic_length(last_fragment)
- new_semantic_len = get_semantic_length(new_text)
-
- if last_semantic_len + new_semantic_len <= max_len:
- if last_fragment.endswith(("。", "!", "?")) and new_semantic_len < min_tail_length:
- fragments.append(new_text)
- else:
- separator = ""
- if not last_fragment.endswith(" ") and re.match(r"^[a-zA-Z0-9]", new_text):
- separator = " "
- fragments[-1] += separator + new_text
- else:
- fragments.append(new_text)
- return fragments
-
-
-def split_long_fragment(text_fragment: str, max_len: int) -> list[str]:
- if get_semantic_length(text_fragment) <= max_len:
- return [text_fragment]
-
- fragments: list[str] = []
- current_fragment = ""
- semantic_units = re.findall(r"([\u4e00-\u9fa5]|[a-zA-Z]+|[^a-zA-Z\u4e00-\u9fa5]+)", text_fragment)
- for unit in semantic_units:
- unit_len = get_semantic_length(unit)
- current_len = get_semantic_length(current_fragment)
- if current_len + unit_len <= max_len:
- current_fragment += unit
- else:
- if current_fragment:
- fragments.append(current_fragment)
- if unit_len > max_len:
- fragments.append(unit)
- current_fragment = ""
- else:
- current_fragment = unit
- if current_fragment:
- fragments.append(current_fragment)
- return fragments
-
-
-_DOT_PLACEHOLDER = "##DOT##"
-# default soft cap on fragment length in semantic units
-_DEFAULT_MAX_SEMANTIC_LENGTH: int = 50
-# default tail length controls when a short trailing fragment is
-# merged with the previous one to avoid leaving an awkward stub.
-_DEFAULT_MIN_TAIL_LENGTH: int = 5
-
-
-def cut_text_by_semantic_length(
- text: str,
- max_semantic_length: int = _DEFAULT_MAX_SEMANTIC_LENGTH,
- min_tail_length: int = _DEFAULT_MIN_TAIL_LENGTH,
-) -> list[str]:
- """Segment text into fragments respecting semantic length limits.
-
- Ported from upstream Ming's `front/text_segment_cut.py`.
- Position tracking is omitted (not needed for non-streaming VAE decode).
- """
- if not has_valid_content(text):
- return []
-
- processed = re.sub(r"(\d)\.(\d)", r"\1" + _DOT_PLACEHOLDER + r"\2", text)
- for _ in range(3):
- processed = re.sub(r"([A-Z])\.([A-Z])", r"\1" + _DOT_PLACEHOLDER + r"\2", processed)
- processed = processed.replace("\n", " ").replace("。,", "。")
-
- if get_semantic_length(processed) <= max_semantic_length:
- return [processed.replace(_DOT_PLACEHOLDER, ".")]
-
- normalized = processed.replace(".", "。").replace("!", "!").replace("?", "?").replace(",", ",")
-
- # Phase 1: split into sentences on 。!?
- sentences: list[str] = []
- current: list[str] = []
- for char in normalized:
- current.append(char)
- if char in "。!?":
- s = "".join(current).strip()
- if s:
- sentences.append(s)
- current = []
- if current:
- s = "".join(current).strip()
- if s:
- if not s.endswith(("。", "!", "?")):
- s += "。"
- sentences.append(s)
-
- # Phase 2: merge whole sentences; only clause-split oversized ones.
- # This ensures split points land on sentence boundaries (。!?)
- # rather than mid-sentence commas.
- result_fragments: list[str] = []
- for sentence in sentences:
- sent_len = get_semantic_length(sentence)
-
- if sent_len > max_semantic_length:
- # Oversized sentence: fall back to clause-level splitting
- clauses: list[str] = []
- clause_buf: list[str] = []
- for char in sentence:
- clause_buf.append(char)
- if char in ",;;":
- cl = "".join(clause_buf).strip()
- if cl and has_valid_content(cl):
- clauses.append(cl)
- elif cl and clauses:
- clauses[-1] += cl
- clause_buf = []
- if clause_buf:
- cl = "".join(clause_buf).strip()
- if cl and has_valid_content(cl):
- clauses.append(cl)
- elif cl and clauses:
- clauses[-1] += cl
-
- i = 0
- while i < len(clauses):
- clause = clauses[i]
- clause_len = get_semantic_length(clause)
-
- if clause_len < min_tail_length and i + 1 < len(clauses):
- combined = clause + clauses[i + 1]
- if get_semantic_length(combined) <= max_semantic_length:
- result_fragments = append_text_fragment(
- result_fragments, combined, max_semantic_length, min_tail_length
- )
- i += 2
- continue
-
- if clause_len > max_semantic_length:
- for frag in split_long_fragment(clause, max_semantic_length):
- result_fragments = append_text_fragment(
- result_fragments, frag, max_semantic_length, min_tail_length
- )
- else:
- result_fragments = append_text_fragment(
- result_fragments, clause, max_semantic_length, min_tail_length
- )
- i += 1
- else:
- # Normal sentence: merge at sentence level
- if not result_fragments:
- result_fragments.append(sentence)
- else:
- last_len = get_semantic_length(result_fragments[-1])
- if last_len + sent_len <= max_semantic_length:
- result_fragments[-1] += sentence
- else:
- result_fragments.append(sentence)
-
- return [f.replace(_DOT_PLACEHOLDER, ".") for f in result_fragments]
-
-
-# Streaming sentence boundary detection
-
-_RE_CJK = re.compile(r"[\u4e00-\u9fff]")
-_RE_DIGIT_LAST = re.compile(r"[0-9]")
-
-
-# Left for reference for now
-# Alternative to `cut_text_by_semantic_length`
-def detect_sentence_boundaries(
- text: str,
- max_length: int = 50,
-) -> list[str]:
- """Accumulate tokens and flush at sentence boundaries.
-
- Ported from the streaming sentence detection loop from the Ming repo
- TTS branch, but operates on the full text since we have it available upfront.
- """
- sentences: list[str] = []
- streaming_text: list[str] = []
- count = 0
-
- for ele in tokenize_mixed_text_iterator(text):
- if len(ele) == 0:
- continue
-
- should_process = False
- min_tokens = 12 if count == 0 else 8
-
- if ele[-1] in "!?。,!?":
- if len(streaming_text) >= min_tokens:
- should_process = True
- streaming_text.append(ele)
-
- elif ele[-1] == ".":
- if (
- len(streaming_text) >= min_tokens
- and streaming_text
- and not _RE_DIGIT_LAST.search(streaming_text[-1][-1])
- ):
- should_process = True
- streaming_text.append(ele)
-
- elif ele[-1] == "\n":
- if streaming_text:
- joined = "".join(streaming_text)
- if _RE_CJK.search(joined):
- if _RE_CJK.search(streaming_text[-1][-1]):
- ele = ","
- streaming_text.append(ele)
- else:
- if len(ele) > 1 and re.search(r"[a-zA-Z]", ele[-2]):
- ele = ele[:-1] + "."
- else:
- ele = ele[:-1]
- streaming_text.append(ele)
-
- if len(streaming_text) >= min_tokens:
- should_process = True
- else:
- streaming_text.append(ele)
- continue
-
- if should_process:
- joined = "".join(streaming_text)
- fragments = cut_text_by_semantic_length(joined, max_length)
- sentences.extend(fragments)
- streaming_text = []
- count += 1
-
- # Flush remaining
- if streaming_text and re.search(r"[a-zA-Z\u4e00-\u9fff1-9]", "".join(streaming_text)):
- joined = "".join(streaming_text)
- fragments = cut_text_by_semantic_length(joined, max_length)
- sentences.extend(fragments)
-
- return sentences
-
-
-# number normalization for English. Ported from
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/front/number_en.py
-
-
-_inflect_engine = None
-
-
-def _get_inflect():
- global _inflect_engine
- if _inflect_engine is not None:
- return _inflect_engine
- try:
- import inflect
-
- _inflect_engine = inflect.engine()
- except ImportError:
- logger.warning(
- "Package 'inflect' not installed - English number normalization "
- "will be skipped. Install with: pip install inflect"
- )
- _inflect_engine = None
- return _inflect_engine
-
-
-_comma_number_re = re.compile(r"([0-9][0-9,]+[0-9])")
-_percent_number_re = re.compile(r"(-?[0-9.,]*[0-9]+)%")
-_pounds_re = re.compile(r"£(-?[0-9,]*[0-9]+(?:\.[0-9]+)?)")
-_dollars_re = re.compile(r"\$(-?[0-9.,]*[0-9]+(?:\.[0-9]+)?)")
-_fraction_re = re.compile(r"([0-9]+)\/([0-9]+)")
-_ordinal_re = re.compile(r"\b[0-9]+(st|[nr]d|th)\b")
-_number_re = re.compile(r"\b-?[0-9]+(?:\.[0-9]+)?\b")
-_unit_re = re.compile(
- r"\b(-?\d+(?:\.\d+)?)\s*"
- r"(ms|s|Hz|kHz|MHz|GHz|kb|mb|gb|tb|KB|MB|GB|TB|bps|kbps|Mbps|Gbps|cm|km|kg|V|A|W|°C|°F)\b",
- re.IGNORECASE,
-)
-_version_re = re.compile(r"\b([a-zA-Z]+)([-]?)([0-9]+(?:\.[0-9]+)?)\b")
-_whitespace_re = re.compile(r"\s+")
-
-_unit_mapping = {
- "ms": "milliseconds",
- "s": "seconds",
- "hz": "hertz",
- "khz": "kilohertz",
- "mhz": "megahertz",
- "ghz": "gigahertz",
- "kb": "kilobytes",
- "mb": "megabytes",
- "gb": "gigabytes",
- "tb": "terabytes",
- "kbps": "kilobits per second",
- "mbps": "megabits per second",
- "gbps": "gigabits per second",
- "bps": "bits per second",
- "cm": "centimeters",
- "km": "kilometers",
- "kg": "kilograms",
- "v": "volts",
- "a": "amperes",
- "w": "watts",
- "°c": "degrees celsius",
- "°f": "degrees fahrenheit",
-}
-
-
-def _num_to_words(n: int) -> str:
- p = _get_inflect()
- if p is None:
- return str(n)
- return p.number_to_words(n, andword="")
-
-
-def _expand_decimal(num_str: str) -> str:
- """Expand a decimal number string like '3.14' -> 'three point one four'."""
- is_negative = num_str.startswith("-")
- clean = num_str.lstrip("-") or "0"
-
- if "." in clean:
- parts = clean.split(".", 1)
- integer_part = parts[0] or "0"
- decimal_part = parts[1]
- if not integer_part.isdigit() or not decimal_part.isdigit():
- return num_str
- int_word = _num_to_words(int(integer_part)) if integer_part != "0" else "zero"
- dec_words = " ".join(_num_to_words(int(d)) for d in decimal_part if d.isdigit())
- word = f"{int_word} point {dec_words}"
- else:
- if not clean.isdigit():
- return num_str
- word = _num_to_words(int(clean))
-
- if is_negative:
- word = f"minus {word}"
- return word
-
-
-def _remove_commas(m: re.Match) -> str:
- return m.group(1).replace(",", "")
-
-
-_NUM_PARSE_EXC: tuple[type[BaseException], ...] = (ValueError, TypeError)
-
-
-def _expand_unit(m: re.Match) -> str:
- num_str, unit = m.group(1), m.group(2).lower()
- unit_word = _unit_mapping.get(unit, unit)
- try:
- return f" {_expand_decimal(num_str)} {unit_word} "
- except _NUM_PARSE_EXC:
- return f" {num_str} {unit} "
-
-
-def _expand_percent(m: re.Match) -> str:
- try:
- return f" {_expand_decimal(m.group(1))} percent "
- except _NUM_PARSE_EXC:
- return f" {m.group(1)} percent "
-
-
-def _expand_dollars(m: re.Match) -> str:
- raw = m.group(1)
- clean = raw.lstrip("-") or "0"
- try:
- word = _expand_decimal(raw)
- value = float(clean)
- unit = "dollar" if abs(value) == 1.0 else "dollars"
- return f" {word} {unit} "
- except _NUM_PARSE_EXC:
- return f" {clean} dollars "
-
-
-def _expand_pounds(m: re.Match) -> str:
- raw = m.group(1)
- clean = raw.lstrip("-") or "0"
- try:
- word = _expand_decimal(raw)
- value = float(clean)
- unit = "pound" if abs(value) == 1.0 else "pounds"
- return f" {word} {unit} "
- except _NUM_PARSE_EXC:
- return f" {clean} pounds "
-
-
-def _expand_fraction(m: re.Match) -> str:
- p = _get_inflect()
- if p is None:
- return m.group(0)
- try:
- num, den = int(m.group(1)), int(m.group(2))
- if num == 1 and den == 2:
- return " one half "
- if num == 1 and den == 4:
- return " one quarter "
- if den == 2:
- plural = " half" if num == 1 else " halves"
- return f" {p.number_to_words(num)}{plural} "
- if den == 4:
- plural = " quarter" if num == 1 else " quarters"
- return f" {p.number_to_words(num)}{plural} "
- ordinal = p.ordinal(p.number_to_words(den))
- return f" {p.number_to_words(num)} {ordinal} "
- except _NUM_PARSE_EXC:
- return f" {m.group(1)} over {m.group(2)} "
-
-
-def _expand_ordinal(m: re.Match) -> str:
- try:
- num = int(re.sub(r"(st|and|rd|th)", "", m.group(0)))
- return f" {_num_to_words(num)} "
- except _NUM_PARSE_EXC:
- return m.group(0)
-
-
-def _expand_number(m: re.Match) -> str:
- try:
- return f" {_expand_decimal(m.group(0))} "
- except _NUM_PARSE_EXC:
- return f" {m.group(0)} "
-
-
-def _expand_version(m: re.Match) -> str:
- prefix, _, num_str = m.group(1), m.group(2), m.group(3)
- try:
- word = _expand_decimal(num_str)
- except _NUM_PARSE_EXC:
- return m.group(0)
- return f"{prefix} {word}"
-
-
-def normalize_numbers(text: str) -> str:
- """Expand English numbers, currencies, units, etc. to words.
-
- Returns text unchanged if `inflect` package is not installed.
- """
- if _get_inflect() is None:
- return text
- text = re.sub(_comma_number_re, _remove_commas, text)
- text = re.sub(_unit_re, _expand_unit, text)
- text = re.sub(_pounds_re, _expand_pounds, text)
- text = re.sub(_dollars_re, _expand_dollars, text)
- text = re.sub(_fraction_re, _expand_fraction, text)
- text = re.sub(_percent_number_re, _expand_percent, text)
- text = re.sub(_ordinal_re, _expand_ordinal, text)
- text = re.sub(_version_re, _expand_version, text)
- text = re.sub(_number_re, _expand_number, text)
- text = re.sub(_whitespace_re, " ", text)
- return text.strip()
-
-
-# Top-level API
-def segment_and_normalize(
- text: str,
- max_length: int = _DEFAULT_MAX_SEMANTIC_LENGTH,
-) -> list[str]:
- """Segment text into fragments and expand English numbers for Ming TTS.
-
- This function cuts text by semantic length directly rather than following
- the streaming algorithm to detect sentence boundaries in the upstream
- Ming repo (which is more aggressively splitting at commas). It produces
- fewer and larger segments at natural sentence boundaries.
- """
- if not text or not text.strip():
- return []
-
- segments = cut_text_by_semantic_length(text.strip(), max_length)
-
- normalized: list[str] = []
- for seg in segments:
- if not is_chinese(seg):
- seg = normalize_numbers(seg)
- if seg and seg[0] == ",":
- seg = seg[1:]
- seg = seg.strip()
- if seg:
- normalized.append(seg)
-
- return normalized if normalized else [text.strip()]
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py b/vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py
deleted file mode 100644
index 7976d76ce8d..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Adapted from Ming repository qwen3_moe_vit.py
-# https://github.com/inclusionAI/Ming
-
-from collections.abc import Iterable
-
-import torch
-import torch.nn as nn
-from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.models.qwen3_omni_moe_thinker import (
- Qwen3Omni_VisionTransformer,
-)
-from vllm.model_executor.models.utils import WeightsMapper
-
-logger = init_logger(__name__)
-
-
-def _adapt_vision_config(vision_config):
- # Adapt Ming's Qwen3VLMoeVisionConfig to be compatible with vLLM's
- # Qwen3Omni_VisionTransformer expectations.
- if not hasattr(vision_config, "image_size") or vision_config.image_size is None:
- if hasattr(vision_config, "num_position_embeddings") and vision_config.num_position_embeddings:
- import math
-
- num_grid = int(math.sqrt(vision_config.num_position_embeddings))
- vision_config.image_size = num_grid * vision_config.patch_size
- else:
- vision_config.image_size = vision_config.patch_size * 14 # fallback
-
- if not hasattr(vision_config, "apply_vit_abs_pos_embed"):
- vision_config.apply_vit_abs_pos_embed = True
-
- return vision_config
-
-
-class MingVisionEncoder(nn.Module):
- """**Wrapper** around vLLM's Qwen3Omni_VisionTransformer for Ming."""
-
- hf_to_vllm_mapper = WeightsMapper(
- orig_to_new_substr={
- "deepstack_merger_list.": "merger_list.",
- "merger.norm.": "merger.ln_q.",
- "merger.linear_fc1.": "merger.mlp.0.",
- "merger.linear_fc2.": "merger.mlp.2.",
- }
- )
-
- def __init__(
- self,
- vision_config,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- adapted_config = _adapt_vision_config(vision_config)
- norm_eps = 1e-6
- self.encoder = Qwen3Omni_VisionTransformer(
- vision_config=adapted_config,
- norm_eps=norm_eps,
- quant_config=quant_config,
- prefix=f"{prefix}.encoder",
- )
- self.image_emb_dim = vision_config.out_hidden_size
- self.use_deepstack = (
- hasattr(vision_config, "deepstack_visual_indexes") and vision_config.deepstack_visual_indexes is not None
- )
-
- @property
- def dtype(self) -> torch.dtype:
- return self.encoder.dtype
-
- @property
- def device(self) -> torch.device:
- return self.encoder.device
-
- def forward(
- self,
- pixel_values: torch.Tensor,
- grid_thw: torch.Tensor,
- ) -> torch.Tensor:
- """forward method of the vision encoder.
-
- Args:
- pixel_values: Flattened pixel values.
- grid_thw: [num_images, 3] tensor of (t, h, w) grid sizes.
-
- Returns:
- If deepstack is enabled, returns concatenated multi-scale features
- along the feature dim: [seq_len, hidden_size * (1 + num_deepstack)].
- Otherwise returns [seq_len, hidden_size].
- """
- return self.encoder(pixel_values, grid_thw=grid_thw)
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- import re
-
- def _remap_merger_list_inner(name: str) -> str:
- name = re.sub(r"(merger_list\.\d+)\.norm\.", r"\1.ln_q.", name)
- name = re.sub(r"(merger_list\.\d+)\.linear_fc1\.", r"\1.mlp.0.", name)
- name = re.sub(r"(merger_list\.\d+)\.linear_fc2\.", r"\1.mlp.2.", name)
-
- return name
-
- remapped_weights = self.hf_to_vllm_mapper.apply(weights)
- remapped_weights = ((_remap_merger_list_inner(name), tensor) for name, tensor in remapped_weights)
- loaded_params = self.encoder.load_weights(remapped_weights)
-
- loaded_params = {f"encoder.{loaded_param}" for loaded_param in loaded_params}
-
- return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/voice_presets.py b/vllm_omni/model_executor/models/ming_flash_omni/voice_presets.py
deleted file mode 100644
index 5f54687c0cb..00000000000
--- a/vllm_omni/model_executor/models/ming_flash_omni/voice_presets.py
+++ /dev/null
@@ -1,289 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright (c) Ant Group. All rights reserved.
-# Adapted from:
-# https://github.com/inclusionAI/Ming/blob/e58533db227031990c5a6864dcf5f08fb53ed0d2/modeling_bailing_talker.py
-
-from __future__ import annotations
-
-import json
-import os
-from functools import cached_property
-from typing import TYPE_CHECKING, Any
-
-import soundfile as sf
-import torch
-from transformers.utils.hub import cached_file
-from vllm.logger import init_logger
-
-from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
-
-from .spk_embedding import SpkembExtractor
-from .talker_module import resample
-
-if TYPE_CHECKING:
- from .audio_vae import AudioVAE
- from .talker_module import Aggregator
-
-logger = init_logger(__name__)
-
-
-class InvalidPromptWavError(ValueError):
- """Prompt wav failed local validation and can be skipped in list mode."""
-
-
-class VoicePresetRegistry:
- """Loader and registry for Ming voice presets."""
-
- def __init__(
- self,
- *,
- talker_dir: str,
- model_path: str,
- download_dir: str | None,
- audio_vae: AudioVAE | None,
- aggregator: Aggregator,
- spk_head: torch.nn.Module,
- patch_size: int,
- ) -> None:
- self._talker_dir = talker_dir
- self._model_path = model_path
- self._download_dir = download_dir
- self._audio_vae = audio_vae
- self._aggregator = aggregator
- self._spk_head = spk_head
- self._patch_size = patch_size
-
- self.registered: dict[str, dict[str, Any]] = {}
-
- def __contains__(self, voice_name: str) -> bool:
- return voice_name in self.registered
-
- def get(self, voice_name: str) -> dict[str, Any] | None:
- return self.registered.get(voice_name)
-
- def register(
- self,
- voice_name: str,
- prompt_wav_path: str | list[str],
- *,
- device: torch.device,
- dtype: torch.dtype,
- ) -> None:
- """Register a voice preset from one or more reference wav files.
-
- Args:
- voice_name: Key under which to store the preset.
- prompt_wav_path: Single wav path or a list (multi-clip mode skips
- invalid entries with a warning instead of raising).
- device: Target device for cached prompt latents / projected
- speaker embeddings.
- dtype: Target dtype for the projected speaker embedding head.
- """
- paths = self._normalize_paths(voice_name, prompt_wav_path)
- allow_partial = len(paths) > 1
-
- vae_sr = int(self._audio_vae.config.sample_rate) if self._audio_vae else 44100
- if self._audio_vae is None:
- logger.warning(
- "Voice preset '%s' being registered without AudioVAE features",
- voice_name,
- )
-
- speech_chunks: list[torch.Tensor] = []
- spk_emb_list: list[torch.Tensor] = []
- for wav_path in paths:
- try:
- speech_for_vae, raw_emb = self._load_single_wav(voice_name, wav_path, vae_sr)
- except (FileNotFoundError, InvalidPromptWavError) as e:
- if allow_partial:
- logger.warning(
- "Voice preset '%s': skipping invalid prompt wav %s: %s",
- voice_name,
- wav_path,
- e,
- )
- continue
- raise
- speech_chunks.append(speech_for_vae)
- if raw_emb is not None:
- projected = self._spk_head(raw_emb.to(device=device, dtype=dtype))
- spk_emb_list.append(projected)
-
- if not speech_chunks:
- raise RuntimeError(f"Failed to register voice preset '{voice_name}': no valid prompt wavs remained")
- if not spk_emb_list and self._audio_vae is None:
- raise RuntimeError(
- f"Failed to register voice preset '{voice_name}': neither speaker "
- "embeddings nor AudioVAE prompt features are available"
- )
-
- prompt_wav_lat, prompt_wav_emb = self._build_wav_embeddings(
- voice_name, torch.cat(speech_chunks, dim=-1), device=device
- )
-
- if voice_name in self.registered:
- logger.warning("Voice preset '%s' is being overwritten", voice_name)
- self.registered[voice_name] = {
- "prompt_wav_lat": prompt_wav_lat,
- "prompt_wav_emb": prompt_wav_emb,
- "spk_emb": spk_emb_list,
- }
- logger.info("Registered voice preset '%s' from %s", voice_name, paths)
-
- def load_presets_from_manifest(self, *, device: torch.device, dtype: torch.dtype) -> None:
- """Resolve voice_name.json on disk or HF hub and register all entries.
-
- Each entry is registered onto the supplied device and dtype.
- """
- voice_json_path, base_dir = self._locate_manifest()
- if voice_json_path is None:
- logger.info("No voice_name.json found; voice presets unavailable")
- return
-
- with open(voice_json_path) as f:
- voice_dict = json.load(f)
-
- for name, info in voice_dict.items():
- wav_path = info.get("prompt_wav_path", "")
- prompt_text = info.get("prompt_text", "")
- if not wav_path:
- logger.warning("Voice preset '%s' has no prompt_wav_path, skipping", name)
- continue
- if not os.path.isabs(wav_path):
- wav_path = os.path.join(base_dir, wav_path)
- if not os.path.isfile(wav_path):
- logger.warning("Voice preset '%s': wav not found at %s, skipping", name, wav_path)
- continue
- try:
- self.register(name, wav_path, device=device, dtype=dtype)
- self.registered[name]["prompt_text"] = prompt_text
- except Exception as e: # pragma: no cover — manifest is best-effort
- logger.warning("Failed to register voice preset '%s': %s", name, e)
-
- @cached_property
- def _spkemb_extractor(self) -> SpkembExtractor:
- """Lazily resolve the CAMPPlus ONNX extractor."""
- for candidate in (self._talker_dir, self._model_path):
- path = os.path.join(candidate, "campplus.onnx")
- if os.path.isfile(path):
- extractor = SpkembExtractor(path)
- logger.info("Initialized SpkembExtractor from %s", path)
- return extractor
- try:
- path = cached_file(self._model_path, "campplus.onnx", subfolder="talker")
- except Exception as e:
- raise RuntimeError("campplus.onnx not found. Expected at /talker/campplus.onnx") from e
- extractor = SpkembExtractor(path)
- logger.info("Initialized SpkembExtractor from %s", path)
- return extractor
-
- @staticmethod
- def _normalize_paths(voice_name: str, prompt_wav_path: str | list[str]) -> list[str]:
- if not isinstance(voice_name, str) or not voice_name.strip():
- raise ValueError("voice_name must be a non-empty string")
- if isinstance(prompt_wav_path, str):
- paths = [prompt_wav_path]
- elif isinstance(prompt_wav_path, list):
- paths = list(prompt_wav_path)
- else:
- raise TypeError("prompt_wav_path must be a string path or a list of string paths")
- paths = [p.strip() for p in paths]
- if not paths or any(not p for p in paths):
- raise ValueError("Provided audio path is invalid")
- return paths
-
- def _load_single_wav(self, voice_name: str, wav_path: str, vae_sr: int) -> tuple[torch.Tensor, torch.Tensor | None]:
- """Return ``(speech_for_vae, raw_spk_emb_or_none)``.
-
- Stays device-agnostic — both returned tensors live on CPU; the caller
- moves them to the target device when projecting / encoding.
- """
- if not os.path.isfile(wav_path):
- raise FileNotFoundError(f"prompt wav not found: {wav_path}")
-
- data, sample_rate = sf.read(wav_path, dtype="float32")
- speech_tmp = torch.from_numpy(data)
- if speech_tmp.ndim == 1:
- speech_tmp = speech_tmp.unsqueeze(0)
- elif speech_tmp.ndim == 2:
- num_channels = speech_tmp.shape[1]
- if num_channels > 1:
- logger.warning(
- "Voice preset '%s': downmixing %d-channel audio at %s to mono",
- voice_name,
- num_channels,
- wav_path,
- )
- speech_tmp = speech_tmp.mean(dim=1, keepdim=True).T
- else:
- raise InvalidPromptWavError(f"unsupported audio shape {tuple(speech_tmp.shape)} for {wav_path}")
-
- if not torch.isfinite(speech_tmp).all():
- raise InvalidPromptWavError(f"audio file contains NaN or Inf samples: {wav_path}")
-
- speech_for_vae = resample(speech_tmp, sample_rate, vae_sr)
-
- # Speaker embedding (16 kHz CAMPPlus). If the extractor fails to
- # resolve (missing ONNX model), skip embedding extraction rather than
- # blocking VAE-only registration.
- raw_emb: torch.Tensor | None = None
- try:
- extractor = self._spkemb_extractor
- speech_for_spk = resample(speech_tmp, sample_rate, 16000)
- raw_emb = extractor(speech_for_spk)
- except RuntimeError:
- raw_emb = None
- return speech_for_vae, raw_emb
-
- def _build_wav_embeddings(
- self,
- voice_name: str,
- speech: torch.Tensor,
- *,
- device: torch.device,
- ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
- if self._audio_vae is None:
- return None, None
-
- patch_pt = self._audio_vae.encoder.hop_size * max(1, self._audio_vae.encoder.patch_size) * self._patch_size
- if speech.shape[-1] % patch_pt != 0:
- pad_len = (speech.shape[-1] + patch_pt - 1) // patch_pt * patch_pt
- pad_speech = torch.zeros((speech.shape[0], pad_len), dtype=speech.dtype, device=speech.device)
- pad_speech[:, -speech.shape[-1] :] = speech
- speech = pad_speech
-
- prompt_wav_lat, _ = self._audio_vae.encode_latent(
- speech.to(dtype=torch.bfloat16, device=device),
- torch.tensor([speech.size(1)], dtype=torch.long, device=device),
- )
- assert prompt_wav_lat.shape[1] % self._patch_size == 0, (
- f"AudioVAE latent length is incompatible with patch_size for voice preset '{voice_name}'"
- )
- prompt_wav_lat = prompt_wav_lat.reshape(-1, self._patch_size, prompt_wav_lat.shape[-1])
- prompt_wav_emb = self._aggregator(prompt_wav_lat)
- prompt_wav_lat = prompt_wav_lat.reshape(1, -1, prompt_wav_lat.shape[-1])
- prompt_wav_emb = prompt_wav_emb.reshape(1, -1, prompt_wav_emb.shape[-1])
- return prompt_wav_lat, prompt_wav_emb
-
- def _locate_manifest(self) -> tuple[str | None, str | None]:
- for candidate in (self._talker_dir, self._model_path):
- path = os.path.join(candidate, "data", "voice_name.json")
- if os.path.isfile(path):
- return path, candidate
-
- if not os.path.isdir(self._model_path):
- try:
- hf_root = download_weights_from_hf_specific(
- self._model_path,
- self._download_dir,
- allow_patterns=["talker/data/**"],
- require_all=True,
- )
- candidate = os.path.join(hf_root, "talker", "data", "voice_name.json")
- if os.path.isfile(candidate):
- return candidate, os.path.join(hf_root, "talker")
- except Exception as e: # pragma: no cover
- logger.info("Could not download voice presets from HF: %s", e)
-
- return None, None
diff --git a/vllm_omni/model_executor/models/omnivoice/omnivoice.py b/vllm_omni/model_executor/models/omnivoice/omnivoice.py
index 7fde8f16faa..a3603a3c398 100644
--- a/vllm_omni/model_executor/models/omnivoice/omnivoice.py
+++ b/vllm_omni/model_executor/models/omnivoice/omnivoice.py
@@ -15,7 +15,6 @@
import numpy as np
import torch
import torch.nn as nn
-import torchaudio
from transformers.feature_extraction_utils import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -78,21 +77,31 @@ def _ensure_cached_runtime_components(self, model_dir: str, config: OmniVoiceCon
self.text_tokenizer = AutoTokenizer.from_pretrained(model_dir)
- # Audio tokenizer for encoding reference audio (requires transformers>=5.3)
+ # Audio tokenizer for encoding reference audio
audio_tokenizer_path = os.path.join(model_dir, "audio_tokenizer")
- try:
- from transformers import (
- AutoFeatureExtractor,
- HiggsAudioV2TokenizerModel,
- )
+ if os.path.isdir(audio_tokenizer_path):
+ try:
+ from transformers import (
+ AutoFeatureExtractor,
+ HiggsAudioV2TokenizerModel,
+ )
+ except ImportError as e:
+ raise ImportError(
+ "OmniVoice voice cloning requires transformers with "
+ "HiggsAudioV2TokenizerModel. Upgrade transformers or "
+ "use text-only mode (no reference audio)."
+ ) from e
self.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(audio_tokenizer_path, device_map="cpu")
self.feature_extractor = AutoFeatureExtractor.from_pretrained(audio_tokenizer_path)
self.audio_tokenizer.eval()
- except ImportError:
+ else:
self.audio_tokenizer = None
self.feature_extractor = None
- logger.warning("Voice cloning disabled (requires transformers>=5.3.0).")
+ logger.warning(
+ "audio_tokenizer not found at %s, voice cloning disabled",
+ audio_tokenizer_path,
+ )
self._cached_model_dir = model_dir
@@ -157,16 +166,20 @@ def _call_hf_processor(
if self.feature_extractor is not None:
target_sr = self.feature_extractor.sampling_rate
if sr != target_sr:
+ import torchaudio
+
audio_signal = torchaudio.functional.resample(audio_signal, sr, target_sr)
# Encode reference audio to 8-codebook tokens
- if self.audio_tokenizer is None:
- raise RuntimeError("Voice cloning requires transformers>=5.3.0. Try: uv pip install 'transformers>=5.3.0'")
-
- with torch.inference_mode():
- ref_audio_tokens = self.audio_tokenizer.encode(audio_signal) # [8, T_ref]
- if ref_audio_tokens.dim() == 3:
- ref_audio_tokens = ref_audio_tokens.squeeze(0) # [8, T_ref]
+ if self.audio_tokenizer is not None:
+ with torch.inference_mode():
+ ref_audio_tokens = self.audio_tokenizer.encode(audio_signal) # [8, T_ref]
+ if ref_audio_tokens.dim() == 3:
+ ref_audio_tokens = ref_audio_tokens.squeeze(0) # [8, T_ref]
+ else:
+ raise RuntimeError(
+ "Audio tokenizer not available for voice cloning. Ensure audio_tokenizer/ exists in model directory."
+ )
ft = BatchFeature(
{
diff --git a/vllm_omni/model_executor/models/output_templates.py b/vllm_omni/model_executor/models/output_templates.py
index 9d6a84ac407..2ed20980651 100644
--- a/vllm_omni/model_executor/models/output_templates.py
+++ b/vllm_omni/model_executor/models/output_templates.py
@@ -3,13 +3,11 @@
import torch
from vllm.sequence import IntermediateTensors
-from vllm_omni.data_entry_keys import OmniPayload
-
class OmniOutput(NamedTuple):
"""Output from the merged Omni model containing both text and audio."""
text_hidden_states: torch.Tensor
- multimodal_outputs: OmniPayload | None = None
+ multimodal_outputs: dict | None = None
intermediate_tensors: IntermediateTensors | None = None
next_token_id: torch.Tensor | None = None
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py
deleted file mode 100644
index b44d08eb32a..00000000000
--- a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Qwen2.5-Omni pipeline topology (frozen).
-
-Stage 0: Thinker — multimodal understanding + text generation
-Stage 1: Talker — text embeddings → speech tokens
-Stage 2: Code2Wav — speech tokens → audio waveform
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-_PROC = "vllm_omni.model_executor.stage_input_processors.qwen2_5_omni"
-
-QWEN2_5_OMNI_PIPELINE = PipelineConfig(
- model_type="qwen2_5_omni",
- model_arch="Qwen2_5OmniForConditionalGeneration",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="thinker",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- final_output=True,
- final_output_type="text",
- owns_tokenizer=True,
- requires_multimodal_data=True,
- engine_output_type="latent",
- sampling_constraints={"detokenize": True},
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="talker",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(0,),
- engine_output_type="latent",
- custom_process_input_func=f"{_PROC}.thinker2talker",
- sampling_constraints={
- "detokenize": True,
- "stop_token_ids": [8294],
- },
- ),
- StagePipelineConfig(
- stage_id=2,
- model_stage="code2wav",
- execution_type=StageExecutionType.LLM_GENERATION,
- input_sources=(1,),
- final_output=True,
- final_output_type="audio",
- engine_output_type="audio",
- sampling_constraints={"detokenize": True},
- ),
- ),
-)
-
-
-# Single-stage thinker-only variant for the abort test.
-QWEN2_5_OMNI_THINKER_ONLY_PIPELINE = PipelineConfig(
- model_type="qwen2_5_omni_thinker_only",
- model_arch="Qwen2_5OmniForConditionalGeneration",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="thinker",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- final_output=True,
- final_output_type="text",
- owns_tokenizer=True,
- requires_multimodal_data=True,
- engine_output_type="latent",
- sampling_constraints={"detokenize": True},
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py
index cd4c1aa1ce3..067c08e3c7d 100644
--- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py
@@ -32,7 +32,6 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
-from vllm_omni.data_entry_keys import OmniPayload
from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
from vllm_omni.model_executor.models.output_templates import OmniOutput
@@ -667,8 +666,6 @@ def talker_preprocess(
# - For Decode segments, if per-request auxiliary decode embeddings are provided (optional),
# add them; otherwise, keep the original embedding.
- payload: OmniPayload = info_dict
-
# Ensure we have base embeddings when only ids are provided
if input_embeds is None and input_ids is not None:
input_embeds = self.talker.embed_input_ids(input_ids)
@@ -676,27 +673,23 @@ def talker_preprocess(
span_len = input_ids.shape[0]
if span_len > 1:
# prefill
- return self.thinker_to_talker_process(input_ids, input_embeds, payload)
+ return self.thinker_to_talker_process(input_ids, input_embeds, **info_dict)
else:
# decode
- return self.thinker_to_talker_decode_one_step(input_ids, input_embeds, payload)
+ return self.thinker_to_talker_decode_one_step(input_ids, input_embeds, **info_dict)
def thinker_to_talker_process(
self,
input_ids: torch.Tensor,
input_embeds: torch.Tensor,
- payload: OmniPayload,
+ **info_dict: object,
):
- embed = payload.get("embed", {})
- hs = payload.get("hidden_states", {})
- ids = payload.get("ids", {})
-
update_dict = {}
- prompt_embeds = embed.get("prefill") # Tensor [P,H]
- thinker_result = hs.get("output") # Tensor [K,H]
- prompt_token_ids = ids.get("prompt") # list[int]
- thinker_output_token_ids = ids.get("output") # list[int]
+ prompt_embeds = info_dict.get("prompt_embeds") # Tensor [P,H]
+ thinker_result = info_dict.get("thinker_result") # Tensor [K,H]
+ prompt_token_ids = info_dict.get("prompt_token_ids") # list[int]
+ thinker_output_token_ids = info_dict.get("thinker_output_token_ids") # list[int]
if not isinstance(prompt_embeds, torch.Tensor):
prompt_embeds = torch.zeros(
@@ -721,7 +714,7 @@ def thinker_to_talker_process(
)
if thinker_result.ndim == 2 and thinker_result.shape[0] > 0:
- update_dict.setdefault("embed", {})["thinker_reply"] = thinker_result[1:].detach().to("cpu").contiguous()
+ update_dict["thinker_reply_part"] = thinker_result[1:].detach().to("cpu").contiguous()
return req_input_ids, req_embeds, update_dict
@@ -770,20 +763,18 @@ def _thinker_to_talker_prefill(
)
return prompt_token_ids_processed, prompt_embeds
- def thinker_to_talker_decode_one_step(self, input_ids, input_embeds, payload: OmniPayload):
- embed = payload.get("embed", {})
-
+ def thinker_to_talker_decode_one_step(self, input_ids, input_embeds, **info_dict):
update_dict = {}
# choose step vector in priority order
step_vec = None
- q = embed.get("thinker_reply", None)
+ q = info_dict.get("thinker_reply_part", None)
if isinstance(q, torch.Tensor) and q.numel() > 0:
step_vec = q[0:1]
new_q = q[1:].detach().to("cpu").contiguous()
- update_dict.setdefault("embed", {})["thinker_reply"] = new_q
+ update_dict["thinker_reply_part"] = new_q
else:
# B) per-request provided decode vector (optional)
- dv = embed.get("decode")
+ dv = info_dict.get("decode_output_prompt_embeds") if isinstance(info_dict, dict) else None
if isinstance(dv, torch.Tensor) and dv.numel() > 0:
step_vec = dv[0:1] if dv.ndim == 2 else dv.view(1, -1)
elif (
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py
index 617f0f9e325..0307034089c 100644
--- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py
@@ -64,10 +64,6 @@
)
from vllm.sequence import IntermediateTensors
-from vllm_omni.quantization.component_config import (
- resolve_encoder_quant_config,
-)
-
try:
import flash_attn
except (ImportError, ModuleNotFoundError):
@@ -363,12 +359,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.quant_config = quant_config
- # Pre-quantized checkpoints (modelopt NVFP4/FP8/MXFP8) only quantize
- # the Thinker LM. Vision encoder weights remain in BF16 with no FP8
- # scale tensors; passing quant_config causes FP8 kernels to run on
- # BF16 weights, producing garbage embeddings. Keep None for encoders.
- visual_quant_config = resolve_encoder_quant_config(quant_config)
-
with self._mark_tower_model(vllm_config, "audio"):
if multimodal_config.get_limit_per_prompt("audio"):
self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
@@ -380,7 +370,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.visual = Qwen2_5_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
- quant_config=visual_quant_config,
+ quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
else:
diff --git a/vllm_omni/model_executor/models/qwen3_omni/pipeline.py b/vllm_omni/model_executor/models/qwen3_omni/pipeline.py
deleted file mode 100644
index 1c69ec79570..00000000000
--- a/vllm_omni/model_executor/models/qwen3_omni/pipeline.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Qwen3-Omni-MoE pipeline topology (frozen).
-
-Stage 0: Thinker — multimodal understanding + text generation
-Stage 1: Talker — text embeddings → RVQ codec codes
-Stage 2: Code2Wav — RVQ codes → audio waveform
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-_PROC = "vllm_omni.model_executor.stage_input_processors.qwen3_omni"
-
-QWEN3_OMNI_PIPELINE = PipelineConfig(
- model_type="qwen3_omni_moe",
- model_arch="Qwen3OmniMoeForConditionalGeneration",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="thinker",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- final_output=True,
- final_output_type="text",
- owns_tokenizer=True,
- requires_multimodal_data=True,
- hf_config_name="thinker_config",
- engine_output_type="latent",
- custom_process_next_stage_input_func=(f"{_PROC}.thinker2talker_async_chunk"),
- sampling_constraints={"detokenize": True},
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="talker",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(0,),
- hf_config_name="talker_config",
- engine_output_type="latent",
- custom_process_input_func=f"{_PROC}.thinker2talker",
- custom_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"),
- sampling_constraints={
- "detokenize": False,
- "stop_token_ids": [2150],
- },
- ),
- StagePipelineConfig(
- stage_id=2,
- model_stage="code2wav",
- execution_type=StageExecutionType.LLM_GENERATION,
- input_sources=(1,),
- final_output=True,
- final_output_type="audio",
- hf_config_name="thinker_config",
- engine_output_type="audio",
- custom_process_input_func=f"{_PROC}.talker2code2wav",
- sampling_constraints={"detokenize": True},
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
index 28b969ff7cd..ed6df6af36a 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
@@ -36,7 +36,6 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
-from vllm_omni.data_entry_keys import Embeddings, HiddenStates, Ids, OmniPayload, OmniPayloadMeta
from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
from vllm_omni.model_executor.models.output_templates import OmniOutput
from vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker import (
@@ -176,15 +175,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.suppressed_tokens = self._get_talker_suppressed_tokens()
self.requires_raw_input_tokens = True
# Keys that should stay on GPU in model_intermediate_buffer to avoid CPU↔GPU round-trips
- self.gpu_resident_buffer_keys: set[tuple[str, str]] = {
- ("hidden_states", "last"),
- ("hidden_states", "trailing_text"),
- ("embed", "tts_pad_projected"),
- }
- # Keys that need to be accumulated across streaming inputs
- self.streaming_accumulated_keys: set[tuple[str, str]] = {
- ("embed", "prefill"),
- ("hidden_states", "output"),
+ self.gpu_resident_buffer_keys: set[str] = {
+ "last_talker_hidden",
+ "trailing_text_hidden",
+ "tts_pad_embed_projected",
}
elif self.model_stage == "code2wav":
@@ -430,9 +424,8 @@ def forward(
left_context_size = []
if runtime_additional_information is not None:
for info in runtime_additional_information:
- meta = info.get("meta", {})
- if "left_context_size" in meta:
- left_context_size.append(meta["left_context_size"])
+ if "left_context_size" in info:
+ left_context_size.append(info["left_context_size"])
else:
logger.debug("No additional_information provided to code2wav stage.")
audio_tensors = self.generate_audio(codes, left_context_size, seq_token_counts)
@@ -461,7 +454,7 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -
text_hidden_states, captured_layer_dict = model_outputs
# Compute thinker-side TTS token embeddings for BOS/EOS/PAD and expose via multimodal outputs.
# These will later be projected into talker text space by the talker stage.
- multimodal_outputs: OmniPayload = captured_layer_dict if captured_layer_dict is not None else {}
+ multimodal_outputs = captured_layer_dict if captured_layer_dict is not None else {}
try:
thinker_tts_embeds = self.thinker.embed_input_ids(self.tts_tokens) # [1,3,thinker_hidden]
if (
@@ -470,10 +463,9 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -
and thinker_tts_embeds.shape[1] == 3
):
bos_eos_pad = thinker_tts_embeds.to(text_hidden_states.device).chunk(3, dim=1) # 3 * [1,1,H]
- embed = multimodal_outputs.setdefault("embed", {})
- embed["tts_bos"] = [bos_eos_pad[0]]
- embed["tts_eos"] = [bos_eos_pad[1]]
- embed["tts_pad"] = [bos_eos_pad[2]]
+ multimodal_outputs["tts_bos_embed"] = [bos_eos_pad[0]]
+ multimodal_outputs["tts_eos_embed"] = [bos_eos_pad[1]]
+ multimodal_outputs["tts_pad_embed"] = [bos_eos_pad[2]]
except Exception:
# Best-effort; absence will be handled by talker with fallbacks
pass
@@ -497,10 +489,9 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs) -
if "runtime_additional_information" in kwargs and "model_intermediate_buffer" not in kwargs:
logger.warning_once("runtime_additional_information is deprecated, use model_intermediate_buffer")
- code_predictor_codes = [info.get("codes", {}).get("audio") for info in info_dicts]
- audio_codes = torch.cat(code_predictor_codes, dim=0)
- multimodal_outputs: OmniPayload = {"codes": {"audio": audio_codes}}
- span_len = audio_codes.shape[0]
+ code_predictor_codes = [info.get("code_predictor_codes") for info in info_dicts]
+ multimodal_outputs = {"code_predictor_codes": torch.cat(code_predictor_codes, dim=0)}
+ span_len = multimodal_outputs["code_predictor_codes"].shape[0]
talker_hidden = talker_hidden[:span_len]
return OmniOutput(text_hidden_states=talker_hidden, multimodal_outputs=multimodal_outputs)
elif self.model_stage == "code2wav":
@@ -619,14 +610,13 @@ def _init_special_tokens_embeddings(self) -> set[str]:
# Speaker token IDs (for voice selection)
# In Qwen3, speaker_id mapping is in talker_config.speaker_id
- # Keys are lowercased for case-insensitive matching with serving layer.
if hasattr(talker_hf_config, "speaker_id") and talker_hf_config.speaker_id:
- self.tts_text_spk_token_ids = {k.lower(): v for k, v in talker_hf_config.speaker_id.items()}
+ self.tts_text_spk_token_ids = talker_hf_config.speaker_id
else:
# Default to audio_start_token_id if no speaker mapping
self.tts_text_spk_token_ids = {
"default": talker_hf_config.audio_start_token_id,
- "ethan": talker_hf_config.audio_start_token_id,
+ "Ethan": talker_hf_config.audio_start_token_id,
"prefix_caching": talker_hf_config.audio_start_token_id,
}
@@ -644,48 +634,47 @@ def talker_postprocess(self, hidden_states: torch.Tensor, **info_dict: object):
"""
Postprocess the talker hidden states.
"""
- return {"hidden_states": {"last": hidden_states[-1, :].detach()}}
+ update_dict = {}
+ update_dict["last_talker_hidden"] = hidden_states[-1, :].detach()
+ return update_dict
def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict):
"""
Preprocess talker embeds. Noted that we set the MTP here.
"""
- payload: OmniPayload = info_dict
- meta = payload.setdefault("meta", {})
-
# Ensure we have base embeddings when only ids are provided
if input_embeds is None and input_ids is not None:
input_embeds = self.talker.embed_input_ids(input_ids)
span_len = input_ids.shape[0]
- update_dict: OmniPayload = {}
+ update_dict = {}
if span_len > 1:
# prefill
- input_ids, input_embeds, update_dict = self.talker_preprocess_prefill(input_ids, input_embeds, payload)
+ input_ids, input_embeds, update_dict = self.talker_preprocess_prefill(input_ids, input_embeds, **info_dict)
code_predictor_codes = torch.zeros(
(input_embeds.shape[0], self.talker.num_code_groups),
device=self._module_device(self.talker),
dtype=torch.long,
)
- update_dict.setdefault("codes", {})["audio"] = code_predictor_codes
+ update_dict["code_predictor_codes"] = code_predictor_codes
else:
# decode
- if not meta.get("decode_flag", False):
+ if not info_dict.get("decode_flag", False):
# Prefill already consumed the first text token via the
# assistant bootstrap path, so decode starts from the
# remaining-text boundary rather than cumulative index 0.
- prefill_consumed_text_tokens = meta.get("prefill_consumed_text_tokens")
+ prefill_consumed_text_tokens = info_dict.get("prefill_consumed_text_tokens")
if prefill_consumed_text_tokens is None:
raise RuntimeError("Missing prefill_consumed_text_tokens for talker decode handoff.")
- meta["num_processed_tokens"] = prefill_consumed_text_tokens
- update_dict.setdefault("meta", {})["decode_flag"] = True
+ info_dict["num_processed_tokens"] = prefill_consumed_text_tokens
+ update_dict["decode_flag"] = True
last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode(
- input_ids, input_embeds, update_dict, payload
+ input_ids, input_embeds, update_dict, **info_dict
)
update_dict["mtp_inputs"] = last_talker_hidden, text_step
- update_dict.setdefault("meta", {})["num_processed_tokens"] = meta.get("num_processed_tokens", 0) + span_len
+ update_dict["num_processed_tokens"] = info_dict.get("num_processed_tokens", 0) + span_len
return input_ids, input_embeds, update_dict
def talker_mtp(
@@ -694,7 +683,6 @@ def talker_mtp(
input_embeds: torch.Tensor,
last_talker_hidden: torch.Tensor,
text_step: torch.Tensor,
- **kwargs: Any,
):
# TODO(Peiqi): not support intermediate_tensors now
input_ids = safe_tensor_reshape(input_ids, (input_ids.shape[0], -1))
@@ -742,16 +730,11 @@ def _proj_from_thinker(x_opt: torch.Tensor | None) -> torch.Tensor:
self.tts_pad_embed = _proj_from_thinker(tts_pad_thinker)
return self.tts_bos_embed, self.tts_eos_embed, self.tts_pad_embed
- def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, payload: OmniPayload):
- hs: HiddenStates = payload.get("hidden_states", {})
- embed: Embeddings = payload.get("embed", {})
- ids: Ids = payload.get("ids", {})
- meta: OmniPayloadMeta = payload.get("meta", {})
-
+ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, **info_dict: dict):
# Containers to return per-request updates (e.g., code_predictor_hidden_per_request)
- update_dict: OmniPayload = {}
+ update_dict: dict[str, dict] = {}
- voice_type = payload.get("speaker")
+ voice_type = info_dict.get("speaker")
if voice_type is not None and isinstance(voice_type, (list, tuple)) and len(voice_type) > 0:
voice_type = voice_type[0]
if not isinstance(voice_type, str) or not voice_type.strip():
@@ -759,34 +742,40 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
voice_type = self.default_tts_text_spk_type
else:
voice_type = str(voice_type).lower().strip()
- start_index = meta.get("num_processed_tokens", 0)
+ start_index = info_dict.get("num_processed_tokens", 0)
end_index = start_index + input_embeds.shape[0]
# Read thinker outputs for prefill
- thinker_sequence_embeds = embed["prefill"].to(
+ thinker_sequence_embeds = info_dict.get("thinker_prefill_embeddings").to(
device=self._module_device(self.talker), dtype=torch.bfloat16
) # Tensor [P,H]
- thinker_hidden_states = hs["output"].to(
+ thinker_hidden_states = info_dict.get("thinker_hidden_states").to(
device=self._module_device(self.talker), dtype=torch.bfloat16
) # Tensor [K,H]
thinker_sequences = (
- ids.get("all")
- if ids.get("all") is None
- else torch.as_tensor(ids["all"], device=self._module_device(self.talker))
+ info_dict.get("thinker_sequences")
+ if info_dict.get("thinker_sequences") is None
+ else torch.as_tensor(info_dict.get("thinker_sequences"), device=self._module_device(self.talker))
)
thinker_chatml_ids = (
- ids.get("prompt")
- if ids.get("prompt") is None
- else torch.as_tensor(ids["prompt"], device=self._module_device(self.talker))
+ info_dict.get("thinker_input_ids")
+ if info_dict.get("thinker_input_ids") is None
+ else torch.as_tensor(info_dict.get("thinker_input_ids"), device=self._module_device(self.talker))
)
- tts_bos_thinker = embed["tts_bos"].to(device=self._module_device(self.talker), dtype=torch.bfloat16)
- tts_eos_thinker = embed["tts_eos"].to(device=self._module_device(self.talker), dtype=torch.bfloat16)
- tts_pad_thinker = embed["tts_pad"].to(device=self._module_device(self.talker), dtype=torch.bfloat16)
+ tts_bos_thinker = info_dict.get("tts_bos_embed").to(
+ device=self._module_device(self.talker), dtype=torch.bfloat16
+ )
+ tts_eos_thinker = info_dict.get("tts_eos_embed").to(
+ device=self._module_device(self.talker), dtype=torch.bfloat16
+ )
+ tts_pad_thinker = info_dict.get("tts_pad_embed").to(
+ device=self._module_device(self.talker), dtype=torch.bfloat16
+ )
if thinker_sequence_embeds is None or thinker_hidden_states is None:
raise ValueError(
"additional_information_by_req_id must include "
- "'embed.prefill' and 'hidden_states.output' for talker prefill."
+ "'thinker_prefill_embeddings' and 'thinker_hidden_states' for talker prefill."
)
# Normalize to tensors
@@ -841,7 +830,7 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
# compatible with old shape [1,S,D]
rem_tail = trailing_text_hidden.squeeze(0)
if rem_tail.shape[0] > 0:
- update_dict.setdefault("hidden_states", {})["trailing_text"] = rem_tail.detach()
+ update_dict["trailing_text_hidden"] = rem_tail.detach()
# Also persist projected tts_pad for decode fallback if needed
if isinstance(tts_pad_thinker, torch.Tensor):
pad_in = tts_pad_thinker
@@ -850,27 +839,27 @@ def talker_preprocess_prefill(self, input_ids: torch.Tensor, input_embeds: torch
if pad_in.ndim == 1:
pad_in = pad_in.view(1, 1, -1)
pad_proj = self.talker.text_projection(pad_in.to(self._module_device(self.talker)))
- update_dict.setdefault("embed", {})["tts_pad_projected"] = pad_proj.detach()
+ update_dict["tts_pad_embed_projected"] = pad_proj.detach()
except Exception:
pass
- update_dict.setdefault("meta", {})["prefill_consumed_text_tokens"] = 1
- self._talker_cache_thinker_decode_embeds(embed, update_dict)
+ update_dict["prefill_consumed_text_tokens"] = 1
+ self._talker_cache_thinker_decode_embeds(info_dict, update_dict)
return req_input_ids[start_index:end_index], req_embeds[start_index:end_index], update_dict
def _talker_cache_thinker_decode_embeds(
self,
- embed: Embeddings,
- update_dict: OmniPayload,
+ info_dict: dict[str, Any],
+ update_dict: dict[str, Any],
) -> None:
"""
Cache thinker embeds for decode stage.
"""
- thinker_decode_embeds = embed.get("decode", None)
+ thinker_decode_embeds = info_dict.get("thinker_decode_embeddings", None)
if thinker_decode_embeds is not None:
- cached_thinker_decode_embeds = embed.get("cached_decode", None)
+ cached_thinker_decode_embeds = info_dict.get("cached_thinker_decode_embeddings", None)
if cached_thinker_decode_embeds is None:
- update_dict.setdefault("embed", {})["cached_decode"] = thinker_decode_embeds
+ update_dict["cached_thinker_decode_embeddings"] = thinker_decode_embeds
else:
cached_thinker_decode_embeds = cached_thinker_decode_embeds.to(
device=self._module_device(self.talker), dtype=torch.bfloat16
@@ -878,10 +867,10 @@ def _talker_cache_thinker_decode_embeds(
thinker_decode_embeds = thinker_decode_embeds.to(
device=self._module_device(self.talker), dtype=torch.bfloat16
)
- update_dict.setdefault("embed", {})["cached_decode"] = torch.cat(
+ update_dict["cached_thinker_decode_embeddings"] = torch.cat(
[cached_thinker_decode_embeds, thinker_decode_embeds], dim=0
)
- update_dict.setdefault("embed", {})["decode"] = None
+ update_dict["thinker_decode_embeddings"] = None
def _thinker_to_talker_prefill(
self,
@@ -901,11 +890,10 @@ def _thinker_to_talker_prefill(
Returns:
(input_ids, input_embeds) for talker
"""
- target_len = thinker_result_ids.shape[-1]
im_start_indexes = torch.cat(
(
torch.nonzero(input_ids[0] == self.config.im_start_token_id).squeeze(),
- torch.tensor([target_len], device=input_ids.device, dtype=input_ids.dtype),
+ torch.tensor([thinker_result_ids.shape[-1]], device=input_ids.device, dtype=input_ids.dtype),
),
dim=-1,
) # Shape [n_starts + 1]; Take batch 0 since batched inference is not supported here.
@@ -968,27 +956,23 @@ def _thinker_to_talker_prefill(
def _thinker_decode_to_talker_decode(
self,
- payload: OmniPayload,
+ info_dict: dict,
device: torch.device,
update_dict,
):
"""
- Project thinker outputs to talker inputs during decode stage.
+ Project thinker outputs to talker inputs during prefill stage.
Returns:
- text_step embedding for talker
+ (input_ids, input_embeds) for talker
"""
- embed = payload.get("embed", {})
- meta = payload.get("meta", {})
- ids = payload.get("ids", {})
-
- cached_thinker_decode_embeds = embed.get("cached_decode", None)
- thinker_decode_embed = embed.get("decode", None)
- start_index = meta.get("num_processed_tokens", 0)
- thinker_output_token_ids = ids.get("output", [])
+ cached_thinker_decode_embeds = info_dict.get("cached_thinker_decode_embeddings", None)
+ thinker_decode_embed = info_dict.get("thinker_decode_embeddings", None)
+ start_index = info_dict.get("num_processed_tokens", 0)
+ thinker_output_token_ids = info_dict.get("thinker_output_token_ids", [])
if start_index >= len(thinker_output_token_ids) - 1:
- if meta.get("finished"):
+ if info_dict.get("finished_flag"):
return self.tts_pad_embed.to(device)
- update_dict.setdefault("meta", {})["finished"] = True
+ update_dict["finished_flag"] = True
return self.tts_eos_embed.to(device)
if cached_thinker_decode_embeds is not None and start_index < cached_thinker_decode_embeds.shape[0]:
@@ -997,26 +981,25 @@ def _thinker_decode_to_talker_decode(
if thinker_decode_embed is not None:
thinker_decode_embed = thinker_decode_embed.to(device)
cached_thinker_decode_embeds = torch.cat([cached_thinker_decode_embeds, thinker_decode_embed], dim=0)
- update_dict.setdefault("embed", {})["cached_decode"] = cached_thinker_decode_embeds
+ update_dict["cached_thinker_decode_embeddings"] = cached_thinker_decode_embeds
else:
thinker_embed = thinker_decode_embed
if thinker_embed.device != device:
thinker_embed = thinker_embed.to(device)
- update_dict.setdefault("embed", {})["decode"] = None
+
+ update_dict["thinker_decode_embeddings"] = None
return self.talker.text_projection(thinker_embed).to(device)
def talker_preprocess_decode(
- self, input_ids: torch.Tensor, input_embeds: torch.Tensor, update_dict: OmniPayload, payload: OmniPayload
+ self, input_ids: torch.Tensor, input_embeds: torch.Tensor, update_dict: dict, **info_dict: dict
):
- hs = payload.get("hidden_states", {})
-
last_talker_hidden = None
text_step = None
try:
if self.vllm_config.model_config.async_chunk:
- text_step = self._thinker_decode_to_talker_decode(payload, input_ids.device, update_dict)
+ text_step = self._thinker_decode_to_talker_decode(info_dict, input_ids.device, update_dict)
else:
- q_tail = hs.get("trailing_text", None)
+ q_tail = info_dict.get("trailing_text_hidden", None)
if isinstance(q_tail, torch.Tensor) and q_tail.numel() > 0:
use_vec = q_tail[0:1, :]
new_q_tail = (
@@ -1025,11 +1008,11 @@ def talker_preprocess_decode(
else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype)
)
text_step = use_vec.to(input_embeds.device, dtype=input_embeds.dtype)
- update_dict.setdefault("hidden_states", {})["trailing_text"] = new_q_tail
+ update_dict["trailing_text_hidden"] = new_q_tail
else:
text_step = self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype)
- last_talker_hidden_tensor = hs.get("last")
+ last_talker_hidden_tensor = info_dict.get("last_talker_hidden")
if last_talker_hidden_tensor is not None:
last_talker_hidden = last_talker_hidden_tensor.to(input_embeds.device, dtype=input_embeds.dtype)
last_talker_hidden = last_talker_hidden.reshape(*last_talker_hidden.shape[-2:]) # [1, hidden_size]
@@ -1045,35 +1028,8 @@ def talker_preprocess_decode(
return last_talker_hidden, text_step, update_dict
def _get_talker_user_parts(self, im_start_index, segment_end_index, multimodal_mask, thinker_hidden, thinker_embed):
- clamped = min(
- segment_end_index,
- multimodal_mask.shape[0],
- thinker_hidden.shape[0],
- thinker_embed.shape[0],
- )
- if clamped < segment_end_index:
- logger.warning(
- "_get_talker_user_parts: segment_end_index %d clamped to %d "
- "(embed=%d, hidden=%d, mask=%d). "
- "This usually means _merge_pd_embeddings failed to merge "
- "prefill embeddings – check PD prefill_mm keys.",
- segment_end_index,
- clamped,
- thinker_embed.shape[0],
- thinker_hidden.shape[0],
- multimodal_mask.shape[0],
- )
- segment_end_index = clamped
- seg_len = segment_end_index - im_start_index
- if seg_len <= 0:
- return torch.empty(
- (0, self.config.talker_config.text_config.hidden_size),
- device=thinker_hidden.device,
- dtype=torch.bfloat16,
- )
-
user_talker_part = torch.empty(
- (seg_len, self.config.talker_config.text_config.hidden_size),
+ (segment_end_index - im_start_index, self.config.talker_config.text_config.hidden_size),
device=thinker_hidden.device,
dtype=torch.bfloat16,
)
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py
index 73c7e41d26f..2ceaafdb670 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_code_predictor_mtp.py
@@ -1,29 +1,510 @@
-"""Qwen3-Omni Code Predictor -- thin wrapper over CodePredictorWrapper."""
+"""Qwen3-Omni Code Predictor -- optimized re-prefill, no KV cache.
+* SDPA attention (F.scaled_dot_product_attention) with native GQA support
+* HF-compatible numerics (float32 RMSNorm, float32 RoPE, separate linear layers)
+* Per-call embedding buffer to avoid cross-request aliasing
+* Pre-allocated position_ids (read-only, safe to persist)
+* torch.compile (epilogue_fusion=False) on inner transformer by default
+* Inline sampling (top-k + top-p) -- no custom op overhead
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm_omni.model_executor.models.common.qwen3_code_predictor import (
- CodePredictorWrapper,
- CodePredictorWrapperConfig,
-)
from vllm_omni.platforms import current_omni_platform
+logger = init_logger(__name__)
+
+
+# ===================================================================
+# HF-numerics-compatible layers for code predictor
+# ===================================================================
+#
+# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32,
+# rotate_half RoPE) to produce outputs numerically identical to the
+# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel,
+# get_rope) introduce small precision differences that compound across
+# the autoregressive steps of the code predictor, causing severe
+# audio quality degradation.
+#
+# See: https://github.com/vllm-project/vllm-omni/issues/2274
+
+
+class _RMSNorm(nn.Module):
+ """RMSNorm matching HuggingFace's implementation exactly.
+
+ Computes variance in float32 to avoid bfloat16 precision loss.
+ """
+
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+class _RotaryEmbedding(nn.Module):
+ """RoPE matching HuggingFace's implementation exactly.
+
+ Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False).
+ """
+
+ def __init__(self, config) -> None:
+ super().__init__()
+ head_dim = getattr(
+ config,
+ "head_dim",
+ config.hidden_size // config.num_attention_heads,
+ )
+ rope_theta = getattr(config, "rope_theta", 10000.0)
+ inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ # position_ids: [batch, seq_len]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 (matching HF)
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Qwen3OmniCodePredictorAttention(nn.Module):
+ """Multi-head self-attention for code predictor.
+
+ Uses ``F.scaled_dot_product_attention`` with HF-compatible RoPE and RMSNorm.
+ No KV cache -- the code predictor always re-prefills the full (short)
+ sequence each AR step.
+
+ Input : [B, seq_len, hidden_size]
+ Output: [B, seq_len, hidden_size]
+ """
+
+ def __init__(
+ self,
+ config,
+ prefix: str = "",
+ ):
+ super().__init__()
+ cp_cfg = config.code_predictor_config
+ self.num_heads = cp_cfg.num_attention_heads
+ self.num_kv_heads = cp_cfg.num_key_value_heads
+ self.head_dim = getattr(
+ cp_cfg,
+ "head_dim",
+ cp_cfg.hidden_size // cp_cfg.num_attention_heads,
+ )
+ self.hidden_size = cp_cfg.hidden_size
+ self.scaling = self.head_dim**-0.5
+ self._use_gqa = self.num_kv_heads != self.num_heads
+
+ # Separate q/k/v projections matching HF (no fused packing)
+ self.q_proj = nn.Linear(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=False,
+ )
+ self.k_proj = nn.Linear(
+ self.hidden_size,
+ self.num_kv_heads * self.head_dim,
+ bias=False,
+ )
+ self.v_proj = nn.Linear(
+ self.hidden_size,
+ self.num_kv_heads * self.head_dim,
+ bias=False,
+ )
+ self.o_proj = nn.Linear(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=False,
+ )
+ self.q_norm = _RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps)
+ self.k_norm = _RMSNorm(self.head_dim, eps=cp_cfg.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+ bsz, seq_len, _ = hidden_states.shape
+ hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim)
+ hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim)
+
+ q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2)
+ k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2)
+ v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads
+ cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim]
+ sin = sin.unsqueeze(1)
+ q = (q * cos) + (_rotate_half(q) * sin)
+ k = (k * cos) + (_rotate_half(k) * sin)
+
+ attn_out = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ scale=self.scaling,
+ is_causal=True,
+ enable_gqa=self._use_gqa,
+ )
+
+ attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1)
+ output = self.o_proj(attn_out)
+ return output
+
+
+# ===================================================================
+# MLP
+# ===================================================================
+
+
+class Qwen3OmniCodePredictorMLP(nn.Module):
+ """SiLU-gated MLP for code predictor, matching HF's implementation."""
+
+ def __init__(
+ self,
+ config,
+ prefix: str = "",
+ ):
+ super().__init__()
+ hidden_size = config.code_predictor_config.hidden_size
+ intermediate_size = config.code_predictor_config.intermediate_size
+
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
+
+
+# ===================================================================
+# Decoder Layer
+# ===================================================================
+
+
+class Qwen3OmniCodePredictorDecoderLayer(nn.Module):
+ """Transformer decoder layer (SDPA, no KV cache)."""
+
+ def __init__(
+ self,
+ config,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.self_attn = Qwen3OmniCodePredictorAttention(
+ config,
+ prefix=f"{prefix}.self_attn",
+ )
+ self.mlp = Qwen3OmniCodePredictorMLP(
+ config,
+ prefix=f"{prefix}.mlp",
+ )
+ cp_cfg = config.code_predictor_config
+ self.input_layernorm = _RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps)
+ self.post_attention_layernorm = _RMSNorm(cp_cfg.hidden_size, eps=cp_cfg.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.self_attn(hidden_states, position_embeddings)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
-class Qwen3OmniMoeTalkerCodePredictor(CodePredictorWrapper):
- """Qwen3-Omni code predictor (no CUDA graphs, VocabParallelEmbedding)."""
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
- cp_config = vllm_config.model_config.hf_config.code_predictor_config
- super().__init__(
+# ===================================================================
+# Base Transformer Model (re-prefill, no KV cache)
+# ===================================================================
+
+
+class Qwen3OmniCodePredictorBaseModel(nn.Module):
+ """Inner transformer for code predictor.
+
+ Signature: ``forward(inputs_embeds, position_ids) -> hidden_states``
+ -- plain Tensor in, plain Tensor out (no namedtuple).
+ """
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config.code_predictor_config
+ self.config = config
+
+ self.codec_embedding = nn.ModuleList(
+ [VocabParallelEmbedding(config.vocab_size, config.hidden_size) for _ in range(config.num_code_groups - 1)]
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ Qwen3OmniCodePredictorDecoderLayer(
+ vllm_config.model_config.hf_config,
+ prefix=f"{prefix}.layers.{idx}",
+ )
+ for idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = _RotaryEmbedding(config)
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, position_embeddings)
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+
+# ===================================================================
+# Code Predictor Wrapper (optimized re-prefill, per-call embedding buffer)
+# ===================================================================
+
+
+class Qwen3OmniMoeTalkerCodePredictor(nn.Module):
+ """Optimized code predictor -- re-prefill approach, no KV cache.
+
+ Each AR step forwards the full growing sequence (len 2 -> num_code_groups)
+ through the transformer. The extra O(T^2) FLOPs are negligible for
+ short sequences, and this avoids all KV-cache management overhead.
+
+ Optimizations:
+ 1. Per-call embedding buffer -- avoids cross-request aliasing.
+ 2. Pre-allocated position_ids -- no torch.arange per step.
+ 3. Cached module references -- bypass ModuleList indexing.
+ 4. torch.compile on inner transformer.
+ 5. Inline sampling (top-k + top-p) -- no custom op overhead.
+ """
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ self.config = config
+ self.quant_config = vllm_config.quant_config
+ self.prefix = prefix
+
+ self.num_code_groups = config.code_predictor_config.num_code_groups
+ self._hidden_size = config.code_predictor_config.hidden_size
+
+ self.model = Qwen3OmniCodePredictorBaseModel(
vllm_config=vllm_config,
- cp_config=cp_config,
- wrapper_config=CodePredictorWrapperConfig(
- use_cuda_graphs=current_omni_platform.is_npu(),
- use_parallel_embedding=True,
- use_projection=False,
- return_proj_buf=True,
- sampling_mode="stored",
- ),
- talker_hidden_size=cp_config.hidden_size,
prefix=prefix,
)
+
+ # One lm_head per residual layer (layers 1 .. G-1)
+ self.lm_head = nn.ModuleList(
+ [
+ nn.Linear(
+ config.code_predictor_config.hidden_size,
+ config.code_predictor_config.vocab_size,
+ bias=False,
+ )
+ for _ in range(self.num_code_groups - 1)
+ ]
+ )
+
+ self.set_sampling_params()
+
+ # Lazily initialised position ids (read-only, safe to persist)
+ self._pos_ids: torch.Tensor | None = None
+
+ # Cached plain-list refs (set once)
+ self._lm_heads: list | None = None
+ self._codec_embeds: list | None = None
+
+ # Model forward (optionally compiled)
+ self._model_fwd: object | None = None
+
+ def set_sampling_params(self, top_k: int = 50, top_p: float = 0.8):
+ """Configure sampling parameters to maintain consistency with previous implementation."""
+ self._top_k = top_k
+ self._top_p = top_p
+ logger.debug(f"Sampling parameters updated: top_k={top_k}, top_p={top_p}s")
+
+ # ------------------------------------------------------------------
+ # Lazy-init helpers
+ # ------------------------------------------------------------------
+
+ def _ensure_pos_ids(self, device: torch.device) -> None:
+ if self._pos_ids is not None and self._pos_ids.device == device:
+ return
+ max_seq = self.num_code_groups + 1
+ # [1, max_seq] for HF-style RoPE (will be expanded to [bsz, seq_len] at use)
+ self._pos_ids = torch.arange(max_seq, dtype=torch.long, device=device).unsqueeze(0)
+
+ def _ensure_cached_refs(self) -> None:
+ if self._lm_heads is not None:
+ return
+ self._lm_heads = list(self.lm_head)
+ self._codec_embeds = list(self.model.codec_embedding)
+
+ def _ensure_model_fwd(self) -> None:
+ if self._model_fwd is not None:
+ return
+ if current_omni_platform.supports_torch_inductor():
+ # torch.compile fuses RMSNorm/RoPE in ways that lose float32
+ # precision, compounding across AR steps. Use epilogue_fusion=False
+ # to disable the problematic fusions while still getting kernel
+ # fusion benefits for the linear layers and SDPA.
+ self._model_fwd = torch.compile(
+ self.model.forward,
+ dynamic=True,
+ options={
+ "epilogue_fusion": False,
+ },
+ )
+ logger.info("code_predictor: torch.compile enabled (no epilogue fusion)")
+ else:
+ self._model_fwd = self.model.forward
+ logger.info("code_predictor: using eager mode (no torch.compile)")
+
+ # ------------------------------------------------------------------
+ # Forward -- re-prefill + inline sampling
+ # ------------------------------------------------------------------
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ layer0_code: torch.Tensor,
+ layer0_embed: torch.Tensor,
+ last_talker_hidden: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Predict residual codebooks 1..G-1 autoregressively via re-prefill.
+
+ Args:
+ layer0_code: [bsz, 1] int64
+ layer0_embed: [bsz, 1, hidden_size]
+ last_talker_hidden: [bsz, 1, hidden_size]
+
+ Returns:
+ all_codes: [bsz, num_code_groups, 1]
+ proj_buf: [bsz, num_code_groups + 1, hidden_size]
+ pos 0 = last_talker_hidden (NOT a codec embed)
+ pos 1 = layer0_embed
+ pos 2.. = `codec_embedding[i](predicted_code_i)`
+ """
+ bsz = int(layer0_code.shape[0])
+ device = layer0_code.device
+ dtype = last_talker_hidden.dtype
+ num_groups = self.num_code_groups
+
+ # Lazy init (read-only caches only)
+ self._ensure_pos_ids(device)
+ self._ensure_model_fwd()
+ self._ensure_cached_refs()
+
+ # Allocate proj_buf locally each call to avoid cross-call aliasing
+ max_seq = num_groups + 1
+ proj_buf = torch.zeros(bsz, max_seq, self._hidden_size, dtype=dtype, device=device)
+ pos_ids = self._pos_ids
+ model_fwd = self._model_fwd
+ lm_heads = self._lm_heads
+ codec_embeds = self._codec_embeds
+
+ # Output codes
+ all_codes = torch.empty(bsz, num_groups, 1, dtype=torch.int64, device=device)
+ all_codes[:, 0] = layer0_code
+
+ # Fill buffer positions 0 & 1
+ proj_buf[:bsz, 0:1, :] = last_talker_hidden
+ proj_buf[:bsz, 1:2, :] = layer0_embed
+
+ # Autoregressive loop: predict layers 1..G-1
+ for step in range(1, num_groups):
+ seq_len = step + 1
+ projected = proj_buf[:bsz, :seq_len, :]
+ # position_ids: [batch, seq_len] for HF-style RoPE
+ step_pos_ids = pos_ids[:, :seq_len].expand(bsz, -1)
+
+ hidden_out = model_fwd(projected, step_pos_ids)
+
+ # Inline sampling: top-k -> top-p -> softmax -> multinomial
+ logits = lm_heads[step - 1](hidden_out[:, -1, :]) # [bsz, vocab]
+ if self._top_k > 0:
+ topk_vals, _ = logits.topk(self._top_k, dim=-1)
+ logits = logits.masked_fill(logits < topk_vals[:, -1:], float("-inf"))
+ if self._top_p < 1.0:
+ sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
+ cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
+ # Drop tokens whose preceding cumulative probability already reaches top_p (the crossing token is kept)
+ remove_mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= self._top_p
+ sorted_logits[remove_mask] = float("-inf")
+ logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
+ probs = F.softmax(logits, dim=-1)
+ code = torch.multinomial(probs, num_samples=1) # [bsz, 1]
+
+ all_codes[:, step] = code
+
+ # Embed predicted code -> next buffer position
+ new_embed = codec_embeds[step - 1](code) # [batch, 1, hidden_size]
+ proj_buf[:bsz, step + 1 : step + 2, :] = new_embed
+
+ return all_codes, proj_buf[:bsz]
+
+ # ------------------------------------------------------------------
+ # Weight loading
+ # ------------------------------------------------------------------
+
+ def load_weights(self, weights: list[tuple[str, torch.Tensor]]) -> set[str]:
+ """Load weights directly (no fused projection remapping needed).
+
+ Since we use separate nn.Linear for q/k/v/o and gate/up/down,
+ weight names match the HF checkpoint directly.
+ """
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ # Skip rotary embeddings
+ if "rotary_emb.inv_freq" in name:
+ continue
+
+ param = params_dict.get(name)
+ if param is None:
+ continue
+
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ return loaded_params
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
index 28f49918f2c..bb491d01b61 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
@@ -100,23 +100,11 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
talker_config: Qwen3OmniMoeTalkerConfig = vllm_config.model_config.hf_config
- rope_params = getattr(talker_config.text_config, "rope_scaling", None)
+ rope_params = talker_config.text_config.rope_scaling
if rope_params is None:
+ # Newer transformers use rope_parameters instead of rope_scaling
rope_params = getattr(talker_config.text_config, "rope_parameters", None) or {}
- rope_params = dict(rope_params)
- # In transformers <5.0.0, rope_theta is a top-level config attribute
- # (e.g. config.text_config.rope_theta = 1000000.0).
- # In transformers >=5.0.0 (PR #39847), rope_theta moved inside the
- # rope_parameters dict (e.g. config.text_config.rope_parameters =
- # {"rope_theta": 1000000.0, "rope_type": "default"}).
- # Use setdefault so we never overwrite a value already present.
- # Precedence: rope_params["rope_theta"] (already set)
- # > text_config.rope_theta (transformers <5.0.0 top-level attr)
- # > 1000000 (Qwen3 Omni default)
- rope_params.setdefault(
- "rope_theta",
- getattr(talker_config.text_config, "rope_theta", 1000000),
- )
+ rope_params["rope_theta"] = talker_config.text_config.rope_theta
talker_config.text_config.rope_parameters = rope_params
quant_config = vllm_config.quant_config
if isinstance(quant_config, ComponentQuantizationConfig):
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py
index 0d6f4334208..671ffb6cb16 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py
@@ -119,10 +119,7 @@
from vllm_omni.model_executor.models.qwen2_5_omni.qwen2_5_omni_thinker import (
Qwen2_5OmniConditionalGenerationMixin,
)
-from vllm_omni.quantization.component_config import (
- PRE_QUANTIZED_METHODS,
- ComponentQuantizationConfig,
-)
+from vllm_omni.quantization.component_config import ComponentQuantizationConfig
try:
import flash_attn
@@ -545,9 +542,7 @@ def forward(
if captured_hidden_states is not None and capture_set is not None:
if layer_idx in capture_set:
- hs = captured_hidden_states.setdefault("hidden_states", {})
- layers = hs.setdefault("layers", {})
- layers[layer_idx] = hidden_states.clone().view(-1, hidden_states.shape[-1])
+ captured_hidden_states[str(layer_idx)] = hidden_states.clone().view(-1, hidden_states.shape[-1])
hidden_states, residual = layer(
positions,
@@ -1119,24 +1114,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.multimodal_config = multimodal_config
self.quant_config = quant_config
- # Pre-quantized checkpoints (modelopt NVFP4/FP8/MXFP8) only quantize
- # the Thinker LM (language model). Vision and audio encoder weights
- # remain in BF16 and have no corresponding scale tensors in the
- # checkpoint. Dynamic quantization methods (e.g. --quantization fp8)
- # should also only target the language model.
+ # Pre-quantized checkpoints (modelopt NVFP4/FP8/MXFP8) quantize the
+ # entire thinker — audio tower, visual encoder, and language model
+ # all share the same quant method. Dynamic quantization methods
+ # (e.g. --quantization fp8) should only target the language model.
+ _PRE_QUANTIZED_METHODS = {"modelopt", "modelopt_fp4", "modelopt_mxfp8"}
if isinstance(quant_config, ComponentQuantizationConfig):
audio_quant_config = quant_config.resolve("audio_tower")
visual_quant_config = quant_config.resolve("visual")
language_quant_config = quant_config.resolve("language_model")
elif quant_config is not None:
- if quant_config.get_name() in PRE_QUANTIZED_METHODS:
- # Pre-quantized: only the Thinker LM is quantized.
- # Vision/audio encoder weights are BF16 with no FP8 scales;
- # passing quant_config to them causes FP8 kernels to run on
- # BF16 weights (producing garbage embeddings). Keep None.
- audio_quant_config = None
- visual_quant_config = None
+ if quant_config.get_name() in _PRE_QUANTIZED_METHODS:
+ # Pre-quantized: pass quant_config to all subcomponents.
+ audio_quant_config = quant_config
+ visual_quant_config = quant_config
language_quant_config = quant_config
else:
# Dynamic quantization: scope to language_model only.
diff --git a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py
index 26d69fe22a4..96f8c799c13 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/cuda_graph_decoder_wrapper.py
@@ -10,7 +10,6 @@
import torch
from torch.cuda import CUDAGraph
from vllm.logger import init_logger
-from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -130,7 +129,7 @@ def _capture(self, size: int, device: torch.device, dtype: torch.dtype):
graph = CUDAGraph()
with torch.no_grad():
- with torch.cuda.graph(graph, pool=current_platform.get_global_graph_pool()):
+ with torch.cuda.graph(graph):
static_output = self.decoder(static_input)
self.graphs[size] = graph
@@ -141,15 +140,6 @@ def decode(self, codes: torch.Tensor) -> torch.Tensor:
if not self.enabled or not self._warmed_up or codes.shape[0] != 1:
return self.decoder(codes)
- # Inner CUDA graph replay is illegal while an outer stream capture is
- # active (e.g. vLLM's cudagraph_mode=FULL warmup on Stage 1). Fall back
- # to eager in that case so the outer capture can complete. The guard is
- # a no-op at runtime: is_current_stream_capturing() returns False
- # outside the startup capture window, so normal inference still hits
- # the graph fast path.
- if torch.cuda.is_current_stream_capturing():
- return self.decoder(codes)
-
actual_size = codes.shape[-1]
padded_size = self._get_padded_size(actual_size)
diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py
deleted file mode 100644
index 5051715ceac..00000000000
--- a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Qwen3-TTS pipeline: Talker (text → RVQ codec) → Code2Wav (codec → audio).
-
-Chunked vs end-to-end mode is dispatched from ``deploy.async_chunk``.
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-_PROC = "vllm_omni.model_executor.stage_input_processors.qwen3_tts"
-
-QWEN3_TTS_PIPELINE = PipelineConfig(
- model_type="qwen3_tts",
- # Pipeline-level default; the code2wav stage overrides per-stage below.
- model_arch="Qwen3TTSTalkerForConditionalGeneration",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="qwen3_tts",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- owns_tokenizer=True,
- engine_output_type="latent",
- async_chunk_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"),
- sampling_constraints={
- "detokenize": False,
- "stop_token_ids": [2150],
- },
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="code2wav",
- execution_type=StageExecutionType.LLM_GENERATION,
- input_sources=(0,),
- final_output=True,
- final_output_type="audio",
- engine_output_type="audio",
- model_arch="Qwen3TTSCode2Wav",
- sync_process_input_func=f"{_PROC}.talker2code2wav",
- sampling_constraints={"detokenize": True},
- extras={"tts_args": {"max_instructions_length": 500}},
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml b/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml
new file mode 100644
index 00000000000..6e3c78ff934
--- /dev/null
+++ b/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml
@@ -0,0 +1,92 @@
+model_type: qwen3_tts
+async_chunk: true
+
+stages:
+ - stage_id: 0
+ model_stage: qwen3_tts
+ stage_type: llm
+ is_comprehension: true
+ input_sources: []
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ runtime:
+ devices: "0"
+ engine_args:
+ max_num_seqs: 10
+ model_arch: Qwen3TTSTalkerForConditionalGeneration
+ hf_overrides:
+ architectures: [Qwen3TTSTalkerForConditionalGeneration]
+ enforce_eager: false
+ trust_remote_code: true
+ async_scheduling: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ gpu_memory_utilization: 0.08
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 512
+ max_model_len: 4096
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: false
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 1
+ model_stage: code2wav
+ stage_type: llm
+ input_sources: [0]
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ final_output: true
+ final_output_type: audio
+ runtime:
+ devices: "0"
+ engine_args:
+ max_num_seqs: 1
+ model_arch: Qwen3TTSCode2Wav
+ hf_overrides:
+ architectures: [Qwen3TTSCode2Wav]
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ gpu_memory_utilization: 0.08
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 65536
+ max_model_len: 65536
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ tts_args:
+ max_instructions_length: 500
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: true
+ repetition_penalty: 1.0
+
+connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ codec_streaming: true
+ connector_get_sleep_s: 0.01
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+
+edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py
index b6c384881bf..f6ac91a994f 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code2wav.py
@@ -41,7 +41,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._num_quantizers: int | None = None
self._output_sample_rate: int | None = None
self._total_upsample: int | None = None
- self._decoder_sliding_window: int | None = None
self._logged_codec_stats = False
@staticmethod
@@ -107,7 +106,6 @@ def _ensure_speech_tokenizer_loaded(self) -> None:
self._num_quantizers = num_q
self._output_sample_rate = out_sr
self._total_upsample = int(decoder.total_upsample)
- self._decoder_sliding_window = int(getattr(dec_cfg, "sliding_window", 0) or 0)
# Precompute SnakeBeta exp caches (benefits both Triton and eager paths)
if hasattr(decoder, "precompute_snake_caches"):
@@ -130,20 +128,6 @@ def _ensure_speech_tokenizer_loaded(self) -> None:
if isinstance(extra_cfg, dict):
chunk_frames = int(extra_cfg.get("codec_chunk_frames") or 0)
left_frames = int(extra_cfg.get("codec_left_context_frames") or 0)
- if (
- chunk_frames > 0
- and left_frames > 0
- and self._decoder_sliding_window
- and left_frames < self._decoder_sliding_window
- ):
- logger.warning(
- "Qwen3-TTS streaming codec_left_context_frames=%d is smaller than "
- "decoder sliding_window=%d; chunk-boundary distortion may occur. "
- "Increase codec_left_context_frames to at least %d for streaming.",
- left_frames,
- self._decoder_sliding_window,
- self._decoder_sliding_window,
- )
decoder.enable_cudagraph(
device=device,
@@ -234,10 +218,9 @@ def forward(
for i, info in enumerate(runtime_additional_information):
if i >= len(left_context_size):
break
- meta = info.get("meta", {})
- if "left_context_size" in meta:
+ if "left_context_size" in info:
# left_context_size may come through serialization as an int, [int], or tensor([int]).
- value = meta["left_context_size"]
+ value = info["left_context_size"]
if isinstance(value, list):
value = value[0] if value else 0
if isinstance(value, torch.Tensor):
@@ -306,17 +289,21 @@ def forward(
for j, idx in enumerate(valid_indices):
ctx_frames, actual_frames = parsed[idx]
wav = wav_tensors[j]
- # Slice on exact codec-frame boundaries instead of proportionally.
- start = max(0, ctx_frames * upsample)
- end = max(start, actual_frames * upsample)
- if start >= wav.shape[0]:
- logger.warning(
- "Context trim start %d >= decoded length %d; returning empty audio.",
- start,
- wav.shape[0],
- )
- continue
- wav = wav[start : min(end, wav.shape[0])]
+ # Drop the ref_code prefix from the decoded waveform, keeping only newly generated audio.
+ if ctx_frames <= 0:
+ expected_len = actual_frames * upsample
+ if wav.shape[0] > expected_len:
+ wav = wav[:expected_len]
+ else:
+ cut = int(ctx_frames / max(actual_frames, 1) * wav.shape[0])
+ if cut >= wav.shape[0]:
+ logger.warning(
+ "Context trim %d >= decoded length %d; returning empty audio.",
+ cut,
+ wav.shape[0],
+ )
+ continue
+ wav = wav[cut:]
if wav.shape[0] > 0:
audios[idx] = wav.to(dtype=torch.float32).reshape(-1)
@@ -325,18 +312,12 @@ def forward(
multimodal_outputs={"model_outputs": audios, "sr": srs},
)
- def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput | tuple, **kwargs: Any) -> OmniOutput:
+ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput:
if isinstance(model_outputs, OmniOutput):
return model_outputs
- if isinstance(model_outputs, tuple) and len(model_outputs) == len(OmniOutput._fields):
- return OmniOutput(*model_outputs)
-
if not (isinstance(model_outputs, tuple) and len(model_outputs) == 2):
- raise TypeError(
- "Qwen3TTSCode2Wav expected OmniOutput, OmniOutput tuple, "
- f"or (audio_tensor, sr) outputs, got {type(model_outputs)}"
- )
+ raise TypeError(f"Qwen3TTSCode2Wav expected (audio_tensor, sr) outputs, got {type(model_outputs)}")
audio_tensor, sr = model_outputs
return OmniOutput(
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py
index 8d2f0686ae0..11c0369e820 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_code_predictor_vllm.py
@@ -1,27 +1,318 @@
-"""Qwen3-TTS Code Predictor -- thin wrapper over CodePredictorWrapper."""
-
from __future__ import annotations
from collections.abc import Iterable
import torch
+import torch.nn as nn
+import torch.nn.functional as F
from vllm.config import VllmConfig
from vllm.config.vllm import set_current_vllm_config
-
-from vllm_omni.model_executor.models.common.qwen3_code_predictor import (
- CodePredictorBaseModel,
- CodePredictorWrapper,
- CodePredictorWrapperConfig,
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
)
+from vllm_omni.platforms import current_omni_platform
+
from .configuration_qwen3_tts import Qwen3TTSTalkerCodePredictorConfig, Qwen3TTSTalkerConfig
-# Backward-compat alias used by tests
-Qwen3TTSTalkerCodePredictorModelVLLM = CodePredictorBaseModel
+logger = init_logger(__name__)
+
+
+# ===================================================================
+# HF-numerics-compatible layers for code predictor
+# ===================================================================
+#
+# These use plain PyTorch ops (nn.Linear, manual RMSNorm in float32,
+# rotate_half RoPE) to produce outputs numerically identical to the
+# HuggingFace reference. vLLM's fused kernels (RMSNorm, QKVParallel,
+# get_rope) introduce small precision differences that compound across
+# the 15 autoregressive steps of the code predictor, causing severe
+# audio quality degradation (UTMOS ~4.26 → ~2.66).
+#
+# See: https://github.com/vllm-project/vllm-omni/issues/2274
+
+
+class _RMSNorm(nn.Module):
+ """RMSNorm matching HuggingFace's Qwen3TTSRMSNorm exactly.
+
+ Computes variance in float32 to avoid bfloat16 precision loss.
+ """
+
+ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+class _RotaryEmbedding(nn.Module):
+ """RoPE matching HuggingFace's Qwen3TTSRotaryEmbedding exactly.
+
+ Forces float32 computation for cos/sin, matching HF's torch.autocast(enabled=False).
+ """
+
+ def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig) -> None:
+ super().__init__()
+ head_dim = getattr(
+ config,
+ "head_dim",
+ config.hidden_size // config.num_attention_heads,
+ )
+ # Standard default RoPE
+ rope_theta = getattr(config, "rope_theta", 10000.0)
+ inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ # position_ids: [batch, seq_len]
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 (matching HF)
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class _CodePredictorAttention(nn.Module):
+ """Standalone multi-head attention for code predictor.
+
+ Uses F.scaled_dot_product_attention with HF-compatible RoPE and RMSNorm.
+ Input: [B, seq_len, hidden_size], output: [B, seq_len, hidden_size].
+ """
+
+ def __init__(
+ self,
+ config: Qwen3TTSTalkerCodePredictorConfig,
+ *,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
+ self.head_dim = getattr(
+ config,
+ "head_dim",
+ config.hidden_size // config.num_attention_heads,
+ )
+ self.scaling = self.head_dim**-0.5
+ self._use_gqa = self.num_kv_heads != self.num_heads
+
+ # Separate q/k/v projections matching HF (no fused packing)
+ self.q_proj = nn.Linear(
+ self.hidden_size,
+ self.num_heads * self.head_dim,
+ bias=getattr(config, "attention_bias", False),
+ )
+ self.k_proj = nn.Linear(
+ self.hidden_size,
+ self.num_kv_heads * self.head_dim,
+ bias=getattr(config, "attention_bias", False),
+ )
+ self.v_proj = nn.Linear(
+ self.hidden_size,
+ self.num_kv_heads * self.head_dim,
+ bias=getattr(config, "attention_bias", False),
+ )
+ self.o_proj = nn.Linear(
+ self.num_heads * self.head_dim,
+ self.hidden_size,
+ bias=False,
+ )
+ self.q_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = _RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+ bsz, seq_len, _ = hidden_states.shape
+ hidden_shape_q = (bsz, seq_len, self.num_heads, self.head_dim)
+ hidden_shape_kv = (bsz, seq_len, self.num_kv_heads, self.head_dim)
+
+ q = self.q_norm(self.q_proj(hidden_states).view(hidden_shape_q)).transpose(1, 2)
+ k = self.k_norm(self.k_proj(hidden_states).view(hidden_shape_kv)).transpose(1, 2)
+ v = self.v_proj(hidden_states).view(hidden_shape_kv).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ # cos/sin are [batch, seq_len, head_dim], need unsqueeze at dim=1 for heads
+ cos = cos.unsqueeze(1) # [batch, 1, seq_len, head_dim]
+ sin = sin.unsqueeze(1)
+ q = (q * cos) + (_rotate_half(q) * sin)
+ k = (k * cos) + (_rotate_half(k) * sin)
+
+ attn_out = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ scale=self.scaling,
+ is_causal=True,
+ enable_gqa=self._use_gqa,
+ )
+
+ attn_out = attn_out.transpose(1, 2).reshape(bsz, seq_len, -1)
+ output = self.o_proj(attn_out)
+ return output
+
+
+class _CodePredictorMLP(nn.Module):
+ """SiLU-gated MLP for code predictor, matching HF's Qwen3TTSTalkerTextMLP."""
+
+ def __init__(
+ self,
+ config: Qwen3TTSTalkerCodePredictorConfig,
+ *,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class _CodePredictorDecoderLayer(nn.Module):
+ """Transformer decoder layer for code predictor (SDPA, no KV cache)."""
+
+ def __init__(
+ self,
+ config: Qwen3TTSTalkerCodePredictorConfig,
+ *,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.self_attn = _CodePredictorAttention(config, prefix=f"{prefix}.self_attn")
+ self.mlp = _CodePredictorMLP(config, prefix=f"{prefix}.mlp")
+ self.input_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.self_attn(hidden_states, position_embeddings)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+# ===================================================================
+# Code Predictor Transformer Model
+# ===================================================================
+
+
+class Qwen3TTSTalkerCodePredictorModelVLLM(nn.Module):
+ """Transformer model for the code predictor (re-prefill, no KV cache)."""
+
+ def __init__(
+ self,
+ config: Qwen3TTSTalkerCodePredictorConfig,
+ *,
+ talker_hidden_size: int | None = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.config = config
+
+ self.layers = nn.ModuleList(
+ [_CodePredictorDecoderLayer(config, prefix=f"{prefix}.layers.{i}") for i in range(config.num_hidden_layers)]
+ )
+ self.norm = _RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = _RotaryEmbedding(config)
+
+ # Codec embeddings: one per residual group. Stored in talker hidden dim
+ # (some checkpoints use talker_hidden_size != code_predictor hidden_size).
+ emb_dim = int(talker_hidden_size) if talker_hidden_size is not None else int(config.hidden_size)
+ self.codec_embedding = nn.ModuleList(
+ [nn.Embedding(config.vocab_size, emb_dim) for _ in range(config.num_code_groups - 1)]
+ )
+
+ def get_input_embeddings(self) -> nn.ModuleList:
+ return self.codec_embedding
+
+ def forward(
+ self,
+ inputs_embeds: torch.Tensor,
+ position_ids: torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = inputs_embeds
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ for layer in self.layers:
+ hidden_states = layer(hidden_states, position_embeddings)
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+ param = params_dict.get(name)
+ if param is None:
+ continue
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+# ===================================================================
+# Code Predictor Wrapper (optimized re-prefill + torch.compile)
+# ===================================================================
+
+
+class Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM(nn.Module):
+ """vLLM-native code_predictor for the AR talker (residual codebooks).
+ Re-prefill approach: each AR step forwards the full growing sequence
+ through the 5-layer transformer. No KV cache needed. This trades
+ ~O(T^2) extra attention FLOPs (negligible for T=16, 5 layers) for
+ zero KV cache management overhead and a simpler execution model.
-class Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM(CodePredictorWrapper):
- """Qwen3-TTS code predictor (CUDA graphs, per-call sampling, projection)."""
+ Uses HF-compatible layers (plain nn.Linear, float32 RMSNorm, rotate_half
+ RoPE) to ensure numerical fidelity with the reference implementation.
+ Precision matters here because small errors compound across 15 AR steps.
+
+ Optimizations preserved:
+ 1. torch.compile on model forward -- fuses small kernel launches.
+ 2. Pre-allocated embedding buffer [B, max_seq, H] -- no torch.cat per step.
+ 3. Projection caching -- each token projected once and cached.
+ 4. Pre-allocated position_ids -- no torch.arange per step.
+ 5. Inline sampling -- no custom op / forward_context overhead.
+ 6. Cached module references -- bypass nn.Module.__call__ overhead.
+ 7. CUDA graphs per batch-size bucket.
+ """
def __init__(
self,
@@ -31,24 +322,240 @@ def __init__(
talker_config: Qwen3TTSTalkerConfig,
prefix: str = "code_predictor",
) -> None:
- super().__init__(
- vllm_config=vllm_config,
- cp_config=config,
- wrapper_config=CodePredictorWrapperConfig(
- use_cuda_graphs=True,
- use_parallel_embedding=False,
- use_projection=(config.hidden_size != talker_config.hidden_size),
- return_proj_buf=False,
- sampling_mode="per_call",
- ),
+ super().__init__()
+ self._vllm_config = vllm_config
+ self.config = config
+ self.talker_config = talker_config
+
+ self.model = Qwen3TTSTalkerCodePredictorModelVLLM(
+ config,
talker_hidden_size=int(talker_config.hidden_size),
- prefix=prefix,
+ prefix=f"{prefix}.model",
)
- # Store talker_config for backward compat (accessed by some callers)
- self.talker_config = talker_config
- self._vllm_config = vllm_config
+
+ self.lm_head = nn.ModuleList(
+ [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)]
+ )
+
+ if config.hidden_size != talker_config.hidden_size:
+ self.small_to_mtp_projection = nn.Linear(talker_config.hidden_size, config.hidden_size, bias=True)
+ else:
+ self.small_to_mtp_projection = nn.Identity()
+
+ self._num_groups = int(config.num_code_groups)
+ self._talker_hidden = int(talker_config.hidden_size)
+ self._cp_hidden = int(config.hidden_size)
+
+ # Pre-allocated buffers (lazily initialized on first forward).
+ self._proj_buf: torch.Tensor | None = None
+
+ # torch.compile + warmup state (lazily initialized in _setup_compile).
+ self._compiled_model_fwd = None
+ self._bucket_sizes: list[int] = []
+ self._bucket_pos_ids: dict[int, torch.Tensor] = {}
+ self._lm_heads_list: list[nn.Module] | None = None
+ self._codec_embeds_list: list[nn.Module] | None = None
+ self._cuda_graphs: dict[int, tuple[torch.cuda.CUDAGraph, torch.Tensor]] = {}
+
+ def get_input_embeddings(self) -> nn.ModuleList:
+ return self.model.get_input_embeddings()
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- """Load weights with vllm config context (required for VocabParallelEmbedding)."""
with set_current_vllm_config(self._vllm_config):
- return super().load_weights(weights)
+ loaded: set[str] = set()
+ model_weights: list[tuple[str, torch.Tensor]] = []
+ other_weights: list[tuple[str, torch.Tensor]] = []
+ for name, w in weights:
+ if name.startswith("model."):
+ model_weights.append((name[len("model.") :], w))
+ else:
+ other_weights.append((name, w))
+
+ loaded_model = self.model.load_weights(model_weights)
+ loaded |= {f"model.{n}" for n in loaded_model}
+
+ params = dict(self.named_parameters(remove_duplicate=False))
+ for name, w in other_weights:
+ if name not in params:
+ continue
+ default_weight_loader(params[name], w)
+ loaded.add(name)
+
+ return loaded
+
+ # ------------------------------------------------------------------
+ # Pre-allocated buffer management
+ # ------------------------------------------------------------------
+
+ def _ensure_buffers(self, device: torch.device, dtype: torch.dtype) -> None:
+ max_seq = self._num_groups + 1
+ if self._proj_buf is not None and self._proj_buf.device == device and self._proj_buf.dtype == dtype:
+ return
+ max_bsz = self._vllm_config.scheduler_config.max_num_seqs
+ self._proj_buf = torch.zeros(
+ max_bsz,
+ max_seq,
+ self._cp_hidden,
+ dtype=dtype,
+ device=device,
+ )
+
+ def _setup_compile(self) -> None:
+ """Lazily set up torch.compile with manual CUDA graph capture."""
+ if self._compiled_model_fwd is not None:
+ return
+ self._lm_heads_list = list(self.lm_head)
+ self._codec_embeds_list = list(self.model.codec_embedding)
+ if not current_omni_platform.supports_torch_inductor():
+ logger.warning_once("code_predictor: torch.compile disabled")
+ self._compiled_model_fwd = self.model.forward
+ return
+
+ # torch.compile fuses RMSNorm/RoPE in ways that lose float32
+ # precision, compounding across 15 AR steps. Use torch.compile
+ # with options that disable the problematic fusions while still
+ # getting kernel fusion benefits for the linear layers and SDPA.
+ self._compiled_model_fwd = torch.compile(
+ self.model.forward,
+ dynamic=False,
+ options={
+ "epilogue_fusion": False,
+ },
+ )
+ self._warmup_buckets()
+ self._capture_cuda_graphs()
+ logger.info("code_predictor: torch.compile (no epilogue fusion) + CUDA graphs")
+
+ def _padded_bsz(self, bsz: int) -> int:
+ for bucket in self._bucket_sizes:
+ if bsz <= bucket:
+ return bucket
+ return bsz
+
+ def _warmup_buckets(self) -> None:
+ """Warmup power-of-2 batch-size buckets to front-load Inductor compilation."""
+ max_bsz = self._vllm_config.scheduler_config.max_num_seqs
+ bucket_sizes = [1 << i for i in range(max_bsz.bit_length()) if (1 << i) <= max_bsz]
+ if max_bsz not in bucket_sizes:
+ bucket_sizes.append(max_bsz)
+ self._bucket_sizes = sorted(bucket_sizes)
+
+ max_seq = self._num_groups + 1
+ device = next(self.model.parameters()).device
+
+ proj_buf = self._proj_buf
+ for bsz in self._bucket_sizes:
+ # position_ids: [batch, seq_len] for HF-style RoPE
+ pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(bsz, -1)
+ self._bucket_pos_ids[bsz] = pos_ids
+ for _ in range(3):
+ self._compiled_model_fwd(proj_buf[:bsz, :max_seq, :], pos_ids)
+ logger.info("code_predictor: warmup done for buckets %s", self._bucket_sizes)
+
+ def _capture_cuda_graphs(self) -> None:
+ """Capture a CUDA graph per bucket using vLLM's global graph pool."""
+ from vllm.platforms import current_platform
+
+ pool = current_platform.get_global_graph_pool()
+
+ max_seq = self._num_groups + 1
+ proj_buf = self._proj_buf
+
+ for bsz in self._bucket_sizes:
+ static_input = proj_buf[:bsz, :max_seq, :]
+ pos_ids = self._bucket_pos_ids[bsz]
+
+ g = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(g, pool=pool):
+ static_output = self._compiled_model_fwd(static_input, pos_ids)
+
+ self._cuda_graphs[bsz] = (g, static_output)
+
+ logger.info("code_predictor: captured CUDA graphs for buckets %s", self._bucket_sizes)
+
+ # ------------------------------------------------------------------
+ # Optimized forward: re-prefill + torch.compile + projection cache
+ # ------------------------------------------------------------------
+
+ @torch.inference_mode()
+ def forward(
+ self,
+ layer0_code: torch.Tensor,
+ layer0_embed: torch.Tensor,
+ last_talker_hidden: torch.Tensor,
+ do_sample: bool = True,
+ temperature: float = 0.9,
+ top_k: int = 50,
+ top_p: float = 1.0,
+ ) -> torch.Tensor:
+ """Predict residual codebooks 1..Q-1 autoregressively via re-prefill.
+
+ torch.compile fuses the ~60 small kernel launches per step into fewer
+ fused kernels, reducing kernel launch overhead by ~75%.
+
+ Projection caching: each token is projected once via small_to_mtp_projection
+ and cached in _proj_buf, avoiding redundant re-projection of past tokens.
+ """
+ bsz = int(layer0_code.shape[0])
+ num_groups = self._num_groups
+ device = layer0_code.device
+ dtype = layer0_embed.dtype
+
+ all_codes = torch.empty(bsz, num_groups, dtype=torch.long, device=device)
+ all_codes[:, 0] = layer0_code.reshape(bsz)
+
+ self._ensure_buffers(device, dtype)
+ self._setup_compile()
+
+ proj_buf = self._proj_buf
+ max_seq = self._num_groups + 1
+
+ projection = self.small_to_mtp_projection
+ model_fwd = self._compiled_model_fwd
+ lm_heads = self._lm_heads_list
+ codec_embeds = self._codec_embeds_list
+
+ use_sampling = do_sample and temperature > 0
+ inv_temperature = 1.0 / max(temperature, 1e-6) if use_sampling else 0.0
+ if use_sampling and top_p != 1.0:
+ raise NotImplementedError(
+ "top_p sampling is not implemented for the vLLM-native code predictor; please set top_p=1.0."
+ )
+
+ padded_bsz = self._padded_bsz(bsz)
+ proj_buf[:padded_bsz].zero_()
+
+ proj_buf[:bsz, 0, :] = projection(last_talker_hidden.reshape(bsz, 1, -1)).reshape(bsz, -1)
+ proj_buf[:bsz, 1, :] = projection(layer0_embed.reshape(bsz, 1, -1)).reshape(bsz, -1)
+ full_pos_ids = self._bucket_pos_ids.get(padded_bsz)
+ if full_pos_ids is None:
+ full_pos_ids = torch.arange(max_seq, device=device, dtype=torch.long).unsqueeze(0).expand(padded_bsz, -1)
+
+ # Use captured CUDA graph if available, otherwise call compiled fn.
+ cuda_graph_entry = self._cuda_graphs.get(padded_bsz)
+
+ for step in range(1, num_groups):
+ if cuda_graph_entry is not None:
+ cuda_graph_entry[0].replay()
+ hidden_out = cuda_graph_entry[1]
+ else:
+ hidden_out = model_fwd(proj_buf[:padded_bsz, :max_seq, :], full_pos_ids)
+ logits = lm_heads[step - 1](hidden_out[:bsz, step, :])
+
+ if use_sampling:
+ scaled = logits * inv_temperature
+ if top_k > 0:
+ topk_vals, _ = scaled.topk(top_k, dim=-1)
+ scaled = scaled.masked_fill(scaled < topk_vals[:, -1:], float("-inf"))
+ probs = F.softmax(scaled, dim=-1)
+ next_ids = torch.multinomial(probs, num_samples=1)
+ else:
+ next_ids = logits.argmax(dim=-1, keepdim=True)
+
+ all_codes[:, step] = next_ids.reshape(bsz)
+
+ if step < num_groups - 1:
+ new_embed = codec_embeds[step - 1](next_ids)
+ proj_buf[:bsz, step + 1, :] = projection(new_embed.reshape(bsz, 1, -1)).reshape(bsz, -1)
+
+ return all_codes
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
index 53e2a5480e0..bc6222bbe2c 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
@@ -13,6 +13,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+from librosa.filters import mel as librosa_mel_fn
from transformers import AutoTokenizer
from transformers.activations import ACT2FN
from transformers.utils.hub import cached_file
@@ -23,12 +24,9 @@
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.qwen3 import Qwen3Model
from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, WeightsMapper, maybe_prefix
-from vllm.multimodal.audio import AudioResampler
from vllm.sequence import IntermediateTensors
-from vllm_omni.data_entry_keys import OmniPayload
from vllm_omni.model_executor.models.output_templates import OmniOutput
-from vllm_omni.utils.audio import mel_filter_bank
from vllm_omni.utils.voice_cache import VoiceEmbeddingCache
from .configuration_qwen3_tts import Qwen3TTSConfig, Qwen3TTSSpeakerEncoderConfig, Qwen3TTSTalkerConfig
@@ -260,19 +258,14 @@ def mel_spectrogram(
fmax: int | None = None,
center: bool = False,
) -> torch.Tensor:
- """Calculate mel spectrogram of an input signal using torchaudio mel filterbank and torch STFT."""
+ """Calculate mel spectrogram of an input signal using librosa mel filterbank and torch STFT."""
if torch.min(y) < -1.0:
logger.warning("Min value of input waveform signal is %s", torch.min(y))
if torch.max(y) > 1.0:
logger.warning("Max value of input waveform signal is %s", torch.max(y))
device = y.device
- mel_basis = mel_filter_bank(
- sr=sampling_rate,
- n_fft=n_fft,
- n_mels=num_mels,
- fmin=fmin,
- fmax=fmax,
- ).to(device)
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis = torch.from_numpy(mel).float().to(device)
hann_window = torch.hann_window(win_size).to(device)
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
@@ -344,7 +337,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.mtp_hidden_size = int(self.talker_config.hidden_size)
# OmniGPUModelRunner will store talker_mtp output under this key in
# per-request additional_information.
- self.talker_mtp_output_key = ("codes", "audio")
+ self.talker_mtp_output_key = "audio_codes"
self.model = Qwen3Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
@@ -403,11 +396,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Keys that should stay on GPU in model_intermediate_buffer to avoid
# CPU-to-GPU round-trips on every decode step.
- self.gpu_resident_buffer_keys: set[tuple[str, str]] = {
- ("codes", "audio"),
- ("hidden_states", "last"),
- ("embed", "tts_pad"),
- ("hidden_states", "trailing_text"),
+ self.gpu_resident_buffer_keys: set[str] = {
+ "audio_codes",
+ "last_talker_hidden",
+ "tts_pad_embed",
+ "tailing_text_hidden",
}
# Tokenizer for prompt building.
@@ -416,10 +409,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# In-memory LRU cache for voice extraction artifacts (Base voice clone).
self._voice_cache = VoiceEmbeddingCache()
- raw_subtalker_sampling = getattr(vllm_config.model_config, "subtalker_sampling_params", None)
- self._subtalker_sampling_params: dict[str, Any] = (
- dict(raw_subtalker_sampling) if isinstance(raw_subtalker_sampling, Mapping) else {}
- )
# -------------------- vLLM required hooks --------------------
@@ -476,20 +465,18 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
for info in info_dicts:
if not isinstance(info, dict):
continue
- codes = info.get("codes", {})
- meta = info.get("meta", {})
- ac = codes.get("audio")
+ ac = info.get("audio_codes")
if isinstance(ac, torch.Tensor):
audio_codes_list.append(ac)
- cs = meta.get("codec_streaming")
+ cs = info.get("codec_streaming")
if isinstance(cs, bool):
codec_streaming_list.append(
torch.full((int(ac.shape[0]),), int(cs), dtype=torch.int8, device=ac.device)
)
- ref_code = codes.get("ref")
+ ref_code = info.get("ref_code")
if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0:
ref_code_tensor = ref_code
- ref_len = meta.get("ref_code_len")
+ ref_len = info.get("ref_code_len")
if ref_len is None:
continue
if isinstance(ref_len, torch.Tensor):
@@ -514,13 +501,13 @@ def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: A
audio_codes = torch.cat(audio_codes_list, dim=0)
span_len = int(audio_codes.shape[0])
hidden = hidden[:span_len]
- mm: OmniPayload = {"codes": {"audio": audio_codes}}
+ mm: dict[str, torch.Tensor] = {"audio_codes": audio_codes}
if ref_code_len_list:
- mm.setdefault("meta", {})["ref_code_len"] = torch.cat(ref_code_len_list, dim=0)[:span_len]
+ mm["ref_code_len"] = torch.cat(ref_code_len_list, dim=0)[:span_len]
if ref_code_tensor is not None:
- mm.setdefault("codes", {})["ref"] = [ref_code_tensor]
+ mm["ref_code"] = [ref_code_tensor]
if codec_streaming_list:
- mm.setdefault("meta", {})["codec_streaming"] = torch.cat(codec_streaming_list, dim=0)[:span_len]
+ mm["codec_streaming"] = torch.cat(codec_streaming_list, dim=0)[:span_len]
return OmniOutput(text_hidden_states=hidden, multimodal_outputs=mm)
# -------------------- preprocess / postprocess --------------------
@@ -539,11 +526,6 @@ def preprocess(
merged.setdefault(k, v)
info_dict = merged
- payload: OmniPayload = info_dict
- embed = payload.get("embed", {})
- hs = payload.get("hidden_states", {})
- meta = payload.get("meta", {})
-
span_len = int(input_ids.shape[0])
if span_len <= 0:
return input_ids, input_embeds if input_embeds is not None else self.embed_input_ids(input_ids), {}
@@ -553,7 +535,7 @@ def preprocess(
raise ValueError("Missing additional_information.text for Qwen3-TTS AR talker.")
task_type = (info_dict.get("task_type") or ["CustomVoice"])[0]
- codec_streaming_val = meta.get("codec_streaming")
+ codec_streaming_val = info_dict.get("codec_streaming")
if isinstance(codec_streaming_val, list):
codec_streaming_raw = codec_streaming_val[0] if codec_streaming_val else None
else:
@@ -565,8 +547,8 @@ def preprocess(
if span_len > 1:
# Prefill (prompt embeddings)
- prompt_embeds_cpu = embed.get("prefill")
- tts_pad_embed_cpu = embed.get("tts_pad")
+ prompt_embeds_cpu = info_dict.get("talker_prompt_embeds")
+ tts_pad_embed_cpu = info_dict.get("tts_pad_embed")
tts_pad_embed = None
if isinstance(tts_pad_embed_cpu, torch.Tensor) and tts_pad_embed_cpu.numel() > 0:
tts_pad_embed = tts_pad_embed_cpu.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
@@ -581,18 +563,17 @@ def preprocess(
# Store full prompt embeddings on CPU (large, prefill-only).
# tailing_text_hidden and tts_pad_embed stay on GPU (gpu_resident_buffer_keys).
prompt_embeds_cpu = full_prompt_embeds.detach().to("cpu").contiguous()
- info_update: OmniPayload = {
- "embed": {
- "prefill": prompt_embeds_cpu,
- "tts_pad": tts_pad_embed.detach(),
- },
- "hidden_states": {"trailing_text": tailing_text_hidden.detach()},
- "meta": {"talker_prefill_offset": 0, "codec_streaming": codec_streaming},
+ info_update: dict[str, Any] = {
+ "talker_prompt_embeds": prompt_embeds_cpu,
+ "tailing_text_hidden": tailing_text_hidden.detach(),
+ "tts_pad_embed": tts_pad_embed.detach(),
+ "talker_prefill_offset": 0,
+ "codec_streaming": codec_streaming,
}
if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0:
- info_update.setdefault("codes", {})["ref"] = ref_code.detach().to("cpu").contiguous()
+ info_update["ref_code"] = ref_code.detach().to("cpu").contiguous()
if ref_code_len is not None:
- info_update["meta"]["ref_code_len"] = int(ref_code_len)
+ info_update["ref_code_len"] = int(ref_code_len)
# Always return a span_len slice; if the scheduled placeholder is longer, pad with tts_pad_embed.
# This preserves placeholder/embedding alignment.
offset = 0
@@ -604,12 +585,12 @@ def preprocess(
pad_rows = tts_pad_embed.reshape(1, -1).to("cpu").expand(pad_n, -1)
take = torch.cat([take, pad_rows], dim=0)
prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16)
- info_update["meta"]["talker_prefill_offset"] = int(offset + span_len)
+ info_update["talker_prefill_offset"] = int(offset + span_len)
else:
# Subsequent prefill chunk: slice from stored embeddings at running offset.
if tts_pad_embed is None:
raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must initialize it.")
- offset = int(meta.get("talker_prefill_offset", 0) or 0)
+ offset = int(info_dict.get("talker_prefill_offset", 0) or 0)
if offset < 0:
offset = 0
s = max(0, min(offset, int(prompt_embeds_cpu.shape[0])))
@@ -620,9 +601,8 @@ def preprocess(
pad_rows = tts_pad_embed.reshape(1, -1).to("cpu").expand(pad_n, -1)
take = torch.cat([take, pad_rows], dim=0)
prompt_embeds = take.to(device=input_ids.device, dtype=torch.bfloat16)
- info_update = {
- "meta": {"talker_prefill_offset": int(offset + span_len), "codec_streaming": codec_streaming}
- }
+ info_update = {"talker_prefill_offset": int(offset + span_len)}
+ info_update["codec_streaming"] = codec_streaming
# When inputs_embeds is set, token ids are ignored by the model but must stay in-vocab for vLLM bookkeeping.
input_ids_out = input_ids.clone()
@@ -633,18 +613,18 @@ def preprocess(
device=input_ids.device,
dtype=torch.long,
)
- info_update.setdefault("codes", {})["audio"] = zeros
+ info_update["audio_codes"] = zeros
return input_ids_out, prompt_embeds, info_update
# Decode: span_len == 1
# Pop one text-step vector from tailing_text_hidden queue.
# These tensors stay on GPU via gpu_resident_buffer_keys - .to() is a no-op.
- tts_pad_embed_buf = embed.get("tts_pad")
+ tts_pad_embed_buf = info_dict.get("tts_pad_embed")
if not isinstance(tts_pad_embed_buf, torch.Tensor):
raise RuntimeError("Missing `tts_pad_embed` in additional_information; prefill must run first.")
tts_pad_embed = tts_pad_embed_buf.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
- tail = hs.get("trailing_text")
+ tail = info_dict.get("tailing_text_hidden")
if isinstance(tail, torch.Tensor) and tail.ndim == 2 and tail.shape[0] > 0:
text_step = tail[:1].to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
new_tail = tail[1:] if tail.shape[0] > 1 else tail[:0]
@@ -652,9 +632,9 @@ def preprocess(
text_step = tts_pad_embed
new_tail = tail if isinstance(tail, torch.Tensor) else torch.empty((0, tts_pad_embed.shape[-1]))
- last_hidden = hs.get("last")
+ last_hidden = info_dict.get("last_talker_hidden")
if not isinstance(last_hidden, torch.Tensor):
- raise RuntimeError("Missing hidden_states['last'] in additional_information; postprocess must run.")
+ raise RuntimeError("Missing `last_talker_hidden` in additional_information; postprocess must run.")
past_hidden = last_hidden.to(device=input_ids.device, dtype=torch.bfloat16).reshape(1, -1)
# Use OmniGPUModelRunner talker_mtp fast-path for residual codebooks and per-step inputs_embeds update.
@@ -664,9 +644,9 @@ def preprocess(
inputs_embeds_out = last_id_hidden.reshape(1, -1)
info_update = {
- "hidden_states": {"trailing_text": new_tail},
+ "tailing_text_hidden": new_tail,
"mtp_inputs": (past_hidden, text_step),
- "meta": {"codec_streaming": codec_streaming},
+ "codec_streaming": codec_streaming,
}
return input_ids, inputs_embeds_out, info_update
@@ -676,7 +656,7 @@ def postprocess(self, hidden_states: torch.Tensor, **_: Any) -> dict[str, Any]:
if hidden_states.numel() == 0:
return {}
last = hidden_states[-1, :].detach()
- return {"hidden_states": {"last": last}}
+ return {"last_talker_hidden": last}
# -------------------- prompt construction helpers --------------------
@@ -891,7 +871,7 @@ def _load_audio_to_np(self, x: str) -> tuple[np.ndarray, int]:
Uses upstream vLLM's MediaConnector for http(s) URLs and ``file:``
URIs, with unrestricted local access (offline inference is trusted).
"""
- from vllm.multimodal.media.audio import load_audio
+ import librosa
if self._is_url(x):
from vllm.multimodal.media import MediaConnector
@@ -903,7 +883,7 @@ def _load_audio_to_np(self, x: str) -> tuple[np.ndarray, int]:
with io.BytesIO(wav_bytes) as f:
audio, sr = sf.read(f, dtype="float32", always_2d=False)
else:
- audio, sr = load_audio(x, sr=None, mono=True)
+ audio, sr = librosa.load(x, sr=None, mono=True)
if isinstance(audio, np.ndarray) and audio.ndim > 1:
audio = np.mean(audio, axis=-1)
@@ -1109,8 +1089,9 @@ def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor:
# Resample to 24kHz for speaker encoder.
target_sr = int(getattr(self.config.speaker_encoder_config, "sample_rate", 24000))
if sr != target_sr:
- resampler = AudioResampler(target_sr=target_sr)
- wav = resampler.resample(wav.astype(np.float32), orig_sr=int(sr))
+ import librosa
+
+ wav = librosa.resample(y=wav.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
sr = target_sr
# Follow official implementation: mel_spectrogram expects 24kHz.
@@ -1143,19 +1124,14 @@ def _ensure_speech_tokenizer_loaded(self) -> Qwen3TTSTokenizer:
speech_tokenizer_dir,
torch_dtype=torch.bfloat16,
)
- # Only move encoder to GPU; the decoder is unused by Talker (which
- # only calls tok.encode()) and would otherwise waste bf16 VRAM.
- # NOTE: after this point the tokenizer instance is encode-only;
- # calling tok.decode() will fail because tok.model.decoder is None.
+ # Prefer GPU for encoder if available; otherwise keep CPU.
dev = next(self.parameters()).device
if dev.type != "cpu":
try:
- del tok.model.decoder
- tok.model.decoder = None
- tok.model.encoder.to(dev)
+ tok.model.to(dev)
tok.device = dev
except Exception as e:
- raise RuntimeError(f"Failed to move speech tokenizer encoder to {dev}: {e}") from e
+ raise RuntimeError(f"Failed to move speech tokenizer to {dev}: {e}") from e
else:
tok.device = dev
self._speech_tokenizer = tok
@@ -1453,16 +1429,11 @@ def _normalize_voice_clone_prompt(raw: object) -> dict[str, object] | None:
)
if ref_ids is None:
ref_text = _as_singleton(info_dict.get("ref_text"))
- if isinstance(ref_text, str) and ref_text.strip():
- ref_ids = tok(
- self._build_ref_text(ref_text),
- return_tensors="pt",
- padding=False,
- )["input_ids"].to(device=input_ids.device)
- else:
- logger.warning("Base ICL: ref_text/ref_ids missing, falling back to x-vector-only mode.")
- in_context_mode = False
- if in_context_mode:
+ if not isinstance(ref_text, str) or not ref_text.strip():
+ raise ValueError("Base in-context voice cloning requires `ref_text` or tokenized `ref_ids`.")
+ ref_ids = tok(self._build_ref_text(ref_text), return_tensors="pt", padding=False)["input_ids"].to(
+ device=input_ids.device
+ )
icl_input_embed, trailing_text_hidden = self._generate_icl_prompt(
text_id=input_ids[:, 3:-5],
ref_id=ref_ids[:, 3:-2],
@@ -1652,10 +1623,6 @@ def talker_mtp(
input_embeds: torch.Tensor,
last_talker_hidden: torch.Tensor,
text_step: torch.Tensor,
- do_sample: bool | None = None,
- temperature: float | None = None,
- top_k: int | None = None,
- top_p: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""GPU fast-path used by OmniGPUModelRunner to predict residual codebooks (1..Q-1).
Returns (inputs_embeds, audio_codes) for the current step."""
@@ -1674,24 +1641,15 @@ def talker_mtp(
audio_codes = input_ids.reshape(bsz, 1)
return (last_id_hidden + text_step).reshape(bsz, -1), audio_codes
- subtalker_params = self._subtalker_sampling_params
- if do_sample is None:
- do_sample = bool(subtalker_params.get("do_sample", True))
- if temperature is None:
- temperature = float(subtalker_params.get("temperature", 0.9))
- if top_k is None:
- top_k = int(subtalker_params.get("top_k", 50))
- if top_p is None:
- top_p = float(subtalker_params.get("top_p", 1.0))
-
+ # Predict residual codes (1..Q-1) with HF reference sampling params.
audio_codes = self.code_predictor(
layer0_code=input_ids.reshape(bsz, 1),
layer0_embed=last_id_hidden,
last_talker_hidden=past_hidden,
- do_sample=do_sample,
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
+ do_sample=True,
+ temperature=0.9,
+ top_k=50,
+ top_p=1.0,
) # [B, Q]
# Map invalid layer-0 ids (e.g. EOS) to PAD=0 so SpeechTokenizer sees only real codes.
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py
index 14bfbc5eedf..503e6bbc83b 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py
@@ -17,13 +17,12 @@
import urllib.request
from urllib.parse import urlparse
+import librosa
import numpy as np
import soundfile as sf
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
-from vllm.multimodal.audio import AudioResampler
-from vllm.multimodal.media.audio import load_audio as _load_audio_file
from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config
from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import (
@@ -155,14 +154,13 @@ def load_audio(
with io.BytesIO(wav_bytes) as f:
audio, sr = sf.read(f, dtype="float32", always_2d=False)
else:
- audio, sr = _load_audio_file(x, sr=None, mono=True)
+ audio, sr = librosa.load(x, sr=None, mono=True)
if audio.ndim > 1:
audio = np.mean(audio, axis=-1)
if sr != target_sr:
- resampler = AudioResampler(target_sr=target_sr)
- audio = resampler.resample(audio, orig_sr=sr)
+ audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)
return audio.astype(np.float32)
@@ -210,8 +208,7 @@ def _normalize_audio_inputs(
if a.ndim > 1:
a = np.mean(a, axis=-1)
if int(sr) != target_sr:
- resampler = AudioResampler(target_sr=target_sr)
- a = resampler.resample(a.astype(np.float32), orig_sr=int(sr))
+ a = librosa.resample(y=a.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
out.append(a.astype(np.float32))
return out
diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/assets/mel_filters.npz b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/assets/mel_filters.npz
new file mode 100644
index 00000000000..28ea26909db
Binary files /dev/null and b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/assets/mel_filters.npz differ
diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py
index f7e664c74d6..de2c69702c5 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py
@@ -17,17 +17,16 @@
from itertools import accumulate
import onnxruntime
+import sox
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as kaldi
+from librosa.filters import mel as librosa_mel_fn
from torch import Tensor
-from vllm_omni.model_executor.models.whisper_utils import Conv1d, ConvTranspose1d
-from vllm_omni.utils.audio import mel_filter_bank, peak_normalize
-
from .core_vq import DistributedGroupResidualVectorQuantization
-from .whisper_encoder import WhisperEncoder
+from .whisper_encoder import Conv1d, ConvTranspose1d, WhisperEncoder
def dynamic_range_compression_torch(x, c=1, clip_val=1e-5):
@@ -104,14 +103,14 @@ def extract(self, audio, **kwargs):
y = audio
if len(list(self.mel_basis.keys())) == 0:
- mel = mel_filter_bank(
+ mel = librosa_mel_fn(
sr=self.sampling_rate,
n_fft=self.filter_length,
n_mels=self.n_mel_channels,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
)
- self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = mel.to(y.device)
+ self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(y.device)
y = torch.nn.functional.pad(
@@ -153,6 +152,9 @@ def __init__(self, audio_codec_with_xvector):
audio_codec_with_xvector, sess_options=option, providers=providers
)
+ self.tfm = sox.Transformer()
+ self.tfm.norm(db_level=-6)
+
self.mel_ext = MelSpectrogramFeatures(
filter_length=1024,
hop_length=160,
@@ -180,7 +182,8 @@ def extract_code(self, audio):
return norm_embedding.numpy(), ref_mel.permute(0, 2, 1).squeeze(0).numpy()
def sox_norm(self, audio):
- return peak_normalize(audio, db_level=-6)
+ wav_norm = self.tfm.build_array(input_array=audio, sample_rate_in=16000)
+ return wav_norm
class WhisperEncoderVQ(WhisperEncoder):
diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
index 7756720b2ba..e3bd6e1c3a3 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
@@ -14,6 +14,7 @@
# limitations under the License.
import math
import operator
+import os
from functools import cache
from itertools import accumulate
@@ -23,8 +24,6 @@
from torch import Tensor, nn
from vllm_omni.diffusion.attention.backends.utils.fa import HAS_FLASH_ATTN, flash_attn_varlen_func
-from vllm_omni.model_executor.models.whisper_utils import Conv1d, Linear, sinusoids
-from vllm_omni.utils.audio import mel_filter_bank
N_FFT = 400
HOP_LENGTH = 160
@@ -32,8 +31,21 @@
@cache
def mel_filters(device, n_mels: int) -> torch.Tensor:
- """Compute mel filterbank matrix for projecting STFT into a Mel spectrogram."""
- return mel_filter_bank(sr=16000, n_fft=N_FFT, n_mels=n_mels).to(device)
+ """
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+ Allows decoupling librosa dependency; saved using:
+
+ np.savez_compressed(
+ "mel_filters.npz",
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+ mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
+ )
+ """
+ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
+
+ filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+ with np.load(filters_path, allow_pickle=False) as f:
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(
@@ -103,6 +115,30 @@ def get_mel_audio(audio, padding=False, audio_vq_ds_rate=1, n_mels=128):
return mel
+def sinusoids(length, channels, max_timescale=10000):
+ """Returns sinusoids for positional embedding"""
+ assert channels % 2 == 0
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+class Conv1d(nn.Conv1d):
+ def _conv_forward(self, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor:
+ return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
+
+
+class ConvTranspose1d(nn.ConvTranspose1d):
+ def _conv_forward(self, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor:
+ return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
+
+
+class Linear(nn.Linear):
+ def forward(self, x: Tensor) -> Tensor:
+ return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
+
+
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int, use_flash_attention: bool = True):
super().__init__()
diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py
index 0cbe60610ac..1398923458b 100644
--- a/vllm_omni/model_executor/models/registry.py
+++ b/vllm_omni/model_executor/models/registry.py
@@ -108,11 +108,6 @@
"mimo_audio",
"MiMoAudioForConditionalGeneration",
),
- "MiMoV2ASRForCausalLM": (
- "mimo_audio",
- "mimo_audio",
- "MiMoAudioForConditionalGeneration",
- ),
"MiMoAudioLLMModel": (
"mimo_audio",
"mimo_audio_llm",
@@ -150,18 +145,6 @@
"fish_speech_dac_decoder",
"FishSpeechDACDecoder",
),
- ## VoxCPM
- "VoxCPMForConditionalGeneration": (
- "voxcpm",
- "voxcpm",
- "VoxCPMForConditionalGeneration",
- ),
- ## VoxCPM2
- "VoxCPM2TalkerForConditionalGeneration": (
- "voxcpm2",
- "voxcpm2_talker",
- "VoxCPM2TalkerForConditionalGeneration",
- ),
## Voxtral TTS
"VoxtralTTSForConditionalGeneration": (
"voxtral_tts",
@@ -174,33 +157,6 @@
"VoxtralTTSAudioGenerationForConditionalGeneration",
),
"VoxtralTTSAudioTokenizer": ("voxtral_tts", "voxtral_tts_audio_tokenizer", "VoxtralTTSAudioTokenizer"),
- "DyninOmniForConditionalGeneration": (
- "dynin_omni",
- "dynin_omni",
- "DyninOmniForConditionalGeneration",
- ),
- ## Ming-flash-omni-2.0
- "MingFlashOmniForConditionalGeneration": (
- "ming_flash_omni",
- "ming_flash_omni",
- "MingFlashOmniForConditionalGeneration",
- ),
- "MingFlashOmniThinkerForConditionalGeneration": (
- "ming_flash_omni",
- "ming_flash_omni_thinker",
- "MingFlashOmniThinkerForConditionalGeneration",
- ),
- "MingFlashOmniTalkerForConditionalGeneration": (
- "ming_flash_omni",
- "ming_flash_omni_talker",
- "MingFlashOmniTalkerForConditionalGeneration",
- ),
- # Alias: HF repo currently ships this architecture name in config.json
- "BailingMM2NativeForConditionalGeneration": (
- "ming_flash_omni",
- "ming_flash_omni",
- "MingFlashOmniForConditionalGeneration",
- ),
}
diff --git a/vllm_omni/model_executor/models/voxcpm/__init__.py b/vllm_omni/model_executor/models/voxcpm/__init__.py
deleted file mode 100644
index 3b064c0f683..00000000000
--- a/vllm_omni/model_executor/models/voxcpm/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .configuration_voxcpm import VoxCPMConfig
-from .voxcpm import VoxCPMForConditionalGeneration
-
-__all__ = [
- "VoxCPMConfig",
- "VoxCPMForConditionalGeneration",
-]
diff --git a/vllm_omni/model_executor/models/voxcpm/configuration_voxcpm.py b/vllm_omni/model_executor/models/voxcpm/configuration_voxcpm.py
deleted file mode 100644
index ce1d809bd38..00000000000
--- a/vllm_omni/model_executor/models/voxcpm/configuration_voxcpm.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from vllm_omni.transformers_utils.configs.voxcpm import VoxCPMConfig
-
-__all__ = ["VoxCPMConfig"]
diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm.py b/vllm_omni/model_executor/models/voxcpm/voxcpm.py
deleted file mode 100644
index 6fa36fc4200..00000000000
--- a/vllm_omni/model_executor/models/voxcpm/voxcpm.py
+++ /dev/null
@@ -1,886 +0,0 @@
-from __future__ import annotations
-
-import json
-import os
-import sys
-import tempfile
-import warnings
-import wave
-from collections.abc import Callable, Generator, Iterable
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange
-from tqdm import tqdm
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.sequence import IntermediateTensors
-
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-
-from .voxcpm_loader import (
- _build_prompt_cache_with_soundfile,
- _device_to_string,
- _force_cuda_available_for_npu,
- _import_voxcpm_audio_vae_classes,
- _import_voxcpm_base_model_class,
- _is_torchcodec_load_error,
- _normalize_dtype_name,
- _prepare_runtime_model_dir,
- _resolve_runtime_device,
-)
-from .voxcpm_runtime_utils import resolve_voxcpm_model_dir
-from .voxcpm_stage_wrappers import _DirectVoxCPMAudioVAE, _DirectVoxCPMLatentGenerator
-
-logger = init_logger(__name__)
-_VOXCPM_LATENT_MAGIC = 131071
-
-
-def _make_voxcpm_model_for_omni(base: type[Any]) -> type[Any]:
- """Subclass upstream VoxCPMModel: local ``_inference`` + ``latents_only`` prompt-cache generation."""
-
- from voxcpm.model.utils import get_dtype
-
- class VoxCPMModelForOmni(base):
- @torch.inference_mode()
- def build_prompt_cache(self, *args: Any, **kwargs: Any):
- try:
- return super().build_prompt_cache(*args, **kwargs)
- except (ImportError, ModuleNotFoundError, RuntimeError) as exc:
- if not _is_torchcodec_load_error(exc):
- raise
- return _build_prompt_cache_with_soundfile(self, *args, **kwargs)
-
- @torch.inference_mode()
- def _inference(
- self,
- text: torch.Tensor,
- text_mask: torch.Tensor,
- feat: torch.Tensor,
- feat_mask: torch.Tensor,
- min_len: int = 2,
- max_len: int = 2000,
- inference_timesteps: int = 10,
- cfg_value: float = 2.0,
- streaming: bool = False,
- streaming_prefix_len: int = 3,
- ) -> Generator[tuple[torch.Tensor, torch.Tensor | list[torch.Tensor]], None, None]:
- B, _, _, _ = feat.shape
-
- feat_embed = self.feat_encoder(feat)
- feat_embed = self.enc_to_lm_proj(feat_embed)
-
- scale_emb = self.config.lm_config.scale_emb if self.config.lm_config.use_mup else 1.0
- text_embed = self.base_lm.embed_tokens(text) * scale_emb
- combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
-
- prefix_feat_cond = feat[:, -1, ...]
- pred_feat_seq: list[torch.Tensor] = []
-
- audio_patch_count = int(feat_mask.sum().item())
- if audio_patch_count > 0:
- context_len = min(streaming_prefix_len - 1, audio_patch_count)
- prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
- pred_feat_seq = prompt_context_patches + pred_feat_seq
-
- enc_outputs, kv_cache_tuple = self.base_lm(
- inputs_embeds=combined_embed,
- is_causal=True,
- )
- self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
-
- enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
- lm_hidden = enc_outputs[:, -1, :]
-
- residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
- inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
- is_causal=True,
- )
- self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
- residual_hidden = residual_enc_outputs[:, -1, :]
-
- for step_idx in tqdm(range(max_len)):
- dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
- pred_feat = self.feat_decoder(
- mu=dit_hidden,
- patch_size=self.patch_size,
- cond=prefix_feat_cond.transpose(1, 2).contiguous(),
- n_timesteps=inference_timesteps,
- cfg_value=cfg_value,
- ).transpose(1, 2)
-
- curr_embed = self.enc_to_lm_proj(self.feat_encoder(pred_feat.unsqueeze(1)))
- pred_feat_seq.append(pred_feat.unsqueeze(1))
- prefix_feat_cond = pred_feat
-
- if streaming:
- pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
- feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
- yield feat_pred, pred_feat_seq
-
- stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
- if step_idx > min_len and stop_flag == 1:
- break
-
- lm_hidden = self.base_lm.forward_step(
- curr_embed[:, 0, :],
- torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device),
- ).clone()
- lm_hidden = self.fsq_layer(lm_hidden)
- residual_hidden = self.residual_lm.forward_step(
- lm_hidden + curr_embed[:, 0, :],
- torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
- ).clone()
-
- if not streaming:
- pred_feat_seq_cat = torch.cat(pred_feat_seq, dim=1)
- feat_pred = rearrange(pred_feat_seq_cat, "b t p d -> b d (t p)", b=B, p=self.patch_size)
- yield feat_pred, pred_feat_seq_cat.squeeze(0).cpu()
-
- @torch.inference_mode()
- def generate_latents_with_prompt_cache(
- self,
- target_text: str,
- prompt_cache: dict,
- min_len: int = 2,
- max_len: int = 2000,
- inference_timesteps: int = 10,
- cfg_value: float = 2.0,
- retry_badcase: bool = False,
- retry_badcase_max_times: int = 3,
- retry_badcase_ratio_threshold: float = 6.0,
- streaming_prefix_len: int = 3,
- ) -> tuple[None, torch.Tensor, torch.Tensor]:
- return next(
- self._generate_with_prompt_cache(
- target_text=target_text,
- prompt_cache=prompt_cache,
- min_len=min_len,
- max_len=max_len,
- inference_timesteps=inference_timesteps,
- cfg_value=cfg_value,
- retry_badcase=retry_badcase,
- retry_badcase_max_times=retry_badcase_max_times,
- retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
- streaming=False,
- streaming_prefix_len=streaming_prefix_len,
- latents_only=True,
- )
- )
-
- @torch.inference_mode()
- def generate_latents_with_prompt_cache_streaming(
- self,
- target_text: str,
- prompt_cache: dict,
- min_len: int = 2,
- max_len: int = 2000,
- inference_timesteps: int = 10,
- cfg_value: float = 2.0,
- retry_badcase: bool = False,
- retry_badcase_max_times: int = 3,
- retry_badcase_ratio_threshold: float = 6.0,
- streaming_prefix_len: int = 3,
- ) -> Generator[tuple[None, torch.Tensor, torch.Tensor], None, None]:
- return self._generate_with_prompt_cache(
- target_text=target_text,
- prompt_cache=prompt_cache,
- min_len=min_len,
- max_len=max_len,
- inference_timesteps=inference_timesteps,
- cfg_value=cfg_value,
- retry_badcase=retry_badcase,
- retry_badcase_max_times=retry_badcase_max_times,
- retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
- streaming=True,
- streaming_prefix_len=streaming_prefix_len,
- latents_only=True,
- )
-
- @torch.inference_mode()
- def _generate_with_prompt_cache(
- self,
- target_text: str,
- prompt_cache: dict,
- min_len: int = 2,
- max_len: int = 2000,
- inference_timesteps: int = 10,
- cfg_value: float = 2.0,
- retry_badcase: bool = False,
- retry_badcase_max_times: int = 3,
- retry_badcase_ratio_threshold: float = 6.0,
- streaming: bool = False,
- streaming_prefix_len: int = 3,
- latents_only: bool = False,
- ) -> Generator[tuple[torch.Tensor | None, torch.Tensor, torch.Tensor | list[torch.Tensor]], None, None]:
- if retry_badcase and streaming:
- warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
- retry_badcase = False
- if prompt_cache is None:
- prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
- text = target_text
- else:
- prompt_audio_feat = prompt_cache["audio_feat"]
- prompt_text = prompt_cache["prompt_text"]
- text = prompt_text + target_text
-
- text_token = torch.LongTensor(self.text_tokenizer(text))
- text_token = torch.cat(
- [
- text_token,
- torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
- ],
- dim=-1,
- )
- target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
-
- audio_length = prompt_audio_feat.size(0)
- text_length = text_token.shape[0]
- text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
- audio_pad_feat = torch.zeros(
- (text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
- dtype=torch.float32,
- device=text_token.device,
- )
- text_token = torch.cat([text_token, text_pad_token])
- audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
- text_mask = (
- torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
- )
- audio_mask = (
- torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
- )
-
- text_token = text_token.unsqueeze(0).to(self.device)
- text_mask = text_mask.unsqueeze(0).to(self.device)
- audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
- audio_mask = audio_mask.unsqueeze(0).to(self.device)
-
- target_text_length = len(self.text_tokenizer(target_text))
- retry_badcase_times = 0
- while retry_badcase_times < retry_badcase_max_times:
- inference_result = self._inference(
- text_token,
- text_mask,
- audio_feat,
- audio_mask,
- min_len=min_len,
- max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len),
- inference_timesteps=inference_timesteps,
- cfg_value=cfg_value,
- streaming=streaming,
- streaming_prefix_len=streaming_prefix_len,
- )
- if streaming:
- patch_len = self.patch_size * self.chunk_size
- for latent_pred, pred_audio_feat in inference_result:
- if latents_only:
- decode_audio = None
- yield (decode_audio, target_text_token, latent_pred)
- else:
- decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
- decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
- yield (decode_audio, target_text_token, pred_audio_feat)
- break
-
- latent_pred, pred_audio_feat = next(inference_result)
- if retry_badcase and pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
- ratio = pred_audio_feat.shape[0] / target_text_length
- print(f" Badcase detected, audio_text_ratio={ratio}, retrying...", file=sys.stderr)
- retry_badcase_times += 1
- continue
- break
-
- if not streaming:
- if latents_only:
- decode_audio = None
- else:
- decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
- patch_len = self.patch_size * self.chunk_size
- if audio_mask.sum().item() > 0:
- decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
- else:
- decode_audio = decode_audio[..., :].squeeze(1).cpu()
- yield (decode_audio, target_text_token, pred_audio_feat)
-
- VoxCPMModelForOmni.__name__ = "VoxCPMModelForOmni"
- VoxCPMModelForOmni.__qualname__ = "VoxCPMModelForOmni"
- return VoxCPMModelForOmni
-
-
-def _import_voxcpm_model_class() -> type[Any]:
- base = _import_voxcpm_base_model_class()
- return _make_voxcpm_model_for_omni(base)
-
-
-def _load_native_voxcpm_model(
- model_path: str,
- *,
- device: torch.device,
- dtype: str | None,
-):
- VoxCPMModel = _import_voxcpm_model_class()
- model_dir = resolve_voxcpm_model_dir(model_path)
- runtime_model_path = _prepare_runtime_model_dir(model_dir, target_device=device, target_dtype=dtype)
-
- if device.type == "npu" and hasattr(torch, "npu"):
- torch.npu.set_device(device)
-
- with _force_cuda_available_for_npu(device):
- return VoxCPMModel.from_local(
- runtime_model_path,
- optimize=device.type == "cuda",
- )
-
-
-def _load_native_voxcpm_latent_generator(
- model_path: str,
- *,
- device: torch.device,
- dtype: str | None,
-) -> _DirectVoxCPMLatentGenerator:
- return _DirectVoxCPMLatentGenerator(_load_native_voxcpm_model(model_path, device=device, dtype=dtype))
-
-
-def _load_native_voxcpm_audio_vae(
- model_path: str,
- *,
- device: torch.device,
-) -> _DirectVoxCPMAudioVAE:
- AudioVAE, AudioVAEConfig = _import_voxcpm_audio_vae_classes()
- model_dir = resolve_voxcpm_model_dir(model_path)
- runtime_model_path = _prepare_runtime_model_dir(model_dir, target_device=device, target_dtype="float32")
- config_dict = json.loads((Path(runtime_model_path) / "config.json").read_text())
- audio_vae_config = config_dict.get("audio_vae_config")
- audio_vae = AudioVAE(config=AudioVAEConfig(**audio_vae_config)) if audio_vae_config is not None else AudioVAE()
-
- state_dict = torch.load(
- Path(runtime_model_path) / "audiovae.pth",
- map_location="cpu",
- weights_only=True,
- )["state_dict"]
- audio_vae.load_state_dict(state_dict, strict=True)
- audio_vae = audio_vae.to(device=device, dtype=torch.float32).eval()
- if device.type == "npu" and hasattr(torch, "npu"):
- torch.npu.set_device(device)
- patch_size = int(config_dict.get("patch_size", 2))
- return _DirectVoxCPMAudioVAE(audio_vae, patch_size=patch_size)
-
-
-class VoxCPMForConditionalGeneration(nn.Module):
- input_modalities = "audio"
- _LATENT_STAGES = {"latent_generator", "latent", "ar_dit"}
- _VAE_STAGES = {"vae", "audio_vae"}
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- del prefix
- self.vllm_config = vllm_config
- self.model_path = vllm_config.model_config.model
- self.model_stage = getattr(vllm_config.model_config, "model_stage", "latent_generator")
- self.have_multimodal_outputs = True
- self.has_preprocess = False
- self.has_postprocess = False
- self.enable_update_additional_information = True
- self.requires_raw_input_tokens = True
- self.inject_omni_request_id_into_runtime_info = True
- self._pipeline = None
- self._latent_stream_gens: dict[str, Any] = {}
- self._latent_stream_terminal_pending: dict[str, int] = {}
- self._latent_stream_completed: set[str] = set()
- self._next_local_stream_key = 0
- self._ar_emit_stop_token = True
-
- def _runner_hidden_device_dtype(self) -> tuple[torch.device, torch.dtype]:
- device = _resolve_runtime_device(self.vllm_config)
- model_config = getattr(self.vllm_config, "model_config", None)
- dtype = getattr(model_config, "dtype", torch.float32) if model_config is not None else torch.float32
- return device, dtype
-
- def _ensure_model_loaded(self):
- if self._pipeline is not None:
- return
-
- target_device = _resolve_runtime_device(self.vllm_config)
- model_dtype = getattr(self.vllm_config.model_config, "dtype", None)
- normalized_dtype = _normalize_dtype_name(model_dtype)
- if self.model_stage in self._LATENT_STAGES:
- self._pipeline = _load_native_voxcpm_latent_generator(
- self.model_path,
- device=target_device,
- dtype=normalized_dtype,
- )
- elif self.model_stage in self._VAE_STAGES:
- self._pipeline = _load_native_voxcpm_audio_vae(
- self.model_path,
- device=target_device,
- )
- else:
- raise ValueError(
- f"Unsupported VoxCPM model_stage: {self.model_stage}. "
- "pure_voxcpm only supports split-stage latent_generator/vae inference."
- )
-
- logger.info("Loaded VoxCPM stage '%s' on %s", self.model_stage, _device_to_string(target_device))
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- del weights
- self._ensure_model_loaded()
- return set()
-
- @staticmethod
- def _extract_val(info: dict[str, Any], key: str, default: Any) -> Any:
- value = info.get(key, default)
- if isinstance(value, list):
- return value[0] if value else default
- return value
-
- def _resolve_stream_request_key(self, info: dict[str, Any]) -> str:
- request_key = info.get("__voxcpm_stream_key")
- if request_key is not None:
- return str(request_key)
-
- request_key = info.get("_omni_req_id")
- if request_key is not None:
- request_key = str(request_key)
- info["__voxcpm_stream_key"] = request_key
- return request_key
-
- request_key = f"voxcpm-local-{self._next_local_stream_key}"
- self._next_local_stream_key += 1
- info["__voxcpm_stream_key"] = request_key
- return str(request_key)
-
- def _recover_latent_from_input_ids(self, input_ids: torch.Tensor | None) -> torch.Tensor | None:
- if input_ids is None or input_ids.numel() == 0:
- return None
- flat_ids = input_ids.detach().reshape(-1).to("cpu")
- if flat_ids.numel() < 4 or int(flat_ids[0].item()) != _VOXCPM_LATENT_MAGIC:
- return None
- latent_dim = int(flat_ids[1].item())
- time_dim = int(flat_ids[2].item())
- payload = flat_ids[3:]
- expected = latent_dim * time_dim
- if latent_dim <= 0 or time_dim <= 0:
- raise ValueError(f"Invalid VoxCPM latent header: latent_dim={latent_dim}, time_dim={time_dim}")
- if int(payload.numel()) != expected:
- raise ValueError(
- "Invalid VoxCPM latent payload size: "
- f"expected={expected}, actual={int(payload.numel())}, "
- f"latent_dim={latent_dim}, time_dim={time_dim}"
- )
- packed = payload.to(dtype=torch.int32).to(torch.uint16)
- return packed.view(torch.bfloat16).to(torch.float32).reshape(1, latent_dim, time_dim)
-
- def _maybe_recover_vae_infos(
- self,
- infos: list[dict[str, Any]],
- input_ids: torch.Tensor | None,
- *,
- async_chunk: bool,
- ) -> list[dict[str, Any]]:
- if not async_chunk:
- return infos
- if any(self._extract_val(info, "latent_audio_feat", None) is not None for info in infos):
- return infos
- recovered = self._recover_latent_from_input_ids(input_ids)
- if recovered is None:
- return infos
- return [{"latent_audio_feat": recovered}]
-
- @staticmethod
- def _normalize_audio_samples(samples: Any) -> np.ndarray:
- if isinstance(samples, torch.Tensor):
- return samples.detach().cpu().float().reshape(-1).numpy()
- return np.asarray(samples, dtype=np.float32).reshape(-1)
-
- @classmethod
- def _normalize_ref_audio(cls, ref_audio: Any) -> tuple[np.ndarray, int]:
- if isinstance(ref_audio, str):
- raise TypeError("String ref_audio should be handled as a path before waveform normalization.")
-
- if isinstance(ref_audio, dict):
- sample_rate = ref_audio.get("sample_rate") or ref_audio.get("sampling_rate") or ref_audio.get("sr")
- samples = None
- for key in ("audio", "wav", "samples", "array", "waveform"):
- if key in ref_audio and ref_audio[key] is not None:
- samples = ref_audio[key]
- break
- if sample_rate is None or samples is None:
- raise ValueError("ref_audio dict must contain waveform data and sample rate.")
- return cls._normalize_audio_samples(samples), int(sample_rate)
-
- if isinstance(ref_audio, (list, tuple)):
- if len(ref_audio) == 1:
- return cls._normalize_ref_audio(ref_audio[0])
- if len(ref_audio) == 2 and np.isscalar(ref_audio[1]):
- return cls._normalize_audio_samples(ref_audio[0]), int(ref_audio[1])
-
- raise TypeError(f"Unsupported ref_audio format: {type(ref_audio)!r}")
-
- @staticmethod
- def _write_temp_prompt_wav(waveform: np.ndarray, sample_rate: int) -> str:
- prompt_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
- prompt_file.close()
-
- wav = np.asarray(waveform, dtype=np.float32).reshape(-1)
- wav = np.clip(wav, -1.0, 1.0)
- pcm16 = (wav * 32767.0).astype(np.int16)
- with wave.open(prompt_file.name, "wb") as wav_file:
- wav_file.setnchannels(1)
- wav_file.setsampwidth(2)
- wav_file.setframerate(int(sample_rate))
- wav_file.writeframes(pcm16.tobytes())
-
- return prompt_file.name
-
- @classmethod
- def _resolve_prompt_inputs(cls, info: dict[str, Any]) -> tuple[str | None, str | None, str | None]:
- prompt_text = cls._extract_val(info, "prompt_text", None)
- prompt_wav_path = cls._extract_val(info, "prompt_wav_path", None)
- if prompt_wav_path:
- if prompt_text is None:
- prompt_text = cls._extract_val(info, "ref_text", None)
- return prompt_wav_path, prompt_text, None
-
- ref_audio = cls._extract_val(info, "ref_audio", None)
- ref_text = cls._extract_val(info, "ref_text", None)
- if ref_audio is None or ref_text is None:
- return None, None, None
- if isinstance(ref_audio, str):
- return ref_audio, ref_text, None
-
- waveform, sample_rate = cls._normalize_ref_audio(ref_audio)
- temp_prompt_wav = cls._write_temp_prompt_wav(waveform, sample_rate)
- return temp_prompt_wav, ref_text, temp_prompt_wav
-
- def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
- if input_ids.numel() == 0:
- return torch.empty((0, 1), device=input_ids.device, dtype=torch.float32)
- return torch.zeros((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.float32)
-
- def _get_vocab_size(self) -> int:
- model_config = getattr(self.vllm_config, "model_config", None)
- if model_config is not None:
- getter = getattr(model_config, "get_vocab_size", None)
- if callable(getter):
- try:
- return int(getter())
- except Exception:
- pass
- hf_config = getattr(model_config, "hf_text_config", None)
- if hf_config is not None and hasattr(hf_config, "vocab_size"):
- return int(hf_config.vocab_size)
- return 32000
-
- def _make_empty_output(
- self,
- *,
- output_key: str,
- payload_factory: Callable[[], torch.Tensor],
- infos: list[dict[str, Any]],
- sample_rate: int,
- out_device: torch.device,
- out_dtype: torch.dtype,
- hidden_rows: int | None = None,
- ) -> OmniOutput:
- if hidden_rows is None:
- hidden_rows = len(infos)
- return OmniOutput(
- text_hidden_states=torch.zeros((hidden_rows, 1), device=out_device, dtype=out_dtype),
- multimodal_outputs={
- output_key: [payload_factory() for _ in infos],
- "sr": [torch.tensor(sample_rate, dtype=torch.int32) for _ in infos],
- },
- )
-
- def _finalize_stage_output(
- self,
- *,
- output_key: str,
- outputs: list[torch.Tensor],
- sample_rates: list[torch.Tensor],
- out_device: torch.device,
- out_dtype: torch.dtype,
- hidden_rows: int | None = None,
- ) -> OmniOutput:
- multimodal_outputs: dict[str, Any] = {output_key: outputs, "sr": sample_rates}
- if hidden_rows is not None:
- text_hidden_states = torch.zeros((hidden_rows, 1), device=out_device, dtype=out_dtype)
- elif outputs:
- outputs_tensor = torch.stack(outputs)
- text_hidden_states = (
- outputs_tensor.unsqueeze(-1)
- if outputs_tensor.ndim == 1
- else outputs_tensor.reshape(-1, outputs_tensor.shape[-1])
- )
- else:
- text_hidden_states = torch.zeros((0, 1), device=out_device, dtype=out_dtype)
- text_hidden_states = text_hidden_states.to(device=out_device, dtype=out_dtype)
- return OmniOutput(
- text_hidden_states=text_hidden_states,
- multimodal_outputs=multimodal_outputs,
- )
-
- def _forward_vae_stage(
- self,
- infos: list[dict[str, Any]],
- *,
- sample_rate: int,
- async_chunk: bool,
- out_device: torch.device,
- out_dtype: torch.dtype,
- ) -> OmniOutput:
- if all(self._extract_val(info, "latent_audio_feat", None) is None for info in infos):
- self._ar_emit_stop_token = True
- return self._make_empty_output(
- output_key="model_outputs",
- payload_factory=lambda: torch.zeros((0,), dtype=torch.float32),
- infos=infos,
- sample_rate=sample_rate,
- out_device=out_device,
- out_dtype=out_dtype,
- )
-
- outputs: list[torch.Tensor] = []
- sample_rates: list[torch.Tensor] = []
- for info in infos:
- latent_audio_feat = self._extract_val(info, "latent_audio_feat", None)
- audio_tensor = self._pipeline.decode(latent_audio_feat, trim_streaming_patch=async_chunk)
- outputs.append(audio_tensor.float().cpu())
- sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
-
- self._ar_emit_stop_token = True
- return self._finalize_stage_output(
- output_key="model_outputs",
- outputs=outputs,
- sample_rates=sample_rates,
- out_device=out_device,
- out_dtype=out_dtype,
- )
-
- def _forward_latent_stage(
- self,
- infos: list[dict[str, Any]],
- *,
- sample_rate: int,
- async_chunk: bool,
- out_device: torch.device,
- out_dtype: torch.dtype,
- hidden_rows: int,
- ) -> OmniOutput:
- texts = [self._extract_val(info, "text", "") for info in infos]
- if all(not text for text in texts):
- self._ar_emit_stop_token = True
- return self._make_empty_output(
- output_key="latent_audio_feat",
- payload_factory=lambda: torch.zeros((0,), dtype=torch.float32),
- infos=infos,
- sample_rate=sample_rate,
- out_device=out_device,
- out_dtype=out_dtype,
- hidden_rows=hidden_rows,
- )
-
- outputs: list[torch.Tensor] = []
- sample_rates: list[torch.Tensor] = []
- last_chunk_flags: list[bool] | None = [] if async_chunk else None
- payload_finished_flags: list[bool] | None = [] if async_chunk else None
- for info in infos:
- text = self._extract_val(info, "text", "")
- cfg_value = float(self._extract_val(info, "cfg_value", 2.0))
- inference_timesteps = int(self._extract_val(info, "inference_timesteps", 10))
- min_len = int(self._extract_val(info, "min_len", 2))
- max_len = int(self._extract_val(info, "max_len", self._extract_val(info, "max_new_tokens", 4096)))
- retry_badcase = bool(self._extract_val(info, "retry_badcase", True))
- retry_badcase_max_times = int(self._extract_val(info, "retry_badcase_max_times", 3))
- retry_badcase_ratio_threshold = float(self._extract_val(info, "retry_badcase_ratio_threshold", 6.0))
- streaming_prefix_len = int(self._extract_val(info, "streaming_prefix_len", 3))
-
- request_key = self._resolve_stream_request_key(info)
- created_temp: str | None = None
-
- if async_chunk:
- terminal_pending = self._latent_stream_terminal_pending.get(request_key, 0)
- if terminal_pending > 0:
- outputs.append(torch.zeros((0,), dtype=torch.float32))
- assert last_chunk_flags is not None
- last_chunk_flags.append(True)
- assert payload_finished_flags is not None
- payload_finished_flags.append(terminal_pending == 1)
- if terminal_pending == 1:
- self._latent_stream_terminal_pending.pop(request_key, None)
- else:
- self._latent_stream_terminal_pending[request_key] = terminal_pending - 1
- sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
- continue
-
- if request_key in self._latent_stream_completed:
- outputs.append(torch.zeros((0,), dtype=torch.float32))
- assert last_chunk_flags is not None
- last_chunk_flags.append(True)
- assert payload_finished_flags is not None
- payload_finished_flags.append(False)
- sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
- continue
-
- if request_key not in self._latent_stream_gens:
- prompt_wav_path, prompt_text, temp_prompt_wav = self._resolve_prompt_inputs(info)
- created_temp = temp_prompt_wav
- self._latent_stream_gens[request_key] = self._pipeline.iter_latent_chunks_streaming(
- text=text,
- prompt_wav_path=prompt_wav_path,
- prompt_text=prompt_text,
- cfg_value=cfg_value,
- inference_timesteps=inference_timesteps,
- min_len=min_len,
- max_len=max_len,
- streaming_prefix_len=streaming_prefix_len,
- retry_badcase=False,
- retry_badcase_max_times=retry_badcase_max_times,
- retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
- )
- generator = self._latent_stream_gens[request_key]
- try:
- chunk_latent, is_last = next(generator)
- except StopIteration:
- self._latent_stream_gens.pop(request_key, None)
- self._latent_stream_terminal_pending[request_key] = 1
- self._latent_stream_completed.add(request_key)
- outputs.append(torch.zeros((0,), dtype=torch.float32))
- assert last_chunk_flags is not None
- last_chunk_flags.append(True)
- assert payload_finished_flags is not None
- payload_finished_flags.append(True)
- else:
- if is_last:
- self._latent_stream_gens.pop(request_key, None)
- self._latent_stream_terminal_pending[request_key] = 1
- self._latent_stream_completed.add(request_key)
- outputs.append(chunk_latent.detach().float().cpu())
- assert last_chunk_flags is not None
- last_chunk_flags.append(bool(is_last))
- assert payload_finished_flags is not None
- payload_finished_flags.append(False)
- finally:
- if created_temp is not None and os.path.exists(created_temp):
- os.unlink(created_temp)
- sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
- continue
-
- prompt_wav_path, prompt_text, temp_prompt_wav = self._resolve_prompt_inputs(info)
- try:
- latent_audio_feat = self._pipeline.generate_latents(
- text=text,
- prompt_wav_path=prompt_wav_path,
- prompt_text=prompt_text,
- cfg_value=cfg_value,
- inference_timesteps=inference_timesteps,
- min_len=min_len,
- max_len=max_len,
- retry_badcase=retry_badcase,
- retry_badcase_max_times=retry_badcase_max_times,
- retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
- )
- outputs.append(latent_audio_feat.float().cpu())
- finally:
- if temp_prompt_wav is not None and os.path.exists(temp_prompt_wav):
- os.unlink(temp_prompt_wav)
-
- sample_rates.append(torch.tensor(sample_rate, dtype=torch.int32))
-
- self._ar_emit_stop_token = all(last_chunk_flags) if async_chunk and last_chunk_flags else True
- output = self._finalize_stage_output(
- output_key="latent_audio_feat",
- outputs=outputs,
- sample_rates=sample_rates,
- out_device=out_device,
- out_dtype=out_dtype,
- hidden_rows=hidden_rows,
- )
- if async_chunk and payload_finished_flags is not None:
- output.multimodal_outputs["finished"] = [
- torch.tensor(flag, dtype=torch.bool) for flag in payload_finished_flags
- ]
- return output
-
- def compute_logits(self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None) -> torch.Tensor:
- del sampling_metadata
- if isinstance(hidden_states, OmniOutput):
- hidden_states = hidden_states.text_hidden_states
- if hidden_states is None:
- device, dtype = self._runner_hidden_device_dtype()
- hidden_states = torch.zeros((0, 1), device=device, dtype=dtype)
- if hidden_states.ndim == 1:
- hidden_states = hidden_states.unsqueeze(-1)
- elif hidden_states.ndim > 2:
- hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
-
- vocab_size = self._get_vocab_size()
- num_rows = int(hidden_states.shape[0])
- logits = torch.zeros((num_rows, vocab_size), dtype=torch.float32, device=hidden_states.device)
- eos_id = 2 if vocab_size > 2 else 0
- safe_id = 1 if vocab_size > 1 and 1 != eos_id else 0
- emit_stop = getattr(self, "_ar_emit_stop_token", True)
- if num_rows > 0:
- if emit_stop:
- logits[:, eos_id] = 1.0e6
- else:
- logits[:, eos_id] = -1.0e9
- logits[:, safe_id] = 1.0e6
- return logits
-
- @torch.no_grad()
- def forward(
- self,
- input_ids: torch.Tensor | None = None,
- positions: torch.Tensor | None = None,
- intermediate_tensors: Any = None,
- inputs_embeds: torch.Tensor | None = None,
- runtime_additional_information: list[dict[str, Any]] | None = None,
- model_intermediate_buffer: list[dict[str, Any]] | None = None,
- **kwargs: Any,
- ) -> OmniOutput:
- del positions, intermediate_tensors, inputs_embeds, kwargs
- self._ensure_model_loaded()
- out_device, out_dtype = self._runner_hidden_device_dtype()
- if input_ids is not None and input_ids.device.type == out_device.type:
- out_device = input_ids.device
-
- infos = model_intermediate_buffer or runtime_additional_information or [{}]
- hidden_rows = len(infos)
- if input_ids is not None and len(input_ids.shape) > 0:
- hidden_rows = max(hidden_rows, int(input_ids.shape[0]))
- sample_rate = int(getattr(self._pipeline, "sample_rate", 24000))
- async_chunk = bool(getattr(self.vllm_config.model_config, "async_chunk", False))
- if self.model_stage in self._VAE_STAGES:
- infos = self._maybe_recover_vae_infos(infos, input_ids, async_chunk=async_chunk)
- return self._forward_vae_stage(
- infos,
- sample_rate=sample_rate,
- async_chunk=async_chunk,
- out_device=out_device,
- out_dtype=out_dtype,
- )
- if self.model_stage in self._LATENT_STAGES:
- return self._forward_latent_stage(
- infos,
- sample_rate=sample_rate,
- async_chunk=async_chunk,
- out_device=out_device,
- out_dtype=out_dtype,
- hidden_rows=hidden_rows,
- )
- raise ValueError(f"Unsupported VoxCPM model_stage at runtime: {self.model_stage}")
-
- def make_empty_intermediate_tensors(
- self, batch_size: int, dtype: torch.dtype, device: torch.device
- ) -> IntermediateTensors:
- del batch_size, dtype, device
- return {}
-
-
-__all__ = ["VoxCPMForConditionalGeneration"]
diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm_loader.py b/vllm_omni/model_executor/models/voxcpm/voxcpm_loader.py
deleted file mode 100644
index dac7117cad8..00000000000
--- a/vllm_omni/model_executor/models/voxcpm/voxcpm_loader.py
+++ /dev/null
@@ -1,247 +0,0 @@
-from __future__ import annotations
-
-import importlib
-import json
-import os
-import shutil
-import sys
-import tempfile
-from contextlib import contextmanager
-from hashlib import sha256
-from pathlib import Path
-from typing import Any
-from unittest.mock import patch
-
-import numpy as np
-import torch
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def _iter_voxcpm_src_candidates() -> list[Path]:
- candidates: list[Path] = []
- env_path = os.environ.get("VLLM_OMNI_VOXCPM_CODE_PATH")
- if env_path:
- candidates.append(Path(env_path).expanduser())
-
- repo_root = Path(__file__).resolve().parents[4]
- candidates.append(repo_root.parent / "VoxCPM" / "src")
-
- unique_candidates: list[Path] = []
- seen: set[str] = set()
- for candidate in candidates:
- candidate_key = str(candidate)
- if candidate_key in seen:
- continue
- seen.add(candidate_key)
- unique_candidates.append(candidate)
- return unique_candidates
-
-
-def _prepend_voxcpm_src(candidate: Path) -> None:
- candidate_str = str(candidate)
- if candidate_str not in sys.path:
- sys.path.insert(0, candidate_str)
-
-
-def _import_voxcpm_attrs(module_name: str, *attr_names: str) -> tuple[Any, ...]:
- last_exc: ImportError | None = None
- for candidate in _iter_voxcpm_src_candidates():
- if not candidate.exists():
- continue
- _prepend_voxcpm_src(candidate)
- try:
- module = importlib.import_module(module_name)
- return tuple(getattr(module, attr_name) for attr_name in attr_names)
- except ImportError as exc:
- last_exc = exc
-
- try:
- module = importlib.import_module(module_name)
- return tuple(getattr(module, attr_name) for attr_name in attr_names)
- except ImportError as exc:
- last_exc = exc
-
- raise ImportError(f"Failed to import {module_name}.") from last_exc
-
-
-def _import_voxcpm_base_model_class():
- """Import upstream ``VoxCPMModel`` from ``VoxCPM/src/voxcpm`` (env, sibling tree, or pip)."""
- try:
- (VoxCPMModel,) = _import_voxcpm_attrs("voxcpm.model.voxcpm", "VoxCPMModel")
- return VoxCPMModel
- except ImportError as exc:
- raise ImportError(
- "Failed to import VoxCPMModel. Install the `voxcpm` package or set "
- "`VLLM_OMNI_VOXCPM_CODE_PATH` to the VoxCPM repository `src` directory "
- "(the parent of the `voxcpm` package that contains `model/` and `modules/`)."
- ) from exc
-
-
-def _import_voxcpm_audio_vae_classes():
- try:
- return _import_voxcpm_attrs("voxcpm.modules.audiovae", "AudioVAE", "AudioVAEConfig")
- except ImportError as exc:
- raise ImportError(
- "Failed to import VoxCPM AudioVAE. Install the `voxcpm` package or set "
- "`VLLM_OMNI_VOXCPM_CODE_PATH` to the VoxCPM repository `src` directory."
- ) from exc
-
-
-def _device_to_string(device: torch.device) -> str:
- if device.index is None:
- return device.type
- return f"{device.type}:{device.index}"
-
-
-def _normalize_dtype_name(dtype: Any) -> str | None:
- if dtype is None:
- return None
- if isinstance(dtype, torch.dtype):
- mapping = {
- torch.bfloat16: "bfloat16",
- torch.float16: "float16",
- torch.float32: "float32",
- }
- return mapping.get(dtype, str(dtype).removeprefix("torch."))
- dtype_str = str(dtype)
- return dtype_str.removeprefix("torch.")
-
-
-def _resolve_runtime_device(vllm_config: VllmConfig) -> torch.device:
- try:
- from vllm_omni.platforms import current_omni_platform
-
- return current_omni_platform.get_torch_device()
- except Exception:
- pass
-
- device = getattr(getattr(vllm_config, "device_config", None), "device", None)
- if isinstance(device, torch.device):
- return device
- if device:
- return torch.device(device)
- return torch.device("cpu")
-
-
-def _prepare_runtime_model_dir(
- model_path: str | Path,
- *,
- target_device: torch.device,
- target_dtype: str | None,
-) -> str:
- source_dir = Path(model_path)
- config_path = source_dir / "config.json"
- if not config_path.exists():
- return str(source_dir)
-
- config_text = config_path.read_text()
- config_dict = json.loads(config_text)
- desired_device = target_device.type
- desired_dtype = target_dtype or config_dict.get("dtype")
-
- if config_dict.get("device") == desired_device and config_dict.get("dtype") == desired_dtype:
- return str(source_dir)
-
- digest = sha256(f"{source_dir.resolve()}:{config_text}:{desired_device}:{desired_dtype}".encode()).hexdigest()[:16]
- runtime_dir = Path(tempfile.gettempdir()) / "vllm_omni_voxcpm_runtime" / digest
- runtime_dir.mkdir(parents=True, exist_ok=True)
-
- for entry in source_dir.iterdir():
- target = runtime_dir / entry.name
- if entry.name == "config.json" or target.exists():
- continue
- try:
- target.symlink_to(entry, target_is_directory=entry.is_dir())
- except OSError as exc:
- logger.warning(
- "Falling back to copying VoxCPM runtime artifact %s into %s because symlink creation failed: %s",
- entry,
- runtime_dir,
- exc,
- )
- if entry.is_dir():
- shutil.copytree(entry, target, dirs_exist_ok=True)
- else:
- shutil.copy2(entry, target)
-
- patched_config = dict(config_dict)
- patched_config["device"] = desired_device
- if desired_dtype is not None:
- patched_config["dtype"] = desired_dtype
- (runtime_dir / "config.json").write_text(json.dumps(patched_config, indent=2, sort_keys=True))
- return str(runtime_dir)
-
-
-@contextmanager
-def _force_cuda_available_for_npu(device: torch.device):
- if device.type != "npu":
- yield
- return
-
- with patch("torch.cuda.is_available", return_value=True):
- yield
-
-
-def _is_torchcodec_load_error(exc: BaseException) -> bool:
- message = str(exc).lower()
- return "torchcodec" in message or "load_with_torchcodec" in message
-
-
-def _load_audio_with_soundfile(
- prompt_wav_path: str,
- *,
- sample_rate: int,
-) -> torch.Tensor:
- try:
- import soundfile as sf
- except ImportError:
- raise
-
- audio_np, source_sr = sf.read(prompt_wav_path, dtype="float32", always_2d=True)
- audio = torch.from_numpy(np.ascontiguousarray(audio_np.T))
-
- if audio.size(0) > 1:
- audio = audio.mean(dim=0, keepdim=True)
-
- if int(source_sr) != int(sample_rate):
- try:
- import torchaudio
- except ImportError as exc:
- raise ImportError("torchaudio is required for resampling prompt audio.") from exc
- audio = torchaudio.functional.resample(audio, int(source_sr), int(sample_rate))
-
- return audio
-
-
-def _build_prompt_cache_with_soundfile(model: Any, *args: Any, **kwargs: Any) -> dict[str, Any]:
- if args:
- prompt_text = args[0]
- prompt_wav_path = args[1] if len(args) > 1 else kwargs.get("prompt_wav_path")
- else:
- prompt_text = kwargs.get("prompt_text")
- prompt_wav_path = kwargs.get("prompt_wav_path")
-
- if not prompt_text or not prompt_wav_path:
- raise ValueError("prompt_text and prompt_wav_path are required")
-
- audio = _load_audio_with_soundfile(prompt_wav_path, sample_rate=int(model.sample_rate))
-
- patch_len = model.patch_size * model.chunk_size
- if audio.size(1) % patch_len != 0:
- padding_size = patch_len - audio.size(1) % patch_len
- audio = torch.nn.functional.pad(audio, (padding_size, 0))
-
- audio_feat = model.audio_vae.encode(audio.to(model.device), model.sample_rate).cpu()
- audio_feat = audio_feat.view(
- model.audio_vae.latent_dim,
- -1,
- model.patch_size,
- ).permute(1, 2, 0)
-
- return {
- "prompt_text": prompt_text,
- "audio_feat": audio_feat,
- }
diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm_runtime_utils.py b/vllm_omni/model_executor/models/voxcpm/voxcpm_runtime_utils.py
deleted file mode 100644
index 36b4282c2d7..00000000000
--- a/vllm_omni/model_executor/models/voxcpm/voxcpm_runtime_utils.py
+++ /dev/null
@@ -1,44 +0,0 @@
-from __future__ import annotations
-
-import json
-import shutil
-from pathlib import Path
-
-
-def resolve_voxcpm_model_dir(model: str) -> Path:
- model_path = Path(model).expanduser()
- if model_path.exists():
- return model_path
-
- from huggingface_hub import snapshot_download
-
- return Path(snapshot_download(repo_id=model))
-
-
-def prepare_voxcpm_hf_config_dir(model_dir: str | Path, hf_config_dir: str | Path) -> Path:
- model_dir = Path(model_dir).expanduser()
- hf_config_dir = Path(hf_config_dir).expanduser()
- hf_config_dir.mkdir(parents=True, exist_ok=True)
-
- source_config_path = model_dir / "config.json"
- if not source_config_path.exists():
- raise FileNotFoundError(f"VoxCPM config.json not found under {model_dir}")
-
- config_path = hf_config_dir / "config.json"
- shutil.copy2(source_config_path, config_path)
-
- source_generation_config_path = model_dir / "generation_config.json"
- if source_generation_config_path.exists():
- shutil.copy2(source_generation_config_path, hf_config_dir / "generation_config.json")
-
- config_dict = json.loads(config_path.read_text(encoding="utf-8"))
- config_dict["model_type"] = "voxcpm"
- config_dict.setdefault("architectures", ["VoxCPMForConditionalGeneration"])
- config_path.write_text(json.dumps(config_dict, indent=2, ensure_ascii=False), encoding="utf-8")
- return hf_config_dir
-
-
-__all__ = [
- "prepare_voxcpm_hf_config_dir",
- "resolve_voxcpm_model_dir",
-]
diff --git a/vllm_omni/model_executor/models/voxcpm/voxcpm_stage_wrappers.py b/vllm_omni/model_executor/models/voxcpm/voxcpm_stage_wrappers.py
deleted file mode 100644
index f4446c796e4..00000000000
--- a/vllm_omni/model_executor/models/voxcpm/voxcpm_stage_wrappers.py
+++ /dev/null
@@ -1,185 +0,0 @@
-from __future__ import annotations
-
-import os
-from collections.abc import Generator
-from typing import Any
-
-import torch
-import torch.nn as nn
-from einops import rearrange
-
-
-class _DirectVoxCPMLatentGenerator:
- def __init__(self, tts_model: Any):
- self.tts_model = tts_model
- self.sample_rate = int(getattr(tts_model, "sample_rate", 24000))
-
- def generate_latents(
- self,
- *,
- text: str,
- prompt_wav_path: str | None = None,
- prompt_text: str | None = None,
- cfg_value: float = 2.0,
- inference_timesteps: int = 10,
- min_len: int = 2,
- max_len: int = 4096,
- retry_badcase: bool = True,
- retry_badcase_max_times: int = 3,
- retry_badcase_ratio_threshold: float = 6.0,
- ) -> torch.Tensor:
- if not isinstance(text, str) or not text.strip():
- raise ValueError("target text must be a non-empty string")
- if (prompt_wav_path is None) != (prompt_text is None):
- raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
- if prompt_wav_path is not None and not os.path.exists(prompt_wav_path):
- raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
-
- prompt_cache = None
- if prompt_wav_path is not None and prompt_text is not None:
- prompt_cache = self.tts_model.build_prompt_cache(
- prompt_text=prompt_text,
- prompt_wav_path=prompt_wav_path,
- )
-
- gen_kw = dict(
- target_text=" ".join(text.split()),
- prompt_cache=prompt_cache,
- min_len=min_len,
- max_len=max_len,
- inference_timesteps=inference_timesteps,
- cfg_value=cfg_value,
- retry_badcase=retry_badcase,
- retry_badcase_max_times=retry_badcase_max_times,
- retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
- )
- latent_entry = getattr(self.tts_model, "generate_latents_with_prompt_cache", None)
- if latent_entry is not None:
- _, _, pred_audio_feat = latent_entry(**gen_kw)
- else:
- try:
- _, _, pred_audio_feat = self.tts_model.generate_with_prompt_cache(
- **gen_kw,
- latents_only=True,
- )
- except TypeError:
- _, _, pred_audio_feat = self.tts_model.generate_with_prompt_cache(**gen_kw)
- return pred_audio_feat.detach().cpu().to(torch.float32)
-
- def iter_latent_chunks_streaming(
- self,
- *,
- text: str,
- prompt_wav_path: str | None = None,
- prompt_text: str | None = None,
- cfg_value: float = 2.0,
- inference_timesteps: int = 10,
- min_len: int = 2,
- max_len: int = 4096,
- streaming_prefix_len: int = 3,
- retry_badcase: bool = False,
- retry_badcase_max_times: int = 3,
- retry_badcase_ratio_threshold: float = 6.0,
- ) -> Generator[tuple[torch.Tensor, bool], None, None]:
- """Yield ``(latent_window, is_last_chunk)`` for Omni async_chunk latent to VAE."""
- if not isinstance(text, str) or not text.strip():
- raise ValueError("target text must be a non-empty string")
- if (prompt_wav_path is None) != (prompt_text is None):
- raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
- if prompt_wav_path is not None and not os.path.exists(prompt_wav_path):
- raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
-
- prompt_cache = None
- if prompt_wav_path is not None and prompt_text is not None:
- prompt_cache = self.tts_model.build_prompt_cache(
- prompt_text=prompt_text,
- prompt_wav_path=prompt_wav_path,
- )
-
- gen_kw = dict(
- target_text=" ".join(text.split()),
- prompt_cache=prompt_cache,
- min_len=min_len,
- max_len=max_len,
- inference_timesteps=inference_timesteps,
- cfg_value=cfg_value,
- retry_badcase=retry_badcase,
- retry_badcase_max_times=retry_badcase_max_times,
- retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
- streaming_prefix_len=streaming_prefix_len,
- )
- stream_entry = getattr(self.tts_model, "generate_latents_with_prompt_cache_streaming", None)
- if stream_entry is not None:
- gen = stream_entry(**gen_kw)
- else:
- fallback_stream_entry = getattr(self.tts_model, "generate_with_prompt_cache_streaming", None)
- if fallback_stream_entry is not None:
- gen = fallback_stream_entry(**gen_kw, latents_only=True)
- else:
- gen = self.tts_model._generate_with_prompt_cache(streaming=True, latents_only=True, **gen_kw)
-
- iterator = iter(gen)
- previous = next(iterator, None)
- while previous is not None:
- current = next(iterator, None)
- _, _target_tok, chunk_latent = previous
- if not isinstance(chunk_latent, torch.Tensor):
- chunk_latent = torch.as_tensor(chunk_latent)
- yield chunk_latent, current is None
- previous = current
-
-
-class _DirectVoxCPMAudioVAE:
- def __init__(self, audio_vae: nn.Module, *, patch_size: int = 2):
- self.audio_vae = audio_vae
- self.sample_rate = int(getattr(audio_vae, "sample_rate", 24000))
- self.latent_dim = int(getattr(audio_vae, "latent_dim", 64))
- self.patch_size = int(patch_size)
- self._chunk_size = int(getattr(audio_vae, "chunk_size", 1))
- self._stream_audio_patch_samples = max(1, self.patch_size * self._chunk_size)
-
- def _prepare_latents_for_decode(self, latent_audio_feat: Any) -> torch.Tensor:
- latents = latent_audio_feat
- if not isinstance(latents, torch.Tensor):
- latents = torch.tensor(latents, dtype=torch.float32)
- latents = latents.detach().to(torch.float32)
-
- if latents.ndim == 3:
- if latents.shape[-1] == self.latent_dim:
- latents = rearrange(latents, "t p d -> 1 d (t p)")
- elif latents.shape[1] == self.latent_dim:
- latents = latents.contiguous()
- else:
- raise ValueError(f"Unsupported latent_audio_feat shape: {tuple(latents.shape)}")
- elif latents.ndim == 2:
- if latents.shape[0] == self.latent_dim:
- latents = latents.unsqueeze(0)
- elif latents.shape[1] == self.latent_dim:
- latents = rearrange(latents, "t d -> 1 d t")
- else:
- raise ValueError(f"Unsupported latent_audio_feat shape: {tuple(latents.shape)}")
- else:
- raise ValueError(f"Unsupported latent_audio_feat ndim: {latents.ndim}")
-
- return latents
-
- @torch.no_grad()
- def decode(self, latent_audio_feat: Any, *, trim_streaming_patch: bool = False) -> torch.Tensor:
- latents = self._prepare_latents_for_decode(latent_audio_feat)
- device = next(self.audio_vae.parameters()).device
- raw = self.audio_vae.decode(latents.to(device=device, dtype=torch.float32))
- if isinstance(raw, dict):
- audio = raw.get("audio")
- if audio is None:
- audio = next(v for v in raw.values() if isinstance(v, torch.Tensor))
- else:
- audio = raw
- if audio.dim() == 3:
- stream = audio.squeeze(1)
- elif audio.dim() == 2:
- stream = audio
- else:
- stream = audio.reshape(audio.shape[0], -1)
- if trim_streaming_patch:
- stream = stream[..., -self._stream_audio_patch_samples :]
- return stream.reshape(-1).detach().cpu().to(torch.float32)
diff --git a/vllm_omni/model_executor/models/voxcpm2/__init__.py b/vllm_omni/model_executor/models/voxcpm2/__init__.py
deleted file mode 100644
index 77bd8dfb518..00000000000
--- a/vllm_omni/model_executor/models/voxcpm2/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from .voxcpm2_talker import VoxCPM2TalkerForConditionalGeneration
-
-__all__ = ["VoxCPM2TalkerForConditionalGeneration"]
diff --git a/vllm_omni/model_executor/models/voxcpm2/minicpm4_hf_compat.py b/vllm_omni/model_executor/models/voxcpm2/minicpm4_hf_compat.py
deleted file mode 100644
index cb3101b16ac..00000000000
--- a/vllm_omni/model_executor/models/voxcpm2/minicpm4_hf_compat.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""fp32 RoPE + MLP matching native VoxCPM2 numerics.
-
-Exports: _MiniCPMLongRoPE, _MiniCPMMLP, _apply_rotary_pos_emb
-"""
-
-from __future__ import annotations
-
-import math
-from typing import Any
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-# ===================================================================
-# Primitives
-# ===================================================================
-
-
-def _rotate_half(x: torch.Tensor) -> torch.Tensor:
- x1, x2 = x.chunk(2, dim=-1)
- return torch.cat((-x2, x1), dim=-1)
-
-
-def _apply_rotary_pos_emb(
- q: torch.Tensor,
- k: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
- """Apply rotary embeddings in float32."""
- orig_dtype = q.dtype
- q, k = q.to(torch.float32), k.to(torch.float32)
- q_embed = (q * cos) + (_rotate_half(q) * sin)
- k_embed = (k * cos) + (_rotate_half(k) * sin)
- return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
-
-
-# ===================================================================
-# LongRoPE — must match native computation order exactly
-# ===================================================================
-
-
-class _MiniCPMLongRoPE(nn.Module):
- """LongRoPE matching native computation order."""
-
- def __init__(
- self,
- hidden_size: int,
- num_attention_heads: int,
- kv_channels: int | None,
- rope_theta: float,
- max_position_embeddings: int,
- rope_scaling: dict[str, Any],
- ) -> None:
- super().__init__()
- self.dim = kv_channels if kv_channels else hidden_size // num_attention_heads
- self.base = rope_theta
- self.max_position_embeddings = max_position_embeddings
- self.short_factor = rope_scaling["short_factor"]
- self.long_factor = rope_scaling["long_factor"]
- self.original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
-
- scale = self.max_position_embeddings / self.original_max_position_embeddings
- self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
-
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- self.max_seq_len_cached = 0
- self.register_buffer("cos_cached", torch.empty(0), persistent=False)
- self.register_buffer("sin_cached", torch.empty(0), persistent=False)
- self._set_cos_sin_cache(self.max_position_embeddings, self.inv_freq.device, torch.float32)
-
- def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
- self.max_seq_len_cached = seq_len
- t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
-
- ext_factors = torch.tensor(
- self.long_factor if seq_len > self.original_max_position_embeddings else self.short_factor,
- dtype=torch.float32,
- device=device,
- )
-
- freqs = torch.mul(
- torch.outer(t, 1.0 / ext_factors).to(device=device),
- self.inv_freq.to(device=device).to(dtype),
- )
- emb = torch.cat((freqs, freqs), dim=-1)
- self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
- self.sin_cached = emb.sin().to(dtype) * self.scaling_factor
-
- def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
- return self.cos_cached[position_ids], self.sin_cached[position_ids]
-
-
-# ===================================================================
-# MLP
-# ===================================================================
-
-
-class _MiniCPMMLP(nn.Module):
- """SiLU-gated MLP matching native MiniCPMMLP."""
-
- def __init__(self, hidden_size: int, intermediate_size: int) -> None:
- super().__init__()
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
diff --git a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py b/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py
deleted file mode 100644
index b87ec5aafef..00000000000
--- a/vllm_omni/model_executor/models/voxcpm2/minicpm4_paged.py
+++ /dev/null
@@ -1,457 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""MiniCPM4 with PagedAttention + fp32 RoPE/RMSNorm for VoxCPM2.
-
-Uses vllm Attention for KV cache, keeps fp32 precision ops from
-minicpm4_hf_compat.py to match native VoxCPM2 numerics.
-"""
-
-from __future__ import annotations
-
-import math
-from collections.abc import Iterable
-from typing import Any
-
-import torch
-import torch.nn as nn
-from vllm.config import CacheConfig, VllmConfig
-from vllm.logger import init_logger
-from vllm.model_executor.layers.attention import Attention
-from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import make_empty_intermediate_tensors_factory
-from vllm.sequence import IntermediateTensors
-
-from .minicpm4_hf_compat import (
- _apply_rotary_pos_emb,
- _MiniCPMLongRoPE,
- _MiniCPMMLP,
-)
-
-logger = init_logger(__name__)
-
-
-def _resolve_lm_cfg(config: Any) -> Any:
- """Extract lm_config from VoxCPM2Config, converting dict to namespace if needed."""
- lm_cfg = getattr(config, "lm_config", config)
- if isinstance(lm_cfg, dict):
-
- class _Cfg:
- pass
-
- c = _Cfg()
- for k, v in lm_cfg.items():
- setattr(c, k, v)
- return c
- return lm_cfg
-
-
-# ===================================================================
-# Attention with vllm PagedAttention backend
-# ===================================================================
-
-
-class _PagedMiniCPM4Attention(nn.Module):
- """PagedAttention + fp32 RoPE with separate q/k/v projections."""
-
- def __init__(
- self,
- hidden_size: int,
- num_attention_heads: int,
- num_key_value_heads: int,
- kv_channels: int | None,
- layer_idx: int,
- cache_config: CacheConfig | None = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.layer_idx = layer_idx
- self.hidden_size = hidden_size
- self.num_heads = num_attention_heads
- self.head_dim = kv_channels if kv_channels else hidden_size // num_attention_heads
- self.num_kv_heads = num_key_value_heads
- self.q_size = self.num_heads * self.head_dim
- self.kv_size = self.num_kv_heads * self.head_dim
-
- self.q_proj = nn.Linear(hidden_size, self.q_size, bias=False)
- self.k_proj = nn.Linear(hidden_size, self.kv_size, bias=False)
- self.v_proj = nn.Linear(hidden_size, self.kv_size, bias=False)
- self.o_proj = nn.Linear(self.q_size, hidden_size, bias=False)
- self._fused_qkv_weight: torch.Tensor | None = None
-
- self.attn = Attention(
- self.num_heads,
- self.head_dim,
- scale=self.head_dim**-0.5,
- num_kv_heads=self.num_kv_heads,
- cache_config=cache_config,
- prefix=f"{prefix}.attn",
- )
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- rope_emb: _MiniCPMLongRoPE | None = None,
- ) -> torch.Tensor:
- """Forward: fused QKV → fp32 RoPE → PagedAttention → o_proj."""
- if self._fused_qkv_weight is None:
- self._fused_qkv_weight = torch.cat(
- [
- self.q_proj.weight,
- self.k_proj.weight,
- self.v_proj.weight,
- ],
- dim=0,
- ).detach()
- qkv = nn.functional.linear(hidden_states, self._fused_qkv_weight)
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-
- if rope_emb is not None:
- cos, sin = rope_emb(positions)
- bsz = q.shape[0]
- q_r = q.view(bsz, self.num_heads, self.head_dim)
- k_r = k.view(bsz, self.num_kv_heads, self.head_dim)
- q_r = q_r.unsqueeze(0).transpose(1, 2) # [1, heads, n_tokens, dim]
- k_r = k_r.unsqueeze(0).transpose(1, 2) # [1, kv_heads, n_tokens, dim]
- q_r, k_r = _apply_rotary_pos_emb(q_r, k_r, cos, sin)
- q = q_r.transpose(1, 2).squeeze(0).reshape(bsz, -1) # [n_tokens, q_size]
- k = k_r.transpose(1, 2).squeeze(0).reshape(bsz, -1) # [n_tokens, kv_size]
-
- attn_output = self.attn(q, k, v)
-
- output = self.o_proj(attn_output)
- return output
-
-
-# ===================================================================
-# Decoder Layer
-# ===================================================================
-
-
-class _PagedMiniCPM4DecoderLayer(nn.Module):
- """Decoder layer: PagedAttention + fp32 RMSNorm + muP scale_depth."""
-
- def __init__(
- self,
- hidden_size: int,
- intermediate_size: int,
- num_attention_heads: int,
- num_key_value_heads: int,
- kv_channels: int | None,
- rms_norm_eps: float,
- layer_idx: int,
- num_hidden_layers: int,
- use_mup: bool,
- scale_depth: float,
- cache_config: CacheConfig | None = None,
- prefix: str = "",
- ) -> None:
- super().__init__()
- self.self_attn = _PagedMiniCPM4Attention(
- hidden_size=hidden_size,
- num_attention_heads=num_attention_heads,
- num_key_value_heads=num_key_value_heads,
- kv_channels=kv_channels,
- layer_idx=layer_idx,
- cache_config=cache_config,
- prefix=f"{prefix}.self_attn",
- )
- self.mlp = _MiniCPMMLP(hidden_size, intermediate_size)
- self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
- self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
-
- self.use_mup = use_mup
- self.scale_depth = scale_depth
- self.num_hidden_layers = num_hidden_layers
-
- def _residual_scale(self) -> float:
- if self.use_mup:
- return self.scale_depth / math.sqrt(self.num_hidden_layers)
- return 1.0
-
- def forward(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- residual: torch.Tensor | None,
- rope_emb: _MiniCPMLongRoPE | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- # Pre-norm + attention
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- hidden_states = self.self_attn(positions, hidden_states, rope_emb)
-
- scale = self._residual_scale()
- if scale != 1.0:
- hidden_states = residual + hidden_states * scale
- else:
- hidden_states = residual + hidden_states
-
- # Pre-norm + FFN
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
-
- if scale != 1.0:
- hidden_states = residual + hidden_states * scale
- else:
- hidden_states = residual + hidden_states
-
- return hidden_states, None
-
-
-# ===================================================================
-# Full Model
-# ===================================================================
-
-
-class MiniCPM4PagedForVoxCPM2(nn.Module):
- """PagedAttention base_lm (28 layers) for VoxCPM2 scaffold."""
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
- super().__init__()
- config = vllm_config.model_config.hf_config
- cache_config = vllm_config.cache_config
- self.config = config
-
- lm_cfg = _resolve_lm_cfg(config)
-
- hidden_size = lm_cfg.hidden_size
- num_hidden_layers = lm_cfg.num_hidden_layers
- kv_channels = getattr(lm_cfg, "kv_channels", None)
-
- self.vocab_size = lm_cfg.vocab_size
- self.embed_tokens = nn.Embedding(self.vocab_size, hidden_size)
-
- rope_scaling = getattr(lm_cfg, "rope_scaling", None)
- if isinstance(rope_scaling, dict):
- rope_scaling_dict = rope_scaling
- elif hasattr(rope_scaling, "__dict__"):
- rope_scaling_dict = {
- "short_factor": rope_scaling.short_factor,
- "long_factor": rope_scaling.long_factor,
- "original_max_position_embeddings": rope_scaling.original_max_position_embeddings,
- }
- else:
- rope_scaling_dict = {}
-
- no_rope = getattr(lm_cfg, "no_rope", False)
- if not no_rope:
- self.rope_emb = _MiniCPMLongRoPE(
- hidden_size=hidden_size,
- num_attention_heads=lm_cfg.num_attention_heads,
- kv_channels=kv_channels,
- rope_theta=getattr(lm_cfg, "rope_theta", 10000.0),
- max_position_embeddings=getattr(lm_cfg, "max_position_embeddings", 32768),
- rope_scaling=rope_scaling_dict,
- )
- else:
- self.rope_emb = None
-
- self.layers = nn.ModuleList(
- [
- _PagedMiniCPM4DecoderLayer(
- hidden_size=hidden_size,
- intermediate_size=lm_cfg.intermediate_size,
- num_attention_heads=lm_cfg.num_attention_heads,
- num_key_value_heads=lm_cfg.num_key_value_heads,
- kv_channels=kv_channels,
- rms_norm_eps=lm_cfg.rms_norm_eps,
- layer_idx=i,
- num_hidden_layers=num_hidden_layers,
- use_mup=getattr(lm_cfg, "use_mup", False),
- scale_depth=getattr(lm_cfg, "scale_depth", 1.0),
- cache_config=cache_config,
- prefix=f"{prefix}.layers.{i}",
- )
- for i in range(num_hidden_layers)
- ]
- )
-
- self.norm = RMSNorm(hidden_size, eps=lm_cfg.rms_norm_eps)
-
- self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
- ["hidden_states", "residual"], hidden_size
- )
-
- use_mup = getattr(lm_cfg, "use_mup", False)
- self._scale_emb = getattr(lm_cfg, "scale_emb", 1.0) if use_mup else 1.0
- self._compiled_layers: set[int] = set()
-
- def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
- return self.embed_tokens(input_ids) * self._scale_emb
-
- def forward(
- self,
- input_ids: torch.Tensor | None,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: Any,
- ) -> torch.Tensor | IntermediateTensors:
- if inputs_embeds is not None:
- hidden_states = inputs_embeds
- else:
- hidden_states = self.embed_input_ids(input_ids)
-
- residual = None
- for layer in self.layers:
- hidden_states, residual = layer(
- positions,
- hidden_states,
- residual,
- self.rope_emb,
- )
-
- hidden_states = self.norm(hidden_states)
- return hidden_states
-
- def precompute_fused_qkv(self) -> None:
- """Materialize fused QKV weights before CUDA Graph capture."""
- for layer in self.layers:
- attn = layer.self_attn
- if attn._fused_qkv_weight is None:
- attn._fused_qkv_weight = torch.cat(
- [attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight],
- dim=0,
- ).detach()
-
- def compile_selective(self) -> list[str]:
- """Compile the full model forward as one graph.
-
- Earlier versions compiled ``layer.mlp`` + ``layer.self_attn.o_proj``
- (PR #2690) and then the whole ``layer`` (perf/voxcpm2-streaming-vae).
- Both still paid one Dynamo dispatch per layer per decode step.
- V3 profiling showed 1,332 per-layer dispatches (~28 layers × ~47
- decode steps) costing ~726 ms of CPU self-time for a long prompt.
-
- Compiling ``forward`` at the model level lets Dynamo unroll the
- 28-layer Python loop inside the graph. Graph breaks at
- PagedAttention produce sub-graphs but Dynamo memoises the whole
- trace once, so the per-step dispatch drops from 28 to just a few.
- """
- if self._compiled_layers:
- return []
- # Null the fused-qkv caches so the compile sees the real weight layout.
- for layer in self.layers:
- layer.self_attn._fused_qkv_weight = None
- self.forward = torch.compile(self.forward, mode="default", fullgraph=False)
- # Mark every layer as compiled so idempotent callers don't double-wrap.
- self._compiled_layers.update(range(len(self.layers)))
- return ["forward (whole model)"]
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- """Load weights from native checkpoint (base_lm. prefix pre-stripped)."""
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- loaded: set[str] = set()
-
- for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name:
- continue
- param = params_dict.get(name)
- if param is None:
- continue
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, loaded_weight)
- loaded.add(name)
-
- return loaded
-
-
-# ===================================================================
-# Residual LM with PagedAttention (no RoPE, 8 layers)
-# ===================================================================
-
-
-class MiniCPM4PagedResidualLM(nn.Module):
- """PagedAttention residual LM (8 layers, no RoPE) for VoxCPM2."""
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
- super().__init__()
- config = vllm_config.model_config.hf_config
- cache_config = vllm_config.cache_config
- self.config = config
-
- lm_cfg = _resolve_lm_cfg(config)
-
- hidden_size = lm_cfg.hidden_size
- num_hidden_layers = getattr(config, "residual_lm_num_layers", 8)
- kv_channels = getattr(lm_cfg, "kv_channels", None)
-
- self.rope_emb = None
-
- self.layers = nn.ModuleList(
- [
- _PagedMiniCPM4DecoderLayer(
- hidden_size=hidden_size,
- intermediate_size=lm_cfg.intermediate_size,
- num_attention_heads=lm_cfg.num_attention_heads,
- num_key_value_heads=lm_cfg.num_key_value_heads,
- kv_channels=kv_channels,
- rms_norm_eps=lm_cfg.rms_norm_eps,
- layer_idx=i,
- num_hidden_layers=num_hidden_layers,
- use_mup=getattr(lm_cfg, "use_mup", False),
- scale_depth=getattr(lm_cfg, "scale_depth", 1.0),
- cache_config=cache_config,
- prefix=f"{prefix}.layers.{i}",
- )
- for i in range(num_hidden_layers)
- ]
- )
-
- self.norm = RMSNorm(hidden_size, eps=lm_cfg.rms_norm_eps)
- self._compiled_layers: set[int] = set()
-
- def forward(
- self,
- positions: torch.Tensor,
- inputs_embeds: torch.Tensor,
- ) -> torch.Tensor:
- hidden_states = inputs_embeds
- residual = None
- for layer in self.layers:
- hidden_states, residual = layer(
- positions,
- hidden_states,
- residual,
- self.rope_emb,
- )
- hidden_states = self.norm(hidden_states)
- return hidden_states
-
- def precompute_fused_qkv(self) -> None:
- """Materialize fused QKV weights before CUDA Graph capture."""
- for layer in self.layers:
- attn = layer.self_attn
- if attn._fused_qkv_weight is None:
- attn._fused_qkv_weight = torch.cat(
- [attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight],
- dim=0,
- ).detach()
-
- def compile_selective(self) -> list[str]:
- """Compile the full residual model forward as one graph (same strategy as base_lm)."""
- if self._compiled_layers:
- return []
- for layer in self.layers:
- layer.self_attn._fused_qkv_weight = None
- self.forward = torch.compile(self.forward, mode="default", fullgraph=False)
- self._compiled_layers.update(range(len(self.layers)))
- return ["forward (whole residual)"]
-
- def load_weights_from_native(self, native_residual_lm: nn.Module) -> int:
- """Load weights from native residual_lm. Returns param count."""
- params_dict = dict(self.named_parameters(remove_duplicate=False))
- loaded = 0
- for name, param in native_residual_lm.named_parameters():
- if "rotary_emb" in name:
- continue
- target = params_dict.get(name)
- if target is None:
- continue
- weight_loader = getattr(target, "weight_loader", default_weight_loader)
- weight_loader(target, param.data)
- loaded += 1
- return loaded
diff --git a/vllm_omni/model_executor/models/voxcpm2/pipeline.py b/vllm_omni/model_executor/models/voxcpm2/pipeline.py
deleted file mode 100644
index 347fce17707..00000000000
--- a/vllm_omni/model_executor/models/voxcpm2/pipeline.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""VoxCPM2 pipeline topology (frozen).
-
-Single-stage AR TTS: text → speech waveform in one pass.
-Uses the native MiniCPM4 base_lm with a per-request StaticKVCache that the
-talker restores into the paged attention layer at step boundaries.
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-VOXCPM2_PIPELINE = PipelineConfig(
- model_type="voxcpm2",
- model_arch="VoxCPM2TalkerForConditionalGeneration",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="latent_generator",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- final_output=True,
- final_output_type="audio",
- owns_tokenizer=True,
- engine_output_type="audio",
- sampling_constraints={
- "detokenize": False,
- "stop_token_ids": [1],
- },
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_import_utils.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_import_utils.py
deleted file mode 100644
index 231a51bbca4..00000000000
--- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_import_utils.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Dynamic import utilities for the native VoxCPM2 package.
-
-Supports three discovery modes (first match wins):
-1. ``VLLM_OMNI_VOXCPM_CODE_PATH`` env var (explicit source tree)
-2. Sibling ``../VoxCPM/src`` relative to the vllm-omni repo root
-3. pip-installed ``voxcpm`` package (>= 2.0)
-"""
-
-from __future__ import annotations
-
-import importlib
-import os
-import sys
-from pathlib import Path
-from typing import Any
-
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def _iter_voxcpm2_src_candidates() -> list[Path]:
- """Yield candidate source directories for VoxCPM2."""
- candidates: list[Path] = []
- env_path = os.environ.get("VLLM_OMNI_VOXCPM_CODE_PATH")
- if env_path:
- candidates.append(Path(env_path).expanduser())
-
- repo_root = Path(__file__).resolve().parents[4]
- candidates.append(repo_root.parent / "VoxCPM" / "src")
-
- seen: set[str] = set()
- unique: list[Path] = []
- for c in candidates:
- key = str(c)
- if key not in seen:
- seen.add(key)
- unique.append(c)
- return unique
-
-
-def _prepend_src(candidate: Path) -> None:
- candidate_str = str(candidate)
- if candidate_str not in sys.path:
- sys.path.insert(0, candidate_str)
-
-
-def _import_voxcpm2_attrs(module_name: str, *attr_names: str) -> tuple[Any, ...]:
- """Import attributes from the voxcpm package, trying source tree first."""
- last_exc: ImportError | None = None
-
- for candidate in _iter_voxcpm2_src_candidates():
- if not candidate.exists():
- continue
- _prepend_src(candidate)
- try:
- mod = importlib.import_module(module_name)
- return tuple(getattr(mod, name) for name in attr_names)
- except (ImportError, AttributeError) as exc:
- last_exc = ImportError(str(exc))
- continue
-
- try:
- mod = importlib.import_module(module_name)
- return tuple(getattr(mod, name) for name in attr_names)
- except (ImportError, AttributeError) as exc:
- last_exc = ImportError(str(exc))
-
- raise ImportError(
- f"Could not import {attr_names} from {module_name}. "
- f"Install voxcpm>=2.0: pip install voxcpm. "
- f"Or set VLLM_OMNI_VOXCPM_CODE_PATH to the VoxCPM source tree. "
- f"Last error: {last_exc}"
- )
-
-
-def import_voxcpm2_core():
- """Import the VoxCPM core class used to load the native TTS model."""
- (VoxCPM,) = _import_voxcpm2_attrs("voxcpm.core", "VoxCPM")
- return VoxCPM
diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
deleted file mode 100644
index 0a9246251b0..00000000000
--- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
+++ /dev/null
@@ -1,1290 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""VoxCPM2 AR talker — PagedAttention pipeline with per-request state.
-
-Architecture:
- MiniCPM4PagedForVoxCPM2 (base_lm, 28 layers, PagedAttention + fp32 RoPE)
- → FSQ → MiniCPM4PagedResidualLM (8 layers, PagedAttention, no RoPE)
- → LocDiT (CFM solver) → AudioVAE → 48kHz waveform
-"""
-
-from __future__ import annotations
-
-import copy
-import dataclasses
-import logging
-import math
-import os
-import time
-from collections.abc import Iterable
-from typing import Any
-
-import torch
-import torch.nn as nn
-from vllm.config import VllmConfig
-from vllm.forward_context import get_forward_context, override_forward_context
-from vllm.logger import init_logger
-from vllm.model_executor.models.utils import (
- AutoWeightsLoader,
- WeightsMapper,
- maybe_prefix,
-)
-from vllm.multimodal.audio import AudioResampler
-from vllm.sequence import IntermediateTensors
-
-from vllm_omni.model_executor.models.output_templates import OmniOutput
-
-from .minicpm4_paged import MiniCPM4PagedForVoxCPM2, MiniCPM4PagedResidualLM
-from .voxcpm2_import_utils import import_voxcpm2_core
-
-logger = init_logger(__name__)
-
-_ENABLE_PROFILING = os.environ.get("VOXCPM2_PROFILE", "0") == "1"
-
-# Lower bound for the _active_states leak-warn threshold. The effective
-# threshold is max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * max_batch_size) so small
-# deployments still get a usable floor instead of a tiny noisy one.
-_ACTIVE_STATE_LEAK_WARN_MIN = 512
-
-
-def is_cjk_char(c: str) -> bool:
- """Check if a character is a CJK ideograph."""
- cp = ord(c)
- return (
- 0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs
- or 0x3400 <= cp <= 0x4DBF # Extension A
- or 0xF900 <= cp <= 0xFAFF # Compatibility Ideographs
- or 0x20000 <= cp <= 0x2A6DF # Extension B
- or 0x2A700 <= cp <= 0x2B73F # Extension C
- or 0x2B740 <= cp <= 0x2B81F # Extension D
- or 0x2F800 <= cp <= 0x2FA1F # Compatibility Supplement
- )
-
-
-def build_cjk_split_map(tokenizer: Any) -> dict[int, list[int]]:
- """Build {multichar_cjk_token_id: [single_char_ids]} from tokenizer vocab."""
- vocab = tokenizer.get_vocab()
- split_map: dict[int, list[int]] = {}
- for token, token_id in vocab.items():
- clean = token.replace("\u2581", "")
- if len(clean) >= 2 and all(is_cjk_char(c) for c in clean):
- char_ids = tokenizer.convert_tokens_to_ids(list(clean))
- if all(cid != tokenizer.unk_token_id for cid in char_ids):
- split_map[token_id] = char_ids
- return split_map
-
-
-def split_multichar_chinese(token_ids: list[int], split_map: dict[int, list[int]]) -> list[int]:
- """Replace multichar Chinese token IDs with single-char IDs (idempotent)."""
- result: list[int] = []
- for tid in token_ids:
- expansion = split_map.get(tid)
- if expansion is not None:
- result.extend(expansion)
- else:
- result.append(tid)
- return result
-
-
-def build_voxcpm2_prompt(
- hf_config: Any,
- tokenizer: Any,
- split_map: dict[int, list[int]],
- text: str,
- ref_audio: Any | None = None,
- ref_sr: int | None = None,
- ref_text: str | None = None,
-) -> dict[str, Any]:
- """Build a VoxCPM2 prefill prompt whose ``prompt_token_ids`` length matches
- the talker-side prefill length.
-
- Used by both online serving (``serving_speech._build_voxcpm2_prompt``) and
- the offline example, so the talker-side length assertion never fires.
- """
- ids = split_multichar_chinese(tokenizer.encode(text, add_special_tokens=True), split_map)
- bos = tokenizer.bos_token_id
- if ids and ids[0] == bos:
- ids = ids[1:]
- prefill_len = len(ids) + 1 # + audio_start
- additional: dict[str, Any] = {"text_token_ids": [ids]}
- if ref_audio is not None:
- vae = hf_config.audio_vae_config
- patch_samples = hf_config.patch_size * math.prod(vae["encoder_rates"])
- ref_len = math.ceil(math.ceil(len(ref_audio) * vae["sample_rate"] / ref_sr) / patch_samples)
- if ref_text is not None:
- additional["prompt_audio"] = [[ref_audio, ref_sr]]
- additional["prompt_text"] = [ref_text]
- ref_ids = split_multichar_chinese(tokenizer.encode(ref_text, add_special_tokens=True), split_map)
- if ref_ids and ref_ids[0] == bos:
- ref_ids = ref_ids[1:]
- prefill_len += ref_len + len(ref_ids)
- else:
- additional["reference_audio"] = [[ref_audio, ref_sr]]
- prefill_len += ref_len + 2 # ref_start / ref_end
- return {"prompt_token_ids": [1] * prefill_len, "additional_information": additional}
-
-
-def _encode_raw_audio(
- tts: nn.Module,
- samples: list[float] | torch.Tensor,
- sr: int,
- padding_mode: str = "right",
-) -> torch.Tensor:
- """Encode raw audio samples using the native VoxCPM2 AudioVAE.
-
- Mirrors ``VoxCPM2Model._encode_wav`` but accepts in-memory samples
- instead of a file path (needed for the OpenAI speech API).
- """
- if isinstance(samples, list):
- audio = torch.tensor(samples, dtype=torch.float32)
- else:
- audio = samples.float()
- if audio.ndim == 1:
- audio = audio.unsqueeze(0)
-
- encode_sr = tts._encode_sample_rate
- if sr != encode_sr:
- audio_np = audio.squeeze(0).numpy()
- resampler = AudioResampler(target_sr=encode_sr)
- audio_np = resampler.resample(audio_np, orig_sr=sr)
- audio = torch.from_numpy(audio_np).unsqueeze(0)
-
- patch_len = tts.patch_size * tts.chunk_size
- if audio.size(1) % patch_len != 0:
- padding_size = patch_len - audio.size(1) % patch_len
- pad = (padding_size, 0) if padding_mode == "left" else (0, padding_size)
- audio = torch.nn.functional.pad(audio, pad)
-
- feat = tts.audio_vae.encode(audio.to(tts.device), encode_sr).cpu()
- return feat.view(tts.audio_vae.latent_dim, -1, tts.patch_size).permute(1, 2, 0)
-
-
-# ===================================================================
-# Per-request state
-# ===================================================================
-
-
-@dataclasses.dataclass
-class _RequestState:
- request_id: str
- curr_embed_for_next: torch.Tensor | None = None
- prev_feat_embed: torch.Tensor | None = None
- curr_prefix_feat_cond: torch.Tensor | None = None
- last_audio_patch_gpu: torch.Tensor | None = None
- precomputed_stop_logits: torch.Tensor | None = None
- # Rolling tail of previously-decoded latents used as VAE receptive-field context.
- # Shape (n_pad_frames, feat_dim) on GPU. None before first decode.
- decode_pad: torch.Tensor | None = None
- # Audio chunks already emitted (CPU float32), concatenated for cumulative output.
- audio_chunks: list[torch.Tensor] = dataclasses.field(default_factory=list)
- decode_step_count: int = 0
- request_start_time: float = 0.0
- prefill_completed: bool = False
- prefill_text: str = ""
- prompt_cache: dict | None = None
- prefill_masks: tuple | None = None
- is_stopping: bool = False
- last_decoded_audio: torch.Tensor | None = None
-
-
-@dataclasses.dataclass
-class _CapturedGraph:
- graph: torch.cuda.CUDAGraph
- input_embeds: torch.Tensor
- positions: torch.Tensor
- output: torch.Tensor
-
-
-# ===================================================================
-# Profiling timer
-# ===================================================================
-
-
-class _PerfTimer:
- __slots__ = ("_enabled", "_timers", "_counts", "_starts", "_pairs")
-
- def __init__(self, enabled: bool = False):
- self._enabled = enabled
- self._timers: dict[str, float] = {}
- self._counts: dict[str, int] = {}
- self._starts: dict[str, torch.cuda.Event] = {}
- self._pairs: list[tuple[str, torch.cuda.Event, torch.cuda.Event]] = []
-
- def start(self, name: str) -> None:
- if not self._enabled:
- return
- evt = torch.cuda.Event(enable_timing=True)
- evt.record()
- self._starts[name] = evt
-
- def stop(self, name: str) -> None:
- if not self._enabled or name not in self._starts:
- return
- start_evt = self._starts.pop(name)
- end_evt = torch.cuda.Event(enable_timing=True)
- end_evt.record()
- self._pairs.append((name, start_evt, end_evt))
-
- def _resolve(self) -> None:
- if not self._pairs:
- return
- torch.cuda.synchronize()
- for name, s, e in self._pairs:
- self._timers[name] = self._timers.get(name, 0.0) + s.elapsed_time(e)
- self._counts[name] = self._counts.get(name, 0) + 1
- self._pairs.clear()
-
- def breakdown(self) -> str:
- if not self._enabled:
- return ""
- self._resolve()
- if not self._timers:
- return ""
- total = self._timers.get("decode_step", sum(self._timers.values()))
- lines = [
- "=== VoxCPM2 Decode Step Breakdown ===",
- f"{'Component':<30} | {'ms':>10} | {'%':>6} | {'N':>5} | {'avg':>8}",
- "-" * 70,
- ]
- for name in sorted(self._timers):
- t, c = self._timers[name], self._counts[name]
- lines.append(f"{name:<30} | {t:>10.2f} | {t / total * 100:>5.1f}% | {c:>5} | {t / c:>8.3f}")
- lines.append(f"{'TOTAL':<30} | {total:>10.2f} |")
- return "\n".join(lines)
-
- def reset(self) -> None:
- self._timers.clear()
- self._counts.clear()
- self._starts.clear()
- self._pairs.clear()
-
-
-# ===================================================================
-# CFM pre-allocated buffers + optimized Euler solver
-# ===================================================================
-
-
-class _CFMBufferManager:
- def __init__(
- self,
- device: torch.device,
- dtype: torch.dtype,
- feat_dim: int,
- patch_size: int,
- dit_hidden_size: int,
- max_batch_size: int = 1,
- sway_sampling_coef: float = 1.0,
- ):
- n = 2 * max_batch_size # CFG doubles the batch
- self.x_in = torch.zeros(n, feat_dim, patch_size, device=device, dtype=dtype)
- self.mu_in = torch.zeros(n, dit_hidden_size, device=device, dtype=dtype)
- self.t_in = torch.zeros(n, device=device, dtype=dtype)
- self.dt_in = torch.zeros(n, device=device, dtype=dtype)
- self.cond_in = torch.zeros(n, feat_dim, patch_size, device=device, dtype=dtype)
- self.noise = torch.zeros(max_batch_size, feat_dim, patch_size, device=device, dtype=dtype)
- self._sway_coef = sway_sampling_coef
- self._device = device
- self._dtype = dtype
- self.t_span_10 = self._make_t_span(10)
-
- def _make_t_span(self, n: int) -> torch.Tensor:
- t = torch.linspace(1, 0, n + 1, device=self._device, dtype=self._dtype)
- return t + self._sway_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
-
- def get_t_span(self, n: int) -> torch.Tensor:
- return self.t_span_10 if n == 10 else self._make_t_span(n)
-
-
-def _optimized_solve_euler(
- cfm_module: nn.Module,
- mu: torch.Tensor,
- patch_size: int,
- cond: torch.Tensor,
- n_timesteps: int,
- cfg_value: float,
- buffers: _CFMBufferManager,
- use_cfg_zero_star: bool = True,
- cfg_cutoff_ratio: float = 1.0,
- perf: _PerfTimer | None = None,
-) -> torch.Tensor:
- estimator = cfm_module.estimator
- mean_mode = getattr(cfm_module, "mean_mode", False)
- b = mu.size(0)
-
- buffers.noise[:b].normal_()
- x = buffers.noise[:b].clone()
-
- t_span = buffers.get_t_span(n_timesteps)
- t, dt = t_span[0], t_span[0] - t_span[1]
- zero_init_steps = max(1, int(len(t_span) * 0.04))
- cfg_cutoff_step = max(zero_init_steps + 1, int(len(t_span) * cfg_cutoff_ratio))
-
- for step in range(1, len(t_span)):
- if use_cfg_zero_star and step <= zero_init_steps:
- dphi_dt = torch.zeros_like(x)
- elif step <= cfg_cutoff_step:
- buffers.x_in[:b].copy_(x)
- buffers.x_in[b : 2 * b].copy_(x)
- buffers.mu_in[:b].copy_(mu)
- buffers.mu_in[b : 2 * b].zero_()
- # Broadcast the 0-dim GPU scalar directly instead of
- # ``.fill_(t.item())`` — ``.item()`` forces a GPU->CPU sync.
- buffers.t_in[: 2 * b].copy_(t)
- if mean_mode:
- buffers.dt_in[: 2 * b].copy_(dt)
- else:
- buffers.dt_in.zero_()
- buffers.cond_in[:b].copy_(cond[:b])
- buffers.cond_in[b : 2 * b].copy_(cond[:b])
-
- if perf:
- perf.start(" cfm.estimator_cfg")
- raw_out = estimator(
- buffers.x_in[: 2 * b],
- buffers.mu_in[: 2 * b],
- buffers.t_in[: 2 * b],
- buffers.cond_in[: 2 * b],
- buffers.dt_in[: 2 * b],
- )
- if perf:
- perf.stop(" cfm.estimator_cfg")
-
- dphi_dt, cfg_dphi_dt = raw_out[:b], raw_out[b : 2 * b]
- if use_cfg_zero_star:
- pos = dphi_dt.reshape(b, -1)
- neg = cfg_dphi_dt.reshape(b, -1)
- st = torch.sum(pos * neg, 1, keepdim=True) / (torch.sum(neg**2, 1, keepdim=True) + 1e-8)
- st = st.view(b, *([1] * (len(dphi_dt.shape) - 1)))
- else:
- st = 1.0
- dphi_dt = cfg_dphi_dt * st + cfg_value * (dphi_dt - cfg_dphi_dt * st)
- else:
- buffers.x_in[:b].copy_(x)
- buffers.mu_in[:b].copy_(mu)
- # Broadcast the 0-dim GPU scalar; ``.fill_(t.item())`` would sync.
- buffers.t_in[:b].copy_(t)
- if mean_mode:
- buffers.dt_in[:b].copy_(dt)
- else:
- buffers.dt_in[:b].zero_()
- buffers.cond_in[:b].copy_(cond[:b])
- if perf:
- perf.start(" cfm.estimator_nocfg")
- dphi_dt = estimator(
- buffers.x_in[:b], buffers.mu_in[:b], buffers.t_in[:b], buffers.cond_in[:b], buffers.dt_in[:b]
- )
- if perf:
- perf.stop(" cfm.estimator_nocfg")
-
- x = x - dt * dphi_dt
- t = t - dt
- if step < len(t_span) - 1:
- dt = t - t_span[step + 1]
- return x
-
-
-# ===================================================================
-# Main talker model
-# ===================================================================
-
-
-class VoxCPM2TalkerForConditionalGeneration(nn.Module):
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- self.vllm_config = vllm_config
- self.config = vllm_config.model_config.hf_config
-
- self.have_multimodal_outputs = True
- self.has_preprocess = True
- self.has_postprocess = True
-
- self.model = MiniCPM4PagedForVoxCPM2(
- vllm_config=vllm_config,
- prefix=maybe_prefix(prefix, "model"),
- )
- self.residual_model = MiniCPM4PagedResidualLM(
- vllm_config=vllm_config,
- prefix=maybe_prefix(prefix, "residual_model"),
- )
- self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors
-
- self._tts: nn.Module | None = None
- self._device = "cuda"
- self._side_dtype = torch.bfloat16
-
- self._patch_size = getattr(self.config, "patch_size", 4)
- self._feat_dim = getattr(self.config, "feat_dim", 64)
- self._sample_rate = getattr(self.config, "sample_rate", 48000)
-
- self._inference_timesteps = 10
- self._cfg_value = 2.0
- self._cfg_cutoff_ratio = 1.0
- # Number of trailing latent frames to keep as VAE receptive-field context
- # for sliding-window streaming decode. 12 matches the nanovllm reference
- # implementation and covers the longest VAE decoder receptive field.
- self._n_decode_pad_frames = 12
- self._enable_torch_compile = True
- self._compile_vae = True
- self._max_decode_steps = 2000
- self._max_batch_size = getattr(vllm_config.scheduler_config, "max_num_seqs", 4)
-
- self._perf = _PerfTimer(enabled=_ENABLE_PROFILING)
- self._cfm_buffers: _CFMBufferManager | None = None
- self._enable_cuda_graph = True
- self._scaffold_graphs: dict[int, _CapturedGraph] = {}
- self._residual_graphs: dict[int, _CapturedGraph] = {}
- self._max_cached_graphs = self._max_batch_size
- self._cuda_graph_pool: tuple | None = None
- self._cuda_graph_warmup_steps = 0
- self._cuda_graph_warmup_threshold = 3
-
- self._multichar_zh_split: dict[int, list[int]] | None = None
-
- self._active_states: dict[str, _RequestState] = {}
- self._current_request_id: str | None = None
- self._pending_requests: list[tuple[str, bool, torch.Tensor | None, int]] = []
- self._results_queue: list[tuple[str, torch.Tensor | None]] = []
- self._audio_queue: list[tuple[str, Any]] = []
- self._deferred_cleanup_ids: set[str] = set()
- self._active_state_warn_threshold = max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * self._max_batch_size)
- # one-shot by design: fires at most once per process to avoid log spam.
- self._active_state_warned = False
-
- @property
- def tts(self) -> nn.Module:
- assert self._tts is not None, "Model not loaded yet"
- return self._tts
-
- # -------------------- request state management --------------------
-
- def _get_or_create_state(self, request_id: str) -> _RequestState:
- state = self._active_states.get(request_id)
- if state is None:
- state = _RequestState(request_id=request_id)
- self._active_states[request_id] = state
- if len(self._active_states) > self._active_state_warn_threshold and not self._active_state_warned:
- logger.warning(
- "VoxCPM2: _active_states size=%d exceeds threshold %d "
- "(max_batch_size=%d); possible cleanup path leak",
- len(self._active_states),
- self._active_state_warn_threshold,
- self._max_batch_size,
- )
- self._active_state_warned = True
- return state
-
- def _switch_to_request(self, request_id: str) -> _RequestState:
- if request_id != self._current_request_id:
- self._current_request_id = request_id
- return self._get_or_create_state(request_id)
-
- def _cleanup_request(self, request_id: str) -> None:
- self._active_states.pop(request_id, None)
- if self._current_request_id == request_id:
- self._current_request_id = None
-
- def on_requests_finished(self, finished_req_ids: set[str] | list[str]) -> None:
- # Defer cleanup: on_requests_finished is called before forward(),
- # so we must not delete state that the current step may still need.
- self._deferred_cleanup_ids.update(finished_req_ids)
-
- def _flush_deferred_cleanup(self) -> None:
- for req_id in self._deferred_cleanup_ids:
- self._cleanup_request(req_id)
- self._deferred_cleanup_ids.clear()
-
- def _build_prompt_cache(
- self,
- ref_audio: Any = None,
- prompt_audio: Any = None,
- prompt_text: str | None = None,
- ) -> dict | None:
- """Build prompt cache, handling both file paths and raw audio data.
-
- The OpenAI speech API sends decoded audio as [samples_list, sr]
- via ``_resolve_ref_audio``, while offline usage sends file paths.
- """
- tts = self.tts
-
- def _is_raw_audio(v: Any) -> bool:
- import numbers
-
- return (
- isinstance(v, (list, tuple))
- and len(v) == 2
- and isinstance(v[1], numbers.Integral)
- and isinstance(v[0], (list, torch.Tensor))
- )
-
- if not _is_raw_audio(ref_audio) and not _is_raw_audio(prompt_audio):
- return tts.build_prompt_cache(
- prompt_text=prompt_text,
- prompt_wav_path=prompt_audio,
- reference_wav_path=ref_audio,
- )
-
- cache: dict[str, Any] = {}
- if ref_audio is not None:
- if _is_raw_audio(ref_audio):
- samples, sr = ref_audio
- cache["ref_audio_feat"] = _encode_raw_audio(tts, samples, sr)
- else:
- cache["ref_audio_feat"] = tts._encode_wav(ref_audio, padding_mode="right")
-
- if prompt_audio is not None and prompt_text is not None:
- cache["prompt_text"] = prompt_text
- if _is_raw_audio(prompt_audio):
- samples, sr = prompt_audio
- cache["audio_feat"] = _encode_raw_audio(tts, samples, sr, padding_mode="left")
- else:
- cache["audio_feat"] = tts._encode_wav(prompt_audio, padding_mode="left")
-
- has_ref = "ref_audio_feat" in cache
- has_prompt = "audio_feat" in cache
- if has_ref and has_prompt:
- cache["mode"] = "ref_continuation"
- elif has_ref:
- cache["mode"] = "reference"
- else:
- cache["mode"] = "continuation"
-
- return cache
-
- # -------------------- compile setup --------------------
-
- def _setup_cfm_buffers(self) -> None:
- if self._cfm_buffers is not None:
- return
- tts = self.tts
- dit_hidden = tts.lm_to_dit_proj.out_features + tts.res_to_dit_proj.out_features
- self._cfm_buffers = _CFMBufferManager(
- device=torch.device(self._device),
- dtype=self._side_dtype,
- feat_dim=self._feat_dim,
- patch_size=self._patch_size,
- dit_hidden_size=dit_hidden,
- max_batch_size=self._max_batch_size,
- )
-
- def _setup_torch_compile(self) -> None:
- if not self._enable_torch_compile:
- return
- tts = self.tts
- estimator = tts.feat_decoder.estimator
- if hasattr(estimator, "_compiled"):
- return
-
- targets: list[str] = []
-
- try:
- tts.feat_decoder.estimator = torch.compile(estimator, mode="reduce-overhead", fullgraph=False)
- tts.feat_decoder.estimator._compiled = True
- targets.append("LocDiT")
- except Exception as e:
- logger.warning("torch.compile LocDiT failed: %s", e)
-
- try:
- if not hasattr(tts.feat_encoder, "_compiled"):
- tts.feat_encoder = torch.compile(tts.feat_encoder, mode="reduce-overhead", fullgraph=False)
- tts.feat_encoder._compiled = True
- targets.append("feat_encoder")
- except Exception as e:
- logger.warning("torch.compile feat_encoder failed: %s", e)
-
- if self._compile_vae:
- try:
- if not hasattr(tts.audio_vae, "_compiled"):
- tts.audio_vae.decode = torch.compile(tts.audio_vae.decode, mode="reduce-overhead", fullgraph=False)
- tts.audio_vae._compiled = True
- targets.append("AudioVAE")
- except Exception as e:
- logger.warning("torch.compile AudioVAE failed: %s", e)
-
- if not self._enable_cuda_graph:
- if not getattr(self.model, "_selective_compiled", False):
- try:
- targets.extend(f"scaffold.{t}" for t in self.model.compile_selective())
- self.model._selective_compiled = True
- except Exception as e:
- logger.warning("scaffold compile failed: %s", e)
-
- if not getattr(self.residual_model, "_selective_compiled", False):
- try:
- targets.extend(f"residual.{t}" for t in self.residual_model.compile_selective())
- self.residual_model._selective_compiled = True
- except Exception as e:
- logger.warning("residual compile failed: %s", e)
- else:
- self.model.precompute_fused_qkv()
- self.residual_model.precompute_fused_qkv()
- targets.append("scaffold+residual (CUDA Graph, skipping compile)")
-
- if not getattr(self, "_projections_compiled", False):
- try:
- self._compiled_dit_proj = torch.compile(self._dit_proj_fn, mode="default", fullgraph=True)
- self._compiled_stop_fn = torch.compile(self._stop_fn, mode="default", fullgraph=True)
- self._projections_compiled = True
- targets.append("projections")
- except Exception as e:
- self._compiled_dit_proj = self._compiled_stop_fn = None
- logger.warning("projections compile failed: %s", e)
-
- if targets:
- logger.info("VoxCPM2: torch.compile applied to: %s", ", ".join(targets))
-
- def _dit_proj_fn(self, lm_h: torch.Tensor, res_h: torch.Tensor) -> torch.Tensor:
- tts = self.tts
- return torch.cat([tts.lm_to_dit_proj(lm_h), tts.res_to_dit_proj(res_h)], dim=-1)
-
- def _stop_fn(self, lm_h: torch.Tensor) -> torch.Tensor:
- tts = self.tts
- return tts.stop_head(tts.stop_actn(tts.stop_proj(lm_h)))
-
- def _get_cuda_graph_pool(self) -> tuple:
- if self._cuda_graph_pool is None:
- self._cuda_graph_pool = torch.cuda.graph_pool_handle()
- return self._cuda_graph_pool
-
- @staticmethod
- def _nullify_volatile_metadata(ctx: Any) -> Any:
- """Set ``scheduler_metadata`` to None on all attention layers.
-
- This is the only tensor FA3 reallocates each step (variable shape).
- All other metadata tensors are persistent model-runner buffers.
- Setting it to None makes FA3 use default scheduling (~0.1ms cost).
- """
- if not isinstance(ctx.attn_metadata, dict):
- return ctx
-
- ctx = copy.copy(ctx)
- new_meta: dict[str, Any] = {}
- for layer_name, meta in ctx.attn_metadata.items():
- if getattr(meta, "scheduler_metadata", None) is not None:
- meta = copy.copy(meta)
- meta.scheduler_metadata = None
- new_meta[layer_name] = meta
- ctx.attn_metadata = new_meta
- return ctx
-
- def _capture_graph(
- self,
- model: nn.Module,
- batch_size: int,
- label: str,
- is_residual: bool = False,
- ) -> _CapturedGraph:
- """Capture a CUDA Graph for *model* at *batch_size*."""
- hidden_size = self.config.hidden_size
- dtype = self._side_dtype
- dev = torch.device(self._device)
- pool = self._get_cuda_graph_pool()
-
- model.precompute_fused_qkv()
-
- g = _CapturedGraph(
- graph=torch.cuda.CUDAGraph(),
- input_embeds=torch.zeros(batch_size, hidden_size, device=dev, dtype=dtype),
- positions=torch.zeros(batch_size, device=dev, dtype=torch.long),
- output=torch.zeros(batch_size, hidden_size, device=dev, dtype=dtype),
- )
-
- if is_residual:
- call_kwargs = dict(positions=g.positions, inputs_embeds=g.input_embeds)
- else:
- call_kwargs = dict(input_ids=None, positions=g.positions, inputs_embeds=g.input_embeds)
-
- ctx = get_forward_context()
- patched_ctx = self._nullify_volatile_metadata(ctx)
-
- with override_forward_context(patched_ctx):
- for _ in range(3):
- _ = model(**call_kwargs)
-
- with torch.cuda.graph(g.graph, pool=pool):
- g.output = model(**call_kwargs)
-
- logger.info("CUDA Graph captured for %s (batch_size=%d)", label, batch_size)
- return g
-
- def _replay_graph(
- self,
- g: _CapturedGraph,
- inputs_embeds: torch.Tensor,
- positions: torch.Tensor,
- batch_size: int,
- ) -> torch.Tensor:
- """Copy fresh inputs into static buffers, then replay.
-
- No metadata copy needed: persistent buffers (seq_lens, slot_mapping,
- etc.) are updated in-place by the model runner. scheduler_metadata
- was nullified at capture time so no kernel references it.
- """
- g.input_embeds[:batch_size].copy_(inputs_embeds[:batch_size])
- g.positions[:batch_size].copy_(positions[:batch_size])
- g.graph.replay()
- return g.output[:batch_size].clone()
-
- # -------------------- vllm hooks --------------------
-
- def embed_input_ids(self, input_ids: torch.Tensor, **_: Any) -> torch.Tensor:
- return self.model.embed_input_ids(input_ids)
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- **kwargs: Any,
- ) -> torch.Tensor | IntermediateTensors:
- self._perf.start("forward_total")
- dev = input_ids.device
-
- num_reqs = len(self._pending_requests)
- num_decode = sum(1 for _, is_p, _, n in self._pending_requests if not is_p and n == 1)
- is_all_decode = num_decode == num_reqs and num_reqs > 0
-
- tts_compiled = getattr(self.tts.feat_decoder.estimator, "_compiled", False) if self._tts is not None else False
- graph_ready = tts_compiled and self._cuda_graph_warmup_steps >= self._cuda_graph_warmup_threshold
- if num_decode > 0:
- self._cuda_graph_warmup_steps += 1
-
- can_use_graph = (
- self._enable_cuda_graph and graph_ready and intermediate_tensors is None and inputs_embeds is not None
- )
-
- if can_use_graph and is_all_decode and num_reqs <= self._max_cached_graphs:
- self._perf.start("scaffold_fwd")
- if num_reqs not in self._scaffold_graphs:
- self._scaffold_graphs[num_reqs] = self._capture_graph(self.model, num_reqs, "scaffold")
- scaffold_hidden = self._replay_graph(self._scaffold_graphs[num_reqs], inputs_embeds, positions, num_reqs)
- self._perf.stop("scaffold_fwd")
-
- else:
- self._perf.start("scaffold_fwd")
- model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds)
- self._perf.stop("scaffold_fwd")
- if isinstance(model_output, IntermediateTensors):
- return model_output
- scaffold_hidden = model_output
- if isinstance(scaffold_hidden, tuple):
- scaffold_hidden = scaffold_hidden[0]
-
- # Phase 1: per-request FSQ + residual input
- token_offset = 0
- residual_inputs: list[torch.Tensor] = []
- residual_positions: list[torch.Tensor] = []
- req_metas: list[tuple] = []
-
- for req_id, is_prefill, _req_embeds, n in self._pending_requests:
- state = self._switch_to_request(req_id)
- req_hidden = scaffold_hidden[token_offset : token_offset + n]
- req_pos = positions[token_offset : token_offset + n]
-
- if is_prefill:
- res_input, meta = self._prepare_residual_prefill(state, req_hidden, dev)
- elif state.prefill_completed:
- res_input, meta = self._prepare_residual_decode(state, req_hidden, dev)
- else:
- token_offset += n
- self._results_queue.append((req_id, None))
- self._audio_queue.append((req_id, None))
- continue
-
- residual_inputs.append(res_input)
- residual_positions.append(req_pos)
- req_metas.append((state, is_prefill, meta))
- token_offset += n
-
- # Phase 2: batch residual_model forward
- if residual_inputs:
- batch_in = torch.cat(residual_inputs, dim=0)
- batch_pos = torch.cat(residual_positions, dim=0)
-
- residual_batch_size = batch_in.shape[0]
- use_residual_graph = (
- self._enable_cuda_graph
- and is_all_decode
- and graph_ready
- and residual_batch_size == num_reqs # 1 token per request
- and residual_batch_size <= self._max_cached_graphs
- )
-
- self._perf.start("residual_fwd")
- if use_residual_graph:
- if residual_batch_size not in self._residual_graphs:
- self._residual_graphs[residual_batch_size] = self._capture_graph(
- self.residual_model, residual_batch_size, "residual", is_residual=True
- )
- batch_out = self._replay_graph(
- self._residual_graphs[residual_batch_size], batch_in, batch_pos, residual_batch_size
- )
- else:
- batch_out = self.residual_model(batch_pos, batch_in)
- self._perf.stop("residual_fwd")
-
- # Phase 3: per-request LocDiT + update
- offset = 0
- for idx, (state, is_prefill, meta) in enumerate(req_metas):
- n = residual_inputs[idx].shape[0]
- res_out = batch_out[offset : offset + n]
- offset += n
-
- if is_prefill:
- self._finish_prefill(state, meta, res_out, dev)
- else:
- self._finish_decode(state, meta, res_out, dev)
-
- self._results_queue.append((state.request_id, state.precomputed_stop_logits))
- self._audio_queue.append((state.request_id, self._collect_audio(state)))
-
- self._pending_requests.clear()
- self._flush_deferred_cleanup()
- self._perf.stop("forward_total")
- return scaffold_hidden
-
- # -------------------- prefill / decode helpers --------------------
-
- def _prepare_residual_prefill(self, state: _RequestState, base_lm_out: torch.Tensor, dev: Any):
- tts = self.tts
- text_mask, feat_mask, feat, feat_embed = state.prefill_masks
- state.prefill_masks = None
-
- tts_len = text_mask.shape[1]
- scaffold_len = base_lm_out.shape[0]
- assert scaffold_len == tts_len, (
- f"voxcpm2 prefill length mismatch: scaffold_len={scaffold_len} tts_len={tts_len}; "
- "caller must pad prompt_token_ids to the full prefill length "
- "(see serving_speech._build_voxcpm2_prompt or the offline example)."
- )
- enc_out = base_lm_out.unsqueeze(0)
-
- prefix_feat_cond = (
- feat[:, -1, ...]
- if feat.shape[1] > 0
- else torch.zeros(1, self._patch_size, self._feat_dim, device=dev, dtype=self._side_dtype)
- )
- enc_outputs = tts.fsq_layer(enc_out) * feat_mask.unsqueeze(-1) + enc_out * text_mask.unsqueeze(-1)
- lm_hidden = enc_outputs[:, -1, :]
-
- residual_input = tts.fusion_concat_proj(torch.cat([enc_outputs, feat_mask.unsqueeze(-1) * feat_embed], dim=-1))
- meta = {"lm_hidden": lm_hidden, "prefix_feat_cond": prefix_feat_cond}
- return residual_input.squeeze(0), meta
-
- def _prepare_residual_decode(self, state: _RequestState, base_lm_out: torch.Tensor, dev: Any):
- tts = self.tts
- state.decode_step_count += 1
-
- if state.decode_step_count >= self._max_decode_steps:
- logger.warning("MAX_DECODE_STEPS for %s (%d), forcing stop", state.request_id, state.decode_step_count)
- state.is_stopping = True
-
- h = base_lm_out.unsqueeze(0) if base_lm_out.ndim == 1 else base_lm_out
- lm_h = tts.fsq_layer(h)
- if lm_h.ndim == 1:
- lm_h = lm_h.unsqueeze(0)
-
- prev = state.prev_feat_embed.to(self._side_dtype)
- if prev.ndim == 1:
- prev = prev.unsqueeze(0)
- res_input = tts.fusion_concat_proj(torch.cat([lm_h, prev], dim=-1))
- return res_input, {"new_lm_hidden": lm_h}
-
- def _run_cfm(self, dit_h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
- if self._cfm_buffers is not None:
- return _optimized_solve_euler(
- self.tts.feat_decoder,
- dit_h,
- self._patch_size,
- cond,
- self._inference_timesteps,
- self._cfg_value,
- self._cfm_buffers,
- cfg_cutoff_ratio=self._cfg_cutoff_ratio,
- perf=self._perf,
- ).transpose(1, 2)
- return self.tts.feat_decoder(
- mu=dit_h,
- patch_size=self._patch_size,
- cond=cond,
- n_timesteps=self._inference_timesteps,
- cfg_value=self._cfg_value,
- ).transpose(1, 2)
-
- def _finish_prefill(self, state: _RequestState, meta: dict, res_out: torch.Tensor, dev: Any):
- tts = self.tts
- lm_hidden = meta["lm_hidden"]
- prefix_feat_cond = meta["prefix_feat_cond"]
- residual_hidden = res_out[-1:, :]
-
- state.precomputed_stop_logits = tts.stop_head(tts.stop_actn(tts.stop_proj(lm_hidden))).detach()
- dit_h = torch.cat([tts.lm_to_dit_proj(lm_hidden), tts.res_to_dit_proj(residual_hidden)], dim=-1)
-
- self._setup_cfm_buffers()
- if self._enable_torch_compile:
- self._setup_torch_compile()
-
- pred_feat = self._run_cfm(dit_h, prefix_feat_cond.transpose(1, 2).contiguous())
-
- with torch.no_grad():
- curr_embed = tts.enc_to_lm_proj(tts.feat_encoder(pred_feat.unsqueeze(1))).squeeze(1)
-
- state.curr_embed_for_next = curr_embed.detach()
- state.prev_feat_embed = curr_embed.detach()
- state.curr_prefix_feat_cond = pred_feat[0].detach()
- state.last_audio_patch_gpu = pred_feat.detach()
- state.decode_step_count = 0
- state.request_start_time = time.perf_counter()
- state.prefill_completed = True
-
- if logger.isEnabledFor(logging.DEBUG):
- # Only compute the norm (which forces a GPU->CPU sync) if we will log it.
- logger.debug("PREFILL[%s]: patch norm=%.4f", state.request_id, pred_feat.norm().item())
- self._perf.reset()
-
- def _finish_decode(self, state: _RequestState, meta: dict, res_out: torch.Tensor, dev: Any):
- self._perf.start("decode_step")
- tts = self.tts
-
- lm_h = meta["new_lm_hidden"]
- res_h = res_out.unsqueeze(0) if res_out.ndim == 1 else res_out
-
- dit_proj = getattr(self, "_compiled_dit_proj", None) or self._dit_proj_fn
- stop_fn = getattr(self, "_compiled_stop_fn", None) or self._stop_fn
-
- dit_h = dit_proj(lm_h, res_h)
- pfc = state.curr_prefix_feat_cond.to(self._side_dtype)
- if pfc.ndim == 2:
- pfc = pfc.unsqueeze(0)
-
- pred_feat = self._run_cfm(dit_h, pfc.transpose(1, 2).contiguous())
- next_embed = tts.enc_to_lm_proj(tts.feat_encoder(pred_feat.unsqueeze(1))).squeeze(1)
-
- state.precomputed_stop_logits = stop_fn(lm_h).detach()
- state.curr_embed_for_next = next_embed.detach()
- state.prev_feat_embed = next_embed.detach()
- state.curr_prefix_feat_cond = pred_feat[0].detach()
- state.last_audio_patch_gpu = pred_feat.detach()
-
- self._perf.stop("decode_step")
- if _ENABLE_PROFILING and state.decode_step_count % 20 == 0:
- logger.info("Step %d[%s]:\n%s", state.decode_step_count, state.request_id, self._perf.breakdown())
-
- # -------------------- audio collection --------------------
-
- def _collect_audio(self, state: _RequestState) -> torch.Tensor | None:
- """Per-step sliding-window VAE decode (nanovllm pattern).
-
- Each decode step feeds ``[decode_pad, new_patch]`` through the VAE
- and slices out only the audio region corresponding to the new patch.
- The pad buffer (last ``_n_decode_pad_frames`` latent frames) provides
- the receptive-field context needed by the VAE's transposed convolutions,
- eliminating boundary artifacts between chunks.
-
- Returns the delta audio chunk (not cumulative) so the output processor
- can stream each chunk to the client independently.
- """
- patch = state.last_audio_patch_gpu
- if patch is None:
- return None
- state.last_audio_patch_gpu = None
-
- # patch shape: (patch_size, feat_dim) or (1, patch_size, feat_dim)
- new_latent = patch.reshape(-1, self._feat_dim).to(torch.float32)
- n_new = new_latent.shape[0] # = patch_size (typically 4)
-
- self._perf.start("vae_decode")
-
- # Build VAE input: [pad_frames | new_latent]
- if state.decode_pad is not None:
- vae_input = torch.cat([state.decode_pad, new_latent], dim=0)
- pad_frames = state.decode_pad.shape[0]
- else:
- vae_input = new_latent
- pad_frames = 0
-
- # VAE decode: (1, feat_dim, T_frames) -> (1, 1, T_samples)
- feat = vae_input.unsqueeze(0).transpose(1, 2).contiguous()
- with torch.no_grad():
- audio = self.tts.audio_vae.decode(feat.to(self._device)).reshape(-1)
-
- # Slice out only the new audio (after the pad region).
- # Each latent frame maps to decoder_chunk_size audio samples.
- dcs = int(getattr(self.tts.audio_vae, "decode_chunk_size", audio.numel() // vae_input.shape[0]))
- new_audio = audio[pad_frames * dcs : (pad_frames + n_new) * dcs].detach().cpu().float()
-
- # Roll the pad buffer: keep last N latent frames as context for next step.
- all_latents = vae_input # [pad + new]
- state.decode_pad = all_latents[-self._n_decode_pad_frames :].detach()
-
- state.audio_chunks.append(new_audio)
- state.last_decoded_audio = new_audio
- self._perf.stop("vae_decode")
- return new_audio
-
- # -------------------- compute_logits --------------------
-
- def compute_logits(
- self, hidden_states: torch.Tensor | OmniOutput, sampling_metadata: Any = None
- ) -> torch.Tensor | None:
- if isinstance(hidden_states, OmniOutput):
- hidden_states = hidden_states.text_hidden_states
- if hidden_states is None:
- return None
-
- bsz = hidden_states.shape[0]
- logits = torch.full(
- (bsz, self.config.vocab_size), float("-inf"), device=hidden_states.device, dtype=hidden_states.dtype
- )
-
- if self._results_queue:
- for i, (req_id, stop_logits) in enumerate(self._results_queue):
- if i >= bsz:
- break
- state = self._active_states.get(req_id)
- if stop_logits is not None:
- if state is not None and state.is_stopping:
- logits[i, 0] = 0.0
- logits[i, 1] = 1.0
- state.precomputed_stop_logits = None
- else:
- logits[i, 0] = stop_logits[0, 0]
- logits[i, 1] = stop_logits[0, 1]
- if state is not None:
- state.is_stopping = bool(stop_logits[0, 1] > stop_logits[0, 0])
- state.precomputed_stop_logits = None
- elif state and state.prefill_completed:
- logits[i, 1] = 1.0
- else:
- logits[i, 0] = 1.0
- self._results_queue.clear()
- else:
- logits[:, 0] = 1.0
- return logits
-
- # -------------------- omni output --------------------
-
- def make_omni_output(self, model_outputs: torch.Tensor | OmniOutput, **kwargs: Any) -> OmniOutput:
- if isinstance(model_outputs, OmniOutput):
- return model_outputs
-
- mm: dict[str, Any] = {}
- if self._audio_queue:
- audio_by_req = {rid: audio for rid, audio in self._audio_queue}
- order = [r for r, _ in self._audio_queue]
- mm["model_outputs"] = [audio_by_req.get(r) for r in order]
- mm["sr"] = [torch.tensor(self._sample_rate, dtype=torch.int32) for _ in order]
- self._audio_queue.clear()
-
- return OmniOutput(text_hidden_states=model_outputs, multimodal_outputs=mm)
-
- # -------------------- Chinese token splitting --------------------
-
- def _get_multichar_zh_split(self) -> dict[int, list[int]]:
- """Lazy-build {multichar_chinese_token_id: [char_id, ...]} map."""
- if self._multichar_zh_split is not None:
- return self._multichar_zh_split
- base_tokenizer = self.tts.text_tokenizer.tokenizer
- self._multichar_zh_split = build_cjk_split_map(base_tokenizer)
- logger.info("VoxCPM2: built multichar Chinese split map (%d entries)", len(self._multichar_zh_split))
- return self._multichar_zh_split
-
- # -------------------- preprocess / postprocess --------------------
-
- def preprocess(
- self, input_ids: torch.Tensor, input_embeds: torch.Tensor | None, **info_dict: Any
- ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
- additional = info_dict.get("additional_information")
- if isinstance(additional, dict):
- merged = {k: v for k, v in info_dict.items() if k != "additional_information"}
- for k, v in additional.items():
- merged.setdefault(k, v)
- info_dict = merged
-
- span_len = int(input_ids.shape[0])
- dev = input_ids.device
- req_id = info_dict.get("request_id", "default")
- is_prefill = span_len > 1
-
- if is_prefill:
- # Do not evict state here: _pending_requests is a per-step prefix,
- # not the full batch. Cleanup is driven by on_requests_finished ->
- # _flush_deferred_cleanup (fed by vLLM scheduler._free_request via
- # gpu_ar_model_runner.py).
- real = info_dict.get("text_token_ids")
- token_ids = input_ids.tolist() if real is None else real[0]
- # Fail-fast: unsplit multichar Chinese IDs in input_ids means the
- # serving layer didn't pre-split. Silent fixup here would cause
- # input_ids/embeds length mismatch (scheduler slot count is fixed).
- split_map = self._get_multichar_zh_split()
- if split_map and any(tid in split_map for tid in token_ids):
- raise ValueError(
- "VoxCPM2 preprocess received unsplit multichar Chinese "
- "token IDs. The serving layer must send prompt_token_ids "
- "with single-char CJK IDs (see _voxcpm2_encode)."
- )
- if token_ids and token_ids[0] == self.config.bos_token_id:
- token_ids = token_ids[1:]
-
- state = self._get_or_create_state(req_id)
- state.prefill_text = ""
- state.decode_pad = None
- state.audio_chunks = []
- state.prefill_completed = False
- state.decode_step_count = 0
- state.precomputed_stop_logits = None
- state.last_audio_patch_gpu = None
- state.curr_embed_for_next = None
- state.prev_feat_embed = None
- state.curr_prefix_feat_cond = None
- state.is_stopping = False
- state.last_decoded_audio = None
-
- # Voice clone / continuation
- ref_audio = info_dict.get("reference_audio") or info_dict.get("ref_audio")
- prompt_audio = info_dict.get("prompt_audio")
- prompt_text = info_dict.get("prompt_text")
- if isinstance(ref_audio, list):
- ref_audio = ref_audio[0] if ref_audio else None
- if isinstance(prompt_audio, list):
- prompt_audio = prompt_audio[0] if prompt_audio else None
- if isinstance(prompt_text, list):
- prompt_text = prompt_text[0] if prompt_text else None
-
- state.prompt_cache = None
- if ref_audio or (prompt_audio and prompt_text):
- try:
- state.prompt_cache = self._build_prompt_cache(
- ref_audio=ref_audio,
- prompt_audio=prompt_audio,
- prompt_text=prompt_text,
- )
- except Exception as e:
- logger.warning("build_prompt_cache failed: %s", e)
-
- inputs = self._build_prefill_inputs(token_ids, dev, req_id)
- tts = self.tts
- feat_embed = tts.enc_to_lm_proj(tts.feat_encoder(inputs["audio_feat"]))
- text_embed = self.model.embed_input_ids(inputs["text_token"].to(dev))
- text_mask, feat_mask = inputs["text_mask"], inputs["audio_mask"]
- embeds = (text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed).squeeze(0)
- state.prefill_masks = (text_mask, feat_mask, inputs["audio_feat"], feat_embed)
- else:
- state = self._active_states.get(req_id)
- curr = state.curr_embed_for_next if state else None
- if curr is not None:
- embeds = curr.to(dev, dtype=self._side_dtype).reshape(1, -1)
- else:
- embeds = torch.zeros(1, self.config.hidden_size, device=dev, dtype=self._side_dtype)
-
- self._pending_requests.append((req_id, is_prefill, embeds, span_len))
- return input_ids, embeds, {}
-
- def postprocess(self, hidden_states: torch.Tensor, **info: Any) -> dict[str, Any]:
- req_id = info.get("request_id", self._current_request_id or "default")
- if _ENABLE_PROFILING:
- state = self._active_states.get(req_id)
- if state and state.decode_step_count > 0:
- logger.info(
- "REQUEST DONE[%s]: %d steps, %.2fs\n%s",
- req_id,
- state.decode_step_count,
- time.perf_counter() - state.request_start_time,
- self._perf.breakdown(),
- )
- return {}
-
- # -------------------- build prefill inputs --------------------
-
- def _build_prefill_inputs(self, token_ids: list[int], dev: Any, req_id: str = "default") -> dict:
- tts = self.tts
- dtype = self._side_dtype
- state = self._active_states.get(req_id)
- cache = state.prompt_cache if state else None
- mode = cache.get("mode", "continuation") if cache else "zero_shot"
-
- if cache and mode in ("continuation", "ref_continuation"):
- prompt_text = cache.get("prompt_text", "")
- prompt_ids = list(tts.text_tokenizer(prompt_text)) if prompt_text else []
- all_ids = prompt_ids + token_ids
- else:
- all_ids = token_ids
-
- text_token = torch.tensor(all_ids, dtype=torch.int32)
- text_token = torch.cat([text_token, torch.tensor([tts.audio_start_token], dtype=torch.int32)], dim=-1)
- text_len = text_token.shape[0]
- latent_dim = tts.audio_vae.latent_dim
- ps = self._patch_size
-
- if mode in ("zero_shot", "continuation"):
- audio_feat = cache["audio_feat"] if cache else torch.empty((0, ps, latent_dim), dtype=torch.float32)
- a_len = audio_feat.size(0)
- text_token = torch.cat([text_token, torch.zeros(a_len, dtype=torch.int32)])
- audio_feat = torch.cat([torch.zeros((text_len, ps, latent_dim), dtype=torch.float32), audio_feat])
- text_mask = torch.cat([torch.ones(text_len, dtype=torch.int32), torch.zeros(a_len, dtype=torch.int32)])
- audio_mask = torch.cat([torch.zeros(text_len, dtype=torch.int32), torch.ones(a_len, dtype=torch.int32)])
- elif mode == "reference":
- ref = cache["ref_audio_feat"]
- rt, rf, rtm, ram = tts._make_ref_prefix(ref, text_token.device)
- text_token = torch.cat([rt.cpu(), text_token])
- audio_feat = torch.cat([rf.cpu(), torch.zeros((text_len, ps, latent_dim), dtype=torch.float32)])
- text_mask = torch.cat([rtm.cpu(), torch.ones(text_len, dtype=torch.int32)])
- audio_mask = torch.cat([ram.cpu(), torch.zeros(text_len, dtype=torch.int32)])
- else: # ref_continuation
- ref = cache["ref_audio_feat"]
- prompt = cache["audio_feat"]
- p_len = prompt.size(0)
- rt, rf, rtm, ram = tts._make_ref_prefix(ref, text_token.device)
- text_token = torch.cat([rt.cpu(), text_token, torch.zeros(p_len, dtype=torch.int32)])
- audio_feat = torch.cat([rf.cpu(), torch.zeros((text_len, ps, latent_dim), dtype=torch.float32), prompt])
- ones_t = torch.ones(text_len, dtype=torch.int32)
- zeros_p = torch.zeros(p_len, dtype=torch.int32)
- zeros_t = torch.zeros(text_len, dtype=torch.int32)
- ones_p = torch.ones(p_len, dtype=torch.int32)
- text_mask = torch.cat([rtm.cpu(), ones_t, zeros_p])
- audio_mask = torch.cat([ram.cpu(), zeros_t, ones_p])
-
- return {
- "text_token": text_token.unsqueeze(0).to(dev),
- "audio_feat": audio_feat.unsqueeze(0).to(dev).to(dtype),
- "text_mask": text_mask.unsqueeze(0).to(dev),
- "audio_mask": audio_mask.unsqueeze(0).to(dev),
- }
-
- # -------------------- weight loading --------------------
-
- hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"base_lm.": "model."})
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- def _base_lm_only(ws):
- for name, tensor in ws:
- if name.startswith("base_lm."):
- yield name, tensor
-
- loader = AutoWeightsLoader(self)
- loaded = loader.load_weights(_base_lm_only(weights), mapper=self.hf_to_vllm_mapper)
-
- model_path = self.vllm_config.model_config.model
- VoxCPM = import_voxcpm2_core()
- native = VoxCPM.from_pretrained(model_path, load_denoiser=False, optimize=False)
- self._tts = native.tts_model.to("cuda")
- self._side_dtype = self._tts.fusion_concat_proj.weight.dtype
- self._device = "cuda"
- self._patch_size = self._tts.patch_size
- self._feat_dim = self._tts.feat_dim
-
- n = self.residual_model.load_weights_from_native(self._tts.residual_lm)
- for name, _ in self.residual_model.named_parameters():
- loaded.add(f"residual_model.{name}")
- logger.info("VoxCPM2: loaded %d params into paged residual_model", n)
-
- del self._tts.base_lm
- self._tts.base_lm = None
- del self._tts.residual_lm
- self._tts.residual_lm = None
- torch.cuda.empty_cache()
-
- logger.info(
- "Loaded VoxCPM2 (patch=%d, feat_dim=%d, dtype=%s)", self._patch_size, self._feat_dim, self._side_dtype
- )
- return loaded
diff --git a/vllm_omni/model_executor/models/voxtral_tts/__init__.py b/vllm_omni/model_executor/models/voxtral_tts/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/vllm_omni/model_executor/models/voxtral_tts/configuration_voxtral_tts.py b/vllm_omni/model_executor/models/voxtral_tts/configuration_voxtral_tts.py
new file mode 100644
index 00000000000..d32a882e786
--- /dev/null
+++ b/vllm_omni/model_executor/models/voxtral_tts/configuration_voxtral_tts.py
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+
+from transformers import PretrainedConfig
+from vllm.logger import init_logger
+from vllm.transformers_utils.config import MistralConfigParser, register_config_parser
+
+logger = init_logger(__name__)
+
+
+class VoxtralTTSConfig(PretrainedConfig):
+ """HuggingFace-style config for Voxtral TTS models."""
+
+ model_type = "voxtral_tts"
+
+ def __init__(
+ self,
+ text_config: PretrainedConfig | dict | None = None,
+ audio_config: dict[str, Any] | None = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ if isinstance(text_config, PretrainedConfig):
+ self.text_config = text_config
+ elif isinstance(text_config, dict):
+ self.text_config = PretrainedConfig.from_dict(text_config)
+ else:
+ self.text_config = PretrainedConfig()
+
+ self.audio_config = audio_config or {}
+
+ def get_text_config(self, **kwargs: Any) -> PretrainedConfig:
+ return self.text_config
+
+
+@register_config_parser("mistral")
+class VoxtralTTSConfigParser(MistralConfigParser):
+ """Config parser that extends the base Mistral parser with TTS support.
+
+    This only supports voxtral_tts for now.
+ """
+
+    def _remap_mistral_audio_args(self, config_dict: dict) -> dict:
+        """Pop the audio sections from config_dict["multimodal"] (mutating it) and build audio_config."""
+        encoder_args = config_dict["multimodal"].pop("audio_model_args")
+        audio_tokenizer_args = config_dict["multimodal"].pop("audio_tokenizer_args", None)
+        if encoder_args is None:
+            return {}
+        return {
+            "sampling_rate": encoder_args["audio_encoding_args"]["sampling_rate"],
+            "codec_args": audio_tokenizer_args,
+            "audio_model_args": encoder_args,
+            "speaker_id": (audio_tokenizer_args or {}).get("voice", {}),  # tokenizer args may be absent (None)
+        }
+
+ def parse(
+ self,
+ model: str | Path,
+ trust_remote_code: bool,
+ revision: str | None = None,
+ code_revision: str | None = None,
+ **kwargs: Any,
+ ) -> tuple[dict, PretrainedConfig]:
+ from vllm.transformers_utils.config import (
+ _download_mistral_config_file,
+ )
+
+ config_dict = _download_mistral_config_file(model, revision)
+
+ from vllm.transformers_utils.configs.mistral import (
+ _remap_general_mistral_args,
+ _remap_mistral_quantization_args,
+ )
+
+ # Extract audio config before building text config
+ audio_config = {}
+ if (config_dict.get("multimodal") or {}).get("audio_model_args"):
+ audio_config = self._remap_mistral_audio_args(config_dict)
+
+ # Build text_config from the top-level keys
+ non_text_keys = {"multimodal"}
+ text_config = {k: v for k, v in config_dict.items() if k not in non_text_keys}
+ text_config = _remap_general_mistral_args(text_config)
+ if text_config.get("quantization"):
+ text_config = _remap_mistral_quantization_args(text_config)
+
+ # The text sub-model is a plain MistralForCausalLM
+ text_config.setdefault("architectures", ["MistralForCausalLM"])
+
+ config = VoxtralTTSConfig(
+ text_config=PretrainedConfig.from_dict(text_config),
+ audio_config=audio_config,
+ architectures=config_dict.get("architectures", ["VoxtralTTSForConditionalGeneration"]),
+ )
+
+ return config_dict, config
diff --git a/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py b/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py
index d7407afe561..395c0d1130d 100644
--- a/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py
+++ b/vllm_omni/model_executor/models/voxtral_tts/cuda_graph_acoustic_transformer_wrapper.py
@@ -11,7 +11,6 @@
import torch
from torch.cuda import CUDAGraph
from vllm.logger import init_logger
-from vllm.platforms import current_platform
from vllm_omni.model_executor.models.voxtral_tts.voxtral_tts_audio_generation import (
AudioSpecialTokens,
@@ -48,13 +47,13 @@ def __init__(
self.n_acoustic_codebook = self.acoustic_transformer.model_args.n_acoustic_codebook
self.acoustic_embeddings_levels = self.acoustic_transformer.acoustic_embeddings_levels
- self.n_steps = self.acoustic_transformer.acoustic_transformer_args.n_decoding_steps
+ self.cfg_alpha = 1.2
+ self.n_steps = 8
# Graph storage
self.graphs: dict[int, CUDAGraph] = {}
self.static_inputs: dict[int, torch.Tensor] = {}
self.static_noise: dict[int, torch.Tensor] = {}
- self.static_cfg_alpha: dict[int, torch.Tensor] = {}
self.static_fake_eos: dict[int, torch.Tensor] = {}
self.static_audio_codes: dict[int, torch.Tensor] = {}
@@ -73,17 +72,15 @@ def _warmup_and_capture(self, device: torch.device, dtype: torch.dtype, hidden_d
)
# Pre-create persistent buffers
- self.timesteps = torch.linspace(0, 1, self.n_steps + 1, device=device, dtype=dtype)
+ self.timesteps = torch.linspace(0, 1, self.n_steps, device=device, dtype=dtype)
self.fake_eos_one = torch.tensor(1.0, dtype=dtype, device=device)
self.fake_eos_zero = torch.tensor(0.0, dtype=dtype, device=device)
# Phase 1: Eager warmup for ALL capture sizes
for size in self.capture_sizes:
dummy = torch.zeros(size, hidden_dim, device=device, dtype=dtype)
- dummy_cfg_alpha = torch.full((size, 1), 1.2, device=device, dtype=dtype)
- dummy_noise = torch.randn(size, self.n_acoustic_codebook, device=device, dtype=dtype)
with torch.no_grad():
- self._forward_cudagraph_compatible(dummy, cfg_alpha=dummy_cfg_alpha, noise=dummy_noise)
+ self._forward_cudagraph_compatible(dummy)
torch.cuda.synchronize(device)
@@ -107,12 +104,7 @@ def _warmup_and_capture(self, device: torch.device, dtype: torch.dtype, hidden_d
len(self.capture_sizes),
)
- def _forward_cudagraph_compatible(
- self,
- hidden_states: torch.Tensor,
- cfg_alpha: torch.Tensor,
- noise: torch.Tensor,
- ):
+ def _forward_cudagraph_compatible(self, hidden_states: torch.Tensor, noise: torch.Tensor | None = None):
"""
The actual computation captured by the CUDA graph.
@@ -124,7 +116,6 @@ def _forward_cudagraph_compatible(
- Calls _predict_velocity directly
- Uses a pre-allocated noise buffer to avoid baking random state
into the CUDA graph
- - Uses a pre-allocated cfg_alpha buffer for per-request CFG strength
"""
at = self.acoustic_transformer
B = hidden_states.shape[0]
@@ -140,7 +131,10 @@ def _forward_cudagraph_compatible(
# --- Flow matching: Euler ODE ---
should_decode = semantic_code.squeeze(1) != self.end_audio_token_id
- x = noise
+ if noise is not None:
+ x = noise
+ else:
+ x = torch.randn(B, self.n_acoustic_codebook, device=hidden_states.device, dtype=hidden_states.dtype)
# Pre-compute zero hidden states for unconditional CFG branch
hidden_states_zero = torch.zeros_like(hidden_states)
@@ -159,8 +153,8 @@ def _forward_cudagraph_compatible(
v_all = at._predict_velocity(x_t=x_batched, llm_output=llm_batched, t_emb=t_emb_batched)
v_t, uncond_v_t = v_all[:B], v_all[B:]
- # CFG combination (cfg_alpha is (B, 1), v_t is (B, C))
- v_t = cfg_alpha * v_t + (1 - cfg_alpha) * uncond_v_t
+ # CFG combination
+ v_t = self.cfg_alpha * v_t + (1 - self.cfg_alpha) * uncond_v_t
x = x + v_t * dt
@@ -193,25 +187,23 @@ def _capture_graph_for_size(
"""Capture a CUDA graph for a specific batch size."""
static_input = torch.zeros(size, hidden_dim, device=device, dtype=dtype)
static_noise = torch.randn(size, self.n_acoustic_codebook, device=device, dtype=dtype)
- static_cfg_alpha = torch.full((size, 1), 1.2, device=device, dtype=dtype)
# Stabilizing eager run
with torch.no_grad():
- _ = self._forward_cudagraph_compatible(static_input, cfg_alpha=static_cfg_alpha, noise=static_noise)
+ _ = self._forward_cudagraph_compatible(static_input, noise=static_noise)
torch.cuda.synchronize(device)
graph = CUDAGraph()
with torch.no_grad():
- with torch.cuda.graph(graph, pool=current_platform.get_global_graph_pool()):
+ with torch.cuda.graph(graph):
static_fake_eos, static_audio_codes = self._forward_cudagraph_compatible(
- static_input, cfg_alpha=static_cfg_alpha, noise=static_noise
+ static_input, noise=static_noise
)
self.graphs[size] = graph
self.static_inputs[size] = static_input
self.static_noise[size] = static_noise
- self.static_cfg_alpha[size] = static_cfg_alpha
self.static_fake_eos[size] = static_fake_eos
self.static_audio_codes[size] = static_audio_codes
@@ -225,7 +217,6 @@ def _get_padded_size(self, actual_size: int) -> int | None:
def __call__(
self,
hidden_states: torch.Tensor,
- cfg_alpha: torch.Tensor,
) -> tuple[torch.Tensor, dict[str, list[torch.Tensor]] | None]:
"""
Drop-in replacement for model.compute_mm_logits().
@@ -237,20 +228,16 @@ def __call__(
actual_size = hidden_states.shape[0]
if not self.enabled or not self._warmed_up:
- return self.model.compute_mm_logits(hidden_states, cfg_alpha=cfg_alpha)
+ return self.model.compute_mm_logits(hidden_states)
padded_size = self._get_padded_size(actual_size)
if padded_size is None or padded_size not in self.graphs:
- return self.model.compute_mm_logits(hidden_states, cfg_alpha=cfg_alpha)
+ return self.model.compute_mm_logits(hidden_states)
# Zero static input, then copy actual data
self.static_inputs[padded_size].zero_()
self.static_inputs[padded_size][:actual_size] = hidden_states
- # Copy per-request cfg_alpha into static buffer (pad with 1.2 default)
- self.static_cfg_alpha[padded_size].fill_(1.2)
- self.static_cfg_alpha[padded_size][:actual_size, 0] = cfg_alpha
-
# Fill noise buffer with fresh random values before replay so the
# flow-matching ODE starts from different initial noise each time.
self.static_noise[padded_size].normal_()
diff --git a/vllm_omni/model_executor/models/voxtral_tts/pipeline.py b/vllm_omni/model_executor/models/voxtral_tts/pipeline.py
deleted file mode 100644
index 3bc2743c6cd..00000000000
--- a/vllm_omni/model_executor/models/voxtral_tts/pipeline.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Voxtral TTS pipeline topology (frozen).
-
-Stage 0: audio_generation — text → acoustic latents (LLM_AR, tokenizer owner).
-Stage 1: audio_tokenizer — acoustic latents → waveform (LLM_GENERATION).
-"""
-
-from vllm_omni.config.stage_config import (
- PipelineConfig,
- StageExecutionType,
- StagePipelineConfig,
-)
-
-_PROC = "vllm_omni.model_executor.stage_input_processors.voxtral_tts"
-
-VOXTRAL_TTS_PIPELINE = PipelineConfig(
- model_type="voxtral_tts",
- model_arch="VoxtralTTSForConditionalGeneration",
- stages=(
- StagePipelineConfig(
- stage_id=0,
- model_stage="audio_generation",
- execution_type=StageExecutionType.LLM_AR,
- input_sources=(),
- final_output=False,
- final_output_type="text",
- owns_tokenizer=True,
- engine_output_type="latent",
- async_chunk_process_next_stage_input_func=(f"{_PROC}.generator2tokenizer_async_chunk"),
- sampling_constraints={"detokenize": True},
- ),
- StagePipelineConfig(
- stage_id=1,
- model_stage="audio_tokenizer",
- execution_type=StageExecutionType.LLM_GENERATION,
- input_sources=(0,),
- final_output=True,
- final_output_type="audio",
- engine_output_type="audio",
- sampling_constraints={"detokenize": True},
- extras={"tts_args": {"max_instructions_length": 500}},
- ),
- ),
-)
diff --git a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts.py b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts.py
index c7808915098..127171067d6 100644
--- a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts.py
+++ b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts.py
@@ -283,30 +283,6 @@ def forward(
multimodal_outputs={"audio": batch_audio_arrays},
)
- _DEFAULT_CFG_ALPHA = 1.2
-
- def _extract_cfg_alpha(self, input_hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
- """Extract per-request cfg_alpha from sampling_extra_args.
-
- Returns a 1-D tensor of shape (B,) with per-request cfg_alpha values.
- Falls back to default if sampling_extra_args is missing or incomplete.
- """
- B = input_hidden_states.shape[0]
- sampling_extra_args = kwargs.get("sampling_extra_args")
- if sampling_extra_args is None:
- return torch.full(
- (B,),
- self._DEFAULT_CFG_ALPHA,
- device=input_hidden_states.device,
- dtype=input_hidden_states.dtype,
- )
- cfg_alpha_values = [ea.get("cfg_alpha", self._DEFAULT_CFG_ALPHA) for ea in sampling_extra_args]
- return torch.tensor(
- cfg_alpha_values,
- device=input_hidden_states.device,
- dtype=input_hidden_states.dtype,
- )
-
def make_omni_output(
self, model_outputs: torch.Tensor | OmniOutput | tuple, logits_index: int | None = None, **kwargs
) -> OmniOutput:
@@ -315,15 +291,10 @@ def make_omni_output(
hidden_states = model_outputs
assert logits_index is not None
input_hidden_states = hidden_states[logits_index]
- cfg_alpha = self._extract_cfg_alpha(input_hidden_states, **kwargs)
if self._cudagraph_acoustic_transformer is not None:
- fake_eos, multimodal_outputs = self._cudagraph_acoustic_transformer(
- input_hidden_states, cfg_alpha=cfg_alpha
- )
+ fake_eos, multimodal_outputs = self._cudagraph_acoustic_transformer(input_hidden_states)
else:
- fake_eos, multimodal_outputs = self.model.compute_mm_logits(
- input_hidden_states, cfg_alpha=cfg_alpha
- )
+ fake_eos, multimodal_outputs = self.model.compute_mm_logits(input_hidden_states)
hidden_states[logits_index, 0] = fake_eos
return OmniOutput(
text_hidden_states=hidden_states,
diff --git a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py
index 8b7dd7d1370..b5d11617337 100644
--- a/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py
+++ b/vllm_omni/model_executor/models/voxtral_tts/voxtral_tts_audio_generation.py
@@ -108,7 +108,6 @@ class AcousticTransformerArgs:
use_biases: bool = False
norm_eps: float = 1e-5
sigma: float = 1e-5 # was 0.01 in beta version
- n_decoding_steps: int | None = None # Number of Euler ODE steps for flow matching
@dataclass
@@ -437,11 +436,14 @@ def __init__(
self._empty_audio_token_id = AudioSpecialTokens.id(AudioSpecialTokens.empty_audio)
# Flow matching constants
- self._n_steps = args.n_decoding_steps
+ # TODO(chenyo): hardcoded, need to fix
+ self._acoustic_decode_iters = 8
+ # TODO(chenyo): hardcoded, need to fix
+ self._cfg_alpha = 1.2
self._noise_scale = 1.0
self.register_buffer(
"_timesteps",
- torch.linspace(0, 1, self._n_steps + 1),
+ torch.linspace(0, 1, self._acoustic_decode_iters),
persistent=False,
)
@@ -510,7 +512,6 @@ def decode_one_frame(
self,
semantic_code: torch.Tensor,
llm_hidden: torch.Tensor,
- cfg_alpha: torch.Tensor,
) -> torch.Tensor:
B = semantic_code.shape[0]
@@ -524,10 +525,6 @@ def decode_one_frame(
timesteps = self._timesteps.to(dtype=llm_hidden.dtype)
llm_hidden_zero = torch.zeros_like(llm_hidden)
- # Reshape cfg_alpha for broadcasting: (B,) -> (B, 1)
- cfg_alpha = cfg_alpha.to(dtype=llm_hidden.dtype, device=llm_hidden.device)
- cfg_alpha = cfg_alpha.unsqueeze(1) # (B, 1) for broadcasting with (B, C)
-
# Euler integration with batched conditional + unconditional velocity
sampled = x_0
for i in range(len(timesteps) - 1):
@@ -547,7 +544,7 @@ def decode_one_frame(
t_emb=t_emb_batched,
)
v_t, uncond_v_t = v_all[:B], v_all[B:]
- v_t = cfg_alpha * v_t + (1 - cfg_alpha) * uncond_v_t
+ v_t = self._cfg_alpha * v_t + (1 - self._cfg_alpha) * uncond_v_t
sampled = sampled + v_t * dt
@@ -588,7 +585,6 @@ def _predict_velocity(
def forward(
self,
llm_hidden: torch.Tensor,
- cfg_alpha: torch.Tensor,
) -> torch.Tensor:
# llm_hidden: BxD
semantic_logit = self.semantic_codebook_output(llm_hidden).float()
@@ -598,10 +594,10 @@ def forward(
# semantic_logit: Bx1
semantic_code = semantic_logit.argmax(dim=-1, keepdim=True)
+ # acoustic codes, TODO(@chenyo): config sampling
acoustic_codes = self.decode_one_frame(
semantic_code.squeeze(1),
llm_hidden,
- cfg_alpha=cfg_alpha,
)
audio_codes = torch.concatenate(
@@ -868,29 +864,6 @@ def get_replacement(item_idx: int):
),
]
- def _apply_hf_processor_mm_only(
- self,
- mm_items: MultiModalDataItems,
- hf_processor_mm_kwargs: Mapping[str, object],
- tokenization_kwargs: Mapping[str, object],
- ) -> BatchFeature:
- """
- Apply the HF processor on the multi-modal data only.
-
- Issue: Voxtral TTS use Mistral Tokenizer with custom audio encoder. It doesn't
- inherit Transformers ProcessorMixin and can't use call_hf_processor_mm_only.
-
- Solution: Override this method to call _apply_hf_processor_text_mm directly.
- """
- mm_counts = mm_items.get_all_counts()
- _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
- prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
- mm_items=mm_items,
- hf_processor_mm_kwargs=hf_processor_mm_kwargs,
- tokenization_kwargs=tokenization_kwargs,
- )
- return mm_processed_data
-
def _cached_apply_hf_processor(
self,
inputs: ProcessorInputs,
@@ -1039,13 +1012,11 @@ def compute_logits(
def compute_mm_logits(
self,
hidden_states: torch.Tensor,
- cfg_alpha: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]:
audio_codes = None
mm_tokens = None
audio_codes = self.acoustic_transformer(
llm_hidden=hidden_states,
- cfg_alpha=cfg_alpha,
)
fake_eos = torch.where(
audio_codes[:, 0] == AudioSpecialTokens.id(AudioSpecialTokens.end_audio),
diff --git a/vllm_omni/model_executor/models/whisper_utils.py b/vllm_omni/model_executor/models/whisper_utils.py
deleted file mode 100644
index 5aa2fc8a3ad..00000000000
--- a/vllm_omni/model_executor/models/whisper_utils.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright (c) 2022 OpenAI
-#
-# Shared Whisper encoder primitives used by multiple model implementations.
-# Originally from the OpenAI Whisper codebase.
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def sinusoids(length, channels, max_timescale=10000):
- """Returns sinusoids for positional embedding."""
- assert channels % 2 == 0
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
- inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
- return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
-
-
-class Conv1d(nn.Conv1d):
- """Conv1d with automatic dtype casting for mixed precision inference."""
-
- def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor:
- return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
-
-
-class ConvTranspose1d(nn.ConvTranspose1d):
- def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor:
- return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
-
-
-class Linear(nn.Linear):
- """Linear layer with automatic dtype casting for mixed precision inference."""
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
diff --git a/vllm_omni/model_executor/stage_configs/bagel.yaml b/vllm_omni/model_executor/stage_configs/bagel.yaml
new file mode 100644
index 00000000000..b0c1b048034
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/bagel.yaml
@@ -0,0 +1,85 @@
+# Stage 0: Thinker (multimodal understanding + text generation)
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
+ runtime:
+ devices: "0"
+ # 3 = 1 user prompt + 2 CFG companions (text-unconditional + image-unconditional).
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 3
+ model_arch: OmniBagelForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: text
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_send_cache: true
+ kv_transfer_criteria:
+ type: prefill_finished #or special token generated
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 52
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: diffusion
+ cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: dit
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: image
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0]
+
+ final_output: true
+ final_output_type: image
+ is_comprehension: false
+ default_sampling_params:
+ seed: 52
+
+# Runtime edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+
+ # Distributed connectors configuration (optional)
+ # More connectors will be supported in the future.
+ connectors:
+ shared_memory_connector:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536 # 64KB threshold
+
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
new file mode 100644
index 00000000000..4919395cad7
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
@@ -0,0 +1,112 @@
+# Stage 0: Thinker (multimodal understanding + text generation)
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: BagelForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: text
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_send_cache: true
+ kv_transfer_criteria:
+ type: prefill_finished #or special token generated
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 52
+ detokenize: True
+ repetition_penalty: 1.05
+ output_connectors:
+ to_stage_1: mooncake_connector
+
+
+ - stage_id: 1
+ stage_type: diffusion
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: dit
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: image
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0]
+
+ final_output: true
+ final_output_type: image
+ is_comprehension: false
+ default_sampling_params:
+ seed: 52
+ input_connectors:
+ from_stage_0: mooncake_connector
+
+
+# Runtime edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+
+ # Distributed connectors configuration (optional)
+ # More connectors will be supported in the future.
+ connectors:
+ # Mooncake connector for cross-node/intra-node communication
+ mooncake_connector:
+ name: MooncakeStoreConnector
+ extra:
+ host: "127.0.0.1"
+ metadata_server: "http://10.90.67.86:8080/metadata"
+ master: "10.90.67.86:50051"
+ segment: 512000000 # 512MB
+ localbuf: 64000000 # 64MB
+ proto: "tcp"
+
+ # PR1 (#1019) note:
+ # - Keep this transfer-engine connector config as a ready-to-use template.
+ # - Bagel does NOT consume this connector in PR1(#1019).
+ # - output_connectors/input_connectors above still point to mooncake_connector.
+ # - We will switch Bagel to this connector in the next PR.
+ rdma_connector:
+ name: MooncakeTransferEngineConnector
+ extra:
+ # NOTE:
+ # - role/sender_host/sender_zmq_port are internal fields resolved by
+ # orchestration logic and should not be set in user YAML.
+ host: "auto" # Auto-detect local IP for RDMA
+ zmq_port: 50051 # ZMQ base port (actual port uses runtime offsets)
+ protocol: "rdma"
+ device_name: "" # e.g. "mlx5_0"; empty for auto-detect
+ memory_pool_size: 2147483648 # 2GB
+ memory_pool_device: "cpu" # "cuda" for GPUDirect RDMA, "cpu" for pinned memory
+
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml b/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml
new file mode 100644
index 00000000000..2c1d84af493
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml
@@ -0,0 +1,32 @@
+# Stage 0: DiT (diffusion image generation); single-stage config with no thinker stage
+
+stage_args:
+
+ - stage_id: 0
+ stage_type: diffusion
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: dit
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: image
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+
+ final_output: true
+ final_output_type: image
+ is_comprehension: false
+ default_sampling_params:
+ seed: 52
+
+# Runtime edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_think.yaml b/vllm_omni/model_executor/stage_configs/bagel_think.yaml
new file mode 100644
index 00000000000..c4cf32c707e
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/bagel_think.yaml
@@ -0,0 +1,86 @@
+# BAGEL Think Model: AR stage decodes thinking tokens before KV transfer to DiT.
+#
+# Differences from bagel.yaml:
+# - No kv_transfer_criteria: AR stage decodes until EOS, then transfers full
+# KV cache (including thinking tokens) via _free_request path.
+# - prompt_expand_func: uses expand_cfg_prompts_think which sets max_tokens=1
+# on companion requests so they stop immediately after prefill.
+# - max_tokens: 2048 for thinking text generation.
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts_think
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 3
+ model_arch: OmniBagelForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: text
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_send_cache: true
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.3
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 52
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: diffusion
+ cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: dit
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: image
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0]
+
+ final_output: true
+ final_output_type: image
+ is_comprehension: false
+ default_sampling_params:
+ seed: 52
+
+# Runtime edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+
+ connectors:
+ shared_memory_connector:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml
new file mode 100644
index 00000000000..632c227f360
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml
@@ -0,0 +1,81 @@
+# Stage config for BAGEL SP: ulysses=2 (2 GPUs)
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
+ runtime:
+ devices: "0"
+ max_batch_size: 1
+ engine_args:
+ model_stage: thinker
+ model_arch: OmniBagelForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: text
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_send_cache: true
+ kv_transfer_criteria:
+ type: prefill_finished
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 52
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: diffusion
+ cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
+ runtime:
+ # devices: "0,1,2,3"
+ devices: "0,1"
+ max_batch_size: 1
+ engine_args:
+ model_stage: dit
+ gpu_memory_utilization: 0.45
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: image
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ parallel_config:
+ ulysses_degree: 2
+ # ring_degree: 2
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0]
+ final_output: true
+ final_output_type: image
+ is_comprehension: false
+ default_sampling_params:
+ seed: 52
+
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+ connectors:
+ shared_memory_connector:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml b/vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml
deleted file mode 100644
index b7d0aeeb742..00000000000
--- a/vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml
+++ /dev/null
@@ -1,46 +0,0 @@
-# Stage config for Ming-flash-omni-2.0
-# Stage 0: Thinker (Multimodal understanding + text generation)
-# Stage 1a: Image Generator (Text embeddings -> PIL image)
-# Stage 1b: Talker (Text embeddings -> audio waveform)
-
-async_chunk: false
-stage_args:
- - stage_id: 0
- stage_type: llm
- runtime:
- devices: "0,1,2,3"
- max_batch_size: 1
- engine_args:
- model_stage: thinker
- model_arch: MingFlashOmniForConditionalGeneration
- # tokenizer_subdir: talker/llm
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 4 # Use 4 GPUs for MoE model
- # pipeline_parallel_size: 4
- hf_config_name: llm_config
- compilation_config:
- pass_config:
- # there's a version mismatch regarding vllm and flashinfer
- # disable fuse allreduce for now
- fuse_allreduce_rms: false
- final_output: true # Can output text directly
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- max_tokens: 2048
- repetition_penalty: 1.05
- seed: 42
- detokenize: true
-
- # Future Stage 1a: Image Generator (Optional - not yet implemented)
- # Future Stage 1b: Talker/TTS (Optional - not yet implemented)
diff --git a/vllm_omni/model_executor/stage_configs/cosyvoice3.yaml b/vllm_omni/model_executor/stage_configs/cosyvoice3.yaml
new file mode 100644
index 00000000000..e215f51428a
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/cosyvoice3.yaml
@@ -0,0 +1,45 @@
+# Stage config for running CosyVoice3 with 2 stage architecture
+# Stage 0: Talker (text prompt → speech tokens)
+# Stage 1: Code2Wav (flow matching → acoustic features → waveform)
+# NOTE: The work of a would-be stage 2 is currently implemented inside stage 1; it can be split out in the future.
+
+stage_args:
+ - stage_id: 0
+ is_comprehension: true
+ runtime:
+ devices: 0
+ engine_args:
+ model_stage: cosyvoice3_talker
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ model_arch: CosyVoice3Model
+ trust_remote_code: true
+ gpu_memory_utilization: 0.2
+ engine_output_type: latent # Output speech tokens for chunk aware flow matching
+ disable_hybrid_kv_cache_manager: true
+ enable_prefix_caching: false
+ enforce_eager: true
+ mm_processor_cache_gb: 0
+ skip_mm_profiling: true
+ dtype: "float32"
+
+ - stage_id: 1
+ runtime:
+ devices: 0
+ engine_args:
+ model_stage: cosyvoice3_code2wav
+ model_arch: CosyVoice3Model
+ trust_remote_code: true
+ worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ engine_output_type: latent
+ gpu_memory_utilization: 0.1
+ enforce_eager: true # CUDA graphs don't work with dynamic conv shapes in code2wav
+ disable_hybrid_kv_cache_manager: true
+ enable_prefix_caching: false
+ skip_mm_profiling: true
+ dtype: "float32"
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.cosyvoice3.text2flow
+ final_output: true
+ final_output_type: audio
diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml
deleted file mode 100644
index 131a0d1cd70..00000000000
--- a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml
+++ /dev/null
@@ -1,75 +0,0 @@
-stage_args:
- - stage_id: 0
- stage_type: llm
- runtime:
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: token2text
- model_arch: DyninOmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- engine_output_type: latent
- trust_remote_code: true
- gpu_memory_utilization: 0.5
- enforce_eager: true
- enable_prefix_caching: false
- async_scheduling: false
- max_num_batched_tokens: 32768
- is_comprehension: true
- final_output: true
- final_output_type: text
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: token2image
- model_arch: DyninOmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- engine_output_type: latent
- trust_remote_code: true
- gpu_memory_utilization: 0.1
- enforce_eager: true
- enable_prefix_caching: false
- async_scheduling: false
- max_num_batched_tokens: 32768
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image
- final_output: true
- final_output_type: image
-
- - stage_id: 2
- stage_type: llm
- runtime:
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: token2audio
- model_arch: DyninOmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- engine_output_type: latent
- trust_remote_code: true
- gpu_memory_utilization: 0.1
- enforce_eager: true
- enable_prefix_caching: false
- async_scheduling: false
- max_num_batched_tokens: 32768
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2image_to_token2audio
- final_output: true
- final_output_type: audio
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
-
- edges:
- - from: 0
- to: 1
- - from: 1
- to: 2
diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml
deleted file mode 100644
index 4a54f8188aa..00000000000
--- a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml
+++ /dev/null
@@ -1,109 +0,0 @@
-stage_args:
- - stage_id: 0
- stage_type: llm
- runtime:
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: token2text
- model_arch: DyninOmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- engine_output_type: latent
- trust_remote_code: false
- enforce_eager: true
- enable_prefix_caching: false
- async_scheduling: false
- max_num_batched_tokens: 32768
- output_connectors:
- to_stage_1: mooncake_connector
- final_output: true
- final_output_type: text
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: token2image
- model_arch: DyninOmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- engine_output_type: latent
- trust_remote_code: false
- enforce_eager: true
- enable_prefix_caching: false
- async_scheduling: false
- max_num_batched_tokens: 32768
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2text_to_token2image
- final_output: true
- final_output_type: image
- input_connectors:
- from_stage_0: mooncake_connector
- output_connectors:
- to_stage_2: mooncake_connector
-
- - stage_id: 2
- stage_type: llm
- runtime:
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: token2audio
- model_arch: DyninOmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- engine_output_type: latent
- trust_remote_code: false
- enforce_eager: true
- enable_prefix_caching: false
- async_scheduling: false
- max_num_batched_tokens: 32768
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.dynin_omni.token2image_to_token2audio
- final_output: true
- final_output_type: audio
- input_connectors:
- from_stage_1: mooncake_connector
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- ####
- # same as Qwen2.5_omni version
- # Distributed connectors configuration (optional)
- # More connectors will be supported in the future.
- connectors:
- # Mooncake connector for cross-node/intra-node communication
- mooncake_connector:
- name: MooncakeConnector
- extra:
- host: "127.0.0.1"
- metadata_server: "http://10.90.67.86:8080/metadata"
- master: "10.90.67.86:50051"
- segment: 512000000 # 512MB
- localbuf: 64000000 # 64MB
- proto: "tcp"
-
- # Yuanrong connector for cross-node/intra-node communication
- yuanrong_connector:
- name: YuanrongConnector
- extra:
- host: "127.0.0.1"
- port: "35000"
-
- # SharedMemory connector for intra-node communication
- # Alternative SHM connector with different threshold
- shared_memory_connector:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536 # 64KB threshold
- ####
-
- edges:
- - from: 0
- to: 1
- - from: 1
- to: 2
diff --git a/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml b/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml
new file mode 100644
index 00000000000..0b0b2785928
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml
@@ -0,0 +1,96 @@
+async_chunk: true
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ is_comprehension: true
+ runtime:
+ devices: "0"
+ max_batch_size: 16
+ engine_args:
+ max_num_seqs: 4
+ model_stage: fish_speech_slow_ar
+ model_arch: FishSpeechSlowARForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ enforce_eager: false
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: latent
+ gpu_memory_utilization: 0.6
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 3072
+ max_model_len: 16384
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.fish_speech.slow_ar_to_dac_decoder_async_chunk
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.8
+ top_k: 30
+ top_p: 0.9
+ max_tokens: 2048
+ seed: 42
+ detokenize: false
+ repetition_penalty: 1.0
+ # <|im_end|> token -- stop when model emits end-of-turn.
+ stop_token_ids: [151645]
+
+ - stage_id: 1
+ stage_type: llm
+ runtime:
+ devices: "0"
+ max_batch_size: 16
+ engine_args:
+ max_num_seqs: 1
+ model_stage: dac_decoder
+ model_arch: FishSpeechDACDecoder
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: audio
+ gpu_memory_utilization: 0.1
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 8192
+ max_model_len: 16384
+ engine_input_source: [0]
+ final_output: true
+ final_output_type: audio
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: true
+ repetition_penalty: 1.0
+
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 16
+
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ codec_streaming: true
+ connector_get_sleep_s: 0.01
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ # Chunk sizes for streaming -- ~21 Hz codec.
+ # 25 frames ≈ 1.16s of audio at 21.5 Hz.
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+ initial_codec_chunk_frames: 4
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/glm_image.yaml b/vllm_omni/model_executor/stage_configs/glm_image.yaml
new file mode 100644
index 00000000000..3cc23e1e251
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/glm_image.yaml
@@ -0,0 +1,80 @@
+# Stage config for running GLM-Image with 2-stage architecture
+# Stage 0: AR Model (vLLM implementation) - generates prior_token_ids
+# Stage 1: Diffusion (DiT + VAE) - denoising and image decoding
+
+stage_args:
+ # Stage 0: AR Model (GlmImageForConditionalGeneration)
+ # This stage uses the vLLM-optimized AR model to generate prior tokens
+ # for conditioning the diffusion process.
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0"
+ requires_multimodal_data: true # Required for i2i mode to receive source images
+ engine_args:
+ model_stage: ar
+ max_num_seqs: 1
+ model_arch: GlmImageForConditionalGeneration
+ model_subdir: vision_language_encoder # AR model config.json is in this subdirectory
+ tokenizer_subdir: processor # Use processor's tokenizer (not ByT5 from tokenizer/)
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: token_ids # Output prior_token_ids for diffusion stage
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ final_output: false # AR is not the final output
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.9 # From model's generation_config.json
+ top_p: 0.75 # From model's generation_config.json
+ top_k: 16512 # vision_vocab_size from generation_config.json
+ max_tokens: 1281 # For 1024x1024: small(16x16=256) + large(32x32=1024) + EOS(1)
+ stop_token_ids: [16385] # eos_token_id from generation_config.json
+ seed: 42
+ detokenize: false
+
+ # Stage 1: Diffusion (DiT + VAE)
+ # This stage receives prior_token_ids from AR and performs denoising + VAE decode
+ - stage_id: 1
+ stage_type: diffusion
+ runtime:
+ process: true
+ devices: "1" # Can use different GPU, or same GPU if memory allows
+ requires_multimodal_data: true # Required for i2i mode to pass condition images
+ engine_args:
+ model_stage: dit
+ max_num_seqs: 1
+ model_arch: GlmImagePipeline # Required for diffusion model class resolution
+ # Diffusion-specific parameters
+ num_gpus: 1
+ enforce_eager: true
+ trust_remote_code: true
+ distributed_executor_backend: "mp"
+ engine_input_source: [0] # Input from AR stage
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.glm_image.ar2diffusion
+ final_output: true
+ final_output_type: image
+ default_sampling_params:
+ # Diffusion-specific parameters only (no LLM params like temperature/top_p/top_k)
+ seed: 42
+ num_inference_steps: 50
+ guidance_scale: 1.5
+ height: 1024
+ width: 1024
+
+# Top-level runtime config
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Trigger downstream only after full upstream completion
+ max_inflight: 1 # Process serially within each stage
+
+ edges:
+ - from: 0 # AR → Diffusion: trigger after AR completes
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
new file mode 100644
index 00000000000..719c73a9fc0
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
@@ -0,0 +1,93 @@
+# Stage config for running GLM-Image with 2-stage architecture (MultiConnector version)
+# Stage 0: AR Model (vLLM implementation) - generates prior_token_ids
+# Stage 1: Diffusion (DiT + VAE) - denoising and image decoding
+#
+# This config uses OmniConnectors for inter-stage communication,
+# enabling efficient tensor transfer between stages on different processes/nodes.
+
+stage_args:
+ # Stage 0: AR Model (GlmImageForConditionalGeneration)
+ # This stage uses the vLLM-optimized AR model to generate prior tokens
+ # for conditioning the diffusion process.
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true
+ devices: "0"
+ requires_multimodal_data: true # Required for i2i mode to receive source images
+ engine_args:
+ model_stage: ar
+ max_num_seqs: 1
+ model_arch: GlmImageForConditionalGeneration
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: token_ids # Output prior_token_ids for diffusion stage
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ hf_config_name: vision_language_encoder # Subfolder in model path
+ final_output: false # AR is not the final output
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.9 # From model's generation_config.json
+ top_p: 0.75 # From model's generation_config.json
+ top_k: 16512 # vision_vocab_size from generation_config.json
+ max_tokens: 1281 # For 1024x1024: small(16x16=256) + large(32x32=1024) + EOS(1)
+ stop_token_ids: [16385] # eos_token_id from generation_config.json
+ seed: 42
+ detokenize: false
+
+ # Stage 1: Diffusion (DiT + VAE)
+ # This stage receives prior_token_ids from AR and performs denoising + VAE decode
+ - stage_id: 1
+ stage_type: diffusion
+ runtime:
+ process: true
+ devices: "1" # Use separate GPU for diffusion
+ requires_multimodal_data: true # Required for i2i mode to pass condition images
+ engine_args:
+ model_stage: dit
+ max_num_seqs: 1
+ # Diffusion-specific parameters
+ num_gpus: 1
+ enforce_eager: true
+ trust_remote_code: true
+ distributed_executor_backend: "mp"
+ engine_input_source: [0] # Input from AR stage
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.glm_image.ar2diffusion
+ final_output: true
+ final_output_type: image
+ default_sampling_params:
+ seed: 42
+ num_inference_steps: 50
+ guidance_scale: 1.5
+ height: 1024
+ width: 1024
+
+# Top-level runtime config with MultiConnector support
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Trigger downstream only after full upstream completion
+ max_inflight: 1 # Process serially within each stage
+
+ edges:
+ - from: 0 # AR → Diffusion
+ to: 1
+ window_size: -1
+
+# OmniConnector configuration for efficient inter-stage tensor transfer
+connectors:
+ - type: tensor_transfer
+ source_stage: 0
+ target_stage: 1
+ # Transfer prior_token_ids efficiently between stages
+ fields:
+ - name: prior_token_ids
+ dtype: int64
+ - name: prior_token_image_ids
+ dtype: int64
+ optional: true
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
deleted file mode 100644
index b68b184ec31..00000000000
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
+++ /dev/null
@@ -1,41 +0,0 @@
-# Stage config for HunyuanImage-3.0 Image-to-Text (I2T / image understanding).
-# Single LLM stage: AR model reads image + text prompt, generates text output.
-
-stage_args:
- - stage_id: 0
- stage_type: llm
- runtime:
- process: true
- devices: "0,1,2,3"
- max_batch_size: 1
- requires_multimodal_data: true
- engine_args:
- model_stage: AR
- max_num_seqs: 1
- model_arch: HunyuanImage3ForCausalMM
- worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.95
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 4
- pipeline_parallel_size: 1
- hf_overrides:
- rope_parameters:
- mrope_section: [0, 32, 32]
- rope_type: default
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 0.95
- top_k: 1024
- max_tokens: 2048
- stop_token_ids: [127957, 128026] # <|endoftext|>,
- detokenize: True
-
-runtime:
- enabled: true
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml
deleted file mode 100644
index 413e0f09cbe..00000000000
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml
+++ /dev/null
@@ -1,74 +0,0 @@
-# Stage config for HunyuanImage-3.0 Image+Text-to-Image (image editing).
-# Stage 0: AR (HunyuanImage3ForConditionalGeneration) — reads (image, text), emits latent tokens
-# Stage 1: Diffusion (HunyuanImage3Pipeline / DiT + VAE) — denoise + decode latents → image
-
-stage_args:
- # Stage 0: AR Model
- - stage_id: 0
- stage_type: llm
- runtime:
- process: true
- devices: "0,1,2,3"
- max_batch_size: 1
- requires_multimodal_data: true # AR needs the original image
- engine_args:
- model_stage: AR
- model_arch: HunyuanImage3ForCausalMM
- worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.95
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # AR outputs latent for DiT
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 4
- pipeline_parallel_size: 1
- hf_overrides:
- rope_parameters:
- mrope_section: [0, 32, 32]
- rope_type: default
- is_comprehension: false # Generation task, not comprehension
- final_output: false # AR is not the final output
- default_sampling_params:
- temperature: 0.6
- top_p: 0.95
- top_k: 1024
- max_tokens: 4096
- stop_token_ids: [127957] # <|endoftext|>
- detokenize: false
-
- # Stage 1: Diffusion (DiT + VAE)
- # Receives latents from AR stage, performs denoising + VAE decode
- - stage_id: 1
- stage_type: diffusion
- runtime:
- process: true
- devices: "4,5,6,7"
- max_batch_size: 1
- requires_multimodal_data: true # May need condition images
- engine_args:
- model_stage: dit
- model_arch: HunyuanImage3ForCausalMM
- enforce_eager: true
- trust_remote_code: true
- distributed_executor_backend: "mp"
- parallel_config:
- tensor_parallel_size: 4
- enable_expert_parallel: true
- omni_kv_config:
- need_recv_cache: true
- engine_input_source: [0] # Input from AR stage
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.hunyuan_image3.ar2diffusion
- final_output: true
- final_output_type: image
- default_sampling_params:
- num_inference_steps: 50
- guidance_scale: 2.5
-
-# Top-level runtime config
-runtime:
- enabled: true
- edges:
- - from: 0 # AR → Diffusion
- to: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml
similarity index 80%
rename from vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml
rename to vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml
index 1d8c7f4812d..0b812ff376b 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit.yaml
@@ -11,9 +11,13 @@ stage_args:
engine_args:
max_num_seqs: 1
model_stage: dit
+ gpu_memory_utilization: 0.65
enforce_eager: true
trust_remote_code: true
+ engine_output_type: image
distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
parallel_config:
tensor_parallel_size: 4
enable_expert_parallel: true
@@ -29,3 +33,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml
index 586b601bc5a..51110c28587 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml
@@ -11,9 +11,13 @@ stage_args:
max_batch_size: 1
engine_args:
model_stage: dit
+ gpu_memory_utilization: 0.9
enforce_eager: true
trust_remote_code: true
+ engine_output_type: image
distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
quantization: "fp8"
parallel_config:
tensor_parallel_size: 2
@@ -30,3 +34,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
deleted file mode 100644
index a0a1a0dc1c4..00000000000
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
+++ /dev/null
@@ -1,42 +0,0 @@
-# Stage config for HunyuanImage-3.0 Text-to-Text (T2T / pure text generation).
-# Single LLM stage: AR model reads text prompt only, generates text output.
-# Sampling params aligned with official generation_config.json.
-
-stage_args:
- - stage_id: 0
- stage_type: llm
- runtime:
- process: true
- devices: "0,1,2,3"
- max_batch_size: 1
- requires_multimodal_data: false
- engine_args:
- model_stage: AR
- max_num_seqs: 1
- model_arch: HunyuanImage3ForCausalMM
- worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.95
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 4
- pipeline_parallel_size: 1
- hf_overrides:
- rope_parameters:
- mrope_section: [0, 32, 32]
- rope_type: default
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 0.95
- top_k: 1024
- max_tokens: 2048
- stop_token_ids: [127957, 128026] # <|endoftext|>,
- detokenize: True
-
-runtime:
- enabled: true
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml
similarity index 60%
rename from vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml
rename to vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml
index f0797c63270..6f4ba306a50 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe.yaml
@@ -1,43 +1,36 @@
-# Stage config for running Hunyuan-Image3.0 with AR→DiT KV reuse.
+# Stage config for running Hunyuan-Image3.0 in the multi-stage omni runtime.
# Stage 0: AR Model (vLLM implementation)
-# Stage 1: DiT Model (diffusion)
-#
-# text-to-image flow: AR (stage 0) → KV transfer → DiT (stage 1)
-# image-to-text flow: AR (stage 0) only
-#
-# Compared to hunyuan_image3_t2i.yaml, this config:
-# 1. Enables both stages [0, 1] for text-to-image (AR prefill + DiT denoising)
-# 2. Adds omni_kv_config to send/receive KV cache between stages
-# The following config has been verified on 8x L40S-48G GPU (4 for AR + 4 for DiT).
+# The following config has been verified on 8x L40S-48G GPU.
+modes:
+ - mode: text-to-image
+ stages: [1]
+ - mode: image-to-text
+ stages: [0]
stage_args:
- stage_id: 0
stage_type: llm # Use llm stage type for AR stages
runtime:
process: true # Run this stage in a separate process
- devices: "0,1,2,3" # AR stage uses GPU 0-3
+ devices: "0,1,2,3,4,5,6,7" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
engine_args:
model_stage: AR
max_num_seqs: 1
model_arch: HunyuanImage3ForCausalMM
worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
+ gpu_memory_utilization: 0.3
enforce_eager: true # Now we only support eager mode
trust_remote_code: true
engine_output_type: latent
enable_prefix_caching: false
max_num_batched_tokens: 32768
- tensor_parallel_size: 4
+ tensor_parallel_size: 8
pipeline_parallel_size: 1
hf_overrides:
rope_parameters:
mrope_section: [0, 32, 32]
rope_type: default
- omni_kv_config:
- need_send_cache: true
- kv_transfer_criteria:
- type: prefill_finished # Send KV cache after AR prefill completes
is_comprehension: true
final_output: true
final_output_type: text
@@ -53,23 +46,25 @@ stage_args:
stage_type: diffusion
runtime:
process: true
- devices: "4,5,6,7" # DiT stage uses GPU 4-7
+ devices: "0,1,2,3,4,5,6,7"
max_batch_size: 1
engine_args:
model_stage: diffusion
+ gpu_memory_utilization: 0.9
enforce_eager: true
+ engine_output_type: image
distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
vae_use_slicing: false
vae_use_tiling: false
cache_backend: null
cache_config: null
enable_cache_dit_summary: false
- omni_kv_config:
- need_recv_cache: true # Receive AR KV cache from stage 0
parallel_config:
pipeline_parallel_size: 1
data_parallel_size: 1
- tensor_parallel_size: 4
+ tensor_parallel_size: 8
enable_expert_parallel: false
sequence_parallel_size: 1
ulysses_degree: 1
@@ -79,18 +74,12 @@ stage_args:
use_hsdp: false
hsdp_shard_size: -1
hsdp_replicate_size: 1
- engine_input_source: [0] # Receive input (including KV) from stage 0
final_output: true
final_output_type: image
-# Top-level runtime config: windows, edges, and connectors
+# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
defaults:
- window_size: -1 # Trigger downstream only after full upstream completion
- max_inflight: 1 # Process serially within each stage
-
- edges:
- - from: 0
- to: 1
- window_size: -1
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe_2gpu.yaml
similarity index 95%
rename from vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml
rename to vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe_2gpu.yaml
index 41ed74ba62a..e029c383623 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image_3_moe_2gpu.yaml
@@ -39,3 +39,6 @@ stage_args:
runtime:
enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/mimo_audio.yaml b/vllm_omni/model_executor/stage_configs/mimo_audio.yaml
new file mode 100644
index 00000000000..123e5f0cf4f
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/mimo_audio.yaml
@@ -0,0 +1,69 @@
+# Stage config for running mimo-audio in the multi-stage omni runtime.
+
+# The following config has been verified on 1x H20-96G GPU (note: stage 1 below is pinned to device "1").
+async_chunk: false
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ engine_args:
+ dtype: bfloat16
+ max_num_seqs: 1
+ model_stage: fused_thinker_talker
+ model_arch: MiMoAudioForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ tensor_parallel_size: 1 # Change to desired TP size for multi-GPU inference (e.g., 4 for 4 GPUs)
+ gpu_memory_utilization: 0.3
+ enforce_eager: true # need to discuss
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: latent # TODO: change the param name, e.g. to pooling_output
+ max_model_len: 8192
+ max_num_batched_tokens: 8192
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.6
+ top_p: 0.95
+ top_k: 50
+ max_tokens: 18192
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+ - stage_id: 1
+ stage_type: llm
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "1" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: MiMoAudioForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ tensor_parallel_size: 1 # Change to desired TP size for multi-GPU inference (e.g., 4 for 4 GPUs)
+ gpu_memory_utilization: 0.2
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ max_model_len: 18192
+ max_num_batched_tokens: 18192
+ async_scheduling: false
+ engine_input_source: [ 0 ]
+ is_comprehension: false
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.mimo_audio.llm2code2wav
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 18192
+ seed: 42
+ detokenize: false
diff --git a/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml
new file mode 100644
index 00000000000..b3c6bbbaf04
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml
@@ -0,0 +1,96 @@
+# Stage config for running mimo-audio in the multi-stage omni runtime (async-chunk variant).
+
+# The following config has been verified on 1x H20-96G GPU.
+async_chunk: true
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ engine_args:
+ dtype: bfloat16
+ max_num_seqs: 1
+ model_stage: fused_thinker_talker
+ model_arch: MiMoAudioForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ tensor_parallel_size: 1 # Change to desired TP size for multi-GPU inference (e.g., 4 for 4 GPUs)
+ gpu_memory_utilization: 0.3
+ enforce_eager: true # need to discuss
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: latent # TODO: change the param name, e.g. to pooling_output
+ max_model_len: 8192
+ max_num_batched_tokens: 8192
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.mimo_audio.llm2code2wav_async_chunk
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.6
+ top_p: 0.95
+ top_k: 50
+ max_tokens: 18192
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+ - stage_id: 1
+ stage_type: llm
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: MiMoAudioForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ tensor_parallel_size: 1 # Change to desired TP size for multi-GPU inference (e.g., 4 for 4 GPUs)
+ gpu_memory_utilization: 0.2
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: audio
+ max_model_len: 8192
+ max_num_batched_tokens: 8192
+ engine_input_source: [ 0 ]
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ is_comprehension: false
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 18192
+ seed: 42
+ detokenize: false
+
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 365536
+ codec_streaming: true
+ connector_get_sleep_s: 0.001
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ codec_chunk_frames: 3
+ codec_left_context_frames: 3
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml b/vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml
deleted file mode 100644
index 1e07d16a0b5..00000000000
--- a/vllm_omni/model_executor/stage_configs/ming_flash_omni.yaml
+++ /dev/null
@@ -1,70 +0,0 @@
-# Multi-stage config for Ming-flash-omni-2.0: Thinker + Talker
-# Stage 0: Thinker (multimodal understanding → text generation)
-# Stage 1: Talker (text → audio waveform via CFM + AudioVAE)
-
-stage_args:
- - stage_id: 0
- stage_type: llm
- runtime:
- devices: "0,1,2,3"
- max_batch_size: 1
- engine_args:
- model_stage: thinker
- model_arch: MingFlashOmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.74
- enforce_eager: false
- trust_remote_code: true
- # Ming Thinker -> talker bridging reads the detokenized text via
- # source_output.outputs[0].text rather than hidden states
- engine_output_type: text
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- tensor_parallel_size: 4
- hf_config_name: llm_config
- compilation_config:
- pass_config:
- # there's a version mismatch regarding vllm and flashinfer
- # disable fuse allreduce for now
- fuse_allreduce_rms: false
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- max_tokens: 2048
- repetition_penalty: 1.05
- detokenize: true
- ignore_eos: false
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "3"
- max_batch_size: 1
- engine_args:
- # Use the standalone talker class
- model_stage: ming_tts
- model_arch: MingFlashOmniTalkerForConditionalGeneration
- worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.18
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: audio
- enable_prefix_caching: false
- max_num_batched_tokens: 1000000
- tokenizer_subdir: talker/llm
- # The HF repo ships BailingMM2Config (thinker-only) at root,
- # OmniModelConfig treats that as "stage does not share outer mrope"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.ming_flash_omni.thinker2talker
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- max_tokens: 1
diff --git a/vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml b/vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml
deleted file mode 100644
index 311f79ec5aa..00000000000
--- a/vllm_omni/model_executor/stage_configs/ming_flash_omni_tts.yaml
+++ /dev/null
@@ -1,32 +0,0 @@
-# Single-stage TTS config for Ming-flash-omni-2.0 Talker
-# Stage 0: Talker only (text → audio waveform via CFM + AudioVAE)
-# Use this config for standalone TTS deployment without thinker.
-
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- max_batch_size: 1
- engine_args:
- model_stage: ming_tts
- model_arch: MingFlashOmniTalkerForConditionalGeneration
- worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: audio
- enable_prefix_caching: false
- max_num_batched_tokens: 100000
- tokenizer_subdir: talker/llm
- # NOTE: `hf_config_name` for Ming talker acts as a placeholder
- # The HF repo ships BailingMM2Config (thinker-only) at root,
- # OmniModelConfig treats that as "stage does not share outer mrope"
- hf_config_name: talker_config
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- max_tokens: 1
diff --git a/vllm_omni/model_executor/stage_configs/omnivoice.yaml b/vllm_omni/model_executor/stage_configs/omnivoice.yaml
index 546e3b3dc2a..49f11e9674d 100644
--- a/vllm_omni/model_executor/stage_configs/omnivoice.yaml
+++ b/vllm_omni/model_executor/stage_configs/omnivoice.yaml
@@ -10,8 +10,10 @@ stage_args:
engine_args:
model_stage: dit
model_class_name: "OmniVoicePipeline"
+ gpu_memory_utilization: 0.5
enforce_eager: true
trust_remote_code: true
+ engine_output_type: audio
distributed_executor_backend: "mp"
dtype: "float32"
final_output: true
diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml
new file mode 100644
index 00000000000..0a307b44778
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml
@@ -0,0 +1,107 @@
+# Stage config for running qwen2.5-omni in the multi-stage omni runtime.
+
+# The following config has been verified on 2x H100-80G GPU.
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: true # Now we only support eager mode
+ trust_remote_code: true
+ engine_output_type: latent
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ mm_processor_cache_gb: 0
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ engine_output_type: latent
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+ default_sampling_params:
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+ stop_token_ids: [8294]
+
+ - stage_id: 2
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true
+ devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ gpu_memory_utilization: 0.15
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ async_scheduling: false
+ engine_output_type: audio
+ engine_input_source: [1]
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
+
+ edges:
+ - from: 0 # thinker → talker: trigger only after receiving full input (-1)
+ to: 1
+ window_size: -1
+ - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
+ to: 2
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml
new file mode 100644
index 00000000000..6e4f871e38d
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml
@@ -0,0 +1,141 @@
+# Stage config for running qwen2.5-omni in the multi-stage omni runtime (multi-connector variant).
+
+# The following config has been verified on 1x H100-80G GPU (note: the stages below are pinned to devices "0", "1", and "2").
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: true # Now we only support eager mode
+ trust_remote_code: true
+ engine_output_type: latent
+ enable_prefix_caching: false
+ mm_processor_cache_gb: 0
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+ # Distributed connector configuration (optional)
+ output_connectors:
+ to_stage_1: mooncake_connector
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+ default_sampling_params:
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+ stop_token_ids: [8294]
+ # Distributed connector configuration (optional)
+ input_connectors:
+ from_stage_0: mooncake_connector
+ output_connectors:
+ to_stage_2: mooncake_connector
+ - stage_id: 2
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true
+ devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ gpu_memory_utilization: 0.3
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ engine_output_type: audio
+ engine_input_source: [1]
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+ # Distributed connector configuration (optional)
+ input_connectors:
+ from_stage_1: mooncake_connector
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
+
+ # Distributed connectors configuration (optional)
+ # More connectors will be supported in the future.
+ connectors:
+ # Mooncake connector for cross-node/intra-node communication
+ mooncake_connector:
+ name: MooncakeStoreConnector
+ extra:
+ host: "127.0.0.1"
+ metadata_server: "http://10.90.67.86:8080/metadata"
+ master: "10.90.67.86:50051"
+ segment: 512000000 # 512MB
+ localbuf: 64000000 # 64MB
+ proto: "tcp"
+
+ # Yuanrong connector for cross-node/intra-node communication
+ yuanrong_connector:
+ name: YuanrongConnector
+ extra:
+ host: "127.0.0.1"
+ port: "35000"
+
+ # SharedMemory connector for intra-node communication
+ # Alternative SHM connector with different threshold
+ shared_memory_connector:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536 # 64KB threshold
+
+ edges:
+ - from: 0 # thinker → talker: trigger only after receiving full input (-1)
+ to: 1
+ window_size: -1
+ - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
+ to: 2
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml
new file mode 100644
index 00000000000..0ce4f0c94fd
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml
@@ -0,0 +1,101 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs.
+async_chunk: false
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 64
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent # Output hidden states for talker
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ hf_config_name: thinker_config
+ tensor_parallel_size: 1
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 64
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent # Output codec codes for code2wav
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ distributed_executor_backend: "mp"
+ hf_config_name: talker_config
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+ # final_output: true
+ # final_output_type: text
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: False
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 2
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 32
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: audio # Final output: audio waveform
+ gpu_memory_utilization: 0.1
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 1000000
+ hf_config_name: thinker_config
+ engine_input_source: [1]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
new file mode 100644
index 00000000000..38626fc081e
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
@@ -0,0 +1,117 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
+# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs.
+async_chunk: true
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 64
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent # Output hidden states for talker
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ hf_config_name: thinker_config
+ tensor_parallel_size: 1
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ # Use named connector to apply runtime.connectors.extra.
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 64
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent # Output codec codes for code2wav
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ distributed_executor_backend: "mp"
+ hf_config_name: talker_config
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
+ engine_input_source: [0]
+ # final_output: true
+ # final_output_type: text
+ # Distributed connector configuration
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: False
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 2
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 64
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: audio # Final output: audio waveform
+ gpu_memory_utilization: 0.1
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 51200 # TODO: if max_num_batched_tokens < max_num_seqs * 800, precision problems occur.
+ hf_config_name: thinker_config
+ engine_input_source: [1]
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+runtime:
+
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ # Align with Omni: small chunks with sufficient context overlap.
+ codec_chunk_frames: 25 # code2wav decode chunk size
+ codec_left_context_frames: 25 # code2wav left context size
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml
new file mode 100644
index 00000000000..6c2d2a7669d
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml
@@ -0,0 +1,143 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings -> 8-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes -> audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs.
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # Output hidden states for talker
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ hf_config_name: thinker_config
+ tensor_parallel_size: 1
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+ # Distributed connector configuration
+ output_connectors:
+ to_stage_1: mooncake_connector
+
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # Output codec codes for code2wav
+ # tensor_parallel_size: 2
+ enable_prefix_caching: false
+ distributed_executor_backend: "mp"
+ hf_config_name: talker_config
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+ # final_output: true
+ # final_output_type: text
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: False
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+ # Distributed connector configuration
+ input_connectors:
+ from_stage_0: mooncake_connector
+ output_connectors:
+ to_stage_2: mooncake_connector
+
+ - stage_id: 2
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 64
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio # Final output: audio waveform
+ gpu_memory_utilization: 0.1
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 1000000
+ hf_config_name: thinker_config
+ engine_input_source: [1]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+ # Distributed connector configuration
+ input_connectors:
+ from_stage_1: mooncake_connector
+
+# Top-level runtime config: default windows and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+
+ # Distributed connectors configuration
+ connectors:
+ # Mooncake connector for cross-node/intra-node communication
+ mooncake_connector:
+ name: MooncakeStoreConnector
+ extra:
+ host: "127.0.0.1"
+ metadata_server: "http://10.90.67.86:8080/metadata"
+ master: "10.90.67.86:50051"
+ segment: 512000000 # 512MB
+ localbuf: 64000000 # 64MB
+ proto: "tcp"
+
+ # SharedMemory connector for intra-node communication
+ shared_memory_connector:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536 # 64KB threshold
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
+ - from: 1
+ to: 2
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
similarity index 91%
rename from vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml
rename to vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
index 4ca8d11ad77..2c5f0a54744 100644
--- a/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml
+++ b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
@@ -17,6 +17,7 @@ stage_args:
enable_prefix_caching: false
engine_output_type: latent
gpu_memory_utilization: 0.3
+ distributed_executor_backend: "mp"
max_num_batched_tokens: 512
max_model_len: 4096
custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
@@ -48,6 +49,7 @@ stage_args:
enable_prefix_caching: false
engine_output_type: audio
gpu_memory_utilization: 0.3
+ distributed_executor_backend: "mp"
# Must be divisible by num_code_groups and cover (left_context + chunk).
# Prefill length is Q * num_frames (e.g. 16 * 2148 = 34368); keep headroom past 32k.
max_num_batched_tokens: 65536
@@ -72,6 +74,9 @@ stage_args:
runtime:
enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
connectors:
connector_of_shared_memory:
@@ -84,10 +89,11 @@ runtime:
connector_get_sleep_s: 0.01
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
- # Match the decoder sliding attention window to avoid chunk-boundary noise.
+ # Align with Omni: small chunks with sufficient context overlap.
codec_chunk_frames: 25
- codec_left_context_frames: 72
+ codec_left_context_frames: 25
edges:
- from: 0
to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
new file mode 100644
index 00000000000..a3509bb3305
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
@@ -0,0 +1,100 @@
+# Variant of qwen3_tts.yaml with batched talker and code2wav stages.
+# Stage 0: max_num_seqs 4, stage 1: max_num_seqs 4.
+# max_num_seqs must be a power of two to align with CUDA graph capture sizes
+# (stage 0) and must match --batch-size in end2end.py / benchmark scripts.
+async_chunk: true
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ is_comprehension: true
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: qwen3_tts
+ max_num_seqs: 4
+ model_arch: Qwen3TTSTalkerForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ enforce_eager: false
+ trust_remote_code: true
+ async_scheduling: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ gpu_memory_utilization: 0.3
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 512
+ max_model_len: 4096
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
+ # Use named connector to apply runtime.connectors.extra.
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: false
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 1
+ stage_type: llm
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 4
+ model_arch: Qwen3TTSCode2Wav
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ gpu_memory_utilization: 0.2
+ distributed_executor_backend: "mp"
+ # Must be divisible by num_code_groups and cover (left_context + chunk).
+ max_num_batched_tokens: 65536
+ # Flat codec prompt can exceed 32k tokens (Q * frames); align with max_tokens below.
+ max_model_len: 65536
+ engine_input_source: [0]
+ final_output: true
+ final_output_type: audio
+ # Distributed connector configuration
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: true
+ repetition_penalty: 1.0
+
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ # Frame-aligned codec streaming transport.
+ codec_streaming: true
+ # Connector polling / timeouts: connector_get_sleep_s is the sleep interval in seconds; the max_wait values are poll-loop counts.
+ connector_get_sleep_s: 0.01
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ # Align with Omni: small chunks with sufficient context overlap.
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/platforms/npu/stage_configs/voxcpm.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
similarity index 59%
rename from vllm_omni/platforms/npu/stage_configs/voxcpm.yaml
rename to vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
index dcd1f40517b..3f412fc4dca 100644
--- a/vllm_omni/platforms/npu/stage_configs/voxcpm.yaml
+++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
@@ -1,47 +1,42 @@
+async_chunk: false
stage_args:
- stage_id: 0
stage_type: llm
is_comprehension: true
runtime:
devices: "0"
- max_batch_size: 1
engine_args:
- dtype: bfloat16
- model_stage: latent_generator
- model_arch: VoxCPMForConditionalGeneration
- # Optional persistent HF-compatible config dir for native VoxCPM models.
- hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
+ model_stage: qwen3_tts
+ max_num_seqs: 1
+ model_arch: Qwen3TTSTalkerForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: true
+ enforce_eager: false
trust_remote_code: true
async_scheduling: false
enable_prefix_caching: false
engine_output_type: latent
- gpu_memory_utilization: 0.75
+ gpu_memory_utilization: 0.3
distributed_executor_backend: "mp"
- max_num_batched_tokens: 4096
+ max_num_batched_tokens: 512
max_model_len: 4096
default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
+ temperature: 0.9
+ top_k: 50
max_tokens: 4096
seed: 42
detokenize: false
- repetition_penalty: 1.0
- final_output: false
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
- stage_id: 1
stage_type: llm
runtime:
devices: "0"
- max_batch_size: 1
engine_args:
- dtype: float32
- model_stage: vae
- model_arch: VoxCPMForConditionalGeneration
- hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen3TTSCode2Wav
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
@@ -49,19 +44,21 @@ stage_args:
async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio
- gpu_memory_utilization: 0.1
+ gpu_memory_utilization: 0.2
distributed_executor_backend: "mp"
- max_num_batched_tokens: 8192
- max_model_len: 4096
+ max_num_batched_tokens: 65536
+ max_model_len: 65536
engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav
final_output: true
final_output_type: audio
+ tts_args:
+ max_instructions_length: 500
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
- max_tokens: 1
+ max_tokens: 65536
seed: 42
detokenize: true
repetition_penalty: 1.0
diff --git a/vllm_omni/model_executor/stage_configs/voxcpm.yaml b/vllm_omni/model_executor/stage_configs/voxcpm.yaml
deleted file mode 100644
index a5f324f6602..00000000000
--- a/vllm_omni/model_executor/stage_configs/voxcpm.yaml
+++ /dev/null
@@ -1,69 +0,0 @@
-# VoxCPM two-stage (latent → VAE) without async_chunk: one-shot latent then decode.
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- max_batch_size: 1
- engine_args:
- dtype: bfloat16
- model_stage: latent_generator
- model_arch: VoxCPMForConditionalGeneration
- # Optional persistent HF-compatible config dir for native VoxCPM models.
- hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.7
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 4096
- max_model_len: 4096
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 4096
- stop_token_ids: [2]
- seed: 42
- detokenize: false
- repetition_penalty: 1.0
- final_output: false
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- max_batch_size: 1
- engine_args:
- dtype: float32
- model_stage: vae
- model_arch: VoxCPMForConditionalGeneration
- hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.15
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 8192
- max_model_len: 4096
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 1
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
diff --git a/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml b/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml
new file mode 100644
index 00000000000..31cccb9ccfd
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml
@@ -0,0 +1,105 @@
+async_chunk: true
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type to launch OmniLLM
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ engine_args:
+ max_num_seqs: 32
+ model_stage: audio_generation
+ model_arch: VoxtralTTSForConditionalGeneration
+ worker_type: ar
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: false
+ trust_remote_code: true
+ async_scheduling: true
+ engine_output_type: latent
+ enable_prefix_caching: false
+ tokenizer_mode: mistral
+ config_format: mistral
+ load_format: mistral
+ skip_mm_profiling: true
+ enable_chunked_prefill: false
+ max_model_len: 4096
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.voxtral_tts.generator2tokenizer_async_chunk
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ is_comprehension: true
+ final_output: false
+ final_output_type: text
+ default_sampling_params: # sampling defaults applied to stage 0
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: true
+ repetition_penalty: 1.1
+ - stage_id: 1
+ stage_type: llm # Use llm stage type to launch OmniLLM
+ runtime:
+ process: true
+ devices: "0"
+ engine_args:
+ max_num_seqs: 32
+ model_stage: audio_tokenizer
+ model_arch: VoxtralTTSForConditionalGeneration
+ worker_type: generation
+ worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ async_scheduling: false
+ gpu_memory_utilization: 0.1
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ skip_mm_profiling: true
+ engine_output_type: audio
+ tokenizer_mode: mistral
+ config_format: mistral
+ load_format: mistral
+ max_num_batched_tokens: 65536
+ max_model_len: 65536
+ engine_input_source: [0]
+ is_comprehension: false
+ final_output: true
+ final_output_type: audio
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ tts_args:
+ max_instructions_length: 500
+ default_sampling_params: # sampling defaults applied to stage 1
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 2048
+ seed: 42
+ detokenize: true
+ repetition_penalty: 1.05
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
+
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ codec_streaming: true
+ connector_get_sleep_s: 0.01
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ codec_chunk_frames: 25
+ codec_chunk_frames_at_begin: 5
+ codec_left_context_frames: 25
+
+ edges:
+ - from: 0 # audio_generation → audio_tokenizer: trigger only after receiving full input (-1)
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_input_processors/bagel.py b/vllm_omni/model_executor/stage_input_processors/bagel.py
index 52cc14d3aa2..6b88fcd4a18 100644
--- a/vllm_omni/model_executor/stage_input_processors/bagel.py
+++ b/vllm_omni/model_executor/stage_input_processors/bagel.py
@@ -82,8 +82,6 @@ def expand_cfg_prompts(
neg_prompt = _get_negative_prompt(prompt, sampling_params)
if "image" in modalities:
- if not neg_prompt:
- return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
@@ -137,13 +135,6 @@ def expand_cfg_prompts(
"i.e. planning process here image here"
)
-VLM_THINK_SYSTEM_PROMPT = (
- "You should first think about the reasoning process in the mind "
- "and then provide the user with the answer. \n"
- "The reasoning process is enclosed within tags, "
- "i.e. reasoning process here answer here"
-)
-
def expand_cfg_prompts_think(
prompt: dict[str, Any] | str,
@@ -168,8 +159,6 @@ def expand_cfg_prompts_think(
companion_params = {"max_tokens": 1}
if "image" in modalities:
- if not neg_prompt:
- return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
@@ -291,10 +280,9 @@ def _get_negative_prompt(
) -> str:
"""Resolve the negative prompt for CFG from prompt or sampling params.
- Returns the negative prompt string when one is supplied, otherwise an
- empty string. Callers decide how to treat the empty case: text2img
- skips the cfg_text companion entirely, while img2img substitutes it
- into the cfg_text prompt template.
+ An empty string is treated the same as absent (falls through to
+ the Bagel default token pair), because an empty negative prompt is
+ not meaningful for CFG guidance.
"""
neg = prompt.get("negative_prompt")
if neg:
@@ -305,4 +293,4 @@ def _get_negative_prompt(
if neg:
return neg
- return ""
+ return "<|im_start|><|im_end|>"
diff --git a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py
index dc1e12dfea8..b7f21eca8fd 100644
--- a/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py
+++ b/vllm_omni/model_executor/stage_input_processors/cosyvoice3.py
@@ -1,67 +1,10 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections import defaultdict
-from contextlib import nullcontext
from typing import Any
-import numpy as np
-import torch
from vllm.inputs import TextPrompt
from vllm_omni.inputs.data import OmniTokensPrompt
-def _ensure_list(x: Any) -> list[Any]:
- if hasattr(x, "_x"):
- return list(x._x)
- if isinstance(x, list):
- return list(x)
- if isinstance(x, tuple):
- return list(x)
- if x is None:
- return []
- try:
- return list(x)
- except TypeError:
- return [x]
-
-
-def _to_cpu_tensor(x: Any) -> torch.Tensor | None:
- if isinstance(x, list):
- if not x:
- return None
- x = x[0]
- if isinstance(x, torch.Tensor):
- return x.detach().cpu()
- return None
-
-
-def _decode_additional_information(raw_info: Any) -> dict[str, Any]:
- if raw_info is None:
- return {}
- if isinstance(raw_info, dict):
- return raw_info
-
- entries = getattr(raw_info, "entries", None)
- if not isinstance(entries, dict):
- return {}
-
- decoded: dict[str, Any] = {}
- for key, entry in entries.items():
- tensor_data = getattr(entry, "tensor_data", None)
- if tensor_data is not None:
- dtype_name = getattr(entry, "tensor_dtype", "float32")
- tensor_shape = getattr(entry, "tensor_shape", None)
- if tensor_shape is None:
- continue
- dt = np.dtype(dtype_name)
- arr = np.frombuffer(tensor_data, dtype=dt).reshape(tensor_shape)
- decoded[key] = torch.from_numpy(arr.copy())
- else:
- decoded[key] = getattr(entry, "list_data", None)
- return decoded
-
-
def text2flow(
stage_list: list[Any],
engine_input_source: list[int],
@@ -72,178 +15,18 @@ def text2flow(
source_stage_id = engine_input_source[0]
source_outputs = stage_list[source_stage_id].engine_outputs
- engine_inputs: list[OmniTokensPrompt] = []
- for source_output in source_outputs:
- output = source_output.outputs[0]
- multi_modal_data = output.multimodal_output
- if multi_modal_data is None:
- raise RuntimeError(f"Missing multimodal_output for request {source_output.request_id}")
-
- output_ids = _ensure_list(output.cumulative_token_ids)
- prefix_ids = _ensure_list(source_output.prompt_token_ids)
- additional_info = dict(multi_modal_data)
- additional_info.setdefault("ids", {})["prompt"] = prefix_ids
- engine_inputs.append(OmniTokensPrompt(prompt_token_ids=output_ids, additional_information=additional_info))
- return engine_inputs
-
-
-def talker2code2wav_async_chunk(
- transfer_manager: Any,
- pooling_output: dict[str, Any] | None,
- request: Any,
- is_finished: bool = False,
-) -> dict[str, Any] | None:
- """CosyVoice3 async_chunk processor: talker token stream -> code2wav chunks."""
- with nullcontext():
- request_id = request.external_req_id
- finished = bool(is_finished or request.is_finished())
-
- connector = getattr(transfer_manager, "connector", None)
- raw_cfg = getattr(connector, "config", {}) or {}
- cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
- chunk_size = int(cfg.get("codec_chunk_frames", 25))
- code_vocab_size = int(cfg.get("codec_vocab_size", 6561))
- pre_lookahead_len = int(cfg.get("codec_pre_lookahead_frames", 3))
- max_chunk_size = int(cfg.get("codec_max_chunk_frames", 4 * chunk_size))
- stream_scale_factor = int(cfg.get("codec_stream_scale_factor", 2))
- if chunk_size <= 0 or pre_lookahead_len < 0 or max_chunk_size <= 0 or stream_scale_factor <= 0:
- raise ValueError(
- f"Invalid codec chunk config: codec_chunk_frames={chunk_size}, "
- f"codec_pre_lookahead_frames={pre_lookahead_len}, "
- f"codec_max_chunk_frames={max_chunk_size}, "
- f"codec_stream_scale_factor={stream_scale_factor}"
- )
-
- request_state = transfer_manager.request_payload.get(request_id)
- if not isinstance(request_state, dict) or "_cosyvoice3_async_state" not in request_state:
- with nullcontext():
- info = _decode_additional_information(getattr(request, "additional_information", None))
- prompt_payload = {}
- for key in ("speech_token", "speech_feat", "embedding"):
- value = _to_cpu_tensor(info.get(key))
- if value is not None:
- prompt_payload[key] = value
- if isinstance(pooling_output, dict):
- for key in ("speech_token", "speech_feat", "embedding"):
- if key in prompt_payload:
- continue
- value = _to_cpu_tensor(pooling_output.get(key))
- if value is not None:
- prompt_payload[key] = value
- prompt_token = prompt_payload.get("speech_token")
- prompt_token_len = (
- int(prompt_token.shape[1])
- if isinstance(prompt_token, torch.Tensor) and prompt_token.ndim >= 2
- else 0
- )
- prompt_token_pad = (
- ((prompt_token_len + chunk_size - 1) // chunk_size) * chunk_size - prompt_token_len
- if prompt_token_len > 0
- else 0
- )
- request_state = {
- "_cosyvoice3_async_state": {
- "seen_len": 0,
- "sent_prompt": False,
- "emitted_chunks": 0,
- "emitted_token_len": 0,
- "token_hop_len": chunk_size,
- "prompt_token_pad": prompt_token_pad,
- "pre_lookahead_len": pre_lookahead_len,
- "token_max_hop_len": max(chunk_size, max_chunk_size),
- "stream_scale_factor": stream_scale_factor,
- "terminal_sent": False,
- "prompt_payload": prompt_payload,
- }
- }
- transfer_manager.request_payload[request_id] = request_state
-
- state = request_state["_cosyvoice3_async_state"]
- if bool(state.get("terminal_sent", False)):
- return None
-
- with nullcontext():
- output_token_ids = _ensure_list(getattr(request, "output_token_ids", []))
- seen_len = int(state.get("seen_len", 0))
- new_tokens = output_token_ids[seen_len:] if seen_len < len(output_token_ids) else []
- state["seen_len"] = len(output_token_ids)
-
- if not hasattr(transfer_manager, "code_prompt_token_ids"):
- transfer_manager.code_prompt_token_ids = defaultdict(list)
- token_frames = transfer_manager.code_prompt_token_ids[request_id]
- for tok in new_tokens:
- tok_int = int(tok)
- if 0 <= tok_int < code_vocab_size:
- token_frames.append([tok_int])
-
- length = len(token_frames)
- if length <= 0:
- if not finished:
- return None
- payload: dict[str, Any] = {
- "codes": {"audio": []},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
- }
- if not state.get("sent_prompt", False):
- payload.update(state.get("prompt_payload", {}))
- state["sent_prompt"] = True
- state["terminal_sent"] = True
- return payload
-
- emitted_token_len = int(state.get("emitted_token_len", 0))
- if finished and length <= emitted_token_len:
- payload = {
- "codes": {"audio": []},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
- }
- if not state.get("sent_prompt", False):
- payload.update(state.get("prompt_payload", {}))
- state["sent_prompt"] = True
- state["terminal_sent"] = True
- return payload
-
- with nullcontext():
- token_hop_len = max(1, int(state.get("token_hop_len", chunk_size)))
- prompt_token_pad = max(0, int(state.get("prompt_token_pad", 0)))
- pre_lookahead_len = max(0, int(state.get("pre_lookahead_len", pre_lookahead_len)))
- available = max(0, length - emitted_token_len)
- this_token_hop_len = token_hop_len + prompt_token_pad if emitted_token_len == 0 else token_hop_len
- required = this_token_hop_len + pre_lookahead_len
-
- if not finished:
- if available < required:
- return None
- prefix_len = emitted_token_len + required
- token_offset = emitted_token_len
- else:
- if available <= 0:
- return None
- prefix_len = length
- token_offset = emitted_token_len
-
- with nullcontext():
- code_predictor_codes = [int(frame[0]) for frame in token_frames[:prefix_len]]
+ if not isinstance(prompt, list):
+ prompt = [prompt]
- payload = {
- "codes": {"audio": code_predictor_codes},
- "meta": {"finished": torch.tensor(finished, dtype=torch.bool)},
- "token_offset": token_offset,
- "left_context_size": token_offset,
- "req_id": [request_id],
- "stream_finished": torch.tensor(finished, dtype=torch.bool),
- }
- if not state.get("sent_prompt", False):
- payload.update(state.get("prompt_payload", {}))
- state["sent_prompt"] = True
+ source_output = source_outputs[0]
+ output = source_output.outputs[0]
- if not finished:
- state["emitted_token_len"] = emitted_token_len + this_token_hop_len
- state["token_hop_len"] = min(
- int(state.get("token_max_hop_len", chunk_size)),
- max(chunk_size, token_hop_len * int(state.get("stream_scale_factor", 1))),
- )
- else:
- state["terminal_sent"] = True
+ multi_modal_data = output.multimodal_output
+ if multi_modal_data is None:
+ raise RuntimeError(f"Missing multimodal_output for request {source_output.request_id}")
- state["emitted_chunks"] = int(state.get("emitted_chunks", 0)) + 1
- return payload
+ output_ids = output.token_ids
+ prefix_ids = source_output.prompt_token_ids
+ multi_modal_data["prefix_ids"] = prefix_ids
+ engine_input = OmniTokensPrompt(prompt_token_ids=output_ids, additional_information=multi_modal_data)
+ return [engine_input]
diff --git a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py
deleted file mode 100644
index 6ce8881a93b..00000000000
--- a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py
+++ /dev/null
@@ -1,164 +0,0 @@
-from __future__ import annotations
-
-import json
-from typing import Any
-
-import torch
-from vllm.inputs import TextPrompt
-
-from vllm_omni.inputs.data import OmniTokensPrompt
-
-
-def _to_prompt_dict(prompt_item: OmniTokensPrompt | TextPrompt | str | None) -> dict[str, Any]:
- if isinstance(prompt_item, dict):
- return prompt_item
- return {}
-
-
-def _to_token_id_list(value: Any) -> list[int]:
- if isinstance(value, torch.Tensor):
- value = value.detach().to("cpu")
- if value.ndim == 0:
- return [int(value.item())]
- if value.ndim > 1:
- value = value[0]
- return [int(x) for x in value.tolist()]
- if isinstance(value, list):
- if not value:
- return []
- if isinstance(value[0], list):
- return [int(x) for x in value[0]]
- return [int(x) for x in value]
- if value is None:
- return []
- return [int(value)]
-
-
-def _to_int(value: Any, default: int = 0) -> int:
- if isinstance(value, torch.Tensor):
- if value.numel() == 0:
- return default
- return int(value.view(-1)[0].item())
- if isinstance(value, list):
- if not value:
- return default
- return int(value[0])
- if value is None:
- return default
- return int(value)
-
-
-def _normalize_additional_info(value: Any) -> dict[str, Any]:
- if not isinstance(value, dict):
- return {}
- normalized: dict[str, Any] = {}
- for key, val in value.items():
- if isinstance(val, list):
- normalized[key] = val
- else:
- normalized[key] = [val]
- return normalized
-
-
-def _decode_runtime_bridge_info(value: Any) -> dict[str, Any]:
- if isinstance(value, torch.Tensor):
- tensor = value.detach().to("cpu").reshape(-1).to(torch.uint8)
- raw = bytes(tensor.tolist())
- elif isinstance(value, (bytes, bytearray)):
- raw = bytes(value)
- elif isinstance(value, list):
- try:
- raw = bytes(int(item) for item in value)
- except Exception:
- return {}
- elif value is None:
- return {}
- else:
- return value if isinstance(value, dict) else {}
-
- if not raw:
- return {}
-
- try:
- decoded = json.loads(raw.decode("utf-8"))
- except Exception:
- return {}
- return decoded if isinstance(decoded, dict) else {}
-
-
-def _bridge_tokens(
- stage_list,
- engine_input_source,
- prompt: OmniTokensPrompt | TextPrompt = None,
- requires_multimodal_data: bool = False,
-):
- if not engine_input_source:
- raise ValueError("engine_input_source cannot be empty")
-
- source_stage_id = engine_input_source[0]
- if source_stage_id >= len(stage_list):
- raise IndexError(f"Invalid stage_id: {source_stage_id}")
-
- if stage_list[source_stage_id].engine_outputs is None:
- raise RuntimeError(f"Stage {source_stage_id} has no outputs yet")
-
- source_outputs = stage_list[source_stage_id].engine_outputs
- next_inputs = []
- if not isinstance(prompt, list):
- prompt = [prompt]
-
- prompt_meta_by_reqid = {src_out.request_id: _to_prompt_dict(p) for src_out, p in zip(source_outputs, prompt)}
-
- for source_output in source_outputs:
- output = source_output.outputs[0]
- mm_out = getattr(output, "multimodal_output", None) or {}
-
- token_ids = _to_token_id_list(mm_out.get("token_ids"))
- if not token_ids:
- token_ids = _to_token_id_list(mm_out.get("text_tokens"))
- if not token_ids:
- token_ids = list(output.cumulative_token_ids or [])
- if not token_ids:
- raise RuntimeError(
- f"Stage {source_stage_id} output for request {source_output.request_id} has no token_ids"
- )
-
- detok_id = _to_int(mm_out.get("detok_id"), default=0)
- src_prompt = prompt_meta_by_reqid.get(source_output.request_id, {})
- src_additional_info = src_prompt.get("additional_information", {}) or {}
- runtime_bridge_info = _decode_runtime_bridge_info(mm_out.get("runtime_info_json"))
- if not runtime_bridge_info:
- runtime_bridge_info = mm_out.get("runtime_info", {}) or {}
-
- additional_information: dict[str, Any] = _normalize_additional_info(src_additional_info)
- additional_information.update(_normalize_additional_info(runtime_bridge_info))
- additional_information["detok_id"] = [detok_id]
-
- next_inputs.append(
- OmniTokensPrompt(
- prompt_token_ids=token_ids,
- additional_information=additional_information,
- multi_modal_data=(src_prompt.get("multi_modal_data") if requires_multimodal_data else None),
- mm_processor_kwargs=None,
- )
- )
-
- return next_inputs
-
-
-def token2text_to_token2image(
- stage_list,
- engine_input_source,
- prompt: OmniTokensPrompt | TextPrompt = None,
- requires_multimodal_data: bool = False,
-):
- return _bridge_tokens(stage_list, engine_input_source, prompt, requires_multimodal_data)
-
-
-def token2image_to_token2audio(
- stage_list,
- engine_input_source,
- prompt: OmniTokensPrompt | TextPrompt = None,
- requires_multimodal_data: bool = False,
-):
- return _bridge_tokens(stage_list, engine_input_source, prompt, requires_multimodal_data)
diff --git a/vllm_omni/model_executor/stage_input_processors/fish_speech.py b/vllm_omni/model_executor/stage_input_processors/fish_speech.py
index 365b303be2b..d857c9123af 100644
--- a/vllm_omni/model_executor/stage_input_processors/fish_speech.py
+++ b/vllm_omni/model_executor/stage_input_processors/fish_speech.py
@@ -110,8 +110,8 @@ def slow_ar_to_dac_decoder_async_chunk(
if length <= 0:
if finished:
return {
- "codes": {"audio": []},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
+ "code_predictor_codes": [],
+ "finished": True,
}
return None
@@ -143,6 +143,7 @@ def slow_ar_to_dac_decoder_async_chunk(
code_predictor_codes = stacked_frames.transpose(0, 1).reshape(-1).tolist()
return {
- "codes": {"audio": code_predictor_codes},
- "meta": {"left_context_size": left_context_size, "finished": torch.tensor(finished, dtype=torch.bool)},
+ "code_predictor_codes": code_predictor_codes,
+ "left_context_size": left_context_size,
+ "finished": finished,
}
diff --git a/vllm_omni/model_executor/stage_input_processors/glm_image.py b/vllm_omni/model_executor/stage_input_processors/glm_image.py
index 828f7f2182c..3063620bf8f 100644
--- a/vllm_omni/model_executor/stage_input_processors/glm_image.py
+++ b/vllm_omni/model_executor/stage_input_processors/glm_image.py
@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Stage input processor for GLM-Image: AR → Diffusion transition."""
-import math
-import time
from typing import Any
import torch
@@ -15,86 +13,6 @@
logger = init_logger(__name__)
-def _has_source_image(mm_data: Any) -> bool:
- """Return whether prompt multi_modal_data contains a source image.
-
- Normalizes legacy/new keys used across omni pipelines:
- - `image`: single PIL image or list
- - `img2img`: legacy single-image key
- - `images`: list or single image
- """
- if not isinstance(mm_data, dict):
- return False
- if mm_data.get("image") is not None:
- return True
- if mm_data.get("img2img") is not None:
- return True
- images = mm_data.get("images")
- return bool(images)
-
-
-def _first_source_image(mm_data: Any) -> Any:
- """Get first source image from normalized multimodal keys."""
- if not isinstance(mm_data, dict):
- return None
-
- image = mm_data.get("image")
- if image is not None:
- if isinstance(image, list):
- return image[0] if image else None
- return image
-
- image = mm_data.get("img2img")
- if image is not None:
- if isinstance(image, list):
- return image[0] if image else None
- return image
-
- images = mm_data.get("images")
- if isinstance(images, list):
- return images[0] if images else None
- return images
-
-
-def compute_max_tokens(height: int, width: int, factor: int = 32, is_i2i: bool = False) -> int:
- """
- Compute max_new_tokens for GLM-Image AR generation.
-
- GLM-Image generation differs by mode:
-
- - text-to-image (t2i): small preview + large target + EOS
- - image-to-image (i2i): large target + EOS
-
- Args:
- height: Target image height in pixels
- width: Target image width in pixels
- factor: Downsampling factor (32 for GLM-Image AR output)
- is_i2i: Whether the request is image-to-image mode
-
- Returns:
- Total number of tokens to generate for the specified mode
- """
- # Large image tokens (target resolution)
- token_h = height // factor
- token_w = width // factor
- large_tokens = token_h * token_w
-
- # Small preview tokens (half resolution in each dimension)
- import math
-
- ratio = token_h / token_w if token_w > 0 else 1.0
- small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
- small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
- small_tokens = small_token_h * small_token_w
-
- # Mode-dependent totals:
- # - t2i: small + large + EOS
- # - i2i: large + EOS
- if is_i2i:
- return large_tokens + 1
- return small_tokens + large_tokens + 1
-
-
def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor:
"""Upsample token IDs by 2x using nearest neighbor interpolation.
@@ -138,49 +56,39 @@ def _parse_generated_tokens(
large_image_tokens = token_h * token_w
# Calculate small preview image dimensions (used in text-to-image)
- ratio = token_h / token_w if token_w > 0 else 1.0
- small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
- small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
+ small_token_h = token_h // 2
+ small_token_w = token_w // 2
small_image_tokens = small_token_h * small_token_w
token_tensor = torch.tensor(token_ids, dtype=torch.long)
# Remove EOS token (16385) from the end if present
eos_token_id = 16385
- has_terminal_eos = len(token_ids) > 0 and token_ids[-1] == eos_token_id
- if has_terminal_eos:
+ if len(token_ids) > 0 and token_ids[-1] == eos_token_id:
token_tensor = token_tensor[:-1]
actual_tokens = len(token_tensor)
+ logger.debug(
+ f"[_parse_generated_tokens] height={height}, width={width}, "
+ f"token_h={token_h}, token_w={token_w}, "
+ f"large_image_tokens={large_image_tokens}, small_image_tokens={small_image_tokens}, "
+ f"actual_tokens={actual_tokens}"
+ )
+
if is_i2i:
+ # Image-to-image mode: check if AR generated small+large tokens (like t2i) or just large tokens
+ # Some AR models output small+large even in i2i mode
if actual_tokens >= small_image_tokens + large_image_tokens:
+ # AR generated full t2i-style output, extract large tokens after small
large_start = small_image_tokens
large_end = large_start + large_image_tokens
prior_token_ids_d32 = token_tensor[large_start:large_end]
actual_h, actual_w = token_h, token_w
- logger.warning(
- "[_parse_generated_tokens] i2i detected t2i-style token layout; "
- "using small-offset extraction: large_start=%s large_end=%s",
- large_start,
- large_end,
- )
- elif actual_tokens >= large_image_tokens:
+ else:
+ # AR generated only large tokens (pure i2i output)
prior_token_ids_d32 = token_tensor[:large_image_tokens]
actual_h, actual_w = token_h, token_w
- logger.info(
- "[_parse_generated_tokens] i2i using offset-0 extraction: large_tokens=%s",
- large_image_tokens,
- )
- else:
- logger.warning(
- "[_parse_generated_tokens] i2i token parse failed: actual_tokens=%s < expected_large_tokens=%s",
- actual_tokens,
- large_image_tokens,
- )
- raise ValueError(
- f"i2i token parse failed: actual_tokens={actual_tokens} < expected_large_tokens={large_image_tokens}"
- )
elif actual_tokens >= small_image_tokens + large_image_tokens:
# Text-to-image: extract large image tokens after small image tokens
large_start = small_image_tokens
@@ -188,22 +96,43 @@ def _parse_generated_tokens(
prior_token_ids_d32 = token_tensor[large_start:large_end]
actual_h, actual_w = token_h, token_w
elif actual_tokens >= large_image_tokens:
- logger.warning(
- "[_parse_generated_tokens] t2i token parse failed: got only large tokens without small preview "
- "(actual_tokens=%s, expected_small_plus_large=%s)",
- actual_tokens,
- small_image_tokens + large_image_tokens,
- )
- raise ValueError("t2i token parse failed: missing small-preview tokens; refusing low-quality fallback")
+ # Image-to-image: large image tokens are at the beginning
+ prior_token_ids_d32 = token_tensor[:large_image_tokens]
+ actual_h, actual_w = token_h, token_w
else:
- logger.warning(
- "[_parse_generated_tokens] token parse failed: insufficient tokens "
- "(actual_tokens=%s, expected=%s, mode=%s)",
- actual_tokens,
- large_image_tokens if is_i2i else (small_image_tokens + large_image_tokens),
- "i2i" if is_i2i else "t2i",
- )
- raise ValueError(f"token parse failed: actual_tokens={actual_tokens}, mode={'i2i' if is_i2i else 't2i'}")
+ # Insufficient tokens - try to infer the actual grid size
+ import math
+
+ for scale in [1, 2, 4]:
+ test_h = token_h // scale
+ test_w = token_w // scale
+ test_small_h = test_h // 2
+ test_small_w = test_w // 2
+ test_large = test_h * test_w
+ test_small = test_small_h * test_small_w
+
+ if actual_tokens >= test_small + test_large:
+ prior_token_ids_d32 = token_tensor[test_small : test_small + test_large]
+ actual_h, actual_w = test_h, test_w
+ height = test_h * factor
+ width = test_w * factor
+ logger.warning(f"Adjusted grid to {test_h}x{test_w}, output will be {height}x{width}")
+ break
+ elif actual_tokens >= test_large:
+ prior_token_ids_d32 = token_tensor[:test_large]
+ actual_h, actual_w = test_h, test_w
+ height = test_h * factor
+ width = test_w * factor
+ logger.warning(f"Adjusted grid to {test_h}x{test_w}, output will be {height}x{width}")
+ break
+ else:
+ sqrt_tokens = int(math.sqrt(actual_tokens))
+ actual_h = actual_w = sqrt_tokens
+ usable_tokens = sqrt_tokens * sqrt_tokens
+ prior_token_ids_d32 = token_tensor[:usable_tokens]
+ height = sqrt_tokens * factor
+ width = sqrt_tokens * factor
+ logger.error(f"Grid pattern mismatch. Using {sqrt_tokens}x{sqrt_tokens}, output: {height}x{width}")
# Upsample from 32x to 16x
prior_token_ids = _upsample_token_ids(prior_token_ids_d32, actual_h, actual_w)
@@ -218,8 +147,6 @@ def ar2diffusion(
requires_multimodal_data: bool = False,
) -> list[dict[str, Any]]:
"""Process AR stage outputs to create Diffusion stage inputs."""
- _t_total = time.perf_counter()
-
if not engine_input_source:
raise ValueError("engine_input_source cannot be empty")
@@ -238,9 +165,8 @@ def ar2diffusion(
prompt = [prompt] if prompt is not None else [{}]
for i, ar_output in enumerate(ar_outputs):
- _t_req = time.perf_counter()
output = ar_output.outputs[0]
- generated_token_ids = output.cumulative_token_ids
+ generated_token_ids = output.token_ids
# Get original prompt info
original_prompt = prompt[i] if i < len(prompt) else {}
@@ -253,82 +179,30 @@ def ar2diffusion(
else:
original_prompt = {}
- mm_processor_kwargs = original_prompt.get("mm_processor_kwargs")
-
- def _coerce_dim(v: Any, default: int) -> int:
- try:
- iv = int(v)
- return iv if iv > 0 else default
- except (TypeError, ValueError):
- return default
-
- # Prefer GLM-Image target size from mm_processor_kwargs (set by serving layer),
- # then fall back to top-level fields for backward compatibility.
- height = _coerce_dim(
- mm_processor_kwargs.get("target_h") if isinstance(mm_processor_kwargs, dict) else None,
- _coerce_dim(original_prompt.get("height"), 1024),
- )
- width = _coerce_dim(
- mm_processor_kwargs.get("target_w") if isinstance(mm_processor_kwargs, dict) else None,
- _coerce_dim(original_prompt.get("width"), 1024),
- )
+ height = original_prompt.get("height", 1024)
+ width = original_prompt.get("width", 1024)
text_prompt = original_prompt.get("prompt", "")
- # Detect i2i mode.
- # Prefer normalized prompt multi_modal_data source-image presence, with
- # multimodal output as secondary signal.
- _t_mode = time.perf_counter()
+ # Detect i2i mode first by checking if multimodal_output contains prior_token_image_ids
is_i2i = False
-
- prompt_modalities = original_prompt.get("modalities")
- if isinstance(prompt_modalities, list) and "img2img" in prompt_modalities:
- is_i2i = True
-
- prompt_mm_data = original_prompt.get("multi_modal_data")
- if _has_source_image(prompt_mm_data):
- is_i2i = True
-
if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output:
mm_output = ar_output.multimodal_output
- if isinstance(mm_output, dict) and mm_output.get("ids", {}).get("prior_image") is not None:
+ if isinstance(mm_output, dict) and mm_output.get("prior_token_image_ids") is not None:
is_i2i = True
- _dt_mode = (time.perf_counter() - _t_mode) * 1000
# Parse and upsample prior tokens
- _t_parse = time.perf_counter()
- try:
- prior_token_ids, pixel_h, pixel_w = _parse_generated_tokens(
- generated_token_ids,
- height,
- width,
- is_i2i=is_i2i,
- )
- except ValueError as e:
- logger.warning(
- "[ar2diffusion] Request %s: skip due to token parse failure: %s "
- "(target=%sx%s, mode=%s, raw_tokens=%s, tail=%s)",
- i,
- e,
- height,
- width,
- "i2i" if is_i2i else "t2i",
- len(generated_token_ids),
- generated_token_ids[-8:] if len(generated_token_ids) >= 8 else generated_token_ids,
- )
- continue
- _dt_parse = (time.perf_counter() - _t_parse) * 1000
+ prior_token_ids, pixel_h, pixel_w = _parse_generated_tokens(generated_token_ids, height, width, is_i2i=is_i2i)
# Get prior_token_image_ids from AR model output (for i2i mode)
# This contains VQ-VAE tokens from input image, used for KV cache conditioning
# NOTE: multimodal_output is attached to ar_output (RequestOutput), NOT output (CompletionOutput)
- _t_prior_img = time.perf_counter()
prior_token_image_ids = None
# Check ar_output (RequestOutput) for multimodal_output - this is the correct location
if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output:
mm_output = ar_output.multimodal_output
if isinstance(mm_output, dict):
- raw_prior_image_ids = mm_output.get("ids", {}).get("prior_image")
+ raw_prior_image_ids = mm_output.get("prior_token_image_ids")
if raw_prior_image_ids is not None:
# Handle different formats:
# 1. Single tensor -> wrap in list
@@ -354,13 +228,12 @@ def _coerce_dim(v: Any, default: int) -> int:
mm_output = output.multimodal_output
logger.debug(f"[ar2diffusion] Request {i}: found multimodal_output on CompletionOutput (fallback)")
if isinstance(mm_output, dict):
- raw_prior_image_ids = mm_output.get("ids", {}).get("prior_image")
+ raw_prior_image_ids = mm_output.get("prior_token_image_ids")
if raw_prior_image_ids is not None:
if isinstance(raw_prior_image_ids, torch.Tensor):
prior_token_image_ids = [raw_prior_image_ids]
elif isinstance(raw_prior_image_ids, list):
prior_token_image_ids = raw_prior_image_ids
- _dt_prior_img = (time.perf_counter() - _t_prior_img) * 1000
diffusion_input = {
"prompt": text_prompt,
@@ -375,38 +248,18 @@ def _coerce_dim(v: Any, default: int) -> int:
if requires_multimodal_data:
mm_data = original_prompt.get("multi_modal_data")
if mm_data:
- pil_image = _first_source_image(mm_data)
+ pil_image = mm_data.get("image")
+ if pil_image is None:
+ # Try "images" (plural) as fallback
+ images = mm_data.get("images")
+ if images:
+ pil_image = images[0] if isinstance(images, list) else images
diffusion_input["pil_image"] = pil_image
for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]:
if key in original_prompt:
diffusion_input[key] = original_prompt[key]
- _dt_req = (time.perf_counter() - _t_req) * 1000
- logger.info(
- "[ar2diffusion] req=%d mode=%s target=%dx%d "
- "raw_tokens=%d prior_tokens=%d prior_image_ids=%s "
- "timing: mode_detect=%.3fms parse+upsample=%.3fms "
- "prior_image_ids_extract=%.3fms req_total=%.3fms",
- i,
- "i2i" if is_i2i else "t2i",
- pixel_h,
- pixel_w,
- len(generated_token_ids),
- len(prior_token_ids),
- "yes" if prior_token_image_ids is not None else "no",
- _dt_mode,
- _dt_parse,
- _dt_prior_img,
- _dt_req,
- )
diffusion_inputs.append(diffusion_input)
- _dt_total = (time.perf_counter() - _t_total) * 1000
- logger.info(
- "[ar2diffusion] batch done: %d reqs, total=%.3fms",
- len(diffusion_inputs),
- _dt_total,
- )
-
return diffusion_inputs
diff --git a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py b/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py
deleted file mode 100644
index 0c0e6d7b37f..00000000000
--- a/vllm_omni/model_executor/stage_input_processors/hunyuan_image3.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Stage input processor for HunyuanImage3: AR → Diffusion transition.
-
-In IT2I (image editing) mode:
- - Stage 0 (AR) receives (image + edit instruction), generates CoT/latent tokens
- - Stage 1 (DiT) receives the AR output + original image, denoises → edited image
-
-The ar2diffusion function bridges these two stages, following the same
-signature pattern as glm_image.ar2diffusion.
-"""
-
-from typing import Any
-
-import torch
-from vllm.inputs import TextPrompt
-from vllm.logger import init_logger
-
-from vllm_omni.inputs.data import OmniTokensPrompt
-
-logger = init_logger(__name__)
-
-
-def ar2diffusion(
- stage_list: list[Any],
- engine_input_source: list[int],
- prompt: OmniTokensPrompt | TextPrompt | list | None = None,
- requires_multimodal_data: bool = False,
-) -> list[dict[str, Any]]:
- """Process AR stage outputs to create Diffusion stage inputs.
-
- Args:
- stage_list: List of stage clients (set by orchestrator).
- engine_input_source: List of source stage IDs (from YAML).
- prompt: Original user prompt (may contain multimodal data).
- requires_multimodal_data: Whether to forward multimodal data.
-
- Returns:
- List of dicts, each consumable by the HunyuanImage3 diffusion pipeline.
- """
- if not engine_input_source:
- raise ValueError("engine_input_source cannot be empty")
-
- source_stage_id = engine_input_source[0]
- if source_stage_id >= len(stage_list):
- raise IndexError(f"Invalid source stage_id: {source_stage_id}")
-
- if stage_list[source_stage_id].engine_outputs is None:
- raise RuntimeError(f"Stage {source_stage_id} has no outputs yet")
-
- ar_outputs = stage_list[source_stage_id].engine_outputs
- diffusion_inputs = []
-
- # Normalize prompt to list
- if not isinstance(prompt, list):
- prompt = [prompt] if prompt is not None else [{}]
-
- for i, ar_output in enumerate(ar_outputs):
- output = ar_output.outputs[0]
- generated_token_ids = output.cumulative_token_ids
- generated_text = getattr(output, "text", "") or ""
-
- # Get original prompt info
- original_prompt = prompt[i] if i < len(prompt) else {}
- if isinstance(original_prompt, dict):
- pass
- elif hasattr(original_prompt, "_asdict"):
- original_prompt = original_prompt._asdict()
- elif hasattr(original_prompt, "__dict__"):
- original_prompt = vars(original_prompt)
- else:
- original_prompt = {}
-
- height = original_prompt.get("height", 1024)
- width = original_prompt.get("width", 1024)
- text_prompt = original_prompt.get("prompt", "")
-
- logger.info(
- "[ar2diffusion] Request %d: AR generated %d tokens, text length=%d, target size=%dx%d",
- i,
- len(generated_token_ids),
- len(generated_text),
- height,
- width,
- )
-
- token_tensor = torch.tensor(generated_token_ids, dtype=torch.long)
-
- diffusion_input: dict[str, Any] = {
- "prompt": text_prompt,
- "height": height,
- "width": width,
- "extra": {
- "ar_token_ids": token_tensor,
- "ar_generated_text": generated_text,
- },
- }
-
- # Forward multimodal data (original image for IT2I conditioning)
- mm_data = original_prompt.get("multi_modal_data")
- if mm_data:
- pil_image = mm_data.get("image")
- if pil_image is None:
- images = mm_data.get("images")
- if images:
- pil_image = images[0] if isinstance(images, list) else images
- if pil_image is not None:
- diffusion_input["pil_image"] = pil_image
-
- # Forward multimodal output from AR (if any)
- if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output:
- mm_output = ar_output.multimodal_output
- if isinstance(mm_output, dict):
- diffusion_input["extra"]["ar_multimodal_output"] = mm_output
-
- # Forward sampling params
- for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]:
- if key in original_prompt:
- diffusion_input[key] = original_prompt[key]
-
- diffusion_inputs.append(diffusion_input)
-
- return diffusion_inputs
diff --git a/vllm_omni/model_executor/stage_input_processors/mammoth_moda2.py b/vllm_omni/model_executor/stage_input_processors/mammoth_moda2.py
index 5eec3d0453e..1239aafafd8 100644
--- a/vllm_omni/model_executor/stage_input_processors/mammoth_moda2.py
+++ b/vllm_omni/model_executor/stage_input_processors/mammoth_moda2.py
@@ -34,7 +34,7 @@ def ar2dit(
prompt_token_ids = ar_output.prompt_token_ids
# exclude the last token because it has no corresponding hidden state
completion_output = ar_output.outputs[0]
- gen_token_ids = completion_output.cumulative_token_ids[:-1]
+ gen_token_ids = completion_output.token_ids[:-1]
full_token_ids = prompt_token_ids + gen_token_ids
mm_output = getattr(completion_output, "multimodal_output", None)
diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py
index 96680b2dd94..71573c0e275 100644
--- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py
+++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py
@@ -9,14 +9,6 @@
logger = init_logger(__name__)
-# Maximum tokens supported by the code2wav stage. The flattened talker codec
-# sequence fed to stage-1 must not exceed this, otherwise gpu_input_batch
-# add_request will fail with a broadcast error when copying prompt_token_ids
-# into token_ids_cpu. Keep in sync with the stage-1 ``max_model_len`` in
-# ``vllm_omni/model_executor/stage_configs/mimo_audio.yaml`` and the offline
-# example ``examples/offline_inference/mimo_audio/end2end.py``.
-MAX_CODE2WAV_TOKENS = 18192
-
def prepend_and_flatten_colmajor(x: torch.Tensor, pad_vec: torch.Tensor) -> torch.Tensor:
"""
@@ -55,7 +47,7 @@ def prepend_and_flatten_colmajor(x: torch.Tensor, pad_vec: torch.Tensor) -> torc
def _make_finished_sentinel() -> dict[str, Any]:
"""Return a minimal payload with finished=True so Stage-1 can end the request."""
- return {"codes": {"audio": []}, "meta": {"finished": torch.tensor(True, dtype=torch.bool)}}
+ return {"code_predictor_codes": [], "finished": torch.tensor(True, dtype=torch.bool)}
def _flush_remaining_codes(
@@ -79,14 +71,12 @@ def _flush_remaining_codes(
flat_codes = torch.tensor(accumulated[-end_index:]).reshape(-1).tolist()
return {
- "codes": {"audio": flat_codes},
- "meta": {
- "left_context_size": left_ctx_frames,
- "codec_chunk_frames": chunk_size,
- "codec_left_context_frames": left_context_size,
- "code_flat_numel": len(flat_codes),
- "finished": torch.tensor(True, dtype=torch.bool),
- },
+ "code_predictor_codes": flat_codes,
+ "left_context_size": left_ctx_frames,
+ "codec_chunk_frames": chunk_size,
+ "codec_left_context_frames": left_context_size,
+ "code_flat_numel": len(flat_codes),
+ "finished": torch.tensor(True, dtype=torch.bool),
}
@@ -124,6 +114,7 @@ def llm2code2wav_async_chunk(
Accumulates codes in connector per request_id,
returns payload only when chunk_size is full or request is finished; returns None when waiting.
"""
+
connector = getattr(transfer_manager, "connector", None)
raw_cfg = getattr(connector, "config", {}) or {}
cfg = raw_cfg.get("extra", raw_cfg) if isinstance(raw_cfg, dict) else {}
@@ -132,14 +123,14 @@ def llm2code2wav_async_chunk(
request_id = getattr(request, "external_req_id", None)
- po_codes = pooling_output.get("codes", {})
- if "audio" not in po_codes:
+ codes = pooling_output.get("code_predictor_codes")
+
+ if _is_codes_empty(codes):
if is_finished:
return _flush_remaining_codes(transfer_manager, request_id, chunk_size, left_context_size)
return None
- code_predictor_codes = po_codes["audio"]
- code_tensor = _to_code_tensor(code_predictor_codes)
+ code_tensor = _to_code_tensor(codes)
if code_tensor is None:
if is_finished:
return _flush_remaining_codes(transfer_manager, request_id, chunk_size, left_context_size)
@@ -168,16 +159,12 @@ def llm2code2wav_async_chunk(
flat_codes = torch.tensor(transfer_manager.code_prompt_token_ids[request_id][-end_index:]).reshape(-1).tolist()
return {
- "codes": {
- "audio": flat_codes,
- },
- "meta": {
- "left_context_size": left_ctx_frames,
- "codec_chunk_frames": chunk_size,
- "codec_left_context_frames": left_context_size,
- "code_flat_numel": len(flat_codes),
- "finished": torch.tensor(is_finished, dtype=torch.bool),
- },
+ "code_predictor_codes": flat_codes,
+ "left_context_size": left_ctx_frames,
+ "codec_chunk_frames": chunk_size,
+ "codec_left_context_frames": left_context_size,
+ "code_flat_numel": len(flat_codes),
+ "finished": torch.tensor(is_finished, dtype=torch.bool),
}
@@ -223,11 +210,8 @@ def llm2code2wav(
# Extract codec codes from talker output
# Expected shape: [8, seq_len] (8-layer RVQ codes)
- mm = output.multimodal_output
- mm_codes = mm.get("codes", {})
- mm_hs = mm.get("hidden_states", {})
- if "audio" in mm_codes:
- codec_codes = mm_codes["audio"].to(torch.long) # [seq_batch_size, 1, 8, 4]
+ if "code_predictor_codes" in output.multimodal_output:
+ codec_codes = output.multimodal_output["code_predictor_codes"].to(torch.long) # [seq_batch_size, 1, 8, 4]
is_all_zero = (codec_codes == 0).all(dim=(1, 2, 3))
non_zero_indices = (~is_all_zero).nonzero(as_tuple=True)[0]
if len(non_zero_indices) == 0:
@@ -241,7 +225,7 @@ def llm2code2wav(
else:
if len(non_zero_indices) < codec_codes.shape[0]:
codec_codes = codec_codes[non_zero_indices]
- elif "output" in mm_hs and "audio" not in mm_codes:
+ elif "latent" in output.multimodal_output and "code_predictor_codes" not in output.multimodal_output:
codec_codes = torch.zeros(1, 1, 8, 4, dtype=torch.long)
else:
raise ValueError(f"Invalid multimodal_output: {output.multimodal_output}")
@@ -251,20 +235,6 @@ def llm2code2wav(
code_final = prepend_and_flatten_colmajor(codec_codes, pad_vec)
code_final = code_final.tolist()
- # Guard against flattened sequences longer than code2wav's max_model_len.
- # Without this, add_request raises ``could not broadcast input array
- # from shape (N,) into shape (max_model_len,)`` and kills the engine
- # core (see issue #2683). Mirrors the offline end2end.py safeguard.
- if len(code_final) > MAX_CODE2WAV_TOKENS:
- request_id = getattr(talker_output, "request_id", f"unknown_{i}")
- logger.warning(
- "Request %s: code_final len=%d > MAX_CODE2WAV_TOKENS=%d, truncating.",
- request_id,
- len(code_final),
- MAX_CODE2WAV_TOKENS,
- )
- code_final = code_final[:MAX_CODE2WAV_TOKENS]
-
code2wav_inputs.append(
OmniTokensPrompt(
prompt_token_ids=code_final,
diff --git a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py b/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py
deleted file mode 100644
index dddca3a9e2d..00000000000
--- a/vllm_omni/model_executor/stage_input_processors/ming_flash_omni.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-"""Stage input processors for Ming-flash-omni-2.0 multi-stage pipeline."""
-
-from __future__ import annotations
-
-from typing import Any
-
-from vllm.inputs import TextPrompt
-
-from vllm_omni.inputs.data import OmniTokensPrompt
-
-
-def _validate_stage_inputs(stage_list, engine_input_source):
- """Validate stage inputs and return the source engine outputs."""
- if not engine_input_source:
- raise ValueError("engine_input_source cannot be empty")
-
- stage_id = engine_input_source[0]
- if stage_id >= len(stage_list):
- raise IndexError(f"Invalid stage_id: {stage_id}")
-
- stage = stage_list[stage_id]
- if stage.engine_outputs is None:
- raise RuntimeError(f"Stage {stage_id} has no outputs yet")
-
- return stage.engine_outputs
-
-
-def thinker2talker(
- stage_list: list[Any],
- engine_input_source: list[int],
- prompt: OmniTokensPrompt | TextPrompt | None = None,
- requires_multimodal_data: bool = False,
-) -> list[OmniTokensPrompt]:
- """Build talker stage inputs from thinker stage outputs.
-
- Extracts the generated text from thinker output and constructs
- a talker input prompt with the text and any speaker/instruction info
- from the original request.
- """
- source_outputs = _validate_stage_inputs(stage_list, engine_input_source)
-
- if not isinstance(prompt, list):
- prompt = [prompt]
-
- talker_inputs: list[OmniTokensPrompt] = []
- for i, source_output in enumerate(source_outputs):
- output = source_output.outputs[0]
-
- # Get the generated text from thinker
- generated_text = output.text if hasattr(output, "text") and output.text else ""
-
- # Extract additional information from the original prompt
- original_prompt = prompt[i] if i < len(prompt) else None
- additional_info = {}
- if original_prompt is not None and hasattr(original_prompt, "additional_information"):
- additional_info = original_prompt.additional_information or {}
-
- # spk_emb can arrive serialised as a plain list from JSON requests;
- # the talker's spk_head wants a torch tensor.
- spk_emb = additional_info.get("spk_emb", None)
- if isinstance(spk_emb, list) and spk_emb and not hasattr(spk_emb[0], "device"):
- import torch
-
- spk_emb = torch.tensor(spk_emb, dtype=torch.float32).unsqueeze(0)
-
- # Omni speech path mirrors upstream `omni_audio_generation`:
- # - `prompt` is hardcoded, `instruction` is forced to None,
- # cfg/sigma/temperature inherit the `tts_job` defaults (the
- # upstream API does NOT expose these knobs).
- # - Voice cloning is preset-only via `voice_name` (default
- # 'DB30'); `get_prompt_emb` is called with
- # `use_spk_emb=True, use_zero_spk_emb=False`, so when no
- # preset resolves upstream simply passes `spk_emb=None`
- # through to `tts_job` rather than substituting a zero
- # vector.
- # The bridge only plumbs the request-specific fields; the
- # talker `forward()` enforces the per-task defaults from
- # `ming_task="omni"` so any stray caller overrides are ignored.
- # Voice presets are resolved by voice_name in the talker's
- # forward() from its registered_prompts cache.
- talker_info = {
- "ming_task": "omni",
- "text": generated_text,
- "spk_emb": spk_emb,
- "voice_name": additional_info.get("voice_name", "DB30"),
- "prompt_text": additional_info.get("prompt_text", None),
- "prompt_wav_lat": additional_info.get("prompt_wav_lat", None),
- "prompt_wav_emb": additional_info.get("prompt_wav_emb", None),
- "max_text_length": additional_info.get("max_text_length", 50),
- }
-
- # Use dummy token IDs (talker builds its own embeddings from text)
- talker_inputs.append(
- OmniTokensPrompt(
- prompt_token_ids=[0],
- additional_information=talker_info,
- multi_modal_data=None,
- mm_processor_kwargs=None,
- )
- )
-
- return talker_inputs
diff --git a/vllm_omni/model_executor/stage_input_processors/omnivoice.py b/vllm_omni/model_executor/stage_input_processors/omnivoice.py
index b866e977913..b7f5c102e40 100644
--- a/vllm_omni/model_executor/stage_input_processors/omnivoice.py
+++ b/vllm_omni/model_executor/stage_input_processors/omnivoice.py
@@ -35,7 +35,7 @@ def tokens2audio(
# Pass audio_tokens from generator to decoder
engine_input = OmniTokensPrompt(
- prompt_token_ids=output.cumulative_token_ids,
+ prompt_token_ids=output.token_ids,
additional_information=multi_modal_data,
)
return [engine_input]
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py
index 072869eafe5..e994589c4dd 100644
--- a/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py
+++ b/vllm_omni/model_executor/stage_input_processors/qwen2_5_omni.py
@@ -1,7 +1,6 @@
import torch
from vllm.inputs import TextPrompt
-from vllm_omni.data_entry_keys import OmniPayload
from vllm_omni.inputs.data import OmniTokensPrompt
TALKER_CODEC_PAD_TOKEN_ID = 8292
@@ -33,21 +32,17 @@ def thinker2talker(
for i, thinker_output in enumerate(thinker_outputs):
output = thinker_output.outputs[0]
prompt_token_ids = thinker_output.prompt_token_ids
- thinker_output_ids = output.cumulative_token_ids
+ thinker_output_ids = output.token_ids
prompt_token_ids_len = len(prompt_token_ids)
- mm: OmniPayload = output.multimodal_output
- latent = mm["latent"]
+ latent = output.multimodal_output["latent"]
thinker_hidden_states = latent.clone().detach().to(latent.device)
additional_information = {
- "hidden_states": {
- "output": thinker_hidden_states[prompt_token_ids_len:].to(torch.float32),
- "output_shape": list(thinker_hidden_states[prompt_token_ids_len:].shape),
- },
- "embed": {
- "prefill": thinker_hidden_states[:prompt_token_ids_len].to(torch.float32),
- "prefill_shape": list(thinker_hidden_states[:prompt_token_ids_len].shape),
- },
- "ids": {"prompt": prompt_token_ids, "output": thinker_output_ids},
+ "thinker_result": thinker_hidden_states[prompt_token_ids_len:].to(torch.float32),
+ "prompt_embeds": thinker_hidden_states[:prompt_token_ids_len].to(torch.float32),
+ "prompt_token_ids": prompt_token_ids,
+ "thinker_output_token_ids": thinker_output_ids,
+ "thinker_result_shape": list(thinker_hidden_states[prompt_token_ids_len:].shape),
+ "prompt_embeds_shape": list(thinker_hidden_states[:prompt_token_ids_len].shape),
}
talker_inputs.append(
OmniTokensPrompt(
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
index 8018df343ee..f4828fddaa5 100644
--- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
+++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
@@ -3,15 +3,12 @@
# Copyright 2025 The Qwen team.
"""Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition."""
-import logging
-from dataclasses import dataclass, field
from typing import Any
import torch
from vllm.inputs import TextPrompt
from vllm.platforms import current_platform
-from vllm_omni.data_entry_keys import OmniPayload
from vllm_omni.engine import OmniEngineCoreRequest
from vllm_omni.inputs.data import OmniTokensPrompt
from vllm_omni.model_executor.stage_input_processors.tts_utils import (
@@ -21,23 +18,16 @@
extract_speaker_from_request,
)
-logger = logging.getLogger(__name__)
-# Pooling output layer keys: "0" = word embedding, "24" = accept_hidden_layer
-_EMBED_LAYER_KEY = "0"
-_HIDDEN_LAYER_KEY = "24"
-
-
-def _compute_talker_prompt_ids_length(info: OmniPayload, device: torch.device | str = "cuda") -> int:
+def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int:
im_start_token_id = 151644
system_token_id = 8948
user_token_id = 872
assistant_token_id = 77091
- ids = info.get("ids", {})
- thinker_sequences = torch.tensor(ids["all"], dtype=torch.long, device=device).unsqueeze(0) # [1, T]
+ thinker_sequences = torch.tensor(info["thinker_sequences"], dtype=torch.long, device=device).unsqueeze(0) # [1, T]
- input_ids = torch.tensor(ids["prompt"], dtype=torch.long, device=device).unsqueeze(0) # [1, T]
+ input_ids = torch.tensor(info["thinker_input_ids"], dtype=torch.long, device=device).unsqueeze(0) # [1, T]
im_start_indexes = torch.cat(
[
@@ -94,192 +84,6 @@ def _validate_stage_inputs(stage_list, engine_input_source):
return stage.engine_outputs
-# =========================
-# PD disaggregation helpers
-# =========================
-
-
-def _get_prefill_stage(stage_list: list[Any], source_stage_id: int) -> Any | None:
- if source_stage_id <= 0:
- return None
- source_stage = stage_list[source_stage_id]
- if not getattr(source_stage, "is_decode_only", False):
- return None
- prev_stage = stage_list[source_stage_id - 1]
- if getattr(prev_stage, "is_prefill_only", False) and prev_stage.engine_outputs is not None:
- return prev_stage
- return None
-
-
-def _merge_pd_embeddings(
- decode_emb: torch.Tensor,
- decode_hid: torch.Tensor,
- prefill_mm: dict[str, Any],
- device: torch.device,
- expected_total: int | None = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
- """Merge prefill prompt embeddings with decode generated embeddings.
-
- In PD mode the prefill engine processes the prompt and the decode engine
- generates tokens starting from position 1. This function concatenates
- them, removing the overlapping token(s):
-
- merged = prefill[:P] + decode[overlap:]
-
- where overlap = P + D - expected_total.
- """
- try:
- p_layers = prefill_mm.get("hidden_states", {}).get("layers", {})
- p_emb = p_layers[int(_EMBED_LAYER_KEY)].detach().to(device=device, dtype=torch.float)
- p_hid = p_layers[int(_HIDDEN_LAYER_KEY)].detach().to(device=device, dtype=torch.float)
- except (KeyError, AttributeError, TypeError) as exc:
- available_keys = list(prefill_mm.keys()) if isinstance(prefill_mm, dict) else type(prefill_mm).__name__
- logger.error(
- "_merge_pd_embeddings: failed to extract prefill embeddings (%s). "
- "Expected keys %r and %r, got: %s. "
- "Falling back to decode-only embeddings – talker user-segment will be degraded.",
- exc,
- _EMBED_LAYER_KEY,
- _HIDDEN_LAYER_KEY,
- available_keys,
- )
- return decode_emb, decode_hid
-
- if p_emb.shape[0] == 0 or decode_emb.shape[0] == 0:
- return decode_emb, decode_hid
-
- raw_total = p_emb.shape[0] + decode_emb.shape[0]
- overlap = max(0, raw_total - expected_total) if expected_total is not None else 0
-
- merged_emb = torch.cat([p_emb, decode_emb[overlap:]], dim=0)
- merged_hid = torch.cat([p_hid, decode_hid[overlap:]], dim=0)
- return merged_emb, merged_hid
-
-
-def _get_prefill_multimodal_output(prefill_stage: Any, output_index: int) -> dict[str, Any] | None:
- """Return multimodal_output dict from the PD prefill stage for a given batch index."""
- try:
- prefill_eos = prefill_stage.engine_outputs
- prefill_eo = prefill_eos[min(output_index, len(prefill_eos) - 1)]
- return prefill_eo.outputs[0].multimodal_output
- except Exception:
- return None
-
-
-def _resolve_tts_token_embedding(
- key: str,
- *,
- thinker_mm: dict[str, Any],
- prefill_mm: dict[str, Any] | None,
- device: torch.device,
-) -> torch.Tensor | None:
- """Return TTS BOS/EOS/PAD embedding tensors for the talker projection path.
-
- Values are taken from the current thinker (decode) ``multimodal_output``; in
- PD mode, missing keys may be filled from the paired prefill stage output.
- """
- val = thinker_mm.get("embed", {}).get(key)
- if val is None and prefill_mm is not None:
- val = prefill_mm.get("embed", {}).get(key)
- return val.detach().to(device=device, dtype=torch.float) if val is not None else None
-
-
-# =========================
-# Streaming input helpers
-# =========================
-
-
-@dataclass
-class _Thinker2TalkerStreamingState:
- last_prompt_len: int = 0
- last_output_len: int = 0
- merged_sequences: list[int] = field(default_factory=list)
-
-
-@dataclass
-class _Qwen3OmniStreamingState:
- thinker2talker: _Thinker2TalkerStreamingState = field(default_factory=_Thinker2TalkerStreamingState)
- talker2code2wav_last_seq_len: int = 0
-
-
-def _get_qwen3_streaming_state(
- request_id: str,
- streaming_context: Any | None,
-) -> _Qwen3OmniStreamingState:
- bridge_states = getattr(streaming_context, "bridge_states", None)
- per_model_state = bridge_states.setdefault("qwen3_omni", {})
- state = per_model_state.get(request_id)
- if state is None:
- state = _Qwen3OmniStreamingState()
- per_model_state[request_id] = state
- return state
-
-
-def _get_streaming_talker_tokens(
- request_id: str,
- prompt_token_ids: list[int],
- output_token_ids: list[int],
- new_prompt_len_snapshot: int | None = None,
- streaming_context: Any | None = None,
- *,
- clear_state: bool = False,
-) -> tuple[list[int], list[int], list[int], list[int]]:
- """Return streaming token slices and merged token views for thinker->talker.
- e.g. For the second streaming input request:
- merged_sequences: [input_prompt 1, output_tokens 1[:-1], input_prompt 2, output_tokens 2]
- thinker_input_ids: [input_prompt 1, output_tokens 1[:-1], input_prompt 2]
- Returns:
- inc_prompt: prompt token delta for this segment.
- inc_output: output token delta for this segment.
- merged_sequences: full thinker_sequences to send downstream.
- thinker_input_ids: full thinker_input_ids paired with merged_sequences.
- """
- state = _get_qwen3_streaming_state(request_id, streaming_context).thinker2talker
- if new_prompt_len_snapshot:
- prompt_token_ids = prompt_token_ids[:-new_prompt_len_snapshot]
- cur_prompt_len = len(prompt_token_ids)
- cur_output_len = len(output_token_ids)
-
- inc_prompt = prompt_token_ids[state.last_prompt_len :]
- inc_output = output_token_ids[state.last_output_len :]
- delta_sequences = inc_prompt + inc_output
- cached_sequences = state.merged_sequences
-
- merged_sequences = cached_sequences + delta_sequences
- thinker_input_ids = cached_sequences + inc_prompt
-
- # Persist history for next segment. Drop the latest sampled token to keep
- # thinker_input_ids / thinker_sequences alignment with next-step append.
- cached_sequences.extend(delta_sequences[:-1])
-
- state.last_prompt_len = cur_prompt_len
- state.last_output_len = cur_output_len
-
- if clear_state:
- state.last_prompt_len = 0
- state.last_output_len = 0
- state.merged_sequences.clear()
-
- return inc_prompt, inc_output, merged_sequences, thinker_input_ids
-
-
-def _get_streaming_codec_delta_len(
- cur_seq_len: int,
- request_id: str,
- talker_output: Any,
- streaming_context: Any | None = None,
-) -> int:
- """Return newly added seq_len for talker->code2wav in streaming mode."""
- state = _get_qwen3_streaming_state(request_id, streaming_context)
- prev_seq_len = state.talker2code2wav_last_seq_len
- seq_len = cur_seq_len - prev_seq_len
- state.talker2code2wav_last_seq_len = cur_seq_len + 1
- if bool(getattr(talker_output, "finished", False)):
- # Final segment: clear history to avoid cross-session carry-over.
- state.talker2code2wav_last_seq_len = 0
- return seq_len
-
-
# =========================
# Thinker -> Talker
# =========================
@@ -300,29 +104,23 @@ def thinker2talker_async_chunk(
request_id = request.external_req_id
chunk_id = transfer_manager.put_req_chunk[request_id]
- thinker_hs = pooling_output.get("hidden_states", {})
- thinker_layers = thinker_hs.get("layers", {})
- thinker_embed = pooling_output.get("embed", {})
-
if chunk_id == 0:
all_token_ids = request.all_token_ids # prefill + decode
prompt_token_ids = request.prompt_token_ids
# Convert ConstantList to regular list for OmniSerializer serialization
all_token_ids = _ensure_list(all_token_ids)
prompt_token_ids = _ensure_list(prompt_token_ids)
- payload: OmniPayload = {
- "embed": {
- "prefill": thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu(),
- # Provide thinker-side TTS token embeddings for talker projection
- "tts_bos": thinker_embed["tts_bos"].detach().cpu(),
- "tts_eos": thinker_embed["tts_eos"].detach().cpu(),
- "tts_pad": thinker_embed["tts_pad"].detach().cpu(),
- },
- "hidden_states": {"output": thinker_layers[int(_HIDDEN_LAYER_KEY)].detach().cpu()},
- "ids": {"all": all_token_ids, "prompt": prompt_token_ids},
- "meta": {"finished": torch.tensor(is_finished, dtype=torch.bool)},
+ talker_additional_info = {
+ "thinker_prefill_embeddings": pooling_output.get("0").detach().cpu(),
+ "thinker_hidden_states": pooling_output.get("24").detach().cpu(),
+ "thinker_sequences": all_token_ids,
+ "thinker_input_ids": prompt_token_ids,
+ # Provide thinker-side TTS token embeddings for talker projection
+ "tts_bos_embed": pooling_output.get("tts_bos_embed").detach().cpu(),
+ "tts_eos_embed": pooling_output.get("tts_eos_embed").detach().cpu(),
+ "tts_pad_embed": pooling_output.get("tts_pad_embed").detach().cpu(),
+ "finished": torch.tensor(is_finished, dtype=torch.bool),
}
- talker_additional_info = payload
speaker = extract_speaker_from_request(request)
if speaker is not None:
talker_additional_info["speaker"] = speaker
@@ -335,18 +133,15 @@ def thinker2talker_async_chunk(
return None
else:
save_payload = transfer_manager.request_payload.pop(request_id)
- talker_additional_info["embed"]["prefill"] = torch.cat(
+ talker_additional_info["thinker_prefill_embeddings"] = torch.cat(
(
- save_payload.get("embed", {}).get("prefill"),
- talker_additional_info.get("embed", {}).get("prefill"),
+ save_payload.get("thinker_prefill_embeddings"),
+ talker_additional_info.get("thinker_prefill_embeddings"),
),
dim=0,
)
- talker_additional_info["hidden_states"]["output"] = torch.cat(
- (
- save_payload.get("hidden_states", {}).get("output"),
- talker_additional_info.get("hidden_states", {}).get("output"),
- ),
+ talker_additional_info["thinker_hidden_states"] = torch.cat(
+ (save_payload.get("thinker_hidden_states"), talker_additional_info.get("thinker_hidden_states")),
dim=0,
)
else:
@@ -354,8 +149,8 @@ def thinker2talker_async_chunk(
# Convert ConstantList to regular list for OmniSerializer serialization
output_token_ids = _ensure_list(output_token_ids)
- talker_additional_info: OmniPayload = {
- "meta": {"finished": torch.tensor(is_finished, dtype=torch.bool)},
+ talker_additional_info = {
+ "finished": torch.tensor(is_finished, dtype=torch.bool),
}
speaker = extract_speaker_from_request(request)
if speaker is not None:
@@ -365,13 +160,14 @@ def thinker2talker_async_chunk(
talker_additional_info["language"] = language
if output_token_ids:
- talker_additional_info["meta"]["override_keys"] = [("embed", "decode"), ("ids", "output")]
- talker_additional_info["embed"] = {"decode": thinker_layers[int(_EMBED_LAYER_KEY)].detach().cpu()}
- talker_additional_info["ids"] = {"output": output_token_ids}
+ talker_additional_info["override_keys"] = ["thinker_decode_embeddings", "thinker_output_token_ids"]
+ talker_additional_info["thinker_decode_embeddings"] = pooling_output.get("0").detach().cpu()
+ talker_additional_info["thinker_output_token_ids"] = output_token_ids
else:
# When prefilling a chunked thinker, thinker_hidden_states needs to be updated.
- talker_additional_info["embed"] = {"prefill": thinker_layers[0].detach().cpu()}
- talker_additional_info["hidden_states"] = {"output": thinker_layers[24].detach().cpu()}
+ talker_additional_info["thinker_prefill_embeddings"] = pooling_output.get("0").detach().cpu()
+ talker_additional_info["thinker_hidden_states"] = pooling_output.get("24").detach().cpu()
+
return talker_additional_info
@@ -380,7 +176,6 @@ def thinker2talker(
engine_input_source: list[int],
prompt: OmniTokensPrompt | TextPrompt | None = None,
requires_multimodal_data: bool = False,
- streaming_context: Any | None = None,
) -> list[OmniTokensPrompt]:
"""
Process thinker outputs to create talker inputs.
@@ -390,9 +185,6 @@ def thinker2talker(
2. Split hidden states into: prompt embeddings + generated embeddings
3. Package for talker with additional information
- In PD disaggregation mode, merges prefill-stage prompt embeddings with
- decode-stage generated embeddings before handing off to the talker.
-
Args:
stage_list: List of stage objects
engine_input_source: Source stage IDs (typically [0] for thinker)
@@ -407,71 +199,22 @@ def thinker2talker(
device = torch.device(current_platform.device_type)
- # PD disaggregation: look up the preceding prefill stage (if any)
- source_stage_id = engine_input_source[0]
- prefill_stage = _get_prefill_stage(stage_list, source_stage_id)
-
# Process each thinker output
for i, thinker_output in enumerate(thinker_outputs):
output = thinker_output.outputs[0]
- req_id = str(getattr(thinker_output, "request_id", f"idx-{i}"))
- prompt_token_ids = _ensure_list(thinker_output.prompt_token_ids)
- output_ids = _ensure_list(output.cumulative_token_ids)
- is_streaming_session = bool(getattr(streaming_context, "enabled", False))
- if is_streaming_session:
- prompt_token_ids, output_ids, thinker_sequences, thinker_input_ids = _get_streaming_talker_tokens(
- req_id,
- prompt_token_ids,
- output_ids,
- getattr(streaming_context, "new_prompt_len_snapshot", None),
- streaming_context,
- clear_state=bool(getattr(thinker_output, "finished", False)),
- )
- else:
- thinker_sequences = prompt_token_ids + output_ids
- thinker_input_ids = prompt_token_ids
- new_seq_length = len(prompt_token_ids + output_ids) - 1
- thinker_mm: OmniPayload = output.multimodal_output
- mm_hs = thinker_mm.get("hidden_states", {})
- mm_layers = mm_hs.get("layers", {})
- thinker_emb = mm_layers[int(_EMBED_LAYER_KEY)].detach().to(device=device, dtype=torch.float)[-new_seq_length:]
- thinker_hid = mm_layers[int(_HIDDEN_LAYER_KEY)].detach().to(device=device, dtype=torch.float)[-new_seq_length:]
-
- prefill_mm: dict[str, Any] | None = None
- if prefill_stage is not None:
- prefill_mm = _get_prefill_multimodal_output(prefill_stage, i)
-
- if prefill_mm is not None:
- expected_total = len(prompt_token_ids) + len(output_ids)
- try:
- thinker_emb, thinker_hid = _merge_pd_embeddings(
- thinker_emb, thinker_hid, prefill_mm, device, expected_total=expected_total
- )
- except Exception as exc:
- logger.warning("[PD] Could not merge prefill embeddings: %s", exc)
-
- payload: OmniPayload = {
- "embed": {
- "prefill": thinker_emb,
- "tts_bos": _resolve_tts_token_embedding(
- "tts_bos", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device
- ),
- "tts_eos": _resolve_tts_token_embedding(
- "tts_eos", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device
- ),
- "tts_pad": _resolve_tts_token_embedding(
- "tts_pad", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device
- ),
- },
- "hidden_states": {
- "output": thinker_hid,
- },
- "ids": {
- "all": thinker_sequences,
- "prompt": thinker_input_ids,
- },
+
+ info = {
+ "thinker_prefill_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float),
+ "thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float),
+ "thinker_sequences": (
+ thinker_output.prompt_token_ids + output.token_ids
+ ), # the thinker_sequences is the whole ids
+ "thinker_input_ids": thinker_output.prompt_token_ids,
+ # Provide thinker-side TTS token embeddings for talker projection
+ "tts_bos_embed": output.multimodal_output["tts_bos_embed"].detach().to(device=device, dtype=torch.float),
+ "tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float),
+ "tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float),
}
- info = payload
speaker = extract_speaker_from_prompt(prompt, index=i)
if speaker is not None:
info["speaker"] = speaker
@@ -479,7 +222,7 @@ def thinker2talker(
if language is not None:
info["language"] = language
- prompt_len = _compute_talker_prompt_ids_length(payload, device=device)
+ prompt_len = _compute_talker_prompt_ids_length(info, device=device)
talker_inputs.append(
OmniTokensPrompt(
@@ -507,9 +250,7 @@ def talker2code2wav_async_chunk(
"""
Pooling version.
"""
- talker_codes = pooling_output.get("codes", {})
- code_predictor_codes = talker_codes.get("audio")
- if code_predictor_codes is None:
+ if "code_predictor_codes" not in pooling_output:
return None
connector = getattr(transfer_manager, "connector", None)
@@ -518,6 +259,8 @@ def talker2code2wav_async_chunk(
chunk_size_config = int(cfg.get("codec_chunk_frames", 25))
left_context_size_config = int(cfg.get("codec_left_context_frames", 25))
+ code_predictor_codes = pooling_output["code_predictor_codes"]
+
if code_predictor_codes is None:
return None
if isinstance(code_predictor_codes, torch.Tensor):
@@ -542,7 +285,6 @@ def talker2code2wav_async_chunk(
request_id = request.external_req_id
transfer_manager.code_prompt_token_ids[request_id].append(codec_codes)
length = len(transfer_manager.code_prompt_token_ids[request_id])
-
chunk_length = length % chunk_size_config
if chunk_length != 0 and not is_finished:
return None
@@ -559,10 +301,12 @@ def talker2code2wav_async_chunk(
.tolist()
)
- return {
- "codes": {"audio": codes},
- "meta": {"left_context_size": left_context_size, "finished": torch.tensor(is_finished, dtype=torch.bool)},
+ info = {
+ "code_predictor_codes": codes,
+ "left_context_size": left_context_size,
+ "finished": torch.tensor(is_finished, dtype=torch.bool),
}
+ return info
def talker2code2wav(
@@ -570,7 +314,6 @@ def talker2code2wav(
engine_input_source: list[int],
prompt: OmniTokensPrompt | TextPrompt | None = None,
requires_multimodal_data: bool = False,
- streaming_context: Any | None = None,
) -> list[OmniTokensPrompt]:
"""
Process talker outputs to create code2wav inputs.
@@ -592,19 +335,19 @@ def talker2code2wav(
talker_outputs = _validate_stage_inputs(stage_list, engine_input_source)
code2wav_inputs: list[OmniTokensPrompt] = []
# Process each talker output
- for i, talker_output in enumerate(talker_outputs):
+ for talker_output in talker_outputs:
output = talker_output.outputs[0]
- req_id = str(getattr(talker_output, "request_id", f"idx-{i}"))
- cur_seq_len = len(output.cumulative_token_ids) - 1
- seq_len = cur_seq_len
- is_streaming_session = bool(getattr(streaming_context, "enabled", False))
- if is_streaming_session:
- seq_len = _get_streaming_codec_delta_len(cur_seq_len, req_id, talker_output, streaming_context)
- mm: OmniPayload = output.multimodal_output
+ seq_len = len(output.token_ids) - 1
# Extract codec codes from talker output
# Expected shape: [8, seq_len] (8-layer RVQ codes)
codec_codes = (
- mm["codes"]["audio"][-seq_len:].to(torch.long).transpose(0, 1).cpu().to(torch.long).reshape(-1).tolist()
+ output.multimodal_output["code_predictor_codes"][-seq_len:]
+ .to(torch.long)
+ .transpose(0, 1)
+ .cpu()
+ .to(torch.long)
+ .reshape(-1)
+ .tolist()
) # 16, seq_len
code2wav_inputs.append(
OmniTokensPrompt(
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py
index 95a771534a5..ade01693216 100644
--- a/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py
+++ b/vllm_omni/model_executor/stage_input_processors/qwen3_tts.py
@@ -37,13 +37,9 @@ def talker2code2wav(
# accumulated the final code sequence.
continue
output = talker_output.outputs[0]
- mm = output.multimodal_output
- mm_codes = mm.get("codes", {})
-
# audio_codes shape: [num_frames, Q] where Q=num_quantizers (16)
- audio_codes = mm_codes["audio"].to(torch.long)
- token_ids = output.cumulative_token_ids
-
+ audio_codes = output.multimodal_output["audio_codes"].to(torch.long)
+ token_ids = output.token_ids
# token_ids provides an upper bound on the newly generated codec span.
# audio_codes may still contain zero-padded / invalid rows, so trim only
# after filtering valid frames instead of trying to align EOS indices.
@@ -55,8 +51,8 @@ def talker2code2wav(
audio_codes = audio_codes[valid_mask]
if seq_len > 0 and audio_codes.ndim == 2 and int(audio_codes.shape[0]) > seq_len:
audio_codes = audio_codes[-seq_len:]
- ref_code = mm_codes.get("ref")
- ref_code_len = mm.get("meta", {}).get("ref_code_len")
+ ref_code = output.multimodal_output.get("ref_code")
+ ref_code_len = output.multimodal_output.get("ref_code_len")
if isinstance(ref_code_len, torch.Tensor):
ref_code_len = int(ref_code_len.reshape(-1)[-1].item()) if ref_code_len.numel() > 0 else 0
elif ref_code_len is None:
@@ -99,7 +95,7 @@ def talker2code2wav(
codec_codes = audio_codes.transpose(0, 1).cpu().reshape(-1).tolist()
additional_information: dict[str, Any] = {}
if ref_code_len > 0:
- additional_information["meta"] = {"left_context_size": [ref_code_len]}
+ additional_information["left_context_size"] = [ref_code_len]
# Propagate speaker and language from the original prompt so they are
# available as runtime_additional_information in later pipeline stages,
# consistent with qwen3-omni and qwen2.5-omni stage input processors.
@@ -121,7 +117,7 @@ def talker2code2wav(
def _extract_last_frame(pooling_output: dict[str, Any]) -> torch.Tensor | None:
- audio_codes = pooling_output.get("codes", {}).get("audio")
+ audio_codes = pooling_output.get("audio_codes")
if not isinstance(audio_codes, torch.Tensor) or audio_codes.numel() == 0:
return None
if audio_codes.ndim == 2:
@@ -152,7 +148,7 @@ def talker2code2wav_async_chunk(
if frame is not None:
codec_codes = frame.cpu().tolist()
transfer_manager.code_prompt_token_ids[request_id].append(codec_codes)
- ref_code = pooling_output.get("codes", {}).get("ref")
+ ref_code = pooling_output.get("ref_code")
if isinstance(ref_code, torch.Tensor) and ref_code.numel() > 0 and request_payload.get(request_id) is None:
request_payload[request_id] = ref_code.to(torch.long).cpu().contiguous()
elif not finished:
@@ -211,8 +207,8 @@ def talker2code2wav_async_chunk(
if length <= 0:
if finished:
return {
- "codes": {"audio": []},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
+ "code_predictor_codes": [],
+ "finished": True,
}
return None
@@ -258,8 +254,9 @@ def talker2code2wav_async_chunk(
code_predictor_codes = [window_frames[f][q] for q in range(num_quantizers) for f in range(num_frames)]
info: dict[str, Any] = {
- "codes": {"audio": code_predictor_codes},
- "meta": {"left_context_size": left_context_size, "finished": torch.tensor(finished, dtype=torch.bool)},
+ "code_predictor_codes": code_predictor_codes,
+ "left_context_size": left_context_size,
+ "finished": finished,
}
# Propagate speaker and language from the request so they are available
# as runtime_additional_information in subsequent pipeline stages, consistent
diff --git a/vllm_omni/model_executor/stage_input_processors/voxcpm.py b/vllm_omni/model_executor/stage_input_processors/voxcpm.py
deleted file mode 100644
index c2fcf521bf4..00000000000
--- a/vllm_omni/model_executor/stage_input_processors/voxcpm.py
+++ /dev/null
@@ -1,128 +0,0 @@
-from __future__ import annotations
-
-from typing import Any
-
-import torch
-from vllm.inputs import TextPrompt
-
-from vllm_omni.inputs.data import OmniTokensPrompt
-
-_VOXCPM_LATENT_MAGIC = 131071
-
-
-def _serialize_latent_to_codes(latent: Any) -> list[int]:
- latent_tensor = latent if isinstance(latent, torch.Tensor) else torch.as_tensor(latent)
- latent_tensor = latent_tensor.detach().cpu().contiguous()
- if latent_tensor.ndim == 3:
- if latent_tensor.shape[0] != 1:
- raise ValueError(f"Expected batch=1 latent tensor, got shape={tuple(latent_tensor.shape)}")
- latent_tensor = latent_tensor.squeeze(0)
- if latent_tensor.ndim != 2:
- raise ValueError(f"Unsupported latent_audio_feat shape for async chunk: {tuple(latent_tensor.shape)}")
- latent_dim, time_dim = int(latent_tensor.shape[0]), int(latent_tensor.shape[1])
- packed = latent_tensor.to(torch.bfloat16).contiguous().view(torch.uint16).reshape(-1).to(torch.int32)
- return [_VOXCPM_LATENT_MAGIC, latent_dim, time_dim, *packed.tolist()]
-
-
-def _coerce_finished_flag(value: Any) -> bool:
- """Normalize VoxCPM async-chunk finished markers to a Python bool."""
- if value is None:
- return False
- if isinstance(value, torch.Tensor):
- if value.numel() != 1:
- raise ValueError(f"finished tensor must be scalar, got shape={tuple(value.shape)}")
- return bool(value.detach().cpu().item())
- if isinstance(value, (list, tuple)):
- if not value:
- return False
- if len(value) != 1:
- raise ValueError(f"finished container must have one element, got len={len(value)}")
- return _coerce_finished_flag(value[0])
- return bool(value)
-
-
-def latent2vae(
- stage_list: list[Any],
- engine_input_source: list[int],
- prompt: OmniTokensPrompt | TextPrompt | None = None,
- requires_multimodal_data: bool = False,
-) -> list[OmniTokensPrompt]:
- del prompt, requires_multimodal_data
-
- if not engine_input_source:
- raise ValueError("engine_input_source cannot be empty")
-
- source_stage_id = engine_input_source[0]
- if source_stage_id >= len(stage_list):
- raise IndexError(f"Invalid stage_id: {source_stage_id}")
-
- source_outputs = stage_list[source_stage_id].engine_outputs
- if source_outputs is None:
- raise RuntimeError(f"Stage {source_stage_id} has no outputs yet")
-
- vae_inputs: list[OmniTokensPrompt] = []
- for source_output in source_outputs:
- output = source_output.outputs[0]
- multimodal_output = getattr(output, "multimodal_output", None)
- if not isinstance(multimodal_output, dict) or "latent_audio_feat" not in multimodal_output:
- raise ValueError(
- "VoxCPM latent stage output missing 'latent_audio_feat'. "
- f"request_id={getattr(source_output, 'request_id', None)}"
- )
-
- additional_information = {
- "latent_audio_feat": multimodal_output["latent_audio_feat"],
- }
- if "sr" in multimodal_output:
- additional_information["sample_rate"] = [int(multimodal_output["sr"])]
-
- vae_inputs.append(
- OmniTokensPrompt(
- prompt_token_ids=[0],
- additional_information=additional_information,
- multi_modal_data=None,
- mm_processor_kwargs=None,
- )
- )
-
- return vae_inputs
-
-
-def latent2vae_async_chunk(
- transfer_manager: Any,
- pooling_output: dict[str, Any] | None,
- request: Any,
- is_finished: bool = False,
-) -> dict[str, Any] | None:
- """Stage-0 latent → stage-1 VAE under ``async_chunk`` (connector payload)."""
- # Kept for callback signature compatibility with OmniChunkTransferAdapter.
- _ = transfer_manager
- finished_request = _coerce_finished_flag(is_finished)
- if callable(getattr(request, "is_finished", None)):
- finished_request = finished_request or _coerce_finished_flag(request.is_finished())
- if not isinstance(pooling_output, dict):
- if finished_request:
- return {
- "code_predictor_codes": [],
- "finished": torch.tensor(True, dtype=torch.bool),
- }
- return None
-
- latent = pooling_output.get("latent_audio_feat")
- if isinstance(latent, torch.Tensor) and latent.numel() == 0:
- latent = None
-
- if latent is None:
- if finished_request:
- return {
- "code_predictor_codes": [],
- "finished": torch.tensor(True, dtype=torch.bool),
- }
- return None
-
- serialized_codes = _serialize_latent_to_codes(latent)
- out: dict[str, Any] = {
- "code_predictor_codes": serialized_codes,
- "finished": torch.tensor(finished_request, dtype=torch.bool),
- }
- return out
diff --git a/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py b/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py
index e1a58b6f16d..2e71a080b6c 100644
--- a/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py
+++ b/vllm_omni/model_executor/stage_input_processors/voxtral_tts.py
@@ -82,8 +82,8 @@ def generator2tokenizer_async_chunk(
if length <= 0:
if finished:
return {
- "codes": {"audio": []},
- "meta": {"finished": torch.tensor(True, dtype=torch.bool)},
+ "code_predictor_codes": [],
+ "finished": torch.tensor(True, dtype=torch.bool),
}
return None
@@ -105,6 +105,6 @@ def generator2tokenizer_async_chunk(
code_predictor_codes = torch.tensor(window_frames).reshape(-1).tolist()
return {
- "codes": {"audio": [int(ctx_frames)] + [int(context_length)] + code_predictor_codes},
- "meta": {"finished": torch.tensor(finished, dtype=torch.bool)},
+ "code_predictor_codes": [int(ctx_frames)] + [int(context_length)] + code_predictor_codes,
+ "finished": torch.tensor(finished, dtype=torch.bool),
}
diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py
index 23930f358bb..ca3ba271a19 100644
--- a/vllm_omni/outputs.py
+++ b/vllm_omni/outputs.py
@@ -9,33 +9,6 @@
from vllm_omni.inputs.data import OmniPromptType
-@dataclass
-class OmniConnectorOutput:
- """Communication results from Model Runner to Scheduler.
-
- Carries transfer readiness signals so the Scheduler can make scheduling
- decisions without ever calling connector.put()/get() directly.
-
- Attributes:
- chunk_ready_req_ids: Request IDs with newly arrived chunks this cycle.
- chunk_finished_req_ids: Request IDs whose final chunk has arrived.
- request_metadata: Lightweight scheduling metadata keyed by request ID
- (e.g. next_stage_prompt_len, code_predictor_codes, left_context_size).
- Full payloads are owned by the Model Runner's local cache.
- kv_sent_req_ids: Request IDs whose KV cache was successfully sent.
- stage_recv_req_ids: Request IDs that received batch stage inputs.
- has_pending_kv_work: True if the mixin has pending, active, or
- completed KV transfers that the scheduler should account for.
- """
-
- chunk_ready_req_ids: set[str] = field(default_factory=set)
- chunk_finished_req_ids: set[str] = field(default_factory=set)
- request_metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
- kv_sent_req_ids: list[str] = field(default_factory=list)
- stage_recv_req_ids: set[str] = field(default_factory=set)
- has_pending_kv_work: bool = False
-
-
class OmniModelRunnerOutput(ModelRunnerOutput):
"""Model runner output for omni models.
@@ -51,7 +24,6 @@ class OmniModelRunnerOutput(ModelRunnerOutput):
# IDs of requests whose KV cache has been extracted from GPU/NPU to CPU.
# The Scheduler can safely free the block tables for these requests.
kv_extracted_req_ids: list[str] | None = None
- omni_connector_output: OmniConnectorOutput | None = None
@dataclass
@@ -86,10 +58,6 @@ class OmniRequestOutput:
images: list[Image.Image] = field(default_factory=list)
prompt: OmniPromptType | None = None
latents: torch.Tensor | None = None
- trajectory_latents: torch.Tensor | None = None
- trajectory_timesteps: torch.Tensor | None = None
- trajectory_log_probs: torch.Tensor | None = None
- trajectory_decoded: list | None = None
metrics: dict[str, Any] = field(default_factory=dict)
_multimodal_output: dict[str, Any] = field(default_factory=dict)
_custom_output: dict[str, Any] = field(default_factory=dict)
@@ -100,22 +68,6 @@ class OmniRequestOutput:
# memory usage info
peak_memory_mb: float = 0.0
- # Error information -- set when the output represents a failed request.
- error: str | None = None
-
- @classmethod
- def from_error(
- cls,
- request_id: str,
- error: str,
- ) -> "OmniRequestOutput":
- """Create an error output for a request that failed during generation."""
- return cls(
- request_id=request_id,
- finished=True,
- error=error,
- )
-
@classmethod
def from_pipeline(
cls,
@@ -149,10 +101,6 @@ def from_diffusion(
prompt: OmniPromptType | None = None,
metrics: dict[str, Any] | None = None,
latents: torch.Tensor | None = None,
- trajectory_latents: torch.Tensor | None = None,
- trajectory_timesteps: torch.Tensor | None = None,
- trajectory_log_probs: torch.Tensor | None = None,
- trajectory_decoded: list | None = None,
multimodal_output: dict[str, Any] | None = None,
custom_output: dict[str, Any] | None = None,
final_output_type: str = "image",
@@ -167,12 +115,8 @@ def from_diffusion(
prompt: The prompt used
metrics: Generation metrics
latents: Optional latent tensors
- trajectory_latents: Optional stacked trajectory latent tensors
- trajectory_timesteps: Optional stacked trajectory timestep tensors
- trajectory_log_probs: Optional stacked trajectory log-probability tensors
- trajectory_decoded: Optional list of decoded trajectory images
multimodal_output: Optional multimodal output dict
- custom_output: Optional custom output dict (e.g. prompt embeds)
+ custom_output: Optional custom output dict (e.g. latent trajectories, prompt embeds)
stage_durations: Optional stage durations (execution time of each stage) dict
peak_memory_mb: Peak memory usage in MB
@@ -185,10 +129,6 @@ def from_diffusion(
images=images,
prompt=prompt,
latents=latents,
- trajectory_latents=trajectory_latents,
- trajectory_timesteps=trajectory_timesteps,
- trajectory_log_probs=trajectory_log_probs,
- trajectory_decoded=trajectory_decoded,
metrics=metrics or {},
_multimodal_output=multimodal_output or {},
_custom_output=custom_output or {},
diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py
index f6c483a92f0..eafff821a21 100644
--- a/vllm_omni/patch.py
+++ b/vllm_omni/patch.py
@@ -1,8 +1,6 @@
import sys
-from functools import cached_property
from aenum import extend_enum
-from vllm.config import ModelConfig as _OriginalModelConfig
from vllm.inputs import TokensPrompt as _OriginalTokensPrompt
from vllm.model_executor.layers.rotary_embedding import (
MRotaryEmbedding as _OriginalMRotaryEmbedding,
@@ -12,63 +10,12 @@
from vllm.v1.engine import EngineCoreRequest as _OriginalEngineCoreRequest
from vllm.v1.request import Request as _OriginalRequest
from vllm.v1.request import RequestStatus
-from vllm.v1.request import StreamingUpdate as _OriginalStreamingUpdate
import vllm_omni.logger # noqa: F401
from vllm_omni.engine import OmniEngineCoreOutput, OmniEngineCoreOutputs, OmniEngineCoreRequest
from vllm_omni.inputs.data import OmniTokensPrompt
from vllm_omni.model_executor.layers.rotary_embedding import OmniMRotaryEmbedding
-from vllm_omni.request import OmniRequest, OmniStreamingUpdate
-
-# =============================================================================
-# Patch ModelConfig.is_mm_prefix_lm to support omni-specific models
-# =============================================================================
-# WHY: HunyuanImage-3.0 requires bidirectional attention for image tokens
-# (cond_token_attn_type: "joint_full" in config.json). vLLM gates this on
-# is_mm_prefix_lm, which checks an internal MM_PREFIX_LM_MODELS list that
-# does not include "hunyuan_image_3_moe" (the upstream HF model_type).
-#
-# WHY NOT model-level: is_mm_prefix_lm is checked in vLLM core (scheduler,
-# attention backend selection) before model code runs — no model-level hook.
-#
-# SCOPE: Only affects model_type in _OMNI_MM_PREFIX_LM_MODELS (currently
-# just "hunyuan_image_3_moe"). All other models fall through to the
-# original vLLM implementation unchanged.
-#
-# FRAGILITY: Relies on is_mm_prefix_lm being a cached_property on
-# ModelConfig. The __dict__ access + __set_name__ dance works around a
-# pydantic dataclass issue in vllm 0.19.0+. If vLLM changes
-# is_mm_prefix_lm to a regular method or removes it, this will break.
-#
-# TODO: Upstream a configurable MM_PREFIX_LM_MODELS or a model_config flag
-# so this patch can be removed.
-_OMNI_MM_PREFIX_LM_MODELS = ("hunyuan_image_3_moe",)
-# Access via __dict__ to avoid triggering cached_property.__get__ which fails
-# with "Cannot use cached_property instance without calling __set_name__" in
-# pydantic dataclasses (vllm 0.19.0+).
-_cp = _OriginalModelConfig.__dict__["is_mm_prefix_lm"]
-_original_is_mm_prefix_lm = _cp.func if hasattr(_cp, "func") else _cp.fget
-
-
-def _patched_is_mm_prefix_lm(self):
- if _original_is_mm_prefix_lm(self):
- return True
- model_type = getattr(self.hf_config, "model_type", "")
- return model_type in _OMNI_MM_PREFIX_LM_MODELS
-
-
-_patched_cp = cached_property(_patched_is_mm_prefix_lm)
-_patched_cp.__set_name__(_OriginalModelConfig, "is_mm_prefix_lm")
-_OriginalModelConfig.is_mm_prefix_lm = _patched_cp
-
-# Sanity check: verify the patch is active. If vLLM changes the descriptor
-# type or __set_name__ semantics, this will fail loudly at import time
-# rather than silently falling back to unpatched behavior.
-_installed = _OriginalModelConfig.__dict__.get("is_mm_prefix_lm")
-assert _installed is _patched_cp, (
- "is_mm_prefix_lm patch failed to install — bidirectional attention "
- "for HunyuanImage3 will not work. Check vLLM ModelConfig changes."
-)
+from vllm_omni.request import OmniRequest
# =============================================================================
# Patch GlmImageTextConfig to expose mrope_section in rope_parameters
@@ -116,7 +63,5 @@ def _patched_glm_image_text_config_init(self, *args, **kwargs):
module.MRotaryEmbedding = OmniMRotaryEmbedding
if hasattr(module, "Request") and module.Request == _OriginalRequest:
module.Request = OmniRequest
- if hasattr(module, "StreamingUpdate") and module.StreamingUpdate == _OriginalStreamingUpdate:
- module.StreamingUpdate = OmniStreamingUpdate
if hasattr(module, "EngineCoreRequest") and module.EngineCoreRequest == _OriginalEngineCoreRequest:
module.EngineCoreRequest = OmniEngineCoreRequest
diff --git a/vllm_omni/platforms/interface.py b/vllm_omni/platforms/interface.py
index 11eec76acdf..4325851e5fb 100644
--- a/vllm_omni/platforms/interface.py
+++ b/vllm_omni/platforms/interface.py
@@ -1,17 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from contextlib import nullcontext
from enum import Enum
from typing import Any
import torch
-import torch.nn as nn
-from vllm.logger import init_logger
from vllm.platforms import Platform
-logger = init_logger(__name__)
-
class OmniPlatformEnum(Enum):
"""Enum for supported Omni platforms."""
@@ -65,20 +60,13 @@ def get_default_stage_config_path(cls) -> str:
@classmethod
def get_diffusion_model_impl_qualname(cls, op_name: str) -> str:
if op_name == "hunyuan_fused_moe":
- return "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
+ return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
raise NotImplementedError(f"Unsupported diffusion model op: {op_name}")
@classmethod
def prepare_diffusion_op_runtime(cls, op_name: str, **kwargs: Any) -> None:
return None
- @classmethod
- def get_diffusion_packed_modules_mapping(
- cls,
- model_class: type[nn.Module],
- ) -> dict[str, list[str]] | None:
- return None
-
@classmethod
def get_diffusion_attn_backend_cls(
cls,
@@ -125,31 +113,10 @@ def synchronize(cls) -> None:
def get_free_memory(cls, device: torch.device | None = None) -> int:
raise NotImplementedError
- @classmethod
- def create_autocast_context(
- cls,
- *,
- device_type: str,
- dtype: torch.dtype,
- enabled: bool = True,
- ):
- if not enabled:
- return nullcontext()
-
- try:
- return torch.autocast(device_type=device_type, dtype=dtype, enabled=True)
- except (RuntimeError, TypeError, ValueError) as exc:
- logger.warning("autocast unavailable for device_type=%s dtype=%s: %s", device_type, dtype, exc)
- return nullcontext()
-
@classmethod
def supports_cpu_offload(cls) -> bool:
return True
- @classmethod
- def supports_float64(cls) -> bool:
- return True
-
@classmethod
def set_device_control_env_var(cls, devices: str | int | None) -> None:
import os
diff --git a/vllm_omni/platforms/musa/platform.py b/vllm_omni/platforms/musa/platform.py
index 77ef225100f..932ce62d27e 100644
--- a/vllm_omni/platforms/musa/platform.py
+++ b/vllm_omni/platforms/musa/platform.py
@@ -8,7 +8,6 @@
from vllm_musa.platform import MUSAPlatformBase
from vllm_omni.diffusion.attention.backends.registry import DiffusionAttentionBackendEnum
-from vllm_omni.diffusion.attention.backends.utils.fa import is_mate_available
from vllm_omni.platforms.interface import OmniPlatform, OmniPlatformEnum
logger = init_logger(__name__)
@@ -25,11 +24,11 @@ class MUSAOmniPlatform(OmniPlatform, MUSAPlatformBase):
@classmethod
def get_omni_ar_worker_cls(cls) -> str:
- return "vllm_omni.worker.gpu_ar_worker.GPUARWorker"
+ return "vllm_omni.platforms.musa.worker.musa_ar_worker.MUSAARWorker"
@classmethod
def get_omni_generation_worker_cls(cls) -> str:
- return "vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker"
+ return "vllm_omni.platforms.musa.worker.musa_generation_worker.MUSAGenerationWorker"
@classmethod
def get_default_stage_config_path(cls) -> str:
@@ -39,7 +38,7 @@ def get_default_stage_config_path(cls) -> str:
def get_diffusion_model_impl_qualname(cls, op_name: str) -> str:
# MUSA uses default implementations for diffusion ops
if op_name == "hunyuan_fused_moe":
- return "vllm_omni.diffusion.models.hunyuan_image3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
+ return "vllm_omni.diffusion.models.hunyuan_image_3.hunyuan_fused_moe.HunyuanFusedMoEDefault"
return super().get_diffusion_model_impl_qualname(op_name)
@classmethod
@@ -55,7 +54,9 @@ def get_diffusion_attn_backend_cls(
) -> str:
"""Get the diffusion attention backend class path for MUSA platform.
- MUSA supports FLASH_ATTN via the mate package, and SDPA as fallback.
+ MUSA currently supports SDPA (Scaled Dot Product Attention) as the
+ primary backend. Flash Attention support may be added in future
+ when MUSA-specific implementations are available.
Args:
selected_backend: User-selected backend name (e.g., "FLASH_ATTN",
@@ -65,24 +66,13 @@ def get_diffusion_attn_backend_cls(
Returns:
Fully qualified class path of the selected backend.
"""
-
- flash_attn_available = is_mate_available()
-
if selected_backend is not None:
backend_upper = selected_backend.upper()
- if backend_upper == "FLASH_ATTN" and not flash_attn_available:
- logger.warning("Flash Attention (mate package) not available. Falling back to TORCH_SDPA backend.")
- logger.info("Defaulting to diffusion attention backend SDPA")
- return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path()
backend = DiffusionAttentionBackendEnum[backend_upper]
logger.info("Using diffusion attention backend '%s'", backend_upper)
return backend.get_path()
- # Default to FLASH_ATTN if mate is available, otherwise SDPA
- if flash_attn_available:
- logger.info("Defaulting to diffusion attention backend FLASH_ATTN")
- return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path()
-
+ # Default to SDPA for MUSA as it's the most compatible backend
logger.info("Defaulting to diffusion attention backend SDPA")
return DiffusionAttentionBackendEnum.TORCH_SDPA.get_path()
@@ -91,11 +81,6 @@ def supports_torch_inductor(cls) -> bool:
"""MUSA supports torch.compile with inductor backend."""
return True
- @classmethod
- def supports_float64(cls) -> bool:
- """MUSA does not support float64 yet."""
- return False
-
@classmethod
def get_torch_device(cls, local_rank: int | None = None) -> torch.device:
"""Get the torch device for MUSA platform.
diff --git a/vllm_omni/platforms/musa/worker/__init__.py b/vllm_omni/platforms/musa/worker/__init__.py
new file mode 100644
index 00000000000..bd0054870eb
--- /dev/null
+++ b/vllm_omni/platforms/musa/worker/__init__.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm_omni.platforms.musa.worker.musa_ar_worker import MUSAARWorker
+from vllm_omni.platforms.musa.worker.musa_generation_worker import (
+ MUSAGenerationWorker,
+)
+
+__all__ = ["MUSAARWorker", "MUSAGenerationWorker"]
diff --git a/vllm_omni/platforms/musa/worker/musa_ar_worker.py b/vllm_omni/platforms/musa/worker/musa_ar_worker.py
new file mode 100644
index 00000000000..258e911df18
--- /dev/null
+++ b/vllm_omni/platforms/musa/worker/musa_ar_worker.py
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""MUSA AR (Autoregressive) worker for vLLM-Omni.
+
+This worker handles autoregressive model stages (thinker/talker) on MUSA devices.
+"""
+
+import gc
+import os
+
+import torch
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils.mem_utils import MemorySnapshot, format_gib
+from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.utils import report_usage_stats
+from vllm.v1.worker.gpu_worker import init_worker_distributed_environment
+from vllm.v1.worker.utils import request_memory
+from vllm.v1.worker.workspace import init_workspace_manager
+
+from vllm_omni.worker.base import OmniGPUWorkerBase
+from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner
+from vllm_omni.worker.mixins import OmniWorkerMixin
+
+logger = init_logger(__name__)
+
+
+class MUSAARWorker(OmniWorkerMixin, OmniGPUWorkerBase):
+ """MUSA AR worker for thinker/talker stages in Omni model."""
+
+ def init_device(self):
+ """Initialize the MUSA device for this worker."""
+ # This env var set by Ray causes exceptions with graph building.
+ os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+ parallel_config = self.parallel_config
+ if (
+ parallel_config.distributed_executor_backend not in ("ray", "external_launcher")
+ and parallel_config.data_parallel_backend != "ray"
+ and parallel_config.nnodes_within_dp == 1
+ ):
+ # Use local DP rank if available, otherwise use global DP rank.
+ dp_local_rank = self.parallel_config.data_parallel_rank_local
+ if dp_local_rank is None:
+ dp_local_rank = self.parallel_config.data_parallel_index
+
+ tp_pp_world_size = self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size
+
+ # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
+ self.local_rank += dp_local_rank * tp_pp_world_size
+ assert self.local_rank < torch.musa.device_count(), (
+ f"DP adjusted local rank {self.local_rank} is out of bounds. "
+ )
+ visible_device_count = torch.musa.device_count()
+ assert self.parallel_config.local_world_size <= visible_device_count, (
+ f"local_world_size ({self.parallel_config.local_world_size}) must "
+ f"be less than or equal to the number of visible devices "
+ f"({visible_device_count})."
+ )
+
+ self.device = torch.device(f"musa:{self.local_rank}")
+ torch.musa.set_device(self.device)
+
+ current_platform.check_if_supports_dtype(self.model_config.dtype)
+
+ # Initialize the distributed environment BEFORE taking memory snapshot
+ # This ensures NCCL buffers are allocated before we measure available memory
+ init_worker_distributed_environment(
+ self.vllm_config,
+ self.rank,
+ self.distributed_init_method,
+ self.local_rank,
+ current_platform.dist_backend,
+ )
+
+ # Set random seed.
+ set_random_seed(self.model_config.seed)
+
+ # Now take memory snapshot after distributed environment is initialized
+ gc.collect()
+ torch.musa.empty_cache()
+
+ # Take current memory snapshot
+ self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
+ self.requested_memory = request_memory(init_snapshot, self.cache_config)
+ logger.debug("worker init memory snapshot: %r", self.init_snapshot)
+ logger.debug("worker requested memory: %sGiB", format_gib(self.requested_memory))
+
+ # Initialize workspace manager
+ num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
+ init_workspace_manager(self.device, num_ubatches)
+
+ if self.use_v2_model_runner:
+ # OMNI: v2 model runner does not yet include omni hooks.
+ logger.warning("OMNI MUSAARWorker forces v1 model runner for omni hooks.")
+ self.use_v2_model_runner = False
+
+ # Construct the model runner
+ self.model_runner = GPUARModelRunner(self.vllm_config, self.device)
+
+ if self.rank == 0:
+ # If usage stat is enabled, collect relevant info.
+ report_usage_stats(self.vllm_config)
diff --git a/vllm_omni/platforms/musa/worker/musa_generation_worker.py b/vllm_omni/platforms/musa/worker/musa_generation_worker.py
new file mode 100644
index 00000000000..f433f8897ee
--- /dev/null
+++ b/vllm_omni/platforms/musa/worker/musa_generation_worker.py
@@ -0,0 +1,106 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""MUSA generation worker for vLLM-Omni.
+
+This worker handles non-autoregressive generation stages (e.g., code2wav waveform
+generation) on MUSA devices.
+"""
+
+import gc
+import os
+
+import torch
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.tracing import instrument
+from vllm.utils.mem_utils import MemorySnapshot, format_gib
+from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.utils import report_usage_stats
+from vllm.v1.worker.gpu_worker import init_worker_distributed_environment
+from vllm.v1.worker.utils import request_memory
+from vllm.v1.worker.workspace import init_workspace_manager
+
+from vllm_omni.worker.base import OmniGPUWorkerBase
+from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner
+from vllm_omni.worker.mixins import OmniWorkerMixin
+
+logger = init_logger(__name__)
+
+
+class MUSAGenerationWorker(OmniWorkerMixin, OmniGPUWorkerBase):
+ """MUSA generation worker for non-AR waveform generation stage."""
+
+ @instrument(span_name="Init device")
+ def init_device(self):
+ """Initialize the MUSA device for this worker."""
+ # This env var set by Ray causes exceptions with graph building.
+ os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+ parallel_config = self.parallel_config
+ if (
+ parallel_config.distributed_executor_backend not in ("ray", "external_launcher")
+ and parallel_config.data_parallel_backend != "ray"
+ and parallel_config.nnodes_within_dp == 1
+ ):
+ # Use local DP rank if available, otherwise use global DP rank.
+ dp_local_rank = self.parallel_config.data_parallel_rank_local
+ if dp_local_rank is None:
+ dp_local_rank = self.parallel_config.data_parallel_index
+
+ tp_pp_world_size = self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size
+
+ # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
+ self.local_rank += dp_local_rank * tp_pp_world_size
+ assert self.local_rank < torch.musa.device_count(), (
+ f"DP adjusted local rank {self.local_rank} is out of bounds. "
+ )
+ visible_device_count = torch.musa.device_count()
+ assert self.parallel_config.local_world_size <= visible_device_count, (
+ f"local_world_size ({self.parallel_config.local_world_size}) must "
+ f"be less than or equal to the number of visible devices "
+ f"({visible_device_count})."
+ )
+
+ self.device = torch.device(f"musa:{self.local_rank}")
+ torch.musa.set_device(self.device)
+
+ current_platform.check_if_supports_dtype(self.model_config.dtype)
+
+ # Initialize the distributed environment BEFORE taking memory snapshot
+ # This ensures NCCL buffers are allocated before we measure available memory
+ init_worker_distributed_environment(
+ self.vllm_config,
+ self.rank,
+ self.distributed_init_method,
+ self.local_rank,
+ current_platform.dist_backend,
+ )
+
+ # Set random seed.
+ set_random_seed(self.model_config.seed)
+
+ # Now take memory snapshot after distributed environment is initialized
+ gc.collect()
+ torch.musa.empty_cache()
+
+ # Take current memory snapshot
+ self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
+ self.requested_memory = request_memory(init_snapshot, self.cache_config)
+ logger.debug("worker init memory snapshot: %r", self.init_snapshot)
+ logger.debug("worker requested memory: %sGiB", format_gib(self.requested_memory))
+
+ # Initialize workspace manager
+ num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
+ init_workspace_manager(self.device, num_ubatches)
+
+ if self.use_v2_model_runner:
+ # OMNI: v2 model runner does not yet include omni hooks.
+ logger.warning("OMNI MUSAGenerationWorker forces v1 model runner for omni hooks.")
+ self.use_v2_model_runner = False
+
+ # Construct the model runner
+ self.model_runner = GPUGenerationModelRunner(self.vllm_config, self.device)
+
+ if self.rank == 0:
+ # If usage stat is enabled, collect relevant info.
+ report_usage_stats(self.vllm_config)
diff --git a/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py b/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py
index 05079a7e4ae..fad4c0edfc3 100644
--- a/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py
+++ b/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py
@@ -107,6 +107,12 @@ class AscendHunyuanFusedMoE(AscendSharedFusedMoE):
def __init__(self, *, prefix: str = "", **kwargs: Any) -> None:
super().__init__(prefix=prefix, **kwargs)
self._prefix = prefix
+ self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True)
+
+ def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None:
+ if self.quant_method:
+ self.quant_method.process_weights_after_loading(self)
+ self._init_hook_handle.remove()
def forward(self, hidden_states: Any, router_logits: Any) -> Any:
_set_hunyuan_fused_moe_forward_context(hidden_states.shape[0])
diff --git a/vllm_omni/platforms/npu/platform.py b/vllm_omni/platforms/npu/platform.py
index 1d3e221ffe7..1d6bea7cb5d 100644
--- a/vllm_omni/platforms/npu/platform.py
+++ b/vllm_omni/platforms/npu/platform.py
@@ -1,11 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from contextlib import nullcontext
from typing import Any
import torch
-import torch.nn as nn
from vllm.logger import init_logger
from vllm_ascend.platform import NPUPlatform
@@ -14,12 +12,6 @@
logger = init_logger(__name__)
-_DIFFUSION_PACKED_MODULES_MAPPING = {
- "HunyuanImage3Pipeline": {
- "experts": ["experts.0.gate_up_proj", "experts.0.down_proj"],
- },
-}
-
class NPUOmniPlatform(OmniPlatform, NPUPlatform):
"""NPU/Ascend implementation of OmniPlatform.
@@ -60,13 +52,6 @@ def prepare_diffusion_op_runtime(cls, op_name: str, **kwargs: Any) -> None:
prepare_hunyuan_fused_moe_runtime()
- @classmethod
- def get_diffusion_packed_modules_mapping(
- cls,
- model_class: type[nn.Module],
- ) -> dict[str, list[str]] | None:
- return _DIFFUSION_PACKED_MODULES_MAPPING.get(model_class.__name__, None)
-
@classmethod
def get_diffusion_attn_backend_cls(
cls,
@@ -83,9 +68,6 @@ def get_diffusion_attn_backend_cls(
# Try FLASH_ATTN if mindiesd is available, otherwise fall back to SDPA
if find_spec("mindiesd"):
- # Configure ASCEND_CUSTOM_OPP_PATH for mindiesd custom ops upon import
- import mindiesd # noqa: F401
-
logger.info("Defaulting to diffusion attention backend FLASH_ATTN")
return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path()
@@ -124,24 +106,6 @@ def get_device_total_memory(cls, device_id: int = 0) -> int:
device_props = torch.npu.get_device_properties(device_id)
return device_props.total_memory
- @classmethod
- def create_autocast_context(cls, *, device_type, dtype, enabled=True):
- if device_type != "npu":
- return super().create_autocast_context(
- device_type=device_type,
- dtype=dtype,
- enabled=enabled,
- )
- if not enabled:
- return nullcontext()
-
- # NPU-specific fallback
- try:
- return torch.npu.amp.autocast(dtype=dtype)
- except (RuntimeError, TypeError, ValueError) as exc:
- logger.warning("autocast unavailable for device_type=%s dtype=%s: %s", device_type, dtype, exc)
- return nullcontext()
-
@classmethod
def get_profiler_cls(cls) -> str:
return "vllm_omni.platforms.npu.profiler.NPUTorchProfilerWrapper"
diff --git a/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml b/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_moe_dit.yaml
similarity index 94%
rename from vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml
rename to vllm_omni/platforms/npu/stage_configs/hunyuan_image3_moe_dit.yaml
index 0fd03949d11..053e8a8cca0 100644
--- a/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml
+++ b/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_moe_dit.yaml
@@ -33,3 +33,6 @@ stage_args:
# Runtime defaults
runtime:
enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml
new file mode 100644
index 00000000000..8f7af161d65
--- /dev/null
+++ b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml
@@ -0,0 +1,97 @@
+# Stage config for running Qwen2.5-Omni on the multi-stage omni runtime.
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent
+ enable_prefix_caching: false
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: true # talker ACL graph is not yet supported on NPU
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+ default_sampling_params:
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+ stop_token_ids: [8294]
+ - stage_id: 2
+ stage_type: llm # llm stage type is also used for this non-AR generation stage (worker_type: generation)
+ runtime:
+ process: true
+ devices: "2" # Example: use a different NPU than the previous stage; use "0" if single NPU
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ gpu_memory_utilization: 0.15
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ engine_input_source: [1]
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
+ edges:
+ - from: 0 # thinker → talker: trigger only after receiving full input (-1)
+ to: 1
+ window_size: -1
+ - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
+ to: 2
+ window_size: -1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml
new file mode 100644
index 00000000000..2638c99cd4b
--- /dev/null
+++ b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml
@@ -0,0 +1,99 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 5x A2/A3-64G NPUs.
+stage_args:
+ - stage_id: 0
+ runtime:
+ devices: "0,1"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent # Output hidden states for talker
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ hf_config_name: thinker_config
+ tensor_parallel_size: 2
+ # profiler_config:
+ # profiler: torch
+ # torch_profiler_dir: ./perf
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ runtime:
+ devices: "2"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: true # talker ACL graph is not yet supported on NPU
+ trust_remote_code: true
+ engine_output_type: latent # Output codec codes for code2wav
+ # tensor_parallel_size: 2
+ enable_prefix_caching: false
+ distributed_executor_backend: "mp"
+ hf_config_name: talker_config
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+ # final_output: true
+ # final_output_type: text
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: False
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 2
+ runtime:
+ devices: "2"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: audio # Final output: audio waveform
+ gpu_memory_utilization: 0.3
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 1000000
+ hf_config_name: thinker_config
+ engine_input_source: [1]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml
new file mode 100644
index 00000000000..9aa20baecfb
--- /dev/null
+++ b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml
@@ -0,0 +1,101 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
+# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs. NOTE(review): this NPU config uses devices 0-2; confirm the verified hardware.
+async_chunk: true
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "0,1"
+ engine_args:
+ max_num_seqs: 10
+ model_stage: thinker
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent # Output hidden states for talker
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ hf_config_name: thinker_config
+ tensor_parallel_size: 2
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "2"
+ engine_args:
+ max_num_seqs: 10
+ model_stage: talker
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # Output codec codes for code2wav
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ distributed_executor_backend: "mp"
+ hf_config_name: talker_config
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
+ engine_input_source: [0]
+ # final_output: true
+ # final_output_type: text
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: False
+ repetition_penalty: 1.0
+ stop_token_ids: [2150]
+
+ - stage_id: 2
+ stage_type: llm # llm stage type is also used for this non-AR generation stage (worker_type: generation)
+ runtime:
+ devices: "2"
+ engine_args:
+ max_num_seqs: 10
+ model_stage: code2wav
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ async_scheduling: false
+ enable_prefix_caching: false
+ engine_output_type: audio # Final output: audio waveform
+ gpu_memory_utilization: 0.3
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 51200 # TODO: if max_num_batched_tokens < max_num_seqs * 800, there will be precision problems.
+ hf_config_name: thinker_config
+ engine_input_source: [1]
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
diff --git a/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml
similarity index 60%
rename from vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml
rename to vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml
index c6fd177a359..a741f819a2b 100644
--- a/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml
+++ b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml
@@ -1,5 +1,3 @@
-# VoxCPM two-stage streaming (align with qwen3_tts.yaml async_chunk pattern).
-# Stage0 (latent_generator) emits latent in time chunks; Stage1 (VAE) decodes as chunks arrive.
async_chunk: true
stage_args:
- stage_id: 0
@@ -7,48 +5,42 @@ stage_args:
is_comprehension: true
runtime:
devices: "0"
- max_batch_size: 1
engine_args:
- dtype: bfloat16
- model_stage: latent_generator
- model_arch: VoxCPMForConditionalGeneration
- # Optional persistent HF-compatible config dir for native VoxCPM models.
- hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
+ model_stage: qwen3_tts
+ max_num_seqs: 1
+ model_arch: Qwen3TTSTalkerForConditionalGeneration
worker_type: ar
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
enforce_eager: true
trust_remote_code: true
- async_scheduling: true
+ async_scheduling: false
enable_prefix_caching: false
engine_output_type: latent
- gpu_memory_utilization: 0.7
+ gpu_memory_utilization: 0.3
distributed_executor_backend: "mp"
- max_num_batched_tokens: 4096
+ max_num_batched_tokens: 512
max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.voxcpm.latent2vae_async_chunk
+ custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
+ # Use named connector to apply runtime.connectors.extra.
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
+ temperature: 0.9
+ top_k: 50
max_tokens: 4096
- stop_token_ids: [2]
seed: 42
detokenize: false
- repetition_penalty: 1.0
- final_output: false
- output_connectors:
- to_stage_1: voxcpm_shm
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
- stage_id: 1
stage_type: llm
runtime:
devices: "0"
- max_batch_size: 1
engine_args:
- dtype: float32
- model_stage: vae
- model_arch: VoxCPMForConditionalGeneration
- hf_config_path: ${oc.env:VLLM_OMNI_VOXCPM_HF_CONFIG_PATH,}
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen3TTSCode2Wav
worker_type: generation
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
@@ -56,30 +48,35 @@ stage_args:
async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio
- gpu_memory_utilization: 0.15
+ gpu_memory_utilization: 0.2
distributed_executor_backend: "mp"
- max_num_batched_tokens: 8192
- max_model_len: 4096
+ max_num_batched_tokens: 65536
+ max_model_len: 65536
engine_input_source: [0]
final_output: true
final_output_type: audio
+ # Distributed connector configuration
input_connectors:
- from_stage_0: voxcpm_shm
+ from_stage_0: connector_of_shared_memory
+ tts_args:
+ max_instructions_length: 500
default_sampling_params:
temperature: 0.0
top_p: 1.0
top_k: -1
- max_tokens: 128
+ max_tokens: 65536
seed: 42
detokenize: true
repetition_penalty: 1.0
-
runtime:
enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
connectors:
- voxcpm_shm:
+ connector_of_shared_memory:
name: SharedMemoryConnector
extra:
shm_threshold_bytes: 65536
@@ -90,9 +87,10 @@ runtime:
connector_get_max_wait_first_chunk: 3000
connector_get_max_wait: 300
# Align with Omni: small chunks with sufficient context overlap.
- codec_chunk_frames: 1
- codec_left_context_frames: 1
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
edges:
- from: 0
to: 1
+ window_size: -1
diff --git a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
index ffb997048bd..138948064ba 100644
--- a/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
+++ b/vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
@@ -149,15 +149,7 @@ def execute_model(
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
-
- kv_ids = self.kv_extracted_req_ids
- self.kv_extracted_req_ids = None
-
- output = make_empty_encoder_model_runner_output(scheduler_output)
- if kv_ids:
- output = copy(output)
- output.kv_extracted_req_ids = kv_ids
- return output
+ return make_empty_encoder_model_runner_output(scheduler_output)
if not num_scheduled_tokens:
if (
@@ -171,20 +163,10 @@ def execute_model(
# dummy run to ensure coordinate_batch_across_dp
# is called into to avoid out of sync issues.
self._dummy_run(1)
-
- kv_ids = self.kv_extracted_req_ids
- self.kv_extracted_req_ids = None
-
if not has_kv_transfer_group():
- output = EMPTY_MODEL_RUNNER_OUTPUT
- else:
- output = self.kv_connector_no_forward(scheduler_output, self.vllm_config)
-
- if kv_ids:
- output = copy(output)
- output.kv_extracted_req_ids = kv_ids
-
- return output
+ # Return empty ModelRunnerOutput if no work to do.
+ return EMPTY_MODEL_RUNNER_OUTPUT
+ return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.num_prompt_logprobs, (
"--kv-sharing-fast-prefill produces incorrect "
diff --git a/vllm_omni/platforms/npu/worker/npu_model_runner.py b/vllm_omni/platforms/npu/worker/npu_model_runner.py
index 77b6e9ef05f..8ef39adfa67 100644
--- a/vllm_omni/platforms/npu/worker/npu_model_runner.py
+++ b/vllm_omni/platforms/npu/worker/npu_model_runner.py
@@ -417,30 +417,18 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te
req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded]
last_talker_hidden = self.last_talker_hidden.gpu[:num_tokens_padded]
text_step = self.text_step.gpu[:num_tokens_padded]
- subtalker_params = getattr(self.vllm_config.model_config, "subtalker_sampling_params", None)
- if not isinstance(subtalker_params, dict):
- subtalker_params = {}
with set_ascend_forward_context(
None, self.vllm_config, aclgraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
):
- req_embeds, code_predictor_codes = self.talker_mtp(
- req_input_ids,
- req_embeds,
- last_talker_hidden,
- text_step,
- do_sample=subtalker_params.get("do_sample"),
- temperature=subtalker_params.get("temperature"),
- top_k=subtalker_params.get("top_k"),
- top_p=subtalker_params.get("top_p"),
- )
- # update the inputs_embeds and code_predictor_codes
- code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous()
- out_key = getattr(self.model, "talker_mtp_output_key", ("codes", "audio"))
- if not isinstance(out_key, tuple) or len(out_key) != 2:
- raise TypeError(f"talker_mtp_output_key must be a 2-tuple, got {type(out_key).__name__}: {out_key!r}")
+ req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step)
+ # code_predictor_codes stays on GPU here; _update_intermediate_buffer
+ # keeps it device-resident when the key is in gpu_resident_buffer_keys.
+ # D2H is deferred to sample_tokens where hidden_states.to("cpu") already
+ # syncs the stream, avoiding a per-step cudaStreamSynchronize.
+ out_key = getattr(self.model, "talker_mtp_output_key", "code_predictor_codes")
for idx, req_id in enumerate(decode_req_ids):
req_index = self.input_batch.req_ids.index(req_id)
start_offset = int(self.query_start_loc.cpu[req_index])
inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1]
- update_dict = {out_key[0]: {out_key[1]: code_predictor_codes_cpu[idx : idx + 1]}}
+ update_dict = {out_key: code_predictor_codes[idx : idx + 1]}
self._merge_additional_information_update(req_id, update_dict)
diff --git a/vllm_omni/platforms/rocm/platform.py b/vllm_omni/platforms/rocm/platform.py
index 7b0e09c128e..4479e54f2a2 100644
--- a/vllm_omni/platforms/rocm/platform.py
+++ b/vllm_omni/platforms/rocm/platform.py
@@ -16,34 +16,6 @@ class RocmOmniPlatform(OmniPlatform, RocmPlatform):
Inherits all ROCm-specific implementations from vLLM's RocmPlatform,
and adds Omni-specific interfaces from OmniPlatform.
-
-
- NOTE: AR Attention Backend Overriding Logic:
- ------------------------------------------
- Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm.
- However, the compatibility of ROCM_ATTN with Omni is not guaranteed.
- Therefore, we still use TRITON_ATTN as the default attention backend,
- when the selected_backend is not specified.
-
- So the behaviour of the attention backend overriding logic currently lives in
- extract_stage_metadata in `vllm_omni/engine/stage_init_utils.py`
-
- ```
- if current_omni_platform.is_rocm():
- print(f"engine_args: {str(engine_args)}")
- if engine_args.get("attention_backend") is None:
- from vllm._aiter_ops import rocm_aiter_ops
-
- if rocm_aiter_ops.is_enabled():
- engine_args["attention_backend"] = "ROCM_AITER_FA"
- # Before vLLM v0.19.0, the default attention backend is TRITON_ATTN for ROCm.
- # Since vLLM v0.19.0, the default attention backend is ROCM_ATTN for ROCm.
- # However, the compatibility of ROCM_ATTN with Omni is not guaranteed.
- # Therefore, we still use TRITON_ATTN as the default attention backend,
- # when the selected_backend is not specified.
- engine_args["attention_backend"] = "TRITON_ATTN"
- ```
-
"""
_omni_enum = OmniPlatformEnum.ROCM
diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml
new file mode 100644
index 00000000000..35e81935457
--- /dev/null
+++ b/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml
@@ -0,0 +1,102 @@
+# stage config for running qwen2.5-omni for multi-stage omni runtime.
+
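+# The following config has been verified on 2x H100-80G GPUs. NOTE(review): H100 is NVIDIA hardware, this file is under platforms/rocm, and the stages below use devices "0", "1" and "2" - confirm the ROCm setup this was actually verified on.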
+stage_args:
+ - stage_id: 0
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: true # Now we only support eager mode
+ trust_remote_code: true
+ engine_output_type: latent
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+ - stage_id: 1
+ runtime:
+ process: true
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.8
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ engine_output_type: latent
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+ default_sampling_params:
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+ stop_token_ids: [8294]
+
+ - stage_id: 2
+ runtime:
+ process: true
+ devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ gpu_memory_utilization: 0.15
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ engine_output_type: audio
+ engine_input_source: [1]
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
+
+ edges:
+ - from: 0 # thinker → talker: trigger only after receiving full input (-1)
+ to: 1
+ window_size: -1
+ - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
+ to: 2
+ window_size: -1
diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml
new file mode 100644
index 00000000000..0ca150bee6c
--- /dev/null
+++ b/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml
@@ -0,0 +1,97 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
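+# The following config has been verified on 2x H100-80G GPUs. NOTE(review): H100 is NVIDIA hardware and this file is under platforms/rocm - confirm the ROCm setup this was actually verified on.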
+stage_args:
+ - stage_id: 0
+ runtime:
+ devices: "0"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # Output hidden states for talker
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ hf_config_name: thinker_config
+ tensor_parallel_size: 1
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # Output codec codes for code2wav
+ # tensor_parallel_size: 2
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ distributed_executor_backend: "mp"
+ hf_config_name: talker_config
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+ # final_output: true
+ # final_output_type: text
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: False
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 2
+ runtime:
+ devices: "1"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio # Final output: audio waveform
+ gpu_memory_utilization: 0.1
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 1000000
+ hf_config_name: thinker_config
+ engine_input_source: [1]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/xpu/platform.py b/vllm_omni/platforms/xpu/platform.py
index cc4218ae253..a77aa83a4fe 100644
--- a/vllm_omni/platforms/xpu/platform.py
+++ b/vllm_omni/platforms/xpu/platform.py
@@ -45,7 +45,8 @@ def get_diffusion_attn_backend_cls(
@classmethod
def supports_torch_inductor(cls) -> bool:
- return True
+ # TODO: Enable this when torch compile bugs are resolved
+ return False
@classmethod
def get_default_stage_config_path(cls) -> str:
diff --git a/vllm_omni/platforms/xpu/stage_configs/bagel.yaml b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml
new file mode 100644
index 00000000000..0fc8a25ea5c
--- /dev/null
+++ b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml
@@ -0,0 +1,86 @@
+# stage config for running bagel-7b-mot with architecture of OmniLLM.
+
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ prompt_expand_func: vllm_omni.model_executor.stage_input_processors.bagel.expand_cfg_prompts
+ runtime:
+ devices: "0"
+ engine_args:
+ # 3 = 1 user prompt + 2 CFG companions (text-unconditional + image-unconditional).
+ max_num_seqs: 3
+ model_stage: thinker
+ model_arch: OmniBagelForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: text
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 16384
+ tensor_parallel_size: 1
+ quantization: fp8
+ omni_kv_config:
+ need_send_cache: true
+ kv_transfer_criteria:
+ type: prefill_finished # or special token generated
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 52
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: diffusion
+ cfg_kv_collect_func: vllm_omni.model_executor.stage_input_processors.bagel.collect_cfg_kv_caches
+ runtime:
+ devices: "1"
+ engine_args:
+ max_num_seqs: 1
+ model_stage: dit
+ gpu_memory_utilization: 0.9
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: image
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 1
+ omni_kv_config:
+ need_recv_cache: true
+ engine_input_source: [0]
+
+ final_output: true
+ final_output_type: image
+ is_comprehension: false
+ default_sampling_params:
+ seed: 52
+
+# Runtime edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
+
+ # Distributed connectors configuration (optional)
+ # More connectors will be supported in the future.
+ connectors:
+ shared_memory_connector:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536 # 64KB threshold
+
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml
similarity index 93%
rename from vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml
rename to vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml
index 4e0005f82a1..8f969ced5f4 100644
--- a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml
+++ b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image_3_moe.yaml
@@ -78,3 +78,6 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml
new file mode 100644
index 00000000000..7dbedb29a5e
--- /dev/null
+++ b/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml
@@ -0,0 +1,101 @@
+# stage config for running qwen2.5-omni for multi-stage omni runtime.
+
+# The following config is verified with 2 * Intel Arc Pro B60 XPU.
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0" # Visible devices for this stage
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9 # thinker weight is around 16.74GB for Qwen2.5-Omni-7B
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent
+ enable_prefix_caching: false
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true
+ devices: "1"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.5 # talker weight is 6.03GB for Qwen2.5-Omni-7B
+ enforce_eager: false
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: latent
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
+ default_sampling_params:
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+ stop_token_ids: [8294]
+
+ - stage_id: 2
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true
+ devices: "1"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen2_5OmniForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ gpu_memory_utilization: 0.3 # code2wav weight is around 1.46GB for Qwen2.5-Omni-7B
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio
+ engine_input_source: [1]
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+
+# Top-level runtime config (concise): default windows and stage edges
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Simplified: trigger downstream only after full upstream completion
+ max_inflight: 1 # Simplified: process serially within each stage
+
+ edges:
+ - from: 0 # thinker → talker: trigger only after receiving full input (-1)
+ to: 1
+ window_size: -1
+ - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
+ to: 2
+ window_size: -1
diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml
new file mode 100644
index 00000000000..49914bebc43
--- /dev/null
+++ b/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml
@@ -0,0 +1,102 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config is verified with 8 * Intel Arc Pro B60 XPU.
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "0,1,2,3"
+ engine_args:
+ model_stage: thinker
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9 # thinker weight is around 61.08GB for Qwen3-Omni-30B-A3B-Instruct
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # Output hidden states for talker
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ hf_config_name: thinker_config
+ tensor_parallel_size: 4
+ max_cudagraph_capture_size: 0
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "4"
+ engine_args:
+ model_stage: talker
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6 # talker weight is around 8.5GB for Qwen3-Omni-30B-A3B-Instruct
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # Output codec codes for code2wav
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ distributed_executor_backend: "mp"
+ hf_config_name: talker_config
+ max_cudagraph_capture_size: 0
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+ # final_output: true
+ # final_output_type: text
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ detokenize: False
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+ - stage_id: 2
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ devices: "4"
+ engine_args:
+ model_stage: code2wav
+ max_num_seqs: 1
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_type: generation
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio # Final output: audio waveform
+ gpu_memory_utilization: 0.3 # code2wav weight is around 0.4GB for Qwen3-Omni-30B-A3B-Instruct
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 1000000
+ hf_config_name: thinker_config
+ max_cudagraph_capture_size: 0
+ engine_input_source: [1]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml b/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml
index 0820ab63203..10051c1eda4 100644
--- a/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml
+++ b/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml
@@ -88,6 +88,9 @@ stage_args:
runtime:
enabled: true
+ defaults:
+ window_size: -1
+ max_inflight: 1
connectors:
connector_of_shared_memory:
@@ -105,3 +108,4 @@ runtime:
edges:
- from: 0
to: 1
+ window_size: -1
diff --git a/vllm_omni/profiler/omni_torch_profiler.py b/vllm_omni/profiler/omni_torch_profiler.py
index 69f893b05fa..2257a212838 100644
--- a/vllm_omni/profiler/omni_torch_profiler.py
+++ b/vllm_omni/profiler/omni_torch_profiler.py
@@ -5,8 +5,8 @@
import os
import subprocess
-from datetime import datetime
-from typing import Any, Literal
+import time
+from typing import Literal
import torch
from typing_extensions import override
@@ -62,13 +62,6 @@ def __init__(
self._trace_path: str | None = None
self._table_path: str | None = None
- self._activities = activities
- self._session_dir: str | None = None
- self._artifact_paths: dict[str, str | None] = {}
- self._memory_history_enabled = False
- self._memory_history_backend: str | None = None
- self._memory_history_module = None
-
if local_rank in (None, 0):
logger.info_once(
"Omni torch profiling enabled. Traces will be saved to: %s",
@@ -79,9 +72,6 @@ def __init__(
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
self.profiler = self._create_profiler(profiler_config, activities)
- def _rank(self) -> int:
- return 0 if self.local_rank is None else self.local_rank
-
def _get_default_activities(self) -> list[TorchProfilerActivity]:
"""Get default activities for this platform.
@@ -116,58 +106,19 @@ def set_trace_filename(self, filename: str) -> None:
Can also be a full path (e.g. from diffusion engine).
"""
self._trace_filename = filename
- self._session_dir = None
- self._ensure_session_dir()
-
- def _ensure_session_dir(self) -> str:
- """Create one timestamped directory for this profiling run."""
- if self._session_dir is not None:
- os.makedirs(self._session_dir, exist_ok=True)
- return self._session_dir
-
- ts = datetime.now().strftime("%Y%m%d-%H%M%S")
- base_name = self._trace_filename or self._worker_name
-
- if os.path.dirname(base_name):
- parent_dir = os.path.dirname(base_name)
- leaf_name = os.path.basename(base_name)
- session_name = f"{ts}_{leaf_name}"
- self._session_dir = os.path.join(parent_dir, session_name)
- else:
- session_name = f"{ts}_{base_name}"
- self._session_dir = os.path.join(self._trace_dir, session_name)
-
- os.makedirs(self._session_dir, exist_ok=True)
- self._artifact_paths["session_dir"] = self._session_dir
- return self._session_dir
-
- def _artifact_path(self, stem: str, suffix: str) -> str:
- """Build artifact path under the session directory."""
- return os.path.join(
- self._ensure_session_dir(),
- f"{stem}_rank{self._rank()}{suffix}",
- )
-
- def _write_text_artifact(self, name: str, content: str) -> str:
- path = self._artifact_path(name, ".txt")
- with open(path, "w") as f:
- f.write(content)
- self._artifact_paths[name] = path
- return path
-
- def _has_cuda_like_activity(self) -> bool:
- return any(a in self._activities for a in ("CUDA", "MUSA"))
-
- def _get_time_sort_key(self) -> str:
- if self._has_cuda_like_activity():
- return "self_cuda_time_total"
- return "self_cpu_time_total"
def _on_trace_ready(self, prof) -> None:
"""Custom trace handler: export chrome trace with omni naming."""
- rank = self._rank()
+ rank = self.local_rank
+ filename = self._trace_filename or f"{self._worker_name}_{int(time.time())}"
+ # If filename already contains a directory, use as-is (e.g. from
+ # diffusion engine which builds full path). Otherwise join with trace_dir.
+ if os.path.dirname(filename):
+ json_file = f"{filename}_rank{rank}.json"
+ else:
+ json_file = os.path.join(self._trace_dir, f"{filename}_rank{rank}.json")
- json_file = self._artifact_path("trace", ".json")
+ os.makedirs(os.path.dirname(json_file), exist_ok=True)
try:
prof.export_chrome_trace(json_file)
@@ -192,211 +143,18 @@ def _on_trace_ready(self, prof) -> None:
else:
self._trace_path = json_file
- self._artifact_paths["trace"] = self._trace_path
-
except Exception as e:
logger.warning("[Rank %s] Failed to export trace: %s", rank, e)
- def _try_enable_memory_history(self) -> None:
- """Enable backend-specific memory history for snapshot analysis."""
- if not self.profiler_config.torch_profiler_with_memory:
- return
-
- backend_name, memory_module = self._resolve_memory_history_backend()
- if backend_name is None or memory_module is None:
- return
-
- record_memory_history = getattr(memory_module, "_record_memory_history", None)
- if record_memory_history is None:
- logger.info(
- "[Rank %s] %s memory history is not supported on this platform",
- self._rank(),
- backend_name,
- )
- return
-
- try:
- record_memory_history(
- enabled="all",
- context="all",
- stacks="python",
- max_entries=100000,
- clear_history=True,
- )
- self._memory_history_enabled = True
- self._memory_history_backend = backend_name
- self._memory_history_module = memory_module
- logger.info("[Rank %s] %s memory history enabled", self._rank(), backend_name)
- except Exception as e:
- logger.warning(
- "[Rank %s] Failed to enable %s memory history: %s",
- self._rank(),
- backend_name,
- e,
- )
-
- def _try_dump_memory_snapshot(self) -> None:
- """Dump a backend-specific memory snapshot into the session directory."""
- if not self._memory_history_enabled:
- return
-
- try:
- if self._memory_history_module is None or self._memory_history_backend is None:
- return
-
- dump_snapshot = getattr(self._memory_history_module, "_dump_snapshot", None)
- if dump_snapshot is None:
- logger.info(
- "[Rank %s] %s memory snapshot is not supported on this platform",
- self._rank(),
- self._memory_history_backend,
- )
- return
-
- snapshot_file = self._artifact_path("memory_snapshot", ".pickle")
- dump_snapshot(snapshot_file)
- self._artifact_paths["memory_snapshot"] = snapshot_file
- logger.info(
- "[Rank %s] %s memory snapshot dumped to %s",
- self._rank(),
- self._memory_history_backend,
- snapshot_file,
- )
- except Exception as e:
- logger.warning(
- "[Rank %s] Failed to dump %s memory snapshot: %s",
- self._rank(),
- self._memory_history_backend,
- e,
- )
- finally:
- try:
- if self._memory_history_module is not None:
- disable_memory_history = getattr(
- self._memory_history_module,
- "_record_memory_history",
- None,
- )
- if disable_memory_history is not None:
- disable_memory_history(enabled=None)
- except Exception:
- pass
- self._memory_history_enabled = False
- self._memory_history_backend = None
- self._memory_history_module = None
-
- def _resolve_memory_history_backend(self) -> tuple[str | None, Any]:
- """Resolve the memory backend that supports history/snapshot APIs."""
- backend_specs = [
- ("CUDA", self._has_cuda_like_activity(), getattr(torch, "cuda", None)),
- ("NPU", "NPU" in self._activities, getattr(torch, "npu", None)),
- ("XPU", "XPU" in self._activities, getattr(torch, "xpu", None)),
- ("MUSA", "MUSA" in self._activities, getattr(torch, "musa", None)),
- ]
-
- for backend_name, enabled, device_module in backend_specs:
- if not enabled or device_module is None:
- continue
-
- is_available = getattr(device_module, "is_available", None)
- if callable(is_available) and not is_available():
- continue
-
- memory_module = getattr(device_module, "memory", None)
- if memory_module is not None:
- return backend_name, memory_module
-
- return None, None
-
- def _safe_get(self, obj, name: str, default=None):
- return getattr(obj, name, default)
-
- def _event_list_to_rows(self, event_list) -> list[dict]:
- rows = []
- for evt in event_list:
- row = {
- "name": self._safe_get(evt, "key", None) or self._safe_get(evt, "name", None),
- "count": self._safe_get(evt, "count", None),
- "device_type": self._safe_get(evt, "device_type", None),
- "node_id": self._safe_get(evt, "node_id", None),
- "self_cpu_time_total_us": self._safe_get(evt, "self_cpu_time_total", None),
- "cpu_time_total_us": self._safe_get(evt, "cpu_time_total", None),
- "self_cuda_time_total_us": self._safe_get(evt, "self_cuda_time_total", None),
- "cuda_time_total_us": self._safe_get(evt, "cuda_time_total", None),
- "self_xpu_time_total_us": self._safe_get(evt, "self_xpu_time_total", None),
- "xpu_time_total_us": self._safe_get(evt, "xpu_time_total", None),
- "self_cpu_memory_usage_bytes": self._safe_get(evt, "self_cpu_memory_usage", None),
- "cpu_memory_usage_bytes": self._safe_get(evt, "cpu_memory_usage", None),
- "self_cuda_memory_usage_bytes": self._safe_get(evt, "self_cuda_memory_usage", None),
- "cuda_memory_usage_bytes": self._safe_get(evt, "cuda_memory_usage", None),
- "self_xpu_memory_usage_bytes": self._safe_get(evt, "self_xpu_memory_usage", None),
- "xpu_memory_usage_bytes": self._safe_get(evt, "xpu_memory_usage", None),
- "input_shapes": str(self._safe_get(evt, "input_shapes", None)),
- "stack": "\n".join(self._safe_get(evt, "stack", []) or []),
- "overload_name": self._safe_get(evt, "overload_name", None),
- "is_async": self._safe_get(evt, "is_async", None),
- "is_legacy": self._safe_get(evt, "is_legacy", None),
- }
- rows.append(row)
- return rows
-
- def _write_excel_artifact(self, name: str, sheets: dict[str, list[dict]]) -> str:
- path = self._artifact_path(name, ".xlsx")
-
- try:
- import pandas as pd
- except Exception as e:
- logger.warning(
- "[Rank %s] pandas not available, skip Excel export: %s",
- self._rank(),
- e,
- )
- return path
-
- with pd.ExcelWriter(path, engine="openpyxl") as writer:
- for sheet_name, rows in sheets.items():
- df = pd.DataFrame(rows)
-
- safe_sheet_name = sheet_name if sheet_name else "Sheet1"
-
- df.to_excel(
- writer,
- sheet_name=safe_sheet_name,
- index=False,
- freeze_panes=(1, 0),
- )
-
- ws = writer.sheets[safe_sheet_name]
- ws.auto_filter.ref = ws.dimensions
-
- for col_cells in ws.columns:
- max_len = 0
- col_letter = col_cells[0].column_letter
- for cell in col_cells[:200]:
- try:
- val = "" if cell.value is None else str(cell.value)
- max_len = max(max_len, len(val))
- except Exception:
- pass
- ws.column_dimensions[col_letter].width = min(max(max_len + 2, 12), 80)
-
- self._artifact_paths[name] = path
- return path
-
@override
def _start(self) -> None:
- self._ensure_session_dir()
- self._try_enable_memory_history()
self.profiler.start()
@override
def _stop(self) -> None:
"""Stop profiler, export trace via on_trace_ready, and dump table."""
self.profiler.stop()
- try:
- self._on_stop_hook()
- finally:
- self._try_dump_memory_snapshot()
+ self._on_stop_hook()
def _on_stop_hook(self) -> None:
"""Hook called after profiler.stop().
@@ -405,68 +163,6 @@ def _on_stop_hook(self) -> None:
Base implementation handles CUDA time total dump.
"""
rank = self.local_rank
- sort_key = self._get_time_sort_key()
-
- excel_sheets: dict[str, list[dict]] = {}
-
- # 1) Summary op table
- summary_events = self.profiler.key_averages()
- excel_sheets["summary"] = self._event_list_to_rows(summary_events)
-
- # 2) Shape-grouped op table
- if self.profiler_config.torch_profiler_record_shapes:
- try:
- shape_events = self.profiler.key_averages(
- group_by_input_shape=True,
- )
- excel_sheets["by_shape"] = self._event_list_to_rows(shape_events)
- except Exception as e:
- logger.warning(
- "[Rank %s] Failed to export shape-grouped op table: %s",
- rank,
- e,
- )
-
- # 3) Stack-grouped op table
- if self.profiler_config.torch_profiler_with_stack:
- try:
- stack_events = self.profiler.key_averages(
- group_by_stack_n=8,
- )
- excel_sheets["by_stack"] = self._event_list_to_rows(stack_events)
- except Exception as e:
- logger.warning(
- "[Rank %s] Failed to export stack-grouped op table: %s",
- rank,
- e,
- )
-
- # 4) Export stack files
- try:
- cpu_stack_file = self._artifact_path("stacks_cpu", ".txt")
- self.profiler.export_stacks(
- cpu_stack_file,
- metric="self_cpu_time_total",
- )
- self._artifact_paths["stacks_cpu"] = cpu_stack_file
- except Exception as e:
- logger.warning("[Rank %s] export_stacks(cpu) failed: %s", rank, e)
-
- if self._has_cuda_like_activity():
- try:
- cuda_stack_file = self._artifact_path("stacks_cuda", ".txt")
- self.profiler.export_stacks(
- cuda_stack_file,
- metric="self_cuda_time_total",
- )
- self._artifact_paths["stacks_cuda"] = cuda_stack_file
- except Exception as e:
- logger.warning("[Rank %s] export_stacks(cuda) failed: %s", rank, e)
-
- try:
- self._table_path = self._write_excel_artifact("ops", excel_sheets)
- except Exception as e:
- logger.warning("[Rank %s] Failed to export Excel workbook: %s", rank, e)
if self.profiler_config.torch_profiler_dump_cuda_time_total:
profiler_dir = self.profiler_config.torch_profiler_dir
@@ -474,13 +170,10 @@ def _on_stop_hook(self) -> None:
table = self.profiler.key_averages().table(sort_by=sort_key)
if not _is_uri_path(profiler_dir):
- table_file = os.path.join(
- self._ensure_session_dir(),
- f"profiler_out_{rank}.txt",
- )
+ table_file = os.path.join(profiler_dir, f"profiler_out_{rank}.txt")
with open(table_file, "w") as f:
print(table, file=f)
- self._artifact_paths["profiler_out"] = table_file
+ self._table_path = table_file
if rank == 0:
print(table)
@@ -497,7 +190,6 @@ def get_results(self) -> dict:
return {
"trace": self._trace_path,
"table": self._table_path,
- **self._artifact_paths,
}
diff --git a/vllm_omni/quantization/__init__.py b/vllm_omni/quantization/__init__.py
index a709dc74d85..b50fe8607ff 100644
--- a/vllm_omni/quantization/__init__.py
+++ b/vllm_omni/quantization/__init__.py
@@ -13,7 +13,6 @@
from .component_config import ComponentQuantizationConfig
from .factory import SUPPORTED_QUANTIZATION_METHODS, build_quant_config
-from .inc_config import OmniINCConfig
# DiffusionGGUFConfig is NOT imported here to avoid pulling in
# GGUF -> fused_moe -> pynvml at module load time.
@@ -21,6 +20,5 @@
__all__ = [
"build_quant_config",
"ComponentQuantizationConfig",
- "OmniINCConfig",
"SUPPORTED_QUANTIZATION_METHODS",
]
diff --git a/vllm_omni/quantization/component_config.py b/vllm_omni/quantization/component_config.py
index f9286079be1..7986da8850b 100644
--- a/vllm_omni/quantization/component_config.py
+++ b/vllm_omni/quantization/component_config.py
@@ -23,31 +23,6 @@
)
-# Pre-quantized checkpoints (modelopt FP8/FP4/MXFP8) only quantize the
-# Thinker LM. Vision and audio encoder weights remain in BF16 with no
-# corresponding scale tensors in the checkpoint.
-PRE_QUANTIZED_METHODS: frozenset[str] = frozenset({"modelopt", "modelopt_fp4", "modelopt_mxfp8"})
-
-
-def resolve_encoder_quant_config(
- quant_config: QuantizationConfig | None,
-) -> QuantizationConfig | None:
- """Resolve quantization config for vision / audio encoders.
-
- Returns *None* for pre-quantized methods so that FP8 kernels are never
- applied to BF16 encoder weights (which lack scale tensors). All other
- configs — including ``ComponentQuantizationConfig`` and ``None`` — are
- returned as-is so the caller can handle them.
- """
- if (
- quant_config is not None
- and not isinstance(quant_config, ComponentQuantizationConfig)
- and quant_config.get_name() in PRE_QUANTIZED_METHODS
- ):
- return None
- return quant_config
-
-
class ComponentQuantizationConfig(QuantizationConfig):
"""Routes quantization to different configs by layer prefix."""
diff --git a/vllm_omni/quantization/factory.py b/vllm_omni/quantization/factory.py
index 5af47a1f7db..f85589d69bb 100644
--- a/vllm_omni/quantization/factory.py
+++ b/vllm_omni/quantization/factory.py
@@ -43,16 +43,16 @@ def _build_int8(**kw: Any) -> QuantizationConfig:
def _build_inc(**kw: Any) -> QuantizationConfig:
"""Lazy import for INC/AutoRound config with checkpoint kwarg normalization."""
- from .inc_config import OmniINCConfig
+ from vllm.model_executor.layers.quantization.inc import INCConfig
# Map checkpoint key 'bits' to INCConfig's 'weight_bits'
if "bits" in kw and "weight_bits" not in kw:
kw["weight_bits"] = kw.pop("bits")
# Filter to only valid INCConfig params
- valid = set(inspect.signature(OmniINCConfig.__init__).parameters) - {"self"}
+ valid = set(inspect.signature(INCConfig.__init__).parameters) - {"self"}
filtered = {k: v for k, v in kw.items() if k in valid}
- return OmniINCConfig(**filtered)
+ return INCConfig(**filtered)
_OVERRIDES: dict[str, Callable[..., QuantizationConfig]] = {
diff --git a/vllm_omni/quantization/inc_config.py b/vllm_omni/quantization/inc_config.py
deleted file mode 100644
index fb4851ccabd..00000000000
--- a/vllm_omni/quantization/inc_config.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Extended INC/AutoRound config for multi-stage omni models."""
-
-from __future__ import annotations
-
-from os.path import commonprefix
-from typing import TYPE_CHECKING, Any
-
-from vllm.model_executor.layers.quantization.inc import INCConfig
-from vllm.model_executor.models.utils import WeightsMapper
-
-if TYPE_CHECKING:
- from vllm.model_executor.layers.quantization.base_config import (
- QuantizationConfig,
- )
-
-_REGEX_SPECIAL_CHARS = frozenset(r"*+?^$()[]{}|\\")
-
-
-def _stage_prefix(prefix_map: dict[str, str | None]) -> str:
- """Derive the container/stage prefix from mapper source keys."""
- cp = commonprefix(list(prefix_map.keys()))
- dot = cp.rfind(".")
- return cp[: dot + 1] if dot >= 0 else ""
-
-
-def _map_with_stage_prefix(
- items: list[str],
- prefix_map: dict[str, str | None],
- stage: str,
-) -> list[str]:
- """Apply *prefix_map* to each item and prepend *stage* to mapped items."""
- sorted_keys = sorted(prefix_map, key=len, reverse=True)
- result: list[str] = []
- for item in items:
- new_item = item
- for orig in sorted_keys:
- if item.startswith(orig):
- new_val = prefix_map[orig] or ""
- new_item = stage + new_val + item[len(orig) :]
- break
- result.append(new_item)
- return result
-
-
-class OmniINCConfig(INCConfig):
- """INCConfig extended with multi-stage prefix remapping."""
-
- # ------------------------------------------------------------------
- # Core integration: called by vLLM's configure_quant_config()
- # ------------------------------------------------------------------
-
- def apply_vllm_mapper(self, hf_to_vllm_mapper: WeightsMapper) -> None:
- """Remap HF checkpoint names to vLLM runtime prefixes."""
- prefix_map = getattr(hf_to_vllm_mapper, "orig_to_new_prefix", None) or {}
- stage = _stage_prefix(prefix_map) if prefix_map else ""
-
- # -- Normalize CSV string -----------------------------------------
- if isinstance(self.block_name_to_quantize, str):
- self.block_name_to_quantize = [b.strip() for b in self.block_name_to_quantize.split(",") if b.strip()]
-
- # -- block_name_to_quantize ----------------------------------------
- if self.block_name_to_quantize is not None:
- if prefix_map and stage:
- self.block_name_to_quantize = _map_with_stage_prefix(
- self.block_name_to_quantize,
- prefix_map,
- stage,
- )
- else:
- self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(self.block_name_to_quantize)
-
- # -- extra_config --------------------------------------------------
- if self.extra_config is not None and prefix_map:
- new_extra: dict[str, Any] = {}
- sorted_keys = sorted(prefix_map, key=len, reverse=True)
-
- # Build escaped-dot map for regex pattern keys
- escaped_map: dict[str, str] = {}
- for orig, new in prefix_map.items():
- escaped_map[orig.replace(".", r"\.")] = (new or "").replace(".", r"\.")
- escaped_sorted = sorted(escaped_map, key=len, reverse=True)
-
- for key, val in self.extra_config.items():
- is_regex = any(c in _REGEX_SPECIAL_CHARS for c in key)
- if is_regex:
- # Regex keys: escaped-dot substring replacement.
- # re.search matches anywhere so no stage prefix needed.
- new_key = key
- for esc_orig in escaped_sorted:
- if esc_orig in new_key:
- new_key = new_key.replace(
- esc_orig,
- escaped_map[esc_orig],
- 1,
- )
- break
- else:
- # Plain keys: prefix replacement + stage prefix.
- new_key = key
- for orig in sorted_keys:
- if key.startswith(orig):
- new_val = prefix_map[orig] or ""
- new_key = stage + new_val + key[len(orig) :]
- break
- new_extra[new_key] = val
- self.extra_config = new_extra
- elif self.extra_config is not None:
- self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)
-
- # ------------------------------------------------------------------
- # Upgrading a vanilla INCConfig created by vLLM
- # ------------------------------------------------------------------
-
- @classmethod
- def from_inc_config(cls, inc: INCConfig) -> OmniINCConfig:
- """Promote a vanilla :class:`INCConfig` to :class:`OmniINCConfig`.
-
- Copies all attributes so that the new instance is a drop-in
- replacement.
- """
- omni = object.__new__(cls)
- omni.__dict__.update(inc.__dict__)
- return omni
-
- @classmethod
- def maybe_upgrade(cls, quant_config: QuantizationConfig | None) -> QuantizationConfig | None:
- """Upgrade *quant_config* to :class:`OmniINCConfig` if applicable.
-
- Returns the original config unchanged when it is not an INC
- config or is already an :class:`OmniINCConfig`.
- """
- if quant_config is None:
- return None
- if isinstance(quant_config, cls):
- return quant_config
- if isinstance(quant_config, INCConfig):
- return cls.from_inc_config(quant_config)
- return quant_config
diff --git a/vllm_omni/request.py b/vllm_omni/request.py
index 48cbf9b31d7..3ec325316fd 100644
--- a/vllm_omni/request.py
+++ b/vllm_omni/request.py
@@ -1,11 +1,8 @@
from collections.abc import Callable
-from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
import torch
-from vllm.multimodal.inputs import MultiModalFeatureSpec
-from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
if TYPE_CHECKING:
@@ -95,34 +92,3 @@ def from_engine_core_request(
resumable=request.resumable,
reasoning_ended=request.reasoning_ended,
)
-
-
-@dataclass
-class OmniStreamingUpdate:
- """
- Override: add additional information
- Lightweight data for streaming session continuation.
-
- Contains only the fields needed to update an existing streaming session
- with new input data.
- """
-
- mm_features: list[MultiModalFeatureSpec] | None
- prompt_token_ids: list[int] | None
- max_tokens: int
- arrival_time: float
- sampling_params: SamplingParams | None
- additional_information: AdditionalInformationPayload | None = None
-
- @classmethod
- def from_request(cls, request: "Request") -> "OmniStreamingUpdate | None":
- if not request.resumable:
- return None
- return cls(
- mm_features=request.mm_features,
- prompt_token_ids=request.prompt_token_ids,
- max_tokens=request.max_tokens,
- arrival_time=request.arrival_time,
- sampling_params=request.sampling_params,
- additional_information=request.additional_information,
- )
diff --git a/vllm_omni/transformers_utils/configs/__init__.py b/vllm_omni/transformers_utils/configs/__init__.py
index cb9b8418a50..59b23f91490 100644
--- a/vllm_omni/transformers_utils/configs/__init__.py
+++ b/vllm_omni/transformers_utils/configs/__init__.py
@@ -17,14 +17,6 @@
"FishSpeechConfig": "vllm_omni.transformers_utils.configs.fish_speech",
"FishSpeechSlowARConfig": "vllm_omni.transformers_utils.configs.fish_speech",
"FishSpeechFastARConfig": "vllm_omni.transformers_utils.configs.fish_speech",
- "VoxCPMConfig": "vllm_omni.transformers_utils.configs.voxcpm",
- "VoxCPM2Config": "vllm_omni.transformers_utils.configs.voxcpm2",
- "VoxtralTTSConfig": "vllm_omni.transformers_utils.configs.voxtral_tts",
- "BailingMoeV2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni",
- "BailingMM2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni",
- "MingFlashOmniConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
- "Qwen3VLMoeVisionConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
- "WhisperEncoderConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
}
__all__ = [
@@ -35,14 +27,6 @@
"FishSpeechConfig",
"FishSpeechSlowARConfig",
"FishSpeechFastARConfig",
- "VoxCPMConfig",
- "VoxCPM2Config",
- "VoxtralTTSConfig",
- "BailingMoeV2Config",
- "BailingMM2Config",
- "MingFlashOmniConfig",
- "Qwen3VLMoeVisionConfig",
- "WhisperEncoderConfig",
]
@@ -63,7 +47,3 @@ def __dir__():
# run as soon as `vllm_omni.transformers_utils.configs` is imported.
from vllm_omni.transformers_utils.configs import fish_speech as _fish_speech # noqa: F401, E402
from vllm_omni.transformers_utils.configs import mammoth_moda2 as _mammoth_moda2 # noqa: F401, E402
-from vllm_omni.transformers_utils.configs import ming_flash_omni as _ming_flash_omni # noqa: F401, E402
-from vllm_omni.transformers_utils.configs import voxcpm as _voxcpm # noqa: F401, E402
-from vllm_omni.transformers_utils.configs import voxcpm2 as _voxcpm2 # noqa: F401, E402
-from vllm_omni.transformers_utils.configs import voxtral_tts as _voxtral_tts # noqa: F401, E402
diff --git a/vllm_omni/transformers_utils/configs/ming_flash_omni.py b/vllm_omni/transformers_utils/configs/ming_flash_omni.py
deleted file mode 100644
index 408b208682f..00000000000
--- a/vllm_omni/transformers_utils/configs/ming_flash_omni.py
+++ /dev/null
@@ -1,352 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Configuration for Ming-flash-omni-2.0 model"""
-
-import os
-from typing import Any, ClassVar
-
-from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class BailingMoeV2Config(PretrainedConfig):
- model_type = "bailing_moe_v2"
-
- def __init__(
- self,
- vocab_size=30592,
- hidden_size=1024,
- intermediate_size=None,
- num_hidden_layers=24,
- num_attention_heads=16,
- num_key_value_heads=0,
- hidden_act="silu",
- use_qkv_bias=False,
- use_qk_norm=False,
- use_bias=True,
- rms_norm_eps=1e-05,
- norm_head=False,
- tie_word_embeddings=False,
- embedding_dropout=0.0,
- attention_dropout=0.0,
- output_dropout=0.0,
- initializer_range=0.02,
- max_position_embeddings=16384,
- rope_theta=10000.0,
- use_cache=True,
- use_sliding_window=False,
- sliding_window=81920,
- max_window_layers=28,
- rope_scaling=None,
- mrope_section=None,
- pad_token_id=126081,
- num_experts=16,
- num_shared_experts=1,
- num_experts_per_tok=2,
- n_group=8,
- topk_group=4,
- routed_scaling_factor=2.5,
- moe_intermediate_size=None,
- first_k_dense_replace=0,
- head_dim=None,
- output_router_logits=False,
- partial_rotary_factor=0.5,
- router_type="topN",
- _attn_implementation="flash_attention_2",
- use_interleaved_frame_timestamp=True,
- # Multimodal token IDs
- image_patch_token=157157,
- video_patch_token=157175,
- audio_patch_token=157168,
- image_start_token=157158,
- video_start_token=157160,
- audio_start_token=157169,
- image_end_token=157159,
- video_end_token=157161,
- audio_end_token=157170,
- # Position encoding parameters
- spatial_merge_size=2,
- tokens_per_second=2,
- **kwargs,
- ):
- self.num_hidden_layers = num_hidden_layers
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.num_attention_heads = num_attention_heads
- self.num_key_value_heads = num_key_value_heads
- self.hidden_act = hidden_act
- self.use_qkv_bias = use_qkv_bias
- self.use_bias = use_bias
- self.norm_head = norm_head
- self.rms_norm_eps = rms_norm_eps
- self.embedding_dropout = embedding_dropout
- self.attention_dropout = attention_dropout
- self.output_dropout = output_dropout
- self.initializer_range = initializer_range
- self.max_position_embeddings = max_position_embeddings
- self.rope_theta = rope_theta
- self.use_cache = use_cache
- self.use_sliding_window = use_sliding_window
- self.sliding_window = sliding_window
- self.max_window_layers = max_window_layers
- self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
- self.use_qk_norm = use_qk_norm # arg unused; QK norm is always applied
-
- # By default, match the value of `mrope_section`
- # to `apply_3d_rotary_pos_emb` in Ming's repo:
- # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/modeling_bailing_moe_v2.py
- if mrope_section is None:
- mrope_section = (rope_scaling or {}).get("mrope_section", [8, 12, 12])
- # Ensure mrope_section is stored inside rope_scaling
- if rope_scaling is not None and isinstance(rope_scaling, dict):
- rope_scaling = dict(rope_scaling)
- rope_scaling.setdefault("mrope_section", mrope_section)
- self.rope_scaling = rope_scaling
-
- # NOTE: Expose rope_parameters["mrope_section"]
- # This refers to the pattern used for GLM-Image in vllm_omni/patch.py
- rope_type = (rope_scaling or {}).get("type", (rope_scaling or {}).get("rope_type", ""))
- if rope_type in ("video_rope", "3D", "mrope"):
- self.rope_parameters = {"mrope_section": mrope_section}
- else:
- self.rope_parameters = None
-
- # MoE configs
- self.num_experts = num_experts
- self.num_shared_experts = num_shared_experts
- self.num_experts_per_tok = num_experts_per_tok
- self.n_group = n_group
- self.topk_group = topk_group
- self.moe_intermediate_size = moe_intermediate_size
- self.first_k_dense_replace = first_k_dense_replace
- self.output_router_logits = output_router_logits
- self.routed_scaling_factor = routed_scaling_factor
- self.partial_rotary_factor = partial_rotary_factor
- self.router_type = router_type
- self.use_interleaved_frame_timestamp = use_interleaved_frame_timestamp
- self._attn_implementation = _attn_implementation
-
- # Multimodal token IDs and position encoding
- self.image_patch_token = image_patch_token
- self.video_patch_token = video_patch_token
- self.audio_patch_token = audio_patch_token
- self.image_start_token = image_start_token
- self.video_start_token = video_start_token
- self.audio_start_token = audio_start_token
- self.image_end_token = image_end_token
- self.video_end_token = video_end_token
- self.audio_end_token = audio_end_token
- self.spatial_merge_size = spatial_merge_size
- self.tokens_per_second = tokens_per_second
-
- super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
-
-
-class Qwen3VLMoeVisionConfig(PretrainedConfig):
- """Configuration class for Qwen3 MoE Vision Transformer"""
-
- model_type = "qwen3_moe_vit"
-
- def __init__(
- self,
- depth=27,
- hidden_size=1152,
- hidden_act="gelu_pytorch_tanh",
- intermediate_size=4304,
- num_heads=16,
- in_channels=3,
- patch_size=16,
- spatial_merge_size=2,
- temporal_patch_size=2,
- out_hidden_size=3584,
- num_position_embeddings=2304,
- deepstack_visual_indexes=[8, 16, 24],
- initializer_range=0.02,
- **kwargs,
- ):
- super().__init__(**kwargs)
-
- self.depth = depth
- self.hidden_size = hidden_size
- self.hidden_act = hidden_act
- self.intermediate_size = intermediate_size
- self.num_heads = num_heads
- self.in_channels = in_channels
- self.patch_size = patch_size
- self.spatial_merge_size = spatial_merge_size
- self.temporal_patch_size = temporal_patch_size
- self.out_hidden_size = out_hidden_size
- self.num_position_embeddings = num_position_embeddings
- self.initializer_range = initializer_range
- self.deepstack_visual_indexes = deepstack_visual_indexes
-
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> "PretrainedConfig":
- cls._set_token_in_kwargs(kwargs)
-
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-
- if "vision_config" in config_dict:
- config_dict = config_dict["vision_config"]
-
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
- logger.warning(
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
- )
-
- return cls.from_dict(config_dict, **kwargs)
-
-
-class WhisperEncoderConfig(PretrainedConfig):
- """Configuration class for Whisper audio encoder"""
-
- model_type = "whisper_encoder"
-
- def __init__(
- self,
- whisper_encoder_config: dict[str, Any] | None = None,
- ds_kernel_size=3,
- ds_stride=2,
- norm_query_embeds=True,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.whisper_encoder_config = whisper_encoder_config or {}
- self.ds_kernel_size = ds_kernel_size
- self.ds_stride = ds_stride
- self.norm_query_embeds = norm_query_embeds
-
-
-class BailingMM2Config(PretrainedConfig):
- model_type = "bailingmm_moe_v2_lite"
- is_composition = True
- sub_configs: ClassVar = {"llm_config": AutoConfig}
-
- def __init__(
- self,
- mlp_depth=1,
- llm_config: BailingMoeV2Config | None = None,
- vision_config: Qwen3VLMoeVisionConfig | None = None,
- audio_config: WhisperEncoderConfig | None = None,
- **kwargs,
- ):
- self.audio_config = WhisperEncoderConfig(**audio_config) if isinstance(audio_config, dict) else audio_config
- self.vision_config = (
- Qwen3VLMoeVisionConfig(**vision_config) if isinstance(vision_config, dict) else vision_config
- )
- self.llm_config = BailingMoeV2Config(**llm_config) if isinstance(llm_config, dict) else llm_config
- self.mlp_depth = mlp_depth
- super().__init__(**kwargs)
-
- def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002
- return self.llm_config
-
-
-class MingFlashOmniTalkerConfig(PretrainedConfig):
- """Configuration class for Ming-flash-omni-2.0 talker (TTS) stage.
-
- The talker uses a Qwen2 LLM backbone with CFM (Conditional Flow Matching)
- via a DiT diffusion transformer, plus an Aggregator that maps generated
- audio latents back to the LLM embedding space for autoregressive generation.
- """
-
- model_type = "ming_flash_omni_talker"
-
- def __init__(
- self,
- llm_config: dict[str, Any] | None = None,
- flowmodel: dict[str, Any] | None = None,
- aggregator: dict[str, Any] | None = None,
- steps: int = 10,
- patch_size: int = 4,
- history_patch_size: int = 32,
- latent_dim: int = 64,
- cfg_strength: float = 2.0,
- audio_vae_path: str | None = None,
- campplus_model: str | None = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.llm_config = llm_config
- self.flowmodel = flowmodel or {}
- self.aggregator = aggregator or {}
- self.steps = steps
- self.patch_size = patch_size
- self.history_patch_size = history_patch_size
- self.latent_dim = latent_dim
- self.cfg_strength = cfg_strength
- self.audio_vae_path = audio_vae_path
- self.campplus_model = campplus_model
-
- def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002
- if isinstance(self.llm_config, dict):
- return PretrainedConfig.from_dict(self.llm_config)
- return self.llm_config
-
-
-class MingFlashOmniConfig(PretrainedConfig):
- """Configuration class for unified Ming-flash-omni-2.0 model"""
-
- model_type = "ming_flash_omni"
- is_composition = True
- sub_configs: ClassVar = {
- "thinker_config": BailingMM2Config,
- "talker_config": MingFlashOmniTalkerConfig,
- }
-
- def __init__(
- self,
- thinker_config: BailingMM2Config | dict[str, Any] | None = None,
- image_gen_config: dict[str, Any] | None = None,
- talker_config: MingFlashOmniTalkerConfig | dict[str, Any] | None = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
-
- if isinstance(thinker_config, dict):
- self.thinker_config = BailingMM2Config(**thinker_config)
- else:
- self.thinker_config = thinker_config or BailingMM2Config()
-
- # Image generation config (for future implementation)
- self.image_gen_config = image_gen_config
-
- # Talker config
- if isinstance(talker_config, dict):
- self.talker_config = MingFlashOmniTalkerConfig(**talker_config)
- else:
- self.talker_config = talker_config
-
- def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002
- return self.thinker_config.get_text_config()
-
-
-# Register model_type -> config class for AutoConfig
-AutoConfig.register(BailingMoeV2Config.model_type, BailingMoeV2Config)
-AutoConfig.register(BailingMM2Config.model_type, BailingMM2Config)
-AutoConfig.register(MingFlashOmniTalkerConfig.model_type, MingFlashOmniTalkerConfig)
-AutoConfig.register(MingFlashOmniConfig.model_type, MingFlashOmniConfig)
-
-# Register tokenizer mapping for composition configs so that
-# AutoTokenizer.from_pretrained can resolve the tokenizer class
-AutoTokenizer.register(BailingMM2Config, fast_tokenizer_class=PreTrainedTokenizerFast)
-AutoTokenizer.register(MingFlashOmniTalkerConfig, fast_tokenizer_class=PreTrainedTokenizerFast)
-AutoTokenizer.register(MingFlashOmniConfig, fast_tokenizer_class=PreTrainedTokenizerFast)
diff --git a/vllm_omni/transformers_utils/configs/voxcpm.py b/vllm_omni/transformers_utils/configs/voxcpm.py
deleted file mode 100644
index 02678389150..00000000000
--- a/vllm_omni/transformers_utils/configs/voxcpm.py
+++ /dev/null
@@ -1,68 +0,0 @@
-from transformers import AutoConfig
-from transformers.configuration_utils import PretrainedConfig
-
-
-class VoxCPMConfig(PretrainedConfig):
- model_type = "voxcpm"
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- bos_token_id: int = 1,
- eos_token_id: int = 2,
- vocab_size: int = 32000,
- hidden_size: int = 1024,
- intermediate_size: int = 4096,
- max_position_embeddings: int = 4096,
- num_attention_heads: int = 16,
- num_hidden_layers: int = 24,
- num_key_value_heads: int = 16,
- rms_norm_eps: float = 1e-6,
- rope_theta: float = 10000.0,
- rope_scaling: dict | None = None,
- lm_config: dict | None = None,
- encoder_config: dict | None = None,
- dit_config: dict | None = None,
- audio_vae_config: dict | None = None,
- patch_size: int = 2,
- feat_dim: int = 64,
- residual_lm_num_layers: int = 6,
- scalar_quantization_latent_dim: int = 256,
- scalar_quantization_scale: int = 9,
- max_length: int = 4096,
- device: str = "cuda",
- dtype: str = "bfloat16",
- dit_mean_mode: bool = False,
- **kwargs,
- ):
- super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.max_position_embeddings = max_position_embeddings
- self.num_attention_heads = num_attention_heads
- self.num_hidden_layers = num_hidden_layers
- self.num_key_value_heads = num_key_value_heads
- self.rms_norm_eps = rms_norm_eps
- self.rope_theta = rope_theta
- self.rope_scaling = rope_scaling
-
- self.lm_config = lm_config or {}
- self.encoder_config = encoder_config or {}
- self.dit_config = dit_config or {}
- self.audio_vae_config = audio_vae_config
-
- self.patch_size = patch_size
- self.feat_dim = feat_dim
- self.residual_lm_num_layers = residual_lm_num_layers
- self.scalar_quantization_latent_dim = scalar_quantization_latent_dim
- self.scalar_quantization_scale = scalar_quantization_scale
- self.max_length = max_length
- self.device = device
- self.dtype = dtype
- self.dit_mean_mode = dit_mean_mode
-
-
-AutoConfig.register("voxcpm", VoxCPMConfig)
-
-__all__ = ["VoxCPMConfig"]
diff --git a/vllm_omni/transformers_utils/configs/voxcpm2.py b/vllm_omni/transformers_utils/configs/voxcpm2.py
deleted file mode 100644
index c625284bd67..00000000000
--- a/vllm_omni/transformers_utils/configs/voxcpm2.py
+++ /dev/null
@@ -1,153 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
-
-from transformers import AutoConfig
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_rope_utils import rope_config_validation
-
-
-class VoxCPM2Config(PretrainedConfig):
- """Configuration for VoxCPM2 native AR integration.
-
- The HuggingFace checkpoint stores LM parameters inside a nested
- ``lm_config`` dict. This class hoists them to top-level attributes
- so that vllm's ``MiniCPMModel`` can consume them directly.
-
- vllm's MiniCPM **always** applies muP scaling (scale_emb, scale_depth,
- dim_model_base). VoxCPM2 was trained with ``use_mup=false``, so we
- neutralise the scalings:
- * ``scale_emb = 1.0``
- * ``scale_depth = sqrt(num_hidden_layers)`` (cancels the division)
- * ``dim_model_base = hidden_size`` (makes scale_width = 1.0)
- """
-
- model_type = "voxcpm2"
- keys_to_ignore_at_inference = ["past_key_values"]
-
- def __init__(
- self,
- # -- top-level VoxCPM2 params --
- architecture: str = "voxcpm2",
- lm_config: dict | None = None,
- encoder_config: dict | None = None,
- dit_config: dict | None = None,
- audio_vae_config: dict | None = None,
- patch_size: int = 4,
- feat_dim: int = 64,
- residual_lm_num_layers: int = 8,
- residual_lm_no_rope: bool = True,
- scalar_quantization_latent_dim: int = 512,
- scalar_quantization_scale: int = 9,
- max_length: int = 8192,
- device: str = "cuda",
- dtype: str = "bfloat16",
- # -- LM defaults (overridden by lm_config if present) --
- bos_token_id: int = 1,
- eos_token_id: int = 2,
- vocab_size: int = 73448,
- hidden_size: int = 2048,
- intermediate_size: int = 6144,
- max_position_embeddings: int = 32768,
- num_attention_heads: int = 16,
- num_hidden_layers: int = 28,
- num_key_value_heads: int = 2,
- rms_norm_eps: float = 1e-5,
- rope_theta: float = 10000.0,
- rope_scaling: dict | None = None,
- **kwargs,
- ):
- super().__init__(
- bos_token_id=bos_token_id,
- eos_token_id=eos_token_id,
- **kwargs,
- )
- self.architecture = architecture
-
- # -- VoxCPM2-specific fields --
- self.lm_config = lm_config or {}
- self.encoder_config = encoder_config or {}
- self.dit_config = dit_config or {}
- self.audio_vae_config = audio_vae_config or {}
- self.patch_size = patch_size
- self.feat_dim = feat_dim
- self.residual_lm_num_layers = residual_lm_num_layers
- self.residual_lm_no_rope = residual_lm_no_rope
- self.scalar_quantization_latent_dim = scalar_quantization_latent_dim
- self.scalar_quantization_scale = scalar_quantization_scale
- self.max_length = max_length
- self.device = device
- self.dtype = dtype
-
- # -- Hoist LM parameters to top-level for MiniCPMModel --
- lm = self.lm_config
- self.vocab_size = lm.get("vocab_size", vocab_size)
- self.hidden_size = lm.get("hidden_size", hidden_size)
- self.intermediate_size = lm.get("intermediate_size", intermediate_size)
- self.max_position_embeddings = lm.get("max_position_embeddings", max_position_embeddings)
- self.num_attention_heads = lm.get("num_attention_heads", num_attention_heads)
- self.num_hidden_layers = lm.get("num_hidden_layers", num_hidden_layers)
- self.num_key_value_heads = lm.get("num_key_value_heads", num_key_value_heads)
- self.rms_norm_eps = lm.get("rms_norm_eps", rms_norm_eps)
- self.rope_theta = lm.get("rope_theta", rope_theta)
-
- # MiniCPM-specific: kv_channels overrides head_dim when set.
- kv_channels = lm.get("kv_channels")
- if kv_channels is not None:
- self.head_dim = kv_channels
- else:
- self.head_dim = self.hidden_size // self.num_attention_heads
-
- # MiniCPM requires hidden_act; VoxCPM2 uses SiLU.
- self.hidden_act = "silu"
- self.hidden_act_param = 0.0
- self.tie_word_embeddings = False
- self.num_experts = 0
-
- # -- muP scaling --
- # Native VoxCPM2 MiniCPM gates scale_depth behind use_mup:
- # use_mup=True → residual += h * (scale_depth / sqrt(N))
- # use_mup=False → residual += h (plain add, no scaling)
- # But vllm's MiniCPMModel ALWAYS applies scale_depth / sqrt(N).
- # Native applies scale_emb externally; vllm applies it in embed_input_ids.
- use_mup = lm.get("use_mup", False)
- self.scale_emb = lm.get("scale_emb", 1.0)
- if use_mup:
- self.scale_depth = lm.get("scale_depth", 1.0)
- self.dim_model_base = lm.get("dim_model_base", self.hidden_size)
- else:
- # Neutralize: scale_depth/sqrt(N) = 1.0, scale_width = 1.0
- self.scale_depth = math.sqrt(self.num_hidden_layers)
- self.dim_model_base = self.hidden_size
-
- # -- RoPE scaling (longrope) --
- raw_rope = lm.get("rope_scaling", rope_scaling)
- if raw_rope is not None:
- self.rope_scaling = dict(raw_rope)
- # HF expects "rope_type" not "type"
- if "type" in self.rope_scaling:
- self.rope_scaling["rope_type"] = self.rope_scaling.pop("type")
- # longrope requires "factor" (used by HF validation)
- if "factor" not in self.rope_scaling:
- self.rope_scaling["factor"] = 1.0
- rope_config_validation(self)
-
- # vllm's MiniCPMAttention reads config.rope_parameters (a dict
- # with rope_type, theta, scaling factors, etc.). HF transformers
- # only auto-computes this for known model_types; for custom
- # types we must build it manually.
- if not getattr(self, "rope_parameters", None):
- rp = dict(self.rope_scaling)
- rp["rope_theta"] = self.rope_theta
- self.rope_parameters = rp
- else:
- self.rope_scaling = None
-
- def get_text_config(self, **kwargs):
- """Return self as the text config — LM attributes are top-level."""
- return self
-
-
-AutoConfig.register("voxcpm2", VoxCPM2Config)
-
-__all__ = ["VoxCPM2Config"]
diff --git a/vllm_omni/transformers_utils/configs/voxtral_tts.py b/vllm_omni/transformers_utils/configs/voxtral_tts.py
deleted file mode 100644
index 19af7b02011..00000000000
--- a/vllm_omni/transformers_utils/configs/voxtral_tts.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from __future__ import annotations
-
-from typing import Any
-
-from transformers import AutoConfig, PretrainedConfig
-
-
-class VoxtralTTSConfig(PretrainedConfig):
- """HuggingFace-style config for Voxtral TTS models."""
-
- model_type = "voxtral_tts"
-
- def __init__(
- self,
- text_config: PretrainedConfig | dict | None = None,
- audio_config: dict[str, Any] | None = None,
- **kwargs: Any,
- ) -> None:
- super().__init__(**kwargs)
-
- if isinstance(text_config, PretrainedConfig):
- self.text_config = text_config
- elif isinstance(text_config, dict):
- self.text_config = PretrainedConfig.from_dict(text_config)
- else:
- self.text_config = PretrainedConfig()
-
- self.audio_config = audio_config or {}
-
- def get_text_config(self, **kwargs: Any) -> PretrainedConfig:
- return self.text_config
-
-
-AutoConfig.register("voxtral_tts", VoxtralTTSConfig)
-
-__all__ = ["VoxtralTTSConfig"]
diff --git a/vllm_omni/transformers_utils/parsers/__init__.py b/vllm_omni/transformers_utils/parsers/__init__.py
deleted file mode 100644
index eed3d3f7de9..00000000000
--- a/vllm_omni/transformers_utils/parsers/__init__.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""Custom vLLM config parsers for vllm-omni."""
-
-from __future__ import annotations
-
-import importlib
-
-_CLASS_TO_MODULE: dict[str, str] = {
- "VoxtralTTSConfigParser": "vllm_omni.transformers_utils.parsers.voxtral_tts",
-}
-
-__all__ = ["VoxtralTTSConfigParser"]
-
-
-def __getattr__(name: str):
- if name in _CLASS_TO_MODULE:
- module_name = _CLASS_TO_MODULE[name]
- module = importlib.import_module(module_name)
- return getattr(module, name)
-
- raise AttributeError(f"module 'vllm_omni.transformers_utils.parsers' has no attribute {name!r}")
-
-
-def __dir__():
- return sorted(list(__all__))
-
-
-# Eagerly import parser modules so their registry side-effects run as soon as
-# `vllm_omni.transformers_utils.parsers` is imported.
-from vllm_omni.transformers_utils.parsers import voxtral_tts as _voxtral_tts # noqa: F401, E402
diff --git a/vllm_omni/transformers_utils/parsers/voxtral_tts.py b/vllm_omni/transformers_utils/parsers/voxtral_tts.py
deleted file mode 100644
index d4669258ad4..00000000000
--- a/vllm_omni/transformers_utils/parsers/voxtral_tts.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from __future__ import annotations
-
-from pathlib import Path
-from typing import Any
-
-from transformers import PretrainedConfig
-from vllm.logger import init_logger
-from vllm.transformers_utils.config import (
- _CONFIG_FORMAT_TO_CONFIG_PARSER,
- MistralConfigParser,
- _download_mistral_config_file,
-)
-
-from vllm_omni.transformers_utils.configs.voxtral_tts import VoxtralTTSConfig
-
-logger = init_logger(__name__)
-
-_VOXTRAL_TTS_ARCHS = frozenset({"VoxtralTTSForConditionalGeneration"})
-_VOXTRAL_TTS_MODEL_TYPE = "voxtral_tts"
-
-
-def _is_voxtral_tts_params(config_dict: dict) -> bool:
- """Return True if the Mistral params.json describes a Voxtral-TTS model"""
- if config_dict.get("model_type") == _VOXTRAL_TTS_MODEL_TYPE:
- return True
- architectures = set(config_dict.get("architectures") or [])
- return bool(architectures & _VOXTRAL_TTS_ARCHS)
-
-
-def _remap_voxtral_tts_audio_args(config_dict: dict) -> dict:
- encoder_args = config_dict["multimodal"].pop("audio_model_args")
- audio_tokenizer_args = config_dict["multimodal"].pop("audio_tokenizer_args", None)
- if encoder_args is None:
- return {}
-
- acoustic_args = encoder_args.get("acoustic_transformer_args", {})
- if acoustic_args.get("n_decoding_steps") is None:
- logger.warning(
- "n_decoding_steps not provided in acoustic_transformer_args, defaulting to 7. "
- "Please add 'n_decoding_steps' to params.json under acoustic_transformer_args."
- )
- acoustic_args["n_decoding_steps"] = 7
-
- return {
- "sampling_rate": encoder_args["audio_encoding_args"]["sampling_rate"],
- "codec_args": audio_tokenizer_args,
- "audio_model_args": encoder_args,
- "speaker_id": (audio_tokenizer_args or {}).get("voice", {}),
- }
-
-
-def _parse_voxtral_tts(config_dict: dict) -> tuple[dict, PretrainedConfig]:
- from vllm.transformers_utils.configs.mistral import (
- _remap_general_mistral_args,
- _remap_mistral_quantization_args,
- )
-
- audio_config: dict[str, Any] = {}
- if (config_dict.get("multimodal") or {}).get("audio_model_args"):
- audio_config = _remap_voxtral_tts_audio_args(config_dict)
-
- text_config = {k: v for k, v in config_dict.items() if k != "multimodal"}
- text_config = _remap_general_mistral_args(text_config)
- if text_config.get("quantization"):
- text_config = _remap_mistral_quantization_args(text_config)
- text_config.setdefault("architectures", ["MistralForCausalLM"])
-
- config = VoxtralTTSConfig(
- text_config=PretrainedConfig.from_dict(text_config),
- audio_config=audio_config,
- architectures=config_dict.get("architectures", ["VoxtralTTSForConditionalGeneration"]),
- )
- return config_dict, config
-
-
-class VoxtralTTSConfigParser(MistralConfigParser):
- """Mistral parser that also recognizes Voxtral-TTS checkpoints."""
-
- def parse(
- self,
- model: str | Path,
- trust_remote_code: bool,
- revision: str | None = None,
- code_revision: str | None = None,
- **kwargs: Any,
- ) -> tuple[dict, PretrainedConfig]:
- config_dict = _download_mistral_config_file(model, revision)
-
- if _is_voxtral_tts_params(config_dict):
- return _parse_voxtral_tts(config_dict)
-
- return super().parse(
- model,
- trust_remote_code,
- revision=revision,
- code_revision=code_revision,
- **kwargs,
- )
-
-
-# Replace the default "mistral" slot directly.
-# Any non-Voxtral-TTS Mistral ckpt still goes through
-# the upstream code path via super().parse().
-_CONFIG_FORMAT_TO_CONFIG_PARSER["mistral"] = VoxtralTTSConfigParser
-
-__all__ = ["VoxtralTTSConfigParser"]
diff --git a/vllm_omni/transformers_utils/processors/__init__.py b/vllm_omni/transformers_utils/processors/__init__.py
deleted file mode 100644
index 52ca6575397..00000000000
--- a/vllm_omni/transformers_utils/processors/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-
-from vllm_omni.transformers_utils.processors.ming import (
- MingFlashOmniProcessor,
- MingWhisperFeatureExtractor,
-)
-
-__all__ = [
- "MingFlashOmniProcessor",
- "MingWhisperFeatureExtractor",
-]
diff --git a/vllm_omni/transformers_utils/processors/ming.py b/vllm_omni/transformers_utils/processors/ming.py
deleted file mode 100644
index 7f414b7268c..00000000000
--- a/vllm_omni/transformers_utils/processors/ming.py
+++ /dev/null
@@ -1,430 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# Copyright 2025 The vLLM-Omni team.
-# Copyright 2024 ANT Group and the HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any
-
-import numpy as np
-import torch
-from transformers import AutoFeatureExtractor, AutoProcessor
-from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from transformers.processing_utils import ProcessorMixin
-from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
-
-DEFAULT_IMAGE_PATCH_TOKEN = ""
-DEFAULT_IM_START_TOKEN = ""
-DEFAULT_IM_END_TOKEN = " "
-DEFAULT_VID_START_TOKEN = ""
-DEFAULT_VID_END_TOKEN = " "
-DEFAULT_FRAME_PATCH_TOKEN = ""
-
-DEFAULT_AUDIO_PATCH_TOKEN = ""
-DEFAULT_AU_START_TOKEN = ""
-DEFAULT_AU_END_TOKEN = " "
-
-# High-level placeholders used in user prompts
-PLACEHOLDER_IMAGE_TOKEN_IN_TEXT = ""
-PLACEHOLDER_VIDEO_TOKEN_IN_TEXT = ""
-PLACEHOLDER_AUDIO_TOKEN_IN_TEXT = ""
-
-# Chat template constants
-USER_PREFIX = "HUMAN "
-ASSISTANT_PREFIX = "ASSISTANT "
-SYSTEM_PROMPT_NOTHINK = "SYSTEM 你是一个友好的AI助手。\n\ndetailed thinking off"
-SYSTEM_PROMPT_THINK = "SYSTEM 你是一个友好的AI助手。\n\ndetailed thinking on"
-
-
-_NORM_FACTOR_FOR_DTYPE = {
- torch.int8: 2**7,
- torch.int16: 2**15,
- torch.int32: 2**31,
- torch.int64: 2**63,
- torch.float32: 1,
- torch.float64: 1,
-}
-
-
-def _normalize_audio_tensor(
- waveform: torch.Tensor,
- sample_rate: int,
- target_sample_rate: int = 16000,
-) -> torch.Tensor:
- """Normalize waveform to float32, mono, and optionally resample."""
- norm_factor = _NORM_FACTOR_FOR_DTYPE.get(waveform.dtype, 1)
- waveform = waveform.to(torch.float32) / norm_factor
-
- # Remove channel dimension
- while len(waveform.shape) > 1:
- waveform = waveform[0]
-
- # Resample if needed
- if sample_rate != target_sample_rate:
- import torchaudio
-
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
- waveform = resampler(waveform.unsqueeze(0)).squeeze(0)
-
- return waveform
-
-
-class MingWhisperFeatureExtractor(FeatureExtractionMixin):
- """Whisper log-mel feature extractor for Ming-flash-omni-2.0.
-
- Produces audio_feats in the time-first packed format.
-
- Adapted from Ming's WhisperAudioEncoder
- https://github.com/inclusionAI/Ming/blob/070dc3c13f95d97952ab7d22030df0c9e28a5122/modeling_whisper_encoder.py
- and HF transformers WhisperFeatureExtractor
- https://github.com/huggingface/transformers/blob/f842abaca95a7dbf3fc6e16122e7409109bc1431/src/transformers/models/whisper/feature_extraction_whisper.py#L33
- """
-
- model_input_names = ["audio_feats", "audio_feats_lengths"]
-
- def __init__(self, feature_size: int = 128, sampling_rate: int = 16000, **kwargs):
- # feature_size == n_mels; stored so to_dict() serialises it correctly.
- self.feature_size = feature_size
- self.sampling_rate = sampling_rate
- super().__init__(**kwargs)
-
- @property
- def n_mels(self) -> int:
- return self.feature_size
-
- def __call__(
- self,
- audios: tuple | list,
- return_tensors: str | None = None,
- **kwargs,
- ) -> BatchFeature:
- """Preprocess audio(s) into Whisper log-mel spectrograms"""
- import whisper
-
- if not isinstance(audios, list):
- audios = [audios]
-
- audio_feat_list = []
- for waveform, sr in audios:
- if isinstance(waveform, np.ndarray):
- waveform = torch.from_numpy(waveform)
- waveform = _normalize_audio_tensor(waveform, sr, target_sample_rate=self.sampling_rate)
- mel = whisper.log_mel_spectrogram(waveform, n_mels=self.n_mels)
- audio_feat_list.append(mel.transpose(0, 1)) # [T, n_mels]
-
- audio_feats_lengths = torch.tensor([[feat.shape[0] for feat in audio_feat_list]], dtype=torch.long)
- # Two stride-2 convolutions in series:
- # 1. WhisperAudioEncoder conv2: kernel=3, stride=2, padding=1
- # (conv1 has stride=1 and does not change T)
- # 2. AudioProjector Conv1d: kernel=3, stride=2, padding=1
- # Combined: T → ((T-1)//2 + 1 - 1)//2 + 1
- # See also: AudioProjector.compute_output_length()
- encoder_feats_lengths = ((audio_feats_lengths - 3 + 2 * 1) // 2 + 1 - 3 + 2 * 1) // 2 + 1
- audio_feats = torch.cat(audio_feat_list, dim=0).unsqueeze(0) # [1, T_total, n_mels]
-
- data = {
- # [1, T_total, n_mels], all audio clips concatenated
- "audio_feats": audio_feats.numpy(),
- # [1, n_audios], actual frame count
- "audio_feats_lengths": audio_feats_lengths.numpy(),
- # [1, n_audios]
- "encoder_feats_lengths": encoder_feats_lengths,
- }
- return BatchFeature(data=data, tensor_type=return_tensors)
-
-
-class MingFlashOmniProcessor(ProcessorMixin):
- """Top-level multimodal processor for Ming-flash-omni 2.0.
-
- Adapted from Ming's BailingMM2Processor
- https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/processing_bailingmm2.py
-
- Subprocessors include:
- - Qwen2VLImageProcessor (image/video)
- - MingWhisperFeatureExtractor (modified audio processor using Whisper's log-mel spectrogram)
- """
-
- attributes = ["image_processor", "audio_processor", "tokenizer"]
- image_processor_class = "AutoImageProcessor"
- audio_processor_class = "AutoFeatureExtractor"
- tokenizer_class = "AutoTokenizer"
-
- def __init__(
- self,
- image_processor=None,
- audio_processor=None,
- tokenizer=None,
- merge_size: int = 2,
- **kwargs,
- ):
- # Enforce that all sub-processors exist
- # Keep None defaults in the signature for HF ProcessorMixin compatibility
- if image_processor is None:
- raise ValueError("MingFlashOmniProcessor requires `image_processor`.")
- if audio_processor is None:
- raise ValueError("MingFlashOmniProcessor requires `audio_processor`.")
- if tokenizer is None:
- raise ValueError("MingFlashOmniProcessor requires `tokenizer`.")
-
- self.spatial_merge_size = merge_size
- self.image_token = PLACEHOLDER_IMAGE_TOKEN_IN_TEXT
- self.video_token = PLACEHOLDER_VIDEO_TOKEN_IN_TEXT
- self.audio_token = PLACEHOLDER_AUDIO_TOKEN_IN_TEXT
- super().__init__(
- image_processor=image_processor,
- audio_processor=audio_processor,
- tokenizer=tokenizer,
- )
-
- # Fall back to the tokenizer's own chat_template.
- if self.chat_template is None:
- self.chat_template = getattr(tokenizer, "chat_template", None)
-
- def __call__(
- self,
- text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
- images: Any | None = None,
- videos: Any | None = None,
- audios: tuple[np.ndarray, int] | list[tuple[np.ndarray, int]] | None = None,
- **kwargs,
- ) -> BatchFeature:
- # This should always be parallel implementations that mirror
- # `_get_prompt_updates` logic in Ming processor, and vice versa.
- # Ensure text is a list
- if isinstance(text, str):
- text = [text]
- elif not isinstance(text, list):
- raise ValueError("text must be a string or list of strings")
-
- data: dict[str, Any] = {}
-
- if images is not None:
- image_outputs = self.image_processor(
- images=images,
- videos=None,
- return_tensors="pt",
- **kwargs.get("images_kwargs", {}),
- )
- data.update(image_outputs)
- if "image_grid_thw" in image_outputs:
- text = self._expand_image_tokens(text, image_outputs["image_grid_thw"])
-
- if videos is not None:
- video_outputs = self.image_processor(
- images=None,
- videos=videos,
- return_tensors="pt",
- **kwargs.get("videos_kwargs", {}),
- )
- if "pixel_values" in video_outputs:
- video_outputs["pixel_values_videos"] = video_outputs.pop("pixel_values")
- if "image_grid_thw" in video_outputs:
- video_outputs["video_grid_thw"] = video_outputs.pop("image_grid_thw")
- data.update(video_outputs)
- if "video_grid_thw" in video_outputs:
- text = self._expand_video_tokens(text, video_outputs["video_grid_thw"])
-
- if audios is not None:
- audio_outputs = self.audio_processor(
- audios,
- return_tensors="pt",
- **kwargs.get("audio_kwargs", {}),
- )
- data.update(audio_outputs)
- if "encoder_feats_lengths" in audio_outputs:
- text = self._expand_audio_tokens(text, audio_outputs["encoder_feats_lengths"])
-
- text_outputs = self.tokenizer(
- text,
- return_tensors="pt",
- **kwargs.get("text_kwargs", {}),
- )
- data.update(text_outputs)
- return BatchFeature(data=data)
-
- def _expand_image_tokens(
- self,
- text: list[str],
- image_grid_thw: torch.Tensor,
- special_token: str = PLACEHOLDER_IMAGE_TOKEN_IN_TEXT,
- ) -> list[str]:
- merge_size = self.spatial_merge_size
- num_patches_per_image = torch.prod(image_grid_thw, dim=1) // (merge_size**2)
- prompt_strings = []
- image_index = 0
- for sample in text:
- num_images = sample.count(special_token)
- if num_images > 0:
- for i in range(image_index, num_images + image_index):
- num_patches = int(num_patches_per_image[i].item())
- img_text = (
- DEFAULT_IM_START_TOKEN + (DEFAULT_IMAGE_PATCH_TOKEN * num_patches) + DEFAULT_IM_END_TOKEN + "\n"
- )
- sample = sample.replace(special_token, img_text, 1)
- image_index += num_images
- prompt_strings.append(sample)
- return prompt_strings
-
- def _expand_video_tokens(
- self,
- text: list[str],
- video_grid_thw: torch.Tensor,
- special_token: str = PLACEHOLDER_VIDEO_TOKEN_IN_TEXT,
- ) -> list[str]:
- merge_size = self.spatial_merge_size
- num_patches_per_video = torch.prod(video_grid_thw, dim=1) // (merge_size**2)
- prompt_strings = []
- video_index = 0
- for sample in text:
- num_videos = sample.count(special_token)
- if num_videos > 0:
- for i in range(video_index, num_videos + video_index):
- num_patches = int(num_patches_per_video[i].item())
- video_text = (
- DEFAULT_VID_START_TOKEN
- + (DEFAULT_FRAME_PATCH_TOKEN * num_patches)
- + DEFAULT_VID_END_TOKEN
- + "\n"
- )
- sample = sample.replace(special_token, video_text, 1)
- video_index += num_videos
- prompt_strings.append(sample)
- return prompt_strings
-
- def _expand_audio_tokens(
- self,
- text: list[str],
- encoder_feats_lengths: torch.Tensor,
- special_token: str = PLACEHOLDER_AUDIO_TOKEN_IN_TEXT,
- ) -> list[str]:
- prompt_strings = []
- for sample, lengths_tensor in zip(text, encoder_feats_lengths):
- for length in lengths_tensor:
- num_patches = int(length.item())
- if num_patches == 0:
- continue
- audio_text = DEFAULT_AU_START_TOKEN + (DEFAULT_AUDIO_PATCH_TOKEN * num_patches) + DEFAULT_AU_END_TOKEN
- if special_token in sample:
- sample = sample.replace(special_token, audio_text, 1)
- else:
- sample = sample + audio_text + "\n"
- prompt_strings.append(sample)
- return prompt_strings
-
- def apply_system_template(
- self,
- sys_prompt_exp: str | None = None,
- use_cot_system_prompt: bool = False,
- ) -> str:
- sys_prompt = SYSTEM_PROMPT_THINK if use_cot_system_prompt else SYSTEM_PROMPT_NOTHINK
- if sys_prompt_exp is not None:
- sys_prompt = sys_prompt.replace("你是一个友好的AI助手。", sys_prompt_exp)
- return sys_prompt
-
- def apply_chat_template(
- self,
- conversation: list[dict[str, Any]],
- sys_prompt_exp: str | None = None,
- use_cot_system_prompt: bool = False,
- **kwargs,
- ) -> str:
- eos = self.tokenizer.eos_token
- text = self.apply_system_template(sys_prompt_exp, use_cot_system_prompt) + eos
-
- for idx, message in enumerate(conversation):
- assert message["role"] in ["HUMAN", "ASSISTANT"], (
- f"Invalid role: {message['role']}. Must be 'HUMAN' or 'ASSISTANT'"
- )
- if idx == len(conversation) - 1:
- assert message["role"] == "HUMAN", "Last message must be from HUMAN"
-
- text += USER_PREFIX if message["role"] == "HUMAN" else ASSISTANT_PREFIX
-
- content = message["content"]
- if isinstance(content, str):
- # text-only
- text += content
- elif isinstance(content, list):
- # structured content with multimodal elements
- # Count existing placeholders from text items only
- image_placeholders = 0
- video_placeholders = 0
- audio_placeholders = 0
- for content_item in content:
- if content_item.get("type", "text") == "text":
- t = content_item.get("text", "")
- image_placeholders += t.count(PLACEHOLDER_IMAGE_TOKEN_IN_TEXT)
- video_placeholders += t.count(PLACEHOLDER_VIDEO_TOKEN_IN_TEXT)
- audio_placeholders += t.count(PLACEHOLDER_AUDIO_TOKEN_IN_TEXT)
-
- if video_placeholders > 1:
- raise ValueError("Video count must be at most 1 per message!")
-
- # Insert placeholders only for media items not already covered
- for content_item in content:
- content_type = content_item.get("type", "text")
-
- if content_type == "image":
- image_data = content_item.get("image")
- if image_data is not None:
- from PIL import Image as PILImage
-
- num_images = 1 if isinstance(image_data, (str, PILImage.Image)) else len(image_data)
- for _ in range(num_images):
- if image_placeholders > 0:
- image_placeholders -= 1
- else:
- text += PLACEHOLDER_IMAGE_TOKEN_IN_TEXT
-
- elif content_type == "video":
- if video_placeholders > 0:
- video_placeholders -= 1
- else:
- text += PLACEHOLDER_VIDEO_TOKEN_IN_TEXT
- elif content_type == "audio":
- audio_data = content_item.get("audio")
- if audio_data is not None:
- num_audios = 1 if isinstance(audio_data, str) else len(audio_data)
- for _ in range(num_audios):
- if audio_placeholders > 0:
- audio_placeholders -= 1
- else:
- text += PLACEHOLDER_AUDIO_TOKEN_IN_TEXT
-
- elif content_type == "text":
- text += content_item.get("text", "")
-
- # Add EOS token after each message except the last one
- text += eos
-
- text += ASSISTANT_PREFIX
- return text
-
- def batch_decode(self, *args, **kwargs):
- return self.tokenizer.batch_decode(*args, **kwargs)
-
- def decode(self, *args, **kwargs):
- return self.tokenizer.decode(*args, **kwargs)
-
- @property
- def model_input_names(self):
- names = (
- self.tokenizer.model_input_names
- + self.image_processor.model_input_names
- + self.audio_processor.model_input_names
- )
- return list(dict.fromkeys(names))
-
-
-AutoFeatureExtractor.register("MingWhisperFeatureExtractor", MingWhisperFeatureExtractor)
-AutoProcessor.register("MingFlashOmniProcessor", MingFlashOmniProcessor)
diff --git a/vllm_omni/utils/audio.py b/vllm_omni/utils/audio.py
deleted file mode 100644
index cc25c179471..00000000000
--- a/vllm_omni/utils/audio.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-"""Audio utility functions shared across models and entrypoints."""
-
-import numpy as np
-import torch
-from torchaudio.functional import melscale_fbanks
-
-
-def mel_filter_bank(
- sr: int,
- n_fft: int,
- n_mels: int,
- fmin: float = 0.0,
- fmax: float | None = None,
-) -> torch.Tensor:
- """Compute a mel filterbank matrix.
-
- Drop-in replacement for ``librosa.filters.mel`` using
- ``torchaudio.functional.melscale_fbanks``.
-
- Args:
- sr: Sample rate of the audio.
- n_fft: FFT window size.
- n_mels: Number of mel bands.
- fmin: Minimum frequency (Hz).
- fmax: Maximum frequency (Hz). Defaults to ``sr / 2``.
-
- Returns:
- Tensor of shape ``(n_mels, n_fft // 2 + 1)``.
- """
- if fmax is None:
- fmax = float(sr) / 2.0
- # Use mel_scale='slaney' and norm='slaney' to match librosa's
- # default behaviour (Slaney 1998 frequency mapping with area
- # normalization).
- return melscale_fbanks(
- n_freqs=n_fft // 2 + 1,
- f_min=float(fmin),
- f_max=float(fmax),
- n_mels=n_mels,
- sample_rate=sr,
- mel_scale="slaney",
- norm="slaney",
- ).T
-
-
-def peak_normalize(
- audio: np.ndarray,
- db_level: float = -6.0,
-) -> np.ndarray:
- """Normalize audio so peak amplitude reaches a target dB level.
-
- Drop-in replacement for ``sox.Transformer().norm(db_level=...)``.
-
- Args:
- audio: Input waveform as a 1-D numpy array.
- db_level: Target peak amplitude in dBFS.
-
- Returns:
- Normalized waveform with the same dtype as *audio*.
- """
- peak = np.abs(audio).max()
- if peak == 0:
- return audio
- target = 10.0 ** (db_level / 20.0)
- return audio * (target / peak)
diff --git a/vllm_omni/utils/mm_outputs.py b/vllm_omni/utils/mm_outputs.py
deleted file mode 100644
index 66d4e6ffe04..00000000000
--- a/vllm_omni/utils/mm_outputs.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""Utilities for handling multimodal outputs / building multimodal output
-payloads, most of which are shared by the prefix cache / no prefix cache path.
-"""
-
-import torch
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def build_mm_cpu(multimodal_outputs: dict) -> dict[str, object]:
- """Pre-copies multimodal tensor to CPU once (not per-request) to avoid
- redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU.
-
- In the case of prefix caching, the multimodal outputs provided will
- only contain the passthrough data.
-
- Args:
- multimodal_outputs: Multimodal dict mapping strings to objects.
- """
- # Pre-copy multimodal tensors to CPU once (not per-request) to avoid
- # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU.
- mm_cpu: dict[str, object] = {}
- # Currently there are some cases where this is true at the
- # moment, which should be fixed.
- if not isinstance(multimodal_outputs, dict):
- logger.warning("Multimodal outputs are not a dict and will not be passed")
-
- if multimodal_outputs:
- for k, v in multimodal_outputs.items():
- if isinstance(v, torch.Tensor):
- mm_cpu[k] = v.detach().to("cpu").contiguous()
- elif isinstance(v, dict):
- sub_dict: dict[str, torch.Tensor] = {}
- for sk, sv in v.items():
- if isinstance(sv, torch.Tensor):
- sub_dict[str(sk)] = sv.detach().to("cpu").contiguous()
- if sub_dict:
- mm_cpu[k] = sub_dict
- elif isinstance(v, list) and len(v) > 0:
- cpu_list = []
- for elem in v:
- if isinstance(elem, torch.Tensor):
- cpu_list.append(elem.detach().to("cpu").contiguous())
- else:
- cpu_list.append(elem)
- mm_cpu[k] = cpu_list
- elif v is not None:
- mm_cpu[k] = v
- return mm_cpu
-
-
-def to_payload_element(
- element: object, idx: int, start: int, end: int, pass_lists_through: bool = False, seq_len: int | None = None
-):
- """Build an mm payload element corresponding to one request index
- from an element containing 0 or more CPU tensors.
-
- Args:
- element: The object to be added to the payload.
- idx: The index of the request.
- start: The start index corresponding to the request idx.
- end: The end index corresponding to the request idx.
- pass_lists_through: bool Whether or not lists should be treated as
- passthrough data; this should be False in normal cases, but True
- if we need to avoid splitting nonempty lists prior to calling
- postprocess, which is the case for prefix cache.
- seq_len: Optional sequence length (i.e., dim 0 of hidden states).
- This should be set to None in the prefix caching case, because
- the condition that would be executed here is the same as the
- criteria for being added to the multimodal outputs cache.
- """
- # Prefix cache won't hit this case because this is the condition
- # for being a mm_cache_key in the multimodal outputs tensor.
- if seq_len is not None and isinstance(element, torch.Tensor) and element.shape[0] == seq_len:
- return element[start:end].contiguous()
- # Every other case is shared between prefix cache (passthrough data)
- # and running a model without prefix caching.
- elif isinstance(element, dict):
- return {sk: sv[start:end].contiguous() for sk, sv in element.items()}
- elif isinstance(element, list):
- # For lists, clone tensors to avoid cross-request aliasing
- if pass_lists_through:
- return [elem.clone() if isinstance(elem, torch.Tensor) else elem for elem in element]
- element = element[idx] if idx < len(element) else element[0]
- if isinstance(element, torch.Tensor):
- element = element.clone()
- return element
- elif isinstance(element, torch.Tensor):
- # List-derived tensor payloads are request-invariant; clone to
- # avoid accidental cross-request aliasing on downstream mutation.
- return element.clone()
- return element
diff --git a/vllm_omni/version.py b/vllm_omni/version.py
index 296bebc8e20..e5f0b6b661d 100644
--- a/vllm_omni/version.py
+++ b/vllm_omni/version.py
@@ -5,12 +5,12 @@
and written to _version.py during package build.
"""
-import warnings
-
try:
# Import auto-generated version from _version.py (created by setuptools_scm)
from ._version import __version__, __version_tuple__
except ImportError as e:
+ import warnings
+
warnings.warn(
f"Failed to import version from _version.py: {e}\n"
"This typically happens in development mode before building.\n"
@@ -22,37 +22,4 @@
__version__ = "dev"
__version_tuple__ = (0, 0, "dev")
-
-def warn_if_misaligned_vllm_version():
- """Warn if vLLM and vllm-omni versions don't match (major.minor)."""
- # Import vllm lazily since import order may be sensitive with current monkeypatching,
- # but we want to check this before potentially breaking imports run.
- from vllm import __version__ as vllm_version
- from vllm import __version_tuple__ as vllm_version_tuple
-
- omni_ver: tuple[str | int, ...] = __version_tuple__[:2]
- vllm_ver: tuple[str | int, ...] = vllm_version_tuple[:2]
- # Skip if either version is dev (0, 0)
- if omni_ver == (0, 0) or vllm_ver == (0, 0):
- return
-
- # Compare major.minor
- if omni_ver != vllm_ver:
- warnings.warn(
- "vLLM and vLLM-Omni appear to have mismatched major/minor versions:\n"
- f" --> vLLM-Omni version {__version__}\n"
- f" --> vLLM version {vllm_version}\n"
- "This will likely cause compatibility issues.",
- RuntimeWarning,
- stacklevel=2,
- )
-
-
__all__ = ["__version__", "__version_tuple__"]
-
-# Run version check automatically when this module is imported
-try:
- warn_if_misaligned_vllm_version()
-except ModuleNotFoundError:
- # vLLM not installed (e.g., documentation builds)
- pass
diff --git a/vllm_omni/worker/base.py b/vllm_omni/worker/base.py
index 8bd9efc89c4..f7f5dbd1d8b 100644
--- a/vllm_omni/worker/base.py
+++ b/vllm_omni/worker/base.py
@@ -2,23 +2,15 @@
from __future__ import annotations
-import gc
import os
import time
-from contextlib import AbstractContextManager, nullcontext
import torch
from vllm.logger import init_logger
from vllm.utils.mem_utils import format_gib, memory_profiling
from vllm.v1.worker.gpu_worker import Worker as GPUWorker
-from vllm_omni.diffusion.data import (
- OmniACK,
- OmniSleepTask,
- OmniWakeTask,
-)
from vllm_omni.entrypoints.utils import detect_pid_host
-from vllm_omni.platforms import current_omni_platform
from vllm_omni.worker.gpu_memory_utils import (
get_process_gpu_memory,
is_process_scoped_memory_available,
@@ -38,13 +30,6 @@ class OmniGPUWorkerBase(GPUWorker):
for custom trace naming, background gzip, and trace path collection.
"""
- def load_model(self, *args, **kwargs):
- with self._maybe_get_memory_pool_context("weights"):
- res = super().load_model(*args, **kwargs)
- current_omni_platform.synchronize()
- gc.collect()
- return res
-
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -109,8 +94,10 @@ def determine_available_memory(self) -> int:
"""
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
self.model_runner.profile_run()
- if current_omni_platform.is_rocm():
- torch.cuda.synchronize()
+ logger.info(
+ "Using explicit kv_cache_memory_bytes: %s GiB",
+ format_gib(kv_cache_memory_bytes),
+ )
return kv_cache_memory_bytes
with memory_profiling(
@@ -167,141 +154,3 @@ def determine_available_memory(self) -> int:
)
return int(self.available_kv_cache_memory_bytes)
-
- # Provide memory pool context
- def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
- v1_config_enabled = False
- if hasattr(self, "vllm_config"):
- model_cfg = getattr(self.vllm_config, "model_config", None)
- v1_config_enabled = getattr(model_cfg, "enable_sleep_mode", False)
-
- is_sleep_enabled = v1_config_enabled or getattr(self.cache_config, "enable_sleep_mode", False)
- if is_sleep_enabled:
- current_omni_platform.synchronize()
- gc.collect()
- from vllm.device_allocator.cumem import CuMemAllocator
-
- allocator = CuMemAllocator.get_instance()
- logger.info(f"[LLM Worker {self.rank}] Sleep Mode ENABLED. Activating CuMem pool for tag: {tag}")
- return allocator.use_memory_pool(tag=tag)
- else:
- logger.warning(f"[LLM Worker {self.rank}] Sleep Mode DISABLED.")
- return nullcontext()
-
- def sleep(self, level: int = 1) -> bool:
- """
- Put the worker to sleep.
- Args:
- level: 1 (Offload weights to CPU), level: 2 (Total Discard).
- """
- from vllm.device_allocator.cumem import CuMemAllocator
-
- mem_before = current_omni_platform.get_current_memory_usage(self.device)
- offload_tags = ("weights",) if level == 1 else tuple()
- allocator = CuMemAllocator.get_instance()
- allocator.sleep(offload_tags=offload_tags)
- current_omni_platform.empty_cache()
- current_omni_platform.synchronize()
- mem_after = current_omni_platform.get_current_memory_usage(self.device)
- freed = max(0, mem_before - mem_after)
- remaining_gb = mem_after / 1024**3
- logger.info(
- f"[LLM Worker {self.rank}] Level {level} Sleep: Freed "
- f"{freed / 1024**3:.2f} GiB. {remaining_gb:.2f}GiB memory "
- "is still in use."
- )
- return True
-
- def wake_up(self, tags: list[str] | None = None) -> bool:
- "Physical video memory reloading logic"
- from vllm.device_allocator.cumem import CuMemAllocator
-
- allocator = CuMemAllocator.get_instance()
- allocator.wake_up(tags)
- current_omni_platform.synchronize()
- logger.info(f"[LLM Worker {self.rank}] Wake-up complete.")
- return True
-
- def handle_sleep_task(self, task: OmniSleepTask) -> OmniACK:
- "Handle deterministic Sleep command from the main process"
- try:
- if isinstance(task, dict):
- task = OmniSleepTask(**task)
- logger.info(f"[LLM Worker {self.rank}] Handshake Received: Task {task.task_id}, Level {task.level}")
- if task.level == 2:
- if hasattr(self.model_runner, "graph_runners"):
- self.model_runner.graph_runners.clear()
- logger.info(f"[LLM Worker {self.rank}] LLM CUDA Graphs cleared.")
- mem_before = current_omni_platform.get_current_memory_usage(self.device)
- self.sleep(level=task.level)
- mem_after = current_omni_platform.get_current_memory_usage(self.device)
- rank_freed = max(0, mem_before - mem_after)
- if torch.distributed.is_initialized():
- t_freed = torch.tensor([float(rank_freed)], device=self.device)
- torch.distributed.all_reduce(t_freed)
- total_freed = int(t_freed.item())
- torch.distributed.barrier()
- else:
- total_freed = rank_freed
- if self.rank != 0:
- return None
- current_stage_id = getattr(self.vllm_config.model_config, "stage_id", 0)
- ack = OmniACK(
- task_id=task.task_id,
- status="SUCCESS",
- stage_id=current_stage_id,
- rank=self.rank,
- freed_bytes=total_freed,
- metadata={
- "source": "omni_platform_audit",
- "total_freed_gib": f"{total_freed / 1024**3:.2f}",
- "rank_residual_gib": f"{mem_after / 1024**3:.2f}",
- },
- )
- if hasattr(self, "result_mq") and self.result_mq:
- self.result_mq.put(ack)
- logger.info(f"[LLM Worker {self.rank}] ACK emitted for Task {task.task_id}")
- return ack
- except Exception as e:
- logger.error(f"[LLM Worker {self.rank}] Sleep Task Failed: {e}", exc_info=True)
- if torch.distributed.is_initialized():
- try:
- torch.distributed.barrier()
- except Exception:
- pass
- return OmniACK(task_id=task.task_id, status="ERROR", error_msg=str(e))
-
- def handle_wake_task(self, task: OmniWakeTask) -> OmniACK:
- "Handle deterministic Wakeup command from the main process"
- try:
- if isinstance(task, dict):
- task = OmniWakeTask(**task)
- self.wake_up(tags=task.tags)
- if torch.distributed.is_initialized():
- torch.distributed.barrier()
- gc.collect()
- current_omni_platform.synchronize()
- usage_now = current_omni_platform.get_current_memory_usage(self.device)
- if self.rank != 0:
- return None
- current_stage_id = getattr(self.vllm_config.model_config, "stage_id", 0)
- ack = OmniACK(
- task_id=task.task_id,
- status="SUCCESS",
- stage_id=current_stage_id,
- rank=self.rank,
- metadata={"state": "WARM", "current_vram_gib": f"{usage_now / 1024**3:.2f}"},
- )
- if hasattr(self, "result_mq") and self.result_mq:
- self.result_mq.put(ack)
- logger.info(f"[LLM Worker {self.rank}] Wake-up ACK emitted.")
- return ack
- except Exception as e:
- logger.error(f"[LLM Worker {self.rank}] Wake-up Failed: {e}", exc_info=True)
- if torch.distributed.is_initialized():
- try:
- torch.distributed.barrier()
- except Exception:
- pass
- tid = task.task_id if hasattr(task, "task_id") else "unknown"
- return OmniACK(task_id=tid, status="ERROR", error_msg=str(e))
diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py
index 947b3164f3e..f1115ab4c62 100644
--- a/vllm_omni/worker/gpu_ar_model_runner.py
+++ b/vllm_omni/worker/gpu_ar_model_runner.py
@@ -6,9 +6,7 @@
from __future__ import annotations
-from contextlib import nullcontext
from copy import copy
-from dataclasses import replace
from typing import Any, NamedTuple
import numpy as np
@@ -37,12 +35,9 @@
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
from vllm.v1.worker.utils import is_residual_scattered_for_sp
-from vllm_omni.data_entry_keys import flatten_payload
from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
from vllm_omni.outputs import OmniModelRunnerOutput
-from vllm_omni.utils.mm_outputs import build_mm_cpu, to_payload_element
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
-from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin
logger = init_logger(__name__)
@@ -63,7 +58,7 @@ class ExecuteModelState(NamedTuple):
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None
-class GPUARModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin):
+class GPUARModelRunner(OmniGPUModelRunner):
"""Autoregressive GPU model runner that returns hidden states per request.
Follows the v0.12 two-phase execute/sample flow from GPUModelRunner, and
@@ -94,172 +89,6 @@ def _make_buffer(self, *size, dtype, numpy=True):
with maybe_disable_pin_memory_for_ray(self, total_bytes):
return super()._make_buffer(*size, dtype=dtype, numpy=numpy)
- def _build_model_sampler_output_token_ids(self) -> list[list[int]]:
- """Build decoded-token history for custom model samplers.
-
- vLLM only populates sampling_metadata.output_token_ids when penalties or
- logits processors require it. CosyVoice3's custom RAS sampler also
- depends on this history, so we reconstruct it directly from the input
- batch for prefer_model_sampler models.
- """
- req_output_token_ids = getattr(self.input_batch, "req_output_token_ids", [])
- req_ids = list(getattr(self.input_batch, "req_ids", []))
- output_token_ids = [list(req_output_token_ids[idx] or []) for idx in range(len(req_ids))]
-
- sampled_token_ids_cpu = getattr(self.input_batch, "sampled_token_ids_cpu", None)
- async_copy_ready_event = getattr(self.input_batch, "async_copy_ready_event", None)
- prev_req_id_to_index = getattr(self.input_batch, "prev_req_id_to_index", None)
- if sampled_token_ids_cpu is None or not output_token_ids or prev_req_id_to_index is None:
- return output_token_ids
-
- sampled_token_ids: list[list[int]] | None = None
- for index, req_id in enumerate(req_ids):
- prev_index = prev_req_id_to_index.get(req_id)
- if prev_index is None:
- continue
- req_history = output_token_ids[index]
- if not req_history or req_history[-1] != -1:
- continue
- if sampled_token_ids is None:
- assert async_copy_ready_event is not None
- async_copy_ready_event.synchronize()
- sampled_token_ids = sampled_token_ids_cpu.tolist()
- new_ids = list(sampled_token_ids[prev_index])
- if not new_ids:
- continue
- num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
- first_placeholder = req_history.index(-1)
- num_placeholders = len(req_history) - first_placeholder
- num_to_replace = min(num_sampled_ids, num_placeholders)
- req_history[first_placeholder : first_placeholder + num_to_replace] = new_ids[:num_to_replace]
-
- return output_token_ids
-
- def _sampling_metadata_for_model_sampler(self, sampling_metadata):
- output_token_ids = self._build_model_sampler_output_token_ids()
- if output_token_ids == sampling_metadata.output_token_ids:
- return sampling_metadata
- return replace(sampling_metadata, output_token_ids=output_token_ids)
-
- def capture_model(self) -> int:
- result = super().capture_model()
- self._capture_talker_mtp_graphs()
- return result
-
- def _capture_talker_mtp_graphs(self) -> None:
- from vllm_omni.worker.gpu_model_runner import CUDAGraphWrapper
-
- if not self.has_talker_mtp or not isinstance(self.talker_mtp, CUDAGraphWrapper):
- return
-
- from vllm.compilation.monitor import set_cudagraph_capturing_enabled
- from vllm.distributed.parallel_state import graph_capture
-
- capture_sizes = self.compilation_config.cudagraph_capture_sizes
- num_warmups = self.compilation_config.cudagraph_num_of_warmups
- capture_sizes = sorted(capture_sizes, reverse=True)
- logger.info("Capturing talker_mtp graphs for sizes %s", capture_sizes)
-
- set_cudagraph_capturing_enabled(True)
- try:
- with torch.inference_mode(), graph_capture(device=self.device):
- for bsz in capture_sizes:
- _, batch_desc, _, _, _ = self._determine_batch_execution_and_padding(
- num_tokens=bsz,
- num_reqs=bsz,
- num_scheduled_tokens_np=np.ones(bsz, dtype=np.int32),
- max_num_scheduled_tokens=1,
- use_cascade_attn=False,
- )
- n = batch_desc.num_tokens
- ids = self.talker_mtp_input_ids.gpu[:n]
- emb = self.talker_mtp_inputs_embeds.gpu[:n]
- hid = self.last_talker_hidden.gpu[:n]
- ts = self.text_step.gpu[:n]
-
- for _ in range(num_warmups):
- with set_forward_context(
- None,
- self.vllm_config,
- cudagraph_runtime_mode=CUDAGraphMode.NONE,
- batch_descriptor=batch_desc,
- ):
- self.talker_mtp(ids, emb, hid, ts)
-
- with set_forward_context(
- None,
- self.vllm_config,
- cudagraph_runtime_mode=CUDAGraphMode.FULL,
- batch_descriptor=batch_desc,
- ):
- self.talker_mtp(ids, emb, hid, ts)
- torch.cuda.synchronize()
-
- logger.info("Captured talker_mtp graphs for %d sizes", len(capture_sizes))
- except RuntimeError as e:
- raise RuntimeError(
- f"talker_mtp graph capture failed for a model that declared talker_mtp_graph_safe=True: {e}"
- ) from e
- finally:
- set_cudagraph_capturing_enabled(False)
-
- def _maybe_update_prefix_cache(
- self,
- hidden_states: torch.Tensor,
- multimodal_outputs: dict,
- num_tokens_unpadded: int,
- num_tokens_padded: int,
- ):
- """If prefix caching is enabled and it's the last pipeline parallelism rank,
- retrieve the hidden states & multimodal outputs from the prefix cache based
- on our batch slot mappings.
- """
- # Cache hidden states if we've enabled hidden state prefix caching
- # unless this isn't the last pipeline parallelism rank.
- if self.omni_prefix_cache is not None and get_pp_group().is_last_rank:
- # If this happens, it generally means the model is not following the correct
- # interface yet and is therefore currently not compatible with prefix cache.
- if multimodal_outputs is not None and not isinstance(multimodal_outputs, dict):
- logger.warning_once(
- "prefix caching expects mm outputs to be a dict, but got %s",
- type(multimodal_outputs),
- )
-
- self.omni_prefix_cache.update_omni_tensor_prefix_cache(
- hidden_states=hidden_states,
- multimodal_outputs=multimodal_outputs,
- num_tokens_unpadded=num_tokens_unpadded,
- slot_mapping=self.input_batch.block_table[0].slot_mapping.cpu,
- num_tokens_padded=num_tokens_padded,
- )
-
- def _maybe_get_combined_prefix_cache_tensors(
- self,
- hidden_states: torch.Tensor,
- multimodal_outputs: dict,
- num_scheduled_tokens: dict[str, int],
- ) -> tuple[dict[str, torch.Tensor] | None, dict | None]:
- """If prefix caching is enabled, extract the merged hidden states and multimodal outputs for
- all requests in the batch (including those that aren't a hit on Prefix cache).
- """
- # Prior to applying the post-processing func, extract
- # the prefix cached hidden states and multimodal states.
- combined_hidden_states, combined_multimodal_outputs = None, None
- if self.omni_prefix_cache is not None:
- combined_hidden_states = self.omni_prefix_cache.get_merged_hidden_states(
- query_start_loc=self.query_start_loc.cpu,
- input_batch=self.input_batch,
- hidden_states=hidden_states,
- num_scheduled_tokens=num_scheduled_tokens,
- )
- combined_multimodal_outputs = self.omni_prefix_cache.get_merged_multimodal_states(
- query_start_loc=self.query_start_loc.cpu,
- input_batch=self.input_batch,
- multimodal_outputs=multimodal_outputs,
- num_scheduled_tokens=num_scheduled_tokens,
- )
- return combined_hidden_states, combined_multimodal_outputs
-
@torch.inference_mode()
def execute_model(
self,
@@ -321,49 +150,30 @@ def execute_model(
# Update persistent batch states.
deferred_state_corrections_fn = self._update_states(scheduler_output)
- # Notify model of finished requests for state cleanup
- if scheduler_output.finished_req_ids and hasattr(self.model, "on_requests_finished"):
- self.model.on_requests_finished(scheduler_output.finished_req_ids)
-
if has_ec_transfer() and not get_ec_transfer().is_consumer:
with self.maybe_get_ec_connector_output(
scheduler_output,
encoder_cache=self.encoder_cache,
) as ec_connector_output:
self._execute_mm_encoder(scheduler_output)
-
- kv_ids = self.kv_extracted_req_ids
- self.kv_extracted_req_ids = None
-
- output = make_empty_encoder_model_runner_output(scheduler_output)
- if kv_ids:
- output = copy(output)
- output.kv_extracted_req_ids = kv_ids
- return output
+ return make_empty_encoder_model_runner_output(scheduler_output)
if not num_scheduled_tokens:
if (
self.parallel_config.distributed_executor_backend == "external_launcher"
and self.parallel_config.data_parallel_size > 1
):
+ # this is a corner case when both external launcher
+ # and DP are enabled, num_scheduled_tokens could be
+ # 0, and has_unfinished_requests in the outer loop
+ # returns True. before returning early here we call
+ # dummy run to ensure coordinate_batch_across_dp
+ # is called into to avoid out of sync issues.
self._dummy_run(1)
-
- # Capture KV extraction results before early return;
- # sample_tokens() is skipped on this path so the IDs
- # would otherwise be silently overwritten next step.
- kv_ids = self.kv_extracted_req_ids
- self.kv_extracted_req_ids = None
-
if not has_kv_transfer_group():
- output = EMPTY_MODEL_RUNNER_OUTPUT
- else:
- output = self.kv_connector_no_forward(scheduler_output, self.vllm_config)
-
- if kv_ids:
- output = copy(output)
- output.kv_extracted_req_ids = kv_ids
-
- return output
+ # Return empty ModelRunnerOutput if no work to do.
+ return EMPTY_MODEL_RUNNER_OUTPUT
+ return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
assert not self.num_prompt_logprobs, (
@@ -492,7 +302,6 @@ def execute_model(
# (wait_for_save + clear metadata) until after draft model runs.
defer_kv_connector_finalize = self.speculative_config is not None
with (
- nullcontext(),
set_forward_context(
attn_metadata,
self.vllm_config,
@@ -535,15 +344,6 @@ def execute_model(
hidden_states, multimodal_outputs = self.extract_multimodal_outputs(model_output)
- # Cache hidden states & multimodal outputs if we've enabled hidden state
- # prefix caching unless this isn't the last pipeline parallelism rank.
- self._maybe_update_prefix_cache(
- hidden_states=hidden_states,
- multimodal_outputs=multimodal_outputs,
- num_tokens_unpadded=num_tokens_unpadded,
- num_tokens_padded=num_tokens_padded,
- )
-
if not self.broadcast_pp_output:
# Common case.
if not get_pp_group().is_last_rank:
@@ -624,56 +424,6 @@ def execute_model(
return None
- def _sample(
- self,
- logits: torch.Tensor | None,
- spec_decode_metadata: Any,
- ):
- sampling_metadata = self.input_batch.sampling_metadata
- if spec_decode_metadata is None:
- model_sample = getattr(self.model, "sample", None)
- if logits is not None and callable(model_sample) and getattr(self.model, "prefer_model_sampler", False):
- # Apply logit bias (min_tokens, allowed_token_ids) before
- # the custom model sampler — the standard GPU sampler does
- # this internally, but prefer_model_sampler bypasses it.
- if hasattr(self.sampler, "logit_bias_state"):
- self.sampler.logit_bias_state.apply_logit_bias(
- logits,
- self.input_batch.expanded_idx_mapping,
- self.input_batch.idx_mapping_np,
- self.input_batch.positions[self.input_batch.logits_indices],
- )
- sampler_output = model_sample(
- logits,
- self._sampling_metadata_for_model_sampler(sampling_metadata),
- )
- if sampler_output is not None:
- return sampler_output
- self.input_batch.update_async_output_token_ids()
- return self.sampler(
- logits=logits,
- sampling_metadata=sampling_metadata,
- )
-
- return super()._sample(logits, spec_decode_metadata)
-
- @staticmethod
- def _resolve_req_hidden_states(
- hidden_states_cpu: torch.Tensor,
- combined_hidden_states: dict[str, torch.Tensor] | None,
- rid: str,
- start: int,
- end: int,
- ):
- if combined_hidden_states is not None:
- # We always have all request IDs for prefix cache, even for
- # partial cache misses, so this should never happen.
- if rid not in combined_hidden_states:
- raise RuntimeError("Request IDs in the batch are missing from the merged states!")
- return combined_hidden_states[rid]
- # Prefix caching is disabled
- return hidden_states_cpu[start:end]
-
@torch.inference_mode()
def sample_tokens(
self,
@@ -682,13 +432,6 @@ def sample_tokens(
kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None)
self.kv_extracted_req_ids = None
- # Used for prefix cache
- combined_hidden_states = None
- combined_multimodal_outputs = None
- # Used when we don't use prefix cache; prefix cache builds the payloads
- # internally since it already needs to do this for the cached tensors
- mm_cpu = {}
-
if self.execute_model_state is None:
kv_connector_output = self.kv_connector_output
self.kv_connector_output = None
@@ -720,7 +463,6 @@ def sample_tokens(
slot_mappings, # OMNI: unpack slot_mappings for drafter
) = self.execute_model_state
self.execute_model_state = None
- seq_len = hidden_states.shape[0]
# Apply structured output bitmasks if present.
if grammar_output is not None:
@@ -842,77 +584,67 @@ def propose_draft_token_ids(sampled_token_ids):
dtype=np.int32,
)
- # Prior to applying the post-processing func, extract
- # the prefix cached hidden states and multimodal states.
- if self.omni_prefix_cache is not None:
- (
- combined_hidden_states,
- combined_multimodal_outputs,
- ) = self._maybe_get_combined_prefix_cache_tensors(
- hidden_states,
- multimodal_outputs,
- scheduler_output.num_scheduled_tokens,
- )
- # Otherwise we don't have the mm CPU data yet, so we still need to build it
- if self.omni_prefix_cache is None:
- mm_cpu = build_mm_cpu(flatten_payload(multimodal_outputs))
-
self._process_additional_information_updates(
- hidden_states,
- multimodal_outputs,
- num_scheduled_tokens_np,
- scheduler_output,
- combined_hidden_states,
- combined_multimodal_outputs,
+ hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output
)
+ # Pre-copy multimodal tensors to CPU once (not per-request) to avoid
+ # redundant D2H transfers when gpu_resident_buffer_keys keeps them on GPU.
+ mm_cpu: dict[str, object] = {}
+ if isinstance(multimodal_outputs, dict) and multimodal_outputs:
+ for k, v in multimodal_outputs.items():
+ try:
+ if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]:
+ mm_cpu[k] = v.detach().to("cpu").contiguous()
+ elif isinstance(v, dict):
+ sub_dict: dict[str, torch.Tensor] = {}
+ for sk, sv in v.items():
+ if isinstance(sv, torch.Tensor) and sv.shape[0] == hidden_states_cpu.shape[0]:
+ sub_dict[str(sk)] = sv.detach().to("cpu").contiguous()
+ if sub_dict:
+ mm_cpu[k] = sub_dict
+ elif isinstance(v, list):
+ if len(v) == 0:
+ continue
+ cpu_list = []
+ for elem in v:
+ if isinstance(elem, torch.Tensor):
+ cpu_list.append(elem.detach().to("cpu").contiguous())
+ else:
+ cpu_list.append(elem)
+ mm_cpu[k] = cpu_list
+ except Exception as e:
+ logger.error(f"Error in merge multimodal outputs: {e}")
+
pooler_output: list[dict[str, object]] = []
for rid in req_ids_output_copy:
idx = req_id_to_index_output_copy[rid]
start = int(self.query_start_loc.cpu[idx])
sched = int(num_scheduled_tokens_np[idx])
end = start + sched
- # If prefix cache is enabled, we have already split everything
- # by request and converted the states to CPU tensors
- req_hidden_states = self._resolve_req_hidden_states(
- hidden_states_cpu,
- combined_hidden_states,
- rid,
- start,
- end,
- )
- payload: dict[str, object] = {"hidden": req_hidden_states}
-
- mm_payload: dict[str, object] = {}
- if combined_multimodal_outputs or mm_cpu:
- if combined_multimodal_outputs:
- # Prefix cache enabled; all items have already been processed
- # and split apart for each request as needed, and all tensors
- # have already been detached to the CPU. The only exception is
- # lists, which we keep as passthrough data for consistent behavior
- # in postprocess.
- for mm_key in combined_multimodal_outputs.keys():
- value = combined_multimodal_outputs[mm_key][rid]
- if isinstance(value, list):
- mm_payload[mm_key] = value[idx] if idx < len(value) else value[0]
- else:
- mm_payload[mm_key] = value
-
- else:
- # Prefix cache disabled; we still need to process the data
- for mm_key, mm_val in mm_cpu.items():
- mm_payload[mm_key] = to_payload_element(
- element=mm_val,
- idx=idx,
- start=start,
- end=end,
- pass_lists_through=False,
- seq_len=seq_len,
- )
+ hidden_slice = hidden_states_cpu[start:end]
+ payload: dict[str, object] = {"hidden": hidden_slice}
+ if mm_cpu:
+ mm_payload: dict[str, object] = {}
+ for k, v in mm_cpu.items():
+ if isinstance(v, torch.Tensor) and v.shape[0] == hidden_states_cpu.shape[0]:
+ mm_payload[k] = v[start:end].contiguous()
+ elif isinstance(v, dict):
+ mm_payload[k] = {sk: sv[start:end].contiguous() for sk, sv in v.items()}
+ elif isinstance(v, list):
+ element = v[idx] if idx < len(v) else v[0]
+ # Clone tensors to avoid cross-request aliasing
+ if isinstance(element, torch.Tensor):
+ element = element.clone()
+ mm_payload[k] = element
+ elif isinstance(v, torch.Tensor):
+ # List-derived tensor payloads are request-invariant; clone to
+ # avoid accidental cross-request aliasing on downstream mutation.
+ mm_payload[k] = v.clone()
+ else:
+ mm_payload[k] = v
payload.update(mm_payload)
- # Flatten nested dicts to dotted keys so pooling_output
- # stays dict[str, torch.Tensor] for msgspec serialization.
- pooler_output.append(flatten_payload(payload))
+ pooler_output.append(payload)
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
if self.routed_experts_initialized:
capturer = RoutedExpertsCapturer.get_instance()
diff --git a/vllm_omni/worker/gpu_ar_worker.py b/vllm_omni/worker/gpu_ar_worker.py
index 50ad70c57fe..4abe21964b3 100644
--- a/vllm_omni/worker/gpu_ar_worker.py
+++ b/vllm_omni/worker/gpu_ar_worker.py
@@ -12,7 +12,6 @@
from vllm.v1.worker.utils import request_memory
from vllm.v1.worker.workspace import init_workspace_manager
-from vllm_omni.diffusion.data import OmniACK, OmniSleepTask, OmniWakeTask
from vllm_omni.worker.base import OmniGPUWorkerBase
from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner
from vllm_omni.worker.mixins import OmniWorkerMixin
@@ -29,7 +28,7 @@ class GPUARWorker(OmniWorkerMixin, OmniGPUWorkerBase):
@instrument(span_name="Init device")
def init_device(self):
- if self.device_config.device_type in ("cuda", "musa"):
+ if self.device_config.device_type == "cuda":
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
parallel_config = self.parallel_config
@@ -52,7 +51,7 @@ def init_device(self):
assert self.local_rank < torch.accelerator.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
- visible_device_count = torch.accelerator.device_count()
+ visible_device_count = torch.accelerator.device_count() if torch.cuda.is_available() else 0
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must "
f"be less than or equal to the number of visible devices "
@@ -105,23 +104,3 @@ def init_device(self):
if self.rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)
-
- def handle_sleep_task(self, task: OmniSleepTask | dict) -> OmniACK:
- """
- Explicitly handle sleep commands.
- Calls the implementation in the base class OmniGPUWorkerBase.
- """
- logger.debug(f"[AR Worker {self.rank}] Resolving handle_sleep_task dispatch")
- if isinstance(task, dict):
- task = OmniSleepTask(**task)
- return super().handle_sleep_task(task)
-
- def handle_wake_task(self, task: OmniWakeTask | dict) -> OmniACK:
- """
- Explicitly handle wake-up commands.
- Calls the implementation in the base class OmniGPUWorkerBase.
- """
- logger.debug(f"[AR Worker {self.rank}] Resolving handle_wake_task dispatch")
- if isinstance(task, dict):
- task = OmniWakeTask(**task)
- return super().handle_wake_task(task)
diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py
index f10115c8e90..d95b676f6d6 100644
--- a/vllm_omni/worker/gpu_generation_model_runner.py
+++ b/vllm_omni/worker/gpu_generation_model_runner.py
@@ -39,12 +39,11 @@
from vllm_omni.outputs import OmniModelRunnerOutput
from vllm_omni.worker.gpu_ar_model_runner import ExecuteModelState
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
-from vllm_omni.worker.omni_connector_model_runner_mixin import OmniConnectorModelRunnerMixin
logger = logging.getLogger(__name__)
-class GPUGenerationModelRunner(OmniGPUModelRunner, OmniConnectorModelRunnerMixin):
+class GPUGenerationModelRunner(OmniGPUModelRunner):
"""Generation model runner for vLLM-Omni (non-autoregressive).
- Reuses GPUModelRunner preparation, multimodal handling, and TP/PP/DP glue.
diff --git a/vllm_omni/worker/gpu_generation_worker.py b/vllm_omni/worker/gpu_generation_worker.py
index a356f03ad0c..267ed61c0a4 100644
--- a/vllm_omni/worker/gpu_generation_worker.py
+++ b/vllm_omni/worker/gpu_generation_worker.py
@@ -28,7 +28,7 @@ class GPUGenerationWorker(OmniWorkerMixin, OmniGPUWorkerBase):
@instrument(span_name="Init device")
def init_device(self):
- if self.device_config.device_type in ("cuda", "musa"):
+ if self.device_config.device_type == "cuda":
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
parallel_config = self.parallel_config
@@ -51,7 +51,7 @@ def init_device(self):
assert self.local_rank < torch.accelerator.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
- visible_device_count = torch.accelerator.device_count()
+ visible_device_count = torch.accelerator.device_count() if torch.cuda.is_available() else 0
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must "
f"be less than or equal to the number of visible devices "
diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py
index d914f1b39df..a7abaf7b62a 100644
--- a/vllm_omni/worker/gpu_model_runner.py
+++ b/vllm_omni/worker/gpu_model_runner.py
@@ -1,8 +1,9 @@
+import sys
from typing import TYPE_CHECKING, Any, cast
import numpy as np
import torch
-from vllm.compilation.cuda_graph import CUDAGraphWrapper
+from vllm.compilation.cuda_graph import CUDAGraphWrapper as _OriginalCUDAGraphWrapper
from vllm.config import CUDAGraphMode
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
@@ -20,7 +21,6 @@
from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
-from vllm_omni.core.prefix_cache import OmniTensorPrefixCache
from vllm_omni.engine.serialization import deserialize_additional_information
from vllm_omni.model_executor.layers.rotary_embedding.mrope import OmniMRotaryEmbedding as MRotaryEmbedding
from vllm_omni.model_executor.models.output_templates import OmniOutput
@@ -38,15 +38,28 @@
logger = init_logger(__name__)
+class CUDAGraphWrapper(_OriginalCUDAGraphWrapper):
+    """vLLM CUDAGraphWrapper that forwards unknown attribute lookups to its runnable.
+
+    ``__getattr__`` is only invoked after normal attribute lookup has failed,
+    so attributes defined on the wrapper itself always take precedence.
+
+    Raises:
+        AttributeError: if neither the wrapper nor its runnable has ``key``.
+    """
+
+    def __getattr__(self, key: str) -> Any:
+        # If `runnable` itself is the missing name (e.g. the instance was
+        # created without __init__ by copy/pickle, which then probes for
+        # __setstate__), reading self.runnable below would re-enter this
+        # method without bound and surface as RecursionError. Fail fast with
+        # the AttributeError that hasattr()/getattr() callers expect.
+        if key == "runnable":
+            raise AttributeError(key)
+        # allow accessing the attributes of the runnable.
+        if hasattr(self.runnable, key):
+            return getattr(self.runnable, key)
+        raise AttributeError(f"Attribute {key} not exists in the runnable of cudagraph wrapper")
+
+
+# Replace vLLM's CUDAGraphWrapper with the attribute-forwarding subclass
+# defined above in every already-imported module whose name contains "vllm"
+# and that still references the original class.
+# Iterate over a snapshot: hasattr() on a module can trigger lazy imports
+# that add entries to sys.modules, and mutating the live dict during
+# iteration raises "RuntimeError: dictionary changed size during iteration".
+for _module_name, _module in list(sys.modules.items()):
+    if "vllm" not in _module_name:
+        continue
+    if hasattr(_module, "CUDAGraphWrapper") and _module.CUDAGraphWrapper is _OriginalCUDAGraphWrapper:
+        _module.CUDAGraphWrapper = CUDAGraphWrapper
+
+
class OmniGPUModelRunner(GPUModelRunner):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_intermediate_buffer: dict[str, dict[str, Any]] = {}
self._omni_num_scheduled_tokens_np: np.ndarray | None = None
self._omni_last_model_output: object | None = None
- # The Omni tensor prefix cache will be allocated
- # when we initialize the metadata builders if enabled
- self.omni_prefix_cache = None
def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes):
"""Override to fix scheduler_metadata buffer size for FA3 + CUDA graph.
@@ -74,16 +87,6 @@ def initialize_metadata_builders(self, kv_cache_config, kernel_block_sizes):
device=sm.device,
)
- # Initialize the wrapper for both multimodal output tensors
- # and for hidden states to be passed between stages
- if self.cache_config.enable_prefix_caching:
- self.omni_prefix_cache = OmniTensorPrefixCache(
- num_blocks=kv_cache_config.num_blocks,
- block_size=self.cache_config.block_size,
- hidden_size=self.model_config.get_hidden_size(),
- hs_dtype=self.dtype,
- )
-
@instrument(span_name="Loading (GPU)")
def load_model(self, *args, **kwargs) -> None:
super().load_model(*args, **kwargs)
@@ -97,9 +100,11 @@ def load_model(self, *args, **kwargs) -> None:
self.has_talker_mtp = True
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
+ # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that
+ # have a separate .talker sub-module. TTS models' code predictor
+ # has internal AR loops / torch.multinomial — not graph-safe.
has_separate_talker = getattr(self.model, "talker", None) is not None
- talker_mtp_graph_safe = getattr(self.model, "talker_mtp_graph_safe", False)
- if cudagraph_mode.has_full_cudagraphs() and (has_separate_talker or talker_mtp_graph_safe):
+ if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
self.talker_mtp = CUDAGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
# TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size.
hidden_size = int(
@@ -152,10 +157,8 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
if supports_mrope(self.get_model()):
# Model implements SupportsMRoPE interface
# Pass all extracted metadata; models use what they need via **kwargs
- sp_extra_args = getattr(req_state.sampling_params, "extra_args", {}) if req_state.sampling_params else {}
- target_h = sp_extra_args.get("target_h") if isinstance(sp_extra_args, dict) else None
- target_w = sp_extra_args.get("target_w") if isinstance(sp_extra_args, dict) else None
- kwargs = dict(
+ req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions(
+ req_state.prompt_token_ids,
mm_features=req_state.mm_features,
hf_config=self.model_config.hf_config,
image_grid_thw=image_grid_thw,
@@ -164,14 +167,6 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
- if target_h is not None:
- kwargs["target_h"] = target_h
- if target_w is not None:
- kwargs["target_w"] = target_w
- req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions(
- req_state.prompt_token_ids,
- **kwargs,
- )
else:
req_state.mrope_positions, req_state.mrope_position_delta = MRotaryEmbedding.get_input_positions_tensor(
req_state.prompt_token_ids,
@@ -258,10 +253,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput"):
The SamplingMetadata is updated and copied to the GPU if there is a
new/resumed/paused/finished request in the batch.
"""
- # Used for prefix cache
- if self.omni_prefix_cache is not None:
- self.omni_prefix_cache.reset_prefix_cached_new_req_ids()
-
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
@@ -318,18 +309,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput"):
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
if req_id in self.requests:
- self._update_streaming_input_additional_info(new_req_data, req_id)
req_state = self._update_streaming_request(req_id, new_req_data)
reqs_to_add.append(req_state)
continue
- # Since this is the first time the request has been scheduled,
- # num_computed_tokens > 0 means that we have a hit in prefix
- # caching; mark it so that we can manage the hidden states
- # later on as needed.
- if self.omni_prefix_cache is not None and new_req_data.num_computed_tokens > 0:
- self.omni_prefix_cache.add_prefix_cached_new_req_id(req_id)
-
sampling_params = new_req_data.sampling_params
pooling_params = new_req_data.pooling_params
@@ -386,7 +369,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput"):
logger.warning_once(
"additional_information on request data is deprecated, use model_intermediate_buffer"
)
- info_dict = deserialize_additional_information(new_req_data.additional_information)
+ payload_info = new_req_data.additional_information
+ info_dict = deserialize_additional_information(payload_info)
if info_dict:
self.model_intermediate_buffer[req_id] = info_dict
setattr(
@@ -1037,15 +1021,6 @@ def _build_model_kwargs_extra(self) -> dict:
import traceback
traceback.print_exc()
-
- if getattr(self.model_config, "has_sampling_extra_args", False):
- extra_args_list: list[dict] = []
- for req_id in self.input_batch.req_ids:
- req = self.requests[req_id]
- sp = req.sampling_params if req else None
- extra_args_list.append(sp.extra_args if sp and sp.extra_args else {})
- model_kwargs_extra["sampling_extra_args"] = extra_args_list
-
return model_kwargs_extra
def _process_additional_information_updates(
@@ -1054,8 +1029,6 @@ def _process_additional_information_updates(
multimodal_outputs: object,
num_scheduled_tokens_np: np.ndarray,
scheduler_output: "SchedulerOutput",
- combined_hidden_states: dict[str, torch.Tensor] | None = None,
- combined_multimodal_outputs: dict[str, object] | None = None,
) -> None:
"""Process model-provided per-request updates and merge into model_intermediate_buffer."""
try:
@@ -1064,37 +1037,21 @@ def _process_additional_information_updates(
if hasattr(self.model, "has_postprocess") and self.model.has_postprocess:
for req_index, req_id in enumerate(self.input_batch.req_ids):
req_infos = self.model_intermediate_buffer.get(req_id, {})
- if combined_hidden_states:
- # Combined hidden states contains all hidden states for every request
- hidden_states_slice = combined_hidden_states[req_id]
- else:
- start_offset = int(self.query_start_loc.cpu[req_index])
- sched_tokens = int(num_scheduled_tokens_np[req_index])
- s, e = start_offset, start_offset + sched_tokens
- # only consider to store data into update dict.
- hidden_states_slice = hidden_states[s:e]
-
- if combined_multimodal_outputs:
- # NOTE this is a bit ugly, but the mm data is structured as a list of
- # keys mapping to request IDs, and if enabled, we will always have all
- # request IDs in every subdict, including for cache misses.
- mm_out = {k: v[req_id] for k, v in combined_multimodal_outputs.items()}
- else:
- mm_out = multimodal_outputs
- # Exclude 'hidden_states' from kwargs to avoid clash with
- # the positional arg. The buffer entry must be preserved
- # because preprocess reads hidden_states['last'] from it.
- # TODO: pass req_infos as a single payload arg instead of **unpacking
- # to avoid key collisions with positional args.
- postprocess_kwargs = {k: v for k, v in req_infos.items() if k != "hidden_states"}
+ start_offset = int(self.query_start_loc.cpu[req_index])
+ sched_tokens = int(num_scheduled_tokens_np[req_index])
+ s, e = start_offset, start_offset + sched_tokens
+ # only consider to store data into update dict.
+ hidden_states_slice = hidden_states[s:e]
update_dict = self.model.postprocess(
- hidden_states_slice,
- multimodal_outputs=mm_out,
- **postprocess_kwargs,
+ hidden_states_slice, multimodal_outputs=multimodal_outputs, **req_infos
)
self._update_intermediate_buffer(req_id, update_dict)
except Exception as e:
- logger.error(f"Error merging for requests:{self.input_batch.req_ids} additional information update: {e}")
+ logger.error(
+ f"Error merging for requests:{self.input_batch.req_ids} "
+ f"additional information update: {e}, with the multimodal_outputs "
+ f"as {multimodal_outputs}"
+ )
import traceback
traceback.print_exc()
@@ -1303,7 +1260,6 @@ def _preprocess(
span_len = int(e) - int(s)
# call the custom process function
- req_infos["request_id"] = req_id
embed_slice = inputs_embeds[s:e] if inputs_embeds is not None else None
req_input_ids, req_embeds, update_dict = self.model.preprocess(
input_ids=input_ids[s:e], input_embeds=embed_slice, **req_infos
@@ -1369,31 +1325,20 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te
req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded]
last_talker_hidden = self.last_talker_hidden.gpu[:num_tokens_padded]
text_step = self.text_step.gpu[:num_tokens_padded]
- subtalker_params = getattr(self.vllm_config.model_config, "subtalker_sampling_params", None)
- if not isinstance(subtalker_params, dict):
- subtalker_params = {}
with set_forward_context(
None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
):
- req_embeds, code_predictor_codes = self.talker_mtp(
- req_input_ids,
- req_embeds,
- last_talker_hidden,
- text_step,
- do_sample=subtalker_params.get("do_sample"),
- temperature=subtalker_params.get("temperature"),
- top_k=subtalker_params.get("top_k"),
- top_p=subtalker_params.get("top_p"),
- )
- # update the inputs_embeds and code_predictor_codes
- out_key = getattr(self.model, "talker_mtp_output_key", ("codes", "audio"))
- if not isinstance(out_key, tuple) or len(out_key) != 2:
- raise TypeError(f"talker_mtp_output_key must be a 2-tuple, got {type(out_key).__name__}: {out_key!r}")
+ req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step)
+ # code_predictor_codes stays on GPU here; _update_intermediate_buffer
+ # keeps it device-resident when the key is in gpu_resident_buffer_keys.
+ # D2H is deferred to sample_tokens where hidden_states.to("cpu") already
+ # syncs the stream, avoiding a per-step cudaStreamSynchronize.
+ out_key = getattr(self.model, "talker_mtp_output_key", "code_predictor_codes")
for idx, req_id in enumerate(decode_req_ids):
req_index = self.input_batch.req_ids.index(req_id)
start_offset = int(self.query_start_loc.cpu[req_index])
inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1]
- update_dict = {out_key[0]: {out_key[1]: code_predictor_codes[idx : idx + 1]}}
+ update_dict = {out_key: code_predictor_codes[idx : idx + 1]}
self._merge_additional_information_update(req_id, update_dict)
def _model_forward(
@@ -1421,72 +1366,32 @@ def _model_forward(
self._omni_last_model_output = model_output
return model_output
- def _store_value(self, dest: dict, key: str, value: Any, gpu_keys: set) -> None:
- if isinstance(value, torch.Tensor):
- if key in gpu_keys:
- dest[key] = value.detach().clone()
- else:
- dest[key] = value.detach().to("cpu").contiguous()
- elif isinstance(value, list):
- dest[key] = [
- (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in value
- ]
- else:
- dest[key] = value
-
def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None:
if not isinstance(upd, dict) or not upd:
return
req_state = self.requests.get(req_id)
if req_state is None:
return
- # Check if the model declares keys that should stay on GPU (tuples of (type_key, qualifier))
- gpu_keys: set[tuple[str, str]] = set()
+ # Check if the model declares keys that should stay on GPU
+ gpu_keys: set[str] = set()
if hasattr(self, "model") and hasattr(self.model, "gpu_resident_buffer_keys"):
gpu_keys = self.model.gpu_resident_buffer_keys
existing = self.model_intermediate_buffer.setdefault(req_id, {})
for k, v in upd.items():
- if isinstance(v, dict):
- existing_sub = existing.setdefault(k, {})
- for qual, val in v.items():
- self._store_value(existing_sub, qual, val, {q for tk, q in gpu_keys if tk == k})
+ if isinstance(v, torch.Tensor):
+ if k in gpu_keys:
+ existing[k] = v.detach().clone()
+ else:
+ existing[k] = v.detach().to("cpu").contiguous()
+ elif isinstance(v, list):
+ existing[k] = [
+ (item.detach().to("cpu").contiguous() if isinstance(item, torch.Tensor) else item) for item in v
+ ]
else:
- self._store_value(existing, k, v, set())
+ existing[k] = v
# Backward compatible: mirror to old setattr location
setattr(req_state, "additional_information_cpu", existing)
def _merge_additional_information_update(self, req_id, upd):
logger.warning_once("_merge_additional_information_update is deprecated, use _update_intermediate_buffer")
return self._update_intermediate_buffer(req_id, upd)
-
- def _update_streaming_input_additional_info(self, new_req_data, req_id):
- # For streaming input prefill case only. Update buffer from last segment input
- cached_additional_info = self.model_intermediate_buffer.get(req_id, {})
- if cached_additional_info:
- payload_info = getattr(new_req_data, "additional_information", None)
- inc_info = deserialize_additional_information(payload_info)
- if isinstance(inc_info, dict) and inc_info:
- accumulated_keys: set[tuple[str, str]] = set()
- if hasattr(self, "model") and hasattr(self.model, "streaming_accumulated_keys"):
- accumulated_keys = self.model.streaming_accumulated_keys
- merged_info = dict(cached_additional_info)
- for key, value in inc_info.items():
- if isinstance(value, dict):
- existing_sub = merged_info.get(key)
- merged_sub = dict(existing_sub) if isinstance(existing_sub, dict) else {}
- for sk, sv in value.items():
- if (key, sk) in accumulated_keys and isinstance(sv, torch.Tensor):
- inc_tensor = sv.detach().to("cpu").contiguous()
- old_tensor = merged_sub.get(sk)
- if old_tensor is None:
- merged_sub[sk] = inc_tensor
- else:
- merged_sub[sk] = torch.cat((old_tensor, inc_tensor), dim=0)
- else:
- merged_sub[sk] = sv
- merged_info[key] = merged_sub
- else:
- merged_info[key] = value
- merged_info.setdefault("meta", {})["num_processed_tokens"] = 0
- self.model_intermediate_buffer[req_id] = merged_info
- setattr(self.requests[req_id], "additional_information_cpu", merged_info)
diff --git a/vllm_omni/worker/omni_connector_model_runner_mixin.py b/vllm_omni/worker/omni_connector_model_runner_mixin.py
deleted file mode 100644
index 8e8f5741fa6..00000000000
--- a/vllm_omni/worker/omni_connector_model_runner_mixin.py
+++ /dev/null
@@ -1,2155 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Unified data-plane communication mixin for Model Runners.
-
-All connector.put()/get() calls are consolidated here. Background I/O
-threads handle async_chunk and full_payload_mode transfers; KV cache is delegated to
-the existing OmniKVTransferManager (to be absorbed later).
-
-The mixin reports transfer results via OmniConnectorOutput so that the
-Scheduler can make scheduling decisions without ever touching a connector.
-"""
-
-from __future__ import annotations
-
-import importlib
-import inspect
-import os
-import threading
-from collections import defaultdict, deque
-from types import SimpleNamespace
-from typing import TYPE_CHECKING, Any
-
-import torch
-from vllm.distributed.parallel_state import get_tp_group
-from vllm.logger import init_logger
-
-from vllm_omni.distributed.omni_connectors.factory import OmniConnectorFactory
-from vllm_omni.distributed.omni_connectors.utils.config import ConnectorSpec
-from vllm_omni.outputs import OmniConnectorOutput
-from vllm_omni.worker.payload_span import (
- THINKER_DECODE_EMBEDDINGS_KEY,
- THINKER_DECODE_TOKEN_END_KEY,
- THINKER_DECODE_TOKEN_START_KEY,
- THINKER_OUTPUT_TOKEN_IDS_KEY,
- get_tensor_span,
- merge_tensor_spans,
-)
-
-if TYPE_CHECKING:
- from vllm_omni.distributed.omni_connectors.connectors.base import (
- OmniConnectorBase,
- )
- from vllm_omni.distributed.omni_connectors.kv_transfer_manager import (
- OmniKVTransferManager,
- )
-
-logger = init_logger(__name__)
-
-
-class OmniConnectorModelRunnerMixin:
- """Unified data-plane communication mixin for Model Runners.
-
- Provides three transfer modes through a single pair of bg I/O threads:
- - **full_payload_mode**: ``recv_full_payload_inputs`` / ``send_full_payload_outputs``
- - **Streaming (async_chunk)**: ``recv_chunk`` / ``send_chunk``
- - **KV cache**: ``send_kv_cache`` / ``recv_kv_cache`` (delegates to
- the existing ``OmniKVTransferManager``)
-
- The mixin owns connector instances and background threads. It never
- touches scheduling queues -- readiness is communicated to the Scheduler
- via ``OmniConnectorOutput``.
- """
-
- # ------------------------------------------------------------------ #
- # Init / Shutdown
- # ------------------------------------------------------------------ #
-
- def init_omni_connectors(
- self,
- vllm_config: Any,
- model_config: Any,
- kv_transfer_manager: OmniKVTransferManager | None = None,
- ) -> None:
- """Initialize connectors and background threads.
-
- Args:
- vllm_config: Full vLLM config object.
- model_config: Stage-level model config with connector settings.
- kv_transfer_manager: Existing KV transfer manager to delegate to.
- """
- self._omni_connector: OmniConnectorBase | None = self._create_connector(model_config)
- self._kv_transfer_manager = kv_transfer_manager
-
- self._async_chunk: bool = getattr(model_config, "async_chunk", False)
- self._model_mode: str = getattr(model_config, "worker_type", "ar")
- stage_id = getattr(model_config, "stage_id", 0)
- if isinstance(stage_id, str):
- stage_id = int(stage_id)
- self._stage_id: int = stage_id if isinstance(stage_id, int) else 0
-
- self._custom_process_func_path, self._custom_process_func = self._load_custom_func(model_config)
- self._custom_process_supports_is_finished = self._custom_process_supports_is_finished_kwarg()
- logger.info(
- "[Stage-%s] init_omni_connectors: async_chunk=%s, custom_process_func=%s, connector=%s, func_path=%s",
- self._stage_id,
- self._async_chunk,
- self._custom_process_func,
- type(self._omni_connector).__name__ if self._omni_connector else None,
- self._custom_process_func_path,
- )
-
- # -- next stage ID (from connector config or default stage_id + 1) --
- self._next_stage_id: int = self._resolve_next_stage_id(model_config)
-
- # -- heterogeneous TP rank support --
- rank_cfg = self._parse_rank_mapping(model_config)
- self._from_tp: int = rank_cfg["from_tp"]
- self._to_tp: int = rank_cfg["to_tp"]
- self._local_rank: int = rank_cfg["local_rank"]
- if self._kv_transfer_manager is not None:
- self._kv_transfer_manager.kv_send_key_builder = self.get_rank_aware_kv_send_keys
- self._kv_transfer_manager.kv_recv_key_builder = self.get_rank_aware_kv_keys
- self._kv_transfer_manager.kv_payload_merger = self._merge_rank_sharded_kv_payloads
- self._kv_transfer_manager.kv_payload_slicer = self._slice_rank_sharded_kv_payload
-
- # -- chunk index tracking (ported from OmniChunkTransferAdapter) --
- self._put_req_chunk: dict[str, int] = defaultdict(int)
- self._get_req_chunk: dict[str, int] = defaultdict(int)
- # Send-side async accumulation / staging buffer. Receive-side payload
- # ownership lives in ``_local_stage_payload_cache``.
- self._send_side_request_payload: dict[str, dict[str, Any]] = {}
- self._code_prompt_token_ids: dict[str, list[list[int]]] = defaultdict(list)
- self._request_ids_mapping: dict[str, str] = {}
-
- # -- async I/O state (shared by chunk + full_payload_mode) --
- self._pending_load_reqs: dict[str, Any] = {}
- self._finished_load_reqs: set[str] = set()
- self._pending_save_reqs: dict[str, deque] = {}
- self._pending_save_counts: dict[str, int] = defaultdict(int)
- self._deferred_send_cleanup: set[str] = set()
- # -- per-cycle output accumulator --
- self._chunk_ready_req_ids: set[str] = set()
- self._chunk_finished_req_ids: set[str] = set()
- self._stage_recv_req_ids: set[str] = set()
- self._full_payload_pending_broadcast_req_ids: set[str] = set()
- self._async_chunk_updated_req_ids: set[str] = set()
-
- # -- Model Runner local payload cache (RFC §2.4) --
- # Full stage payloads land here first on the recv side. We
- # intentionally do not write connector recv results straight into
- # `model_intermediate_buffer`: runner-owned runtime state is
- # materialized later by `_sync_local_stage_payloads()` on the
- # model thread. This keeps recv timing separate from execute-step
- # visibility and avoids mixing connector I/O with model runtime
- # ownership.
- self._local_stage_payload_cache: dict[str, dict[str, Any]] = {}
- # Lightweight scheduling metadata pending delivery to the Scheduler.
- self._local_request_metadata: dict[str, dict[str, Any]] = {}
-
- # -- persistent set of request IDs whose chunk stream is complete --
- # Prevents re-registration after the finish sentinel has been received.
- self._chunk_stream_completed: set[str] = set()
-
- # -- full_payload_mode: accumulate latest pooler_output per request,
- # send only when the request finishes (next-cycle flush) --
- self._pending_full_payload_send: dict[str, tuple[Any, Any]] = {}
-
- # -- KV sent accumulator --
- self._kv_sent_req_ids: list[str] = []
-
- # -- KV transfer lifecycle (absorbed from scheduler) --
- # Requests marked for KV transfer: {req_id: {seq_len, block_ids}}
- self._kv_pending_transfers: dict[str, dict[str, Any]] = {}
- # Requests whose KV transfer has been submitted but not yet acked
- self._kv_active_transfers: set[str] = set()
- # Requests whose KV transfer is complete (acked by kv_extracted_req_ids)
- self._kv_completed_transfers: set[str] = set()
- # Dedup guard: requests that have already triggered KV transfer
- self._kv_triggered_requests: set[str] = set()
-
- self._lock = threading.Lock()
- self._stop_event = threading.Event()
- self._work_available = threading.Event()
-
- # Start background threads only when there's a connector
- self._recv_thread: threading.Thread | None = None
- self._save_thread: threading.Thread | None = None
- if self._omni_connector is not None:
- self._recv_thread = threading.Thread(
- target=self._recv_loop,
- daemon=True,
- name="omni-mixin-recv",
- )
- self._recv_thread.start()
- self._save_thread = threading.Thread(
- target=self._save_loop,
- daemon=True,
- name="omni-mixin-save",
- )
- self._save_thread.start()
-
- def shutdown_omni_connectors(self) -> None:
- """Stop background threads and release connector resources."""
- self._stop_event.set()
- if self._recv_thread is not None:
- self._recv_thread.join(timeout=5)
- if self._save_thread is not None:
- self._save_thread.join(timeout=5)
- if self._omni_connector is not None:
- try:
- self._omni_connector.close()
- except Exception:
- pass
-
- def cleanup_finished_request(self, req_id: str) -> None:
- """Clean up per-request state after a request is fully finished.
-
- Call this when a request is freed from the model runner to prevent
- memory leaks in the mixin's tracking dicts/sets. The external
- request ID is resolved before cleaning up ``_put_req_chunk`` which
- is keyed by external ID.
- """
- ext_id = self._request_ids_mapping.pop(req_id, None)
- send_req_id = ext_id if ext_id is not None else req_id
-
- with self._lock:
- if self._pending_save_counts.get(send_req_id, 0):
- self._deferred_send_cleanup.add(send_req_id)
- else:
- self._put_req_chunk.pop(send_req_id, None)
- self._send_side_request_payload.pop(send_req_id, None)
- self._code_prompt_token_ids.pop(send_req_id, None)
- self._kv_pending_transfers.pop(req_id, None)
- self._kv_active_transfers.discard(req_id)
- self._kv_completed_transfers.discard(req_id)
- self._kv_triggered_requests.discard(req_id)
- self._cleanup_recv_delivery_state(req_id)
-
- def drop_inactive_request_delivery_state(self, req_id: str) -> None:
- """Clear recv-side state for inactive requests."""
- ext_id = self._request_ids_mapping.pop(req_id, None)
- if hasattr(self, "_lock"):
- with self._lock:
- self._drop_send_side_payload_state(req_id, ext_id)
- else:
- self._drop_send_side_payload_state(req_id, ext_id)
- self._cleanup_recv_delivery_state(req_id)
-
- def _drop_send_side_payload_state(self, req_id: str, ext_id: str | None) -> None:
- if ext_id is not None:
- self._send_side_request_payload.pop(ext_id, None)
- self._send_side_request_payload.pop(req_id, None)
-
- def _cleanup_recv_delivery_state(self, req_id: str) -> None:
- """Clear recv-side delivery-cycle state."""
- if hasattr(self, "_lock"):
- with self._lock:
- self._clear_recv_delivery_state(req_id)
- else:
- self._clear_recv_delivery_state(req_id)
-
- def _clear_recv_delivery_state(self, req_id: str) -> None:
- self._get_req_chunk.pop(req_id, None)
- self._pending_load_reqs.pop(req_id, None)
- self._finished_load_reqs.discard(req_id)
- self._chunk_ready_req_ids.discard(req_id)
- self._chunk_finished_req_ids.discard(req_id)
- self._chunk_stream_completed.discard(req_id)
- self._stage_recv_req_ids.discard(req_id)
- self._full_payload_pending_broadcast_req_ids.discard(req_id)
- self._async_chunk_updated_req_ids.discard(req_id)
- self._local_stage_payload_cache.pop(req_id, None)
- self._local_request_metadata.pop(req_id, None)
-
- def prune_inactive_requests(self, active_req_ids: Any) -> set[str]:
- """Drop connector state for requests that no longer exist locally.
-
- Preempted / unscheduled requests are expected to stay in
- ``self.requests`` and therefore remain untouched. This only prunes
- stale request IDs that have already fallen out of the active request
- map, preventing background recv/send bookkeeping from outliving the
- request lifecycle.
- """
- if active_req_ids is None:
- return set()
-
- active_req_ids = set(active_req_ids)
- pending_req_ids = set(getattr(self, "_pending_load_reqs", {}).keys())
- received_req_ids = set(getattr(self, "_stage_recv_req_ids", set()))
- received_req_ids.update(getattr(self, "_full_payload_pending_broadcast_req_ids", set()))
- received_req_ids.update(getattr(self, "_local_request_metadata", {}).keys())
- # Pending recv requests may not yet be in the caller's active set
- # (e.g. WAITING_FOR_CHUNK requests live in the coordinator's internal
- # queues, not in model runner self.requests). Protect them so that
- # legitimate waiting requests are not pruned.
- #
- # Likewise, a full payload can arrive on the background recv thread
- # after the scheduler_output snapshot for the current execute_model()
- # cycle was already materialized. Those requests may briefly live only
- # in recv-side buffers/local cache until the next scheduler cycle wakes
- # them up; pruning them here drops the payload before stage_recv can be
- # published.
- active_req_ids.update(pending_req_ids)
- active_req_ids.update(received_req_ids)
- stale_req_ids: set[str] = set()
-
- # NOTE: _pending_load_reqs is excluded from the scan list because
- # all its entries are unconditionally protected above. The mixin
- # cannot distinguish a legitimately-waiting pending recv from an
- # orphaned one (only the coordinator/scheduler knows).
- #
- # Requests with freshly received full payloads / local stage payloads
- # are also protected above. Their scheduler wake-up may lag the recv
- # thread by one execute_model() cycle, especially when the request was
- # added after the current scheduler_output snapshot.
- #
- # Orphaned pending recv entries (e.g. from upstream stage crash)
- # are handled by OmniSchedulingCoordinator.collect_timed_out_request_ids()
- # which detects wait-time violations. The scheduler then removes the
- # request from its queues, sets FINISHED_ERROR, and calls _free_request()
- # which ultimately triggers cleanup_finished_request() here.
- for attr_name in (
- "_request_ids_mapping",
- "_get_req_chunk",
- "_finished_load_reqs",
- "_chunk_ready_req_ids",
- "_chunk_finished_req_ids",
- "_chunk_stream_completed",
- "_stage_recv_req_ids",
- "_full_payload_pending_broadcast_req_ids",
- "_async_chunk_updated_req_ids",
- "_local_stage_payload_cache",
- "_local_request_metadata",
- "_kv_pending_transfers",
- "_kv_active_transfers",
- "_kv_completed_transfers",
- "_kv_triggered_requests",
- ):
- state = getattr(self, attr_name, None)
- if isinstance(state, dict):
- stale_req_ids.update(req_id for req_id in state if req_id not in active_req_ids)
- elif isinstance(state, set):
- stale_req_ids.update(req_id for req_id in state if req_id not in active_req_ids)
-
- for req_id in stale_req_ids:
- self.cleanup_finished_request(req_id)
-
- return stale_req_ids
-
- # ------------------------------------------------------------------ #
- # Local payload cache (RFC §2.4 – Model Runner ownership)
- # ------------------------------------------------------------------ #
-
- def put_local_stage_payload(self, req_id: str, payload: dict[str, Any]) -> None:
- """Store a full stage payload in the local cache."""
- self._local_stage_payload_cache[req_id] = payload
-
- def get_local_stage_payload(self, req_id: str) -> dict[str, Any] | None:
- """Read a stage payload without removing it."""
- return self._local_stage_payload_cache.get(req_id)
-
- def pop_local_stage_payload(self, req_id: str) -> dict[str, Any] | None:
- """Remove and return a stage payload (consume after use)."""
- return self._local_stage_payload_cache.pop(req_id, None)
-
- def put_local_request_metadata(self, req_id: str, metadata: dict[str, Any]) -> None:
- """Store lightweight scheduling metadata for a request."""
- self._local_request_metadata[req_id] = metadata
-
- def get_local_request_metadata(self, req_id: str) -> dict[str, Any] | None:
- """Retrieve scheduling metadata for a request."""
- return self._local_request_metadata.get(req_id)
-
- # ------------------------------------------------------------------ #
- # Scheduling metadata extraction
- # ------------------------------------------------------------------ #
-
- @classmethod
- def _extract_scheduling_metadata(cls, payload: dict[str, Any]) -> dict[str, Any]:
- """Extract only the fields the scheduler needs from a full payload."""
- extracted: dict[str, Any] = {}
- if "next_stage_prompt_len" in payload:
- extracted["next_stage_prompt_len"] = payload["next_stage_prompt_len"]
- audio_codes = cls._payload_audio_codes(payload)
- if audio_codes is not None:
- extracted["code_predictor_codes"] = audio_codes
- meta = payload.get("meta")
- if isinstance(meta, dict) and "left_context_size" in meta:
- extracted["left_context_size"] = meta["left_context_size"]
- elif "left_context_size" in payload:
- logger.warning_once("legacy flat 'left_context_size' key in payload; expected 'meta.left_context_size'")
- return extracted
-
- _NON_CONSUMABLE_PAYLOAD_KEYS = {
- "finished",
- "override_keys",
- "next_stage_prompt_len",
- "left_context_size",
- THINKER_OUTPUT_TOKEN_IDS_KEY,
- THINKER_DECODE_TOKEN_START_KEY,
- THINKER_DECODE_TOKEN_END_KEY,
- }
-
- @staticmethod
- def _payload_value_has_content(value: Any) -> bool:
- if value is None:
- return False
- if isinstance(value, torch.Tensor):
- return value.numel() > 0
- if isinstance(value, (list, tuple, dict, set)):
- return len(value) > 0
- return True
-
- @staticmethod
- def _payload_finished(payload: Any) -> bool:
- if not isinstance(payload, dict):
- return False
- if "finished" in payload:
- logger.warning_once("legacy flat 'finished' key in payload; expected 'meta.finished'")
- meta = payload.get("meta")
- if not isinstance(meta, dict) or "finished" not in meta:
- return False
- flag = meta["finished"]
- if isinstance(flag, torch.Tensor):
- return flag.numel() == 1 and bool(flag.item())
- return bool(flag)
-
- @staticmethod
- def _payload_audio_codes(payload: Any) -> Any:
- if not isinstance(payload, dict):
- return None
- if "code_predictor_codes" in payload:
- logger.warning_once("legacy flat 'code_predictor_codes' key in payload; expected 'codes.audio'")
- codes = payload.get("codes")
- if isinstance(codes, dict):
- return codes.get("audio")
- return None
-
- @classmethod
- def _payload_is_consumable(cls, payload: dict[str, Any] | None) -> bool:
- """Return True when an async payload can drive a real forward step.
-
- Metadata-only wake-ups should not transition WAITING_FOR_CHUNK requests
- back to schedulable state. In particular, a widened token horizon without
- any newly visible thinker decode embeds should not force a placeholder-only
- talker decode step.
- """
- if not isinstance(payload, dict) or not payload:
- return False
-
- decode_embeddings = payload.get(THINKER_DECODE_EMBEDDINGS_KEY)
- if isinstance(decode_embeddings, torch.Tensor):
- if decode_embeddings.ndim == 0:
- return True
- return decode_embeddings.numel() > 0 and decode_embeddings.shape[0] > 0
-
- audio_codes = cls._payload_audio_codes(payload)
- if audio_codes is not None:
- if isinstance(audio_codes, torch.Tensor):
- return audio_codes.numel() > 0
- # Codec code 0 is valid; non-empty code payloads are consumable.
- if hasattr(audio_codes, "__len__"):
- return len(audio_codes) > 0
- return True
-
- for key, value in payload.items():
- if key in cls._NON_CONSUMABLE_PAYLOAD_KEYS:
- continue
- if cls._payload_value_has_content(value):
- return True
- return False
-
- @staticmethod
- def _get_local_tp_group() -> Any | None:
- """Return the local TP group when tensor parallelism is initialized."""
- try:
- return get_tp_group()
- except Exception:
- return None
-
- def _recv_ordinary_stage_result(
- self,
- connector: OmniConnectorBase,
- from_stage: str,
- to_stage: str,
- connector_get_key: str,
- ) -> Any:
- """Receive one ordinary non-KV stage payload on the local leader rank only."""
- tp_group = self._get_local_tp_group()
- if tp_group is None or getattr(tp_group, "world_size", 1) <= 1:
- return connector.get(from_stage, to_stage, connector_get_key)
- if not self.is_data_transfer_rank():
- return None
- return connector.get(from_stage, to_stage, connector_get_key)
-
- def _recv_full_payload_result(
- self,
- connector: OmniConnectorBase,
- from_stage: str,
- to_stage: str,
- connector_get_key: str,
- ) -> Any:
- """Receive one full-payload transfer on the local leader rank only."""
- return self._recv_ordinary_stage_result(
- connector,
- from_stage,
- to_stage,
- connector_get_key,
- )
-
- def _recv_async_chunk_result(
- self,
- connector: OmniConnectorBase,
- from_stage: str,
- to_stage: str,
- connector_get_key: str,
- ) -> Any:
- """Receive one ordinary async chunk on the local leader rank only."""
- return self._recv_ordinary_stage_result(
- connector,
- from_stage,
- to_stage,
- connector_get_key,
- )
-
- @staticmethod
- def _snapshot_payload(payload: Any) -> Any:
- if isinstance(payload, dict):
- return dict(payload)
- return payload
-
- def _broadcast_tp_payload_packet(self, packet: Any) -> Any:
- """Broadcast one ordinary payload packet from TP rank 0 when TP is active."""
- tp_group = self._get_local_tp_group()
- if tp_group is None or getattr(tp_group, "world_size", 1) <= 1:
- return packet
- leader_packet = packet if self.is_data_transfer_rank() else None
- return tp_group.broadcast_object(leader_packet, src=0)
-
- def _apply_staged_payloads_locked(self, staged_payloads: dict[str, Any]) -> None:
- for req_id, payload in staged_payloads.items():
- self._local_stage_payload_cache[req_id] = self._snapshot_payload(payload)
-
- def _collect_full_payload_results_locked(self) -> dict[str, Any] | None:
- if not self._full_payload_pending_broadcast_req_ids:
- return None
- results: dict[str, Any] = {}
- missing_req_ids: list[str] = []
- for req_id in tuple(self._full_payload_pending_broadcast_req_ids):
- payload = self._local_stage_payload_cache.get(req_id)
- if payload is None:
- missing_req_ids.append(req_id)
- continue
- results[req_id] = self._snapshot_payload(payload)
- self._full_payload_pending_broadcast_req_ids.discard(req_id)
- if missing_req_ids:
- logger.warning(
- "[Stage-%s] _collect_full_payload_results_locked: "
- "pending full-payload reqs missing from local cache: %s",
- self._stage_id,
- missing_req_ids,
- )
- return results or None
-
- def _collect_async_chunk_fanout_packet_locked(self) -> dict[str, Any] | None:
- payload_req_ids = set(self._async_chunk_updated_req_ids)
- payload_req_ids.update(self._finished_load_reqs)
- payload_req_ids.update(self._chunk_finished_req_ids)
- payload_req_ids.update(self._local_request_metadata)
- if not (
- payload_req_ids or self._finished_load_reqs or self._chunk_finished_req_ids or self._local_request_metadata
- ):
- return None
-
- staged_payloads = {
- req_id: self._snapshot_payload(self._local_stage_payload_cache[req_id])
- for req_id in payload_req_ids
- if req_id in self._local_stage_payload_cache
- }
- packet = {
- "staged_payloads": staged_payloads,
- "request_metadata": dict(self._local_request_metadata),
- "newly_finished": set(self._finished_load_reqs),
- "chunk_finished": set(self._chunk_finished_req_ids),
- }
-
- self._async_chunk_updated_req_ids.clear()
- self._finished_load_reqs.clear()
- self._chunk_finished_req_ids.clear()
- self._local_request_metadata.clear()
-
- for req_id in packet["chunk_finished"]:
- if req_id not in self._local_stage_payload_cache:
- continue
- ext_req_id = self._request_ids_mapping.get(req_id, req_id)
- self._send_side_request_payload.pop(ext_req_id, None)
- if ext_req_id != req_id:
- self._send_side_request_payload.pop(req_id, None)
-
- return packet
-
- def _apply_async_chunk_fanout_packet(self, packet: dict[str, Any]) -> None:
- staged_payloads = packet.get("staged_payloads", {})
- chunk_finished = set(packet.get("chunk_finished", ()))
- with self._lock:
- self._apply_staged_payloads_locked(staged_payloads)
- for req_id in chunk_finished:
- self._pending_load_reqs.pop(req_id, None)
- self._chunk_stream_completed.add(req_id)
-
- # ------------------------------------------------------------------ #
- # full_payload_mode (recv_full_payload_inputs / send_full_payload_outputs)
- # ------------------------------------------------------------------ #
-
- def recv_full_payload_inputs(self, scheduler_output: Any) -> dict[str, Any] | None:
- """Check for incoming full_payload_mode stage inputs (non-blocking).
-
- Returns a dict mapping ``request_id -> engine_inputs`` for data
- that has arrived, or ``None`` if nothing is ready. Stores full
- payloads in the local cache and extracts scheduling metadata.
- """
- with self._lock:
- results = self._collect_full_payload_results_locked() if self.is_data_transfer_rank() else None
- results = self._broadcast_tp_payload_packet(results)
- if not results:
- return None
- with self._lock:
- self._stage_recv_req_ids.update(results.keys())
- for req_id in results:
- self._pending_load_reqs.pop(req_id, None)
- self._apply_staged_payloads_locked(results)
- for req_id, payload in results.items():
- self._local_request_metadata[req_id] = self._extract_scheduling_metadata(payload)
- logger.info(
- "[Stage-%s] recv_full_payload_inputs: consumed %s reqs: %s, stage_recv_req_ids now=%s",
- self._stage_id,
- len(results),
- list(results.keys()),
- self._stage_recv_req_ids,
- )
- return results
-
- @staticmethod
- def _is_all_zero_tensor(t: Any) -> bool:
- """Return True if *t* is a torch.Tensor whose elements are all zero."""
- return isinstance(t, torch.Tensor) and t.numel() > 0 and not t.any()
-
- def accumulate_full_payload_output(
- self,
- req_id: str,
- pooler_output: Any,
- request: Any,
- ) -> None:
- """Accumulate pooler_output for a request across steps (full_payload_mode).
-
- Per-token tensors (2-D+, matching trailing dims) are concatenated
- along dim-0. Scalar / global tensors (1-D or 0-D) are replaced
- with the latest value.
-
- All-zero tensors (e.g. ``code_predictor_codes`` emitted during
- prefill) are dropped so that they do not pollute downstream stages
- with garbage / noise frames.
-
- The data is actually sent when ``flush_full_payload_outputs`` is called
- with the finished request IDs from the next scheduler cycle.
- """
- # ---- Filter out all-zero tensors from the incoming pooler_output ----
- filtered: dict[str, Any] = {}
- dropped_zero_keys: list[tuple[str, tuple[int, ...]]] = []
- for k, v in pooler_output.items():
- if self._is_all_zero_tensor(v):
- dropped_zero_keys.append((k, tuple(v.shape)))
- continue # skip prefill zero-filled placeholders
- filtered[k] = v
- if dropped_zero_keys:
- logger.info(
- "[Stage-%s] accumulate_full_payload_output: req=%s dropped_zero_keys=%s",
- self._stage_id,
- req_id,
- dropped_zero_keys,
- )
- pooler_output = filtered
-
- existing = self._pending_full_payload_send.get(req_id)
- if existing is None:
- self._pending_full_payload_send[req_id] = (pooler_output, request)
- return
-
- prev_output, _ = existing
- merged: dict[str, Any] = {}
- for k in set(prev_output) | set(pooler_output):
- v_new = pooler_output.get(k)
- v_old = prev_output.get(k)
- if v_new is None:
- merged[k] = v_old
- elif v_old is None:
- merged[k] = v_new
- elif (
- isinstance(v_new, torch.Tensor)
- and isinstance(v_old, torch.Tensor)
- and v_new.dim() >= 2
- and v_old.dim() >= 2
- and v_new.shape[1:] == v_old.shape[1:]
- ):
- merged[k] = torch.cat([v_old, v_new], dim=0)
- else:
- merged[k] = v_new
- self._pending_full_payload_send[req_id] = (merged, request)
-
- def flush_full_payload_outputs(self, finished_req_ids: set[str]) -> None:
- """Send accumulated full_payload outputs for requests that just finished."""
- logger.info(
- "[Stage-%s] flush_full_payload_outputs: finished_req_ids=%s, pending=%s",
- self._stage_id,
- finished_req_ids,
- list(self._pending_full_payload_send.keys()),
- )
- to_send: dict[str, tuple[Any, Any]] = {}
- for req_id in finished_req_ids:
- entry = self._pending_full_payload_send.pop(req_id, None)
- if entry is not None:
- to_send[req_id] = entry
- logger.info("[Stage-%s] flush_full_payload_outputs: to_send=%s", self._stage_id, list(to_send.keys()))
- if to_send:
- self.send_full_payload_outputs(scheduler_output=None, outputs=to_send)
-
- def send_full_payload_outputs(
- self,
- scheduler_output: Any,
- outputs: dict[str, tuple[Any, Any] | Any],
- ) -> list[str]:
- """Send full_payload stage outputs to the next stage via connector.
-
- Args:
- outputs: Mapping of ``req_id`` to either a
- ``(pooling_output, request)`` tuple (preferred) or a raw
- payload dict. When a tuple is supplied the request object
- is forwarded to ``custom_process_stage_input_func``.
-
- Returns list of request IDs successfully enqueued.
- """
- if self._omni_connector is None:
- logger.info("[Stage-%s] send_full_payload_outputs: connector is None, skip", self._stage_id)
- return []
- if not self.is_data_transfer_rank():
- logger.info(
- "[Stage-%s] send_full_payload_outputs: not data_transfer_rank (rank=%s), skip",
- self._stage_id,
- self._local_rank,
- )
- return list(outputs.keys())
- sent_ids: list[str] = []
- next_stage_id = self._next_stage_id
- for req_id, value in outputs.items():
- if isinstance(value, tuple) and len(value) == 2:
- raw_output, request = value
- else:
- raw_output, request = value, None
-
- payload = raw_output
- if self._custom_process_func is not None:
- payload = self._build_custom_process_payload(
- request_id=req_id,
- request=request,
- pooling_output=raw_output,
- )
- if payload is None:
- continue
- if payload is None:
- logger.info("[Stage-%s] send_full_payload_outputs: payload is None for %s", self._stage_id, req_id)
- continue
- if isinstance(payload, dict):
- audio_codes = self._payload_audio_codes(payload)
- if isinstance(audio_codes, torch.Tensor):
- code_len = int(audio_codes.numel())
- elif hasattr(audio_codes, "__len__"):
- code_len = len(audio_codes)
- else:
- code_len = None
- meta = payload.get("meta") if isinstance(payload.get("meta"), dict) else {}
- logger.info(
- "[Stage-%s] send_full_payload_outputs: req=%s payload_keys=%s code_len=%s left_context_size=%s",
- self._stage_id,
- req_id,
- sorted(payload.keys()),
- code_len,
- meta.get("left_context_size"),
- )
-
- external_req_id = self._resolve_external_req_id(request, req_id)
- chunk_id = self._put_req_chunk[req_id]
- self._put_req_chunk[req_id] += 1
- connector_put_key = f"{external_req_id}_{self._stage_id}_{chunk_id}"
-
- logger.info(
- "[Stage-%s] send_full_payload_outputs: enqueue req=%s put_key=%s next_stage=%s",
- self._stage_id,
- req_id,
- connector_put_key,
- next_stage_id,
- )
- task = {
- "stage_id": self._stage_id,
- "next_stage_id": next_stage_id,
- "put_key": connector_put_key,
- "data": payload,
- "request_id": req_id,
- }
- with self._lock:
- self._pending_save_reqs.setdefault(req_id, deque()).append(task)
- self._pending_save_counts[req_id] += 1
- sent_ids.append(req_id)
- if sent_ids:
- self._work_available.set()
- return sent_ids
-
- def recv_stage_inputs(self, scheduler_output: Any) -> dict[str, Any] | None:
- """Compatibility wrapper for ``recv_full_payload_inputs``."""
- return self.recv_full_payload_inputs(scheduler_output)
-
- def accumulate_batch_output(
- self,
- req_id: str,
- pooler_output: Any,
- request: Any,
- ) -> None:
- """Compatibility wrapper for ``accumulate_full_payload_output``."""
- self.accumulate_full_payload_output(req_id, pooler_output, request)
-
- def flush_batch_outputs(self, finished_req_ids: set[str]) -> None:
- """Compatibility wrapper for ``flush_full_payload_outputs``."""
- self.flush_full_payload_outputs(finished_req_ids)
-
- def send_stage_outputs(
- self,
- scheduler_output: Any,
- outputs: dict[str, tuple[Any, Any] | Any],
- ) -> list[str]:
- """Compatibility wrapper for ``send_full_payload_outputs``."""
- return self.send_full_payload_outputs(scheduler_output, outputs)
-
- # ------------------------------------------------------------------ #
- # Streaming chunk mode (recv_chunk / send_chunk)
- # ------------------------------------------------------------------ #
-
- def register_chunk_recv(self, request: Any) -> None:
- """Register a request for async chunk retrieval by the bg thread.
-
- Stage-0 has no upstream producer so this is a no-op there.
- Skips requests whose batch data has already been received to
- prevent the bg thread from polling for non-existent chunks.
- """
- if self._stage_id == 0:
- return
- request_id = request.request_id
- self._request_ids_mapping[request_id] = getattr(
- request,
- "external_req_id",
- request_id,
- )
- with self._lock:
- if request_id in self._stage_recv_req_ids:
- return
- # Don't re-register if the finish sentinel was already received
- if request_id in self._chunk_stream_completed:
- return
- self._pending_load_reqs[request_id] = request
- self._work_available.set()
-
- def recv_chunk(self) -> dict[str, Any]:
- """Collect chunks received by the bg thread since last call.
-
- Returns a dict ``{request_id: chunk_payload}`` for newly arrived
- chunks. Empty dict when nothing is ready.
-
- This method reads from ``_finished_load_reqs`` without clearing
- it -- ``get_omni_connector_output()`` is the sole consumer that
- drains and resets ``_finished_load_reqs`` at the end of each
- ``execute_model`` cycle.
-
- Returns **shallow copies** of the cached payloads so that the
- caller can read them without racing against the background recv
- thread, which may concurrently mutate the live cache entries via
- ``dict.update()``.
- """
- with self._lock:
- finished = set(self._finished_load_reqs)
- if not finished:
- return {}
- # Snapshot the payloads under the lock to avoid racing with
- # _poll_single_request which does existing.update(payload_data)
- # on the same dict objects.
- result = {}
- for rid in finished:
- payload = self._local_stage_payload_cache.get(rid)
- result[rid] = dict(payload) if isinstance(payload, dict) else payload
-
- self._chunk_ready_req_ids.update(finished)
- return result
-
- def send_chunk(
- self,
- request: Any,
- pooling_output: Any | None = None,
- ) -> bool:
- """Derive and enqueue one chunk for async sending.
-
- Payload extraction runs in the caller thread (via
- ``custom_process_stage_input_func``); the actual
- ``connector.put()`` is done by the background save thread.
- Non-KV data is identical across TP ranks; only rank 0 sends.
- """
- if self._omni_connector is None:
- logger.warning("[Stage-%s] send_chunk: connector is None", self._stage_id)
- return False
- if not self.is_data_transfer_rank():
- return True
- raw_req_id = getattr(request, "request_id", None) or getattr(request, "req_id", None)
- request_id = self._resolve_external_req_id(request, raw_req_id)
- # Cache the internal→external mapping so that finish sentinels can
- # resolve the external ID even after the request is freed.
- if raw_req_id and raw_req_id != request_id:
- self._request_ids_mapping.setdefault(raw_req_id, request_id)
- chunk_id = self._put_req_chunk[request_id]
-
- payload_data = self._build_custom_process_payload(
- request_id=request_id,
- request=request,
- pooling_output=pooling_output,
- )
- if payload_data is None:
- if chunk_id == 0:
- logger.warning(
- "[Stage-%s] send_chunk: payload is None for req=%s chunk=%s (process_func=%s)",
- self._stage_id,
- request_id,
- chunk_id,
- self._custom_process_func,
- )
- return False
-
- self._put_req_chunk[request_id] += 1
- next_stage_id = self._next_stage_id
- connector_put_key = f"{request_id}_{self._stage_id}_{chunk_id}"
-
- if chunk_id == 0:
- logger.info(
- "[Stage-%s] send_chunk: first chunk enqueued, req=%s key=%s",
- self._stage_id,
- request_id,
- connector_put_key,
- )
-
- task = {
- "stage_id": self._stage_id,
- "next_stage_id": next_stage_id,
- "put_key": connector_put_key,
- "data": payload_data,
- "request_id": request_id,
- }
- with self._lock:
- self._pending_save_reqs.setdefault(request_id, deque()).append(task)
- self._pending_save_counts[request_id] += 1
- self._work_available.set()
- return True
-
- # ------------------------------------------------------------------ #
- # KV cache (delegates to OmniKVTransferManager)
- # ------------------------------------------------------------------ #
-
- def send_kv_cache(
- self,
- finished_reqs: dict[str, dict[str, Any]],
- kv_caches: list[torch.Tensor],
- block_size: int,
- cache_dtype: str,
- request_id_resolver: Any | None = None,
- ) -> list[str]:
- """Send KV cache for finished requests.
-
- Delegates to the existing ``OmniKVTransferManager``.
- """
- if self._kv_transfer_manager is None:
- return list(finished_reqs.keys()) if finished_reqs else []
- result = self._kv_transfer_manager.handle_finished_requests_kv_transfer(
- finished_reqs=finished_reqs,
- kv_caches=kv_caches,
- block_size=block_size,
- cache_dtype=cache_dtype,
- request_id_resolver=request_id_resolver,
- )
- if result:
- self._kv_sent_req_ids.extend(result)
- return result
-
- def recv_kv_cache(
- self,
- request_id: str,
- target_device: torch.device | None = None,
- ) -> tuple[dict[str, Any] | None, int]:
- """Receive KV cache for a request.
-
- Delegates to the existing ``OmniKVTransferManager``.
- """
- if self._kv_transfer_manager is None:
- return None, 0
- return self._kv_transfer_manager.receive_kv_cache_for_request(
- request_id=request_id,
- target_device=target_device,
- )
-
- def receive_cfg_companion_kv_payloads(
- self,
- cfg_request_ids: dict[str, str],
- target_device: torch.device | None = None,
- ) -> dict[str, tuple[dict[str, Any] | None, int]]:
- """Receive raw CFG companion KV payloads keyed by role."""
- return {
- role: self.recv_kv_cache(companion_rid, target_device=target_device)
- for role, companion_rid in cfg_request_ids.items()
- }
-
- def receive_multi_kv_cache(
- self,
- req: Any,
- cfg_kv_collect_func: Any | None = None,
- target_device: torch.device | None = None,
- ) -> bool:
- """Receive primary and optional companion KV caches for a request.
-
- The mixin owns the runner-facing orchestration: primary KV receive,
- companion payload fetch, and applying any model-specific CFG fields back
- onto ``req.sampling_params``.
- """
- if self._kv_transfer_manager is None:
- return False
-
- request_id = getattr(req, "request_id", None) or (
- req.request_ids[0] if hasattr(req, "request_ids") and req.request_ids else None
- )
- if not request_id:
- logger.warning("Request has no ID, cannot receive KV cache")
- return False
-
- active_requests = getattr(self, "requests", None)
- if active_requests is not None and request_id not in active_requests:
- logger.info("Skip receiving KV cache for inactive request %s", request_id)
- return False
-
- primary_ok = False
- data, _size = self.recv_kv_cache(request_id, target_device=target_device)
- if data:
- self._kv_transfer_manager.apply_kv_cache_to_request(req, data)
- primary_ok = True
-
- cfg_ids = getattr(getattr(req, "sampling_params", None), "cfg_kv_request_ids", None)
- if cfg_ids and cfg_kv_collect_func:
- try:
- cfg_role_payloads = self.receive_cfg_companion_kv_payloads(
- cfg_ids,
- target_device=target_device,
- )
- cfg_kvs = cfg_kv_collect_func(request_id, cfg_role_payloads)
- if cfg_kvs and hasattr(req, "sampling_params") and req.sampling_params is not None:
- for key, value in cfg_kvs.items():
- setattr(req.sampling_params, key, value)
- logger.info("Applied CFG KV caches: %s", list(cfg_kvs.keys()))
- except Exception:
- logger.exception("Failed to collect CFG KV caches for %s", request_id)
-
- return primary_ok
-
- # ------------------------------------------------------------------ #
- # Rank-aware KV transfer routing
- # ------------------------------------------------------------------ #
-
- def get_rank_aware_kv_keys(
- self,
- req_id: str,
- from_stage: int,
- to_stage: int | None = None,
- chunk_id: int = 0,
- ) -> list[str]:
- """Build recv-side connector keys for all remote ranks this rank needs.
-
- For heterogeneous TP receive, the local rank is the target rank and must
- fetch one or more source-rank shards keyed as ``from_rank -> to_rank``.
- """
- remote_ranks = self.get_kv_remote_ranks()
- return [
- self.get_kv_connector_key(
- req_id=req_id,
- from_stage=from_stage,
- chunk_id=chunk_id,
- from_rank=remote_rank,
- to_rank=self._local_rank,
- )
- for remote_rank in remote_ranks
- ]
-
- def get_kv_target_ranks_for_send(self) -> list[int]:
- """Determine which target ranks this local rank should send KV shards to."""
- self._validate_kv_tp_topology()
- if self._from_tp == self._to_tp:
- return [self._local_rank]
- if self._from_tp > self._to_tp:
- tp_ratio = self._from_tp // self._to_tp
- return [self._local_rank // tp_ratio]
- tp_ratio = self._to_tp // self._from_tp
- base_rank = self._local_rank * tp_ratio
- return [base_rank + i for i in range(tp_ratio)]
-
- def get_rank_aware_kv_send_keys(
- self,
- req_id: str,
- from_stage: int,
- to_stage: int | None = None,
- chunk_id: int = 0,
- ) -> list[str]:
- """Build send-side connector keys for this rank's KV shard(s)."""
- target_ranks = self.get_kv_target_ranks_for_send()
- return [
- self.get_kv_connector_key(
- req_id=req_id,
- from_stage=from_stage,
- chunk_id=chunk_id,
- from_rank=self._local_rank,
- to_rank=target_rank,
- )
- for target_rank in target_ranks
- ]
-
- @staticmethod
- def _merge_rank_sharded_kv_payloads(payloads: list[dict[str, Any]]) -> dict[str, Any] | None:
- """Merge multiple source-rank KV shards for one target rank."""
- payloads = [payload for payload in payloads if isinstance(payload, dict)]
- if not payloads:
- return None
- if len(payloads) == 1:
- return payloads[0]
-
- merged = dict(payloads[0])
- layer_blocks = merged.get("layer_blocks")
- if not isinstance(layer_blocks, dict):
- return merged
-
- def _merge_tensor_lists(name: str) -> list[torch.Tensor | None]:
- merged_list: list[torch.Tensor | None] = []
- cache_lists = [payload.get("layer_blocks", {}).get(name, []) for payload in payloads]
- max_len = max((len(cache_list) for cache_list in cache_lists), default=0)
- for idx in range(max_len):
- tensors = [cache_list[idx] for cache_list in cache_lists if idx < len(cache_list)]
- tensors = [tensor for tensor in tensors if isinstance(tensor, torch.Tensor)]
- if not tensors:
- merged_list.append(None)
- elif len(tensors) == 1:
- merged_list.append(tensors[0])
- else:
- merged_list.append(torch.cat(tensors, dim=-2).contiguous())
- return merged_list
-
- merged["layer_blocks"] = {
- "key_cache": _merge_tensor_lists("key_cache"),
- "value_cache": _merge_tensor_lists("value_cache"),
- }
- metadata = dict(merged.get("metadata", {}))
- metadata["merged_remote_rank_count"] = len(payloads)
- merged["metadata"] = metadata
- return merged
-
- def _slice_rank_sharded_kv_payload(self, payload: dict[str, Any] | None) -> dict[str, Any] | None:
- """Slice a duplicated source-rank KV shard for ``from_tp < to_tp`` cases."""
- if payload is None or self._from_tp >= self._to_tp:
- return payload
-
- tp_ratio = self._to_tp // self._from_tp
- shard_index = self._local_rank % tp_ratio
- layer_blocks = payload.get("layer_blocks") if isinstance(payload, dict) else None
- if not isinstance(layer_blocks, dict):
- return payload
-
- def _slice_tensor_list(name: str) -> list[torch.Tensor | None]:
- sliced: list[torch.Tensor | None] = []
- for tensor in layer_blocks.get(name, []):
- if not isinstance(tensor, torch.Tensor) or tensor.ndim < 2:
- sliced.append(tensor)
- continue
- head_dim = tensor.shape[-2]
- if head_dim % tp_ratio != 0:
- sliced.append(tensor)
- continue
- per_rank = head_dim // tp_ratio
- start = shard_index * per_rank
- sliced.append(tensor.narrow(-2, start, per_rank).contiguous())
- return sliced
-
- payload = dict(payload)
- payload["layer_blocks"] = {
- "key_cache": _slice_tensor_list("key_cache"),
- "value_cache": _slice_tensor_list("value_cache"),
- }
- metadata = dict(payload.get("metadata", {}))
- metadata["sliced_for_local_rank"] = self._local_rank
- payload["metadata"] = metadata
- return payload
-
- def should_replicate_payload(self) -> bool:
- """Whether non-KV payloads should be replicated across ranks.
-
- Data payloads (stage inputs, chunks) are identical after all-gather,
- so only rank 0 transfers them. KV payloads are rank-specific and
- all ranks participate.
- """
- return self._local_rank != 0
-
- def get_kv_rank_mapping(self) -> dict[str, Any]:
- """Return the current rank mapping configuration.
-
- Useful for debugging and for downstream code that needs to know
- the TP topology without re-parsing model config.
- """
- return {
- "from_tp": self._from_tp,
- "to_tp": self._to_tp,
- "local_rank": self._local_rank,
- "remote_ranks": self.get_kv_remote_ranks(),
- "is_data_transfer_rank": self.is_data_transfer_rank(),
- }
-
- # ------------------------------------------------------------------ #
- # KV transfer lifecycle (RFC – mixin-owned)
- # ------------------------------------------------------------------ #
-
- def mark_kv_transfer(
- self,
- req_id: str,
- seq_len: int,
- block_ids: list[int],
- custom_metadata: dict[str, Any] | None = None,
- ) -> None:
- """Mark a request as needing KV cache transfer.
-
- Called by the scheduler when a transfer trigger fires. The mixin
- owns the lifecycle from this point: pending → active → completed.
- """
- if req_id in self._kv_pending_transfers:
- return
- self._kv_triggered_requests.add(req_id)
- transfer = {
- "seq_len": seq_len,
- "block_ids": block_ids,
- }
- if custom_metadata is not None:
- transfer["custom_metadata"] = custom_metadata
- self._kv_pending_transfers[req_id] = transfer
-
- def drain_pending_kv_transfers(self) -> dict[str, dict[str, Any]]:
- """Drain pending KV transfers and move them to active.
-
- Returns ``{req_id: {seq_len, block_ids}}`` for the model runner
- to submit to ``send_kv_cache``.
- """
- if not self._kv_pending_transfers:
- return {}
- pending = dict(self._kv_pending_transfers)
- self._kv_active_transfers.update(pending.keys())
- self._kv_pending_transfers.clear()
- return pending
-
- def ack_kv_transfers(self, req_ids: list[str] | set[str]) -> None:
- """Acknowledge completed KV transfers (from kv_extracted_req_ids).
-
- Moves requests from active to completed so the scheduler can
- safely free their blocks.
- """
- for req_id in req_ids:
- self._kv_active_transfers.discard(req_id)
- self._kv_completed_transfers.add(req_id)
-
- def drain_completed_kv_transfers(self) -> set[str]:
- """Drain and return completed KV transfer request IDs.
-
- The scheduler calls this to know which requests' blocks can be freed.
- """
- completed = set(self._kv_completed_transfers)
- self._kv_completed_transfers.clear()
- return completed
-
- def is_kv_transfer_triggered(self, req_id: str) -> bool:
- """Check if a request has already triggered KV transfer."""
- return req_id in self._kv_triggered_requests
-
- def has_pending_kv_work(self) -> bool:
- """True if any KV transfers are pending, active, or awaiting ack."""
- return bool(self._kv_pending_transfers or self._kv_active_transfers or self._kv_completed_transfers)
-
- # Output aggregation
- # ------------------------------------------------------------------ #
-
- def _empty_output_with_connector_signals(self) -> Any:
- """Return a minimal ModelRunnerOutput carrying pending connector signals.
-
- Used by early-return paths (e.g. ``num_scheduled_tokens == 0``)
- that still need to deliver ``omni_connector_output`` to the
- Scheduler so that WAITING_FOR_INPUT / WAITING_FOR_CHUNK
- transitions are not lost.
- """
- from vllm_omni.outputs import OmniModelRunnerOutput
-
- output = OmniModelRunnerOutput(req_ids=[], req_id_to_index={})
- output.omni_connector_output = self.get_omni_connector_output()
- return output
-
- def get_omni_connector_output(self) -> OmniConnectorOutput:
- """Collect and reset transfer results for this execute_model cycle.
-
- ``request_metadata`` carries only lightweight scheduling metadata.
- Full payloads remain owned by the Model Runner local cache for all
- paths.
- """
- if not hasattr(self, "_lock"):
- return OmniConnectorOutput()
-
- tp_group = self._get_local_tp_group()
- if self._async_chunk and tp_group is not None and getattr(tp_group, "world_size", 1) > 1:
- if self.is_data_transfer_rank():
- with self._lock:
- fanout_packet = self._collect_async_chunk_fanout_packet_locked()
- else:
- fanout_packet = None
- fanout_packet = self._broadcast_tp_payload_packet(fanout_packet)
- if fanout_packet is None:
- newly_finished = set()
- chunk_finished = set()
- request_metadata = {}
- else:
- if not self.is_data_transfer_rank():
- self._apply_async_chunk_fanout_packet(fanout_packet)
- newly_finished = set(fanout_packet["newly_finished"])
- chunk_finished = set(fanout_packet["chunk_finished"])
- request_metadata = dict(fanout_packet["request_metadata"])
- else:
- with self._lock:
- newly_finished = set(self._finished_load_reqs)
- self._finished_load_reqs.clear()
- chunk_finished = set(self._chunk_finished_req_ids)
- self._chunk_finished_req_ids.clear()
- request_metadata = dict(self._local_request_metadata)
- self._local_request_metadata.clear()
- # _send_side_request_payload is the async accumulation buffer for
- # future recv chunks. Clearing it on every consumable wake-up drops
- # intermediate
- # thinker decode spans before the model side can consume them.
- # Only terminal chunk_finished requests may release that buffer.
- for req_id in chunk_finished:
- if req_id not in self._local_stage_payload_cache:
- continue
- ext_req_id = self._request_ids_mapping.get(req_id, req_id)
- self._send_side_request_payload.pop(ext_req_id, None)
- if ext_req_id != req_id:
- self._send_side_request_payload.pop(req_id, None)
- self._chunk_ready_req_ids.update(newly_finished)
-
- output = OmniConnectorOutput(
- chunk_ready_req_ids=set(self._chunk_ready_req_ids),
- chunk_finished_req_ids=chunk_finished,
- request_metadata=request_metadata,
- kv_sent_req_ids=list(self._kv_sent_req_ids),
- stage_recv_req_ids=set(self._stage_recv_req_ids),
- has_pending_kv_work=self.has_pending_kv_work(),
- )
- if output.stage_recv_req_ids or chunk_finished or newly_finished:
- logger.info(
- "[Stage-%s] get_omni_connector_output: stage_recv=%s, chunk_finished=%s, chunk_ready=%s",
- self._stage_id,
- output.stage_recv_req_ids,
- chunk_finished,
- output.chunk_ready_req_ids,
- )
- self._chunk_ready_req_ids.clear()
- self._kv_sent_req_ids.clear()
- self._stage_recv_req_ids.clear()
- return output
-
- @staticmethod
- def _connector_output_has_signals(output: OmniConnectorOutput) -> bool:
- return bool(
- output.chunk_ready_req_ids
- or output.chunk_finished_req_ids
- or output.request_metadata
- or output.kv_sent_req_ids
- or output.stage_recv_req_ids
- or output.has_pending_kv_work
- )
-
- def attach_omni_connector_output(self, result: Any | None) -> Any:
- omni_output = self.get_omni_connector_output()
- if not self._connector_output_has_signals(omni_output):
- return result
-
- from copy import copy
-
- from vllm.v1.worker.gpu_model_runner import EMPTY_MODEL_RUNNER_OUTPUT
-
- wrapped = copy(result if result is not None else EMPTY_MODEL_RUNNER_OUTPUT)
- wrapped.omni_connector_output = omni_output
- return wrapped
-
- # ------------------------------------------------------------------ #
- # Properties for compatibility with custom_process funcs that access
- # transfer_manager.put_req_chunk / request_payload / code_prompt_token_ids
- # ------------------------------------------------------------------ #
-
- @property
- def put_req_chunk(self) -> dict[str, int]:
- return self._put_req_chunk
-
- @property
- def request_payload(self) -> dict[str, dict[str, Any]]:
- return self._send_side_request_payload
-
- @request_payload.setter
- def request_payload(self, value: dict[str, dict[str, Any]]) -> None:
- self._send_side_request_payload = value
-
- @property
- def code_prompt_token_ids(self) -> dict[str, list[list[int]]]:
- return self._code_prompt_token_ids
-
- @property
- def connector(self) -> Any | None:
- return self._omni_connector
-
- # ------------------------------------------------------------------ #
- # Background I/O threads
- # ------------------------------------------------------------------ #
-
- def _recv_loop(self) -> None:
- """Background thread: poll connector for incoming data."""
- _recv_poll_count = 0
- while not self._stop_event.is_set():
- with self._lock:
- pending_ids = list(self._pending_load_reqs.keys())
-
- if not pending_ids:
- self._work_available.wait(timeout=0.01)
- self._work_available.clear()
- continue
-
- _recv_poll_count += 1
- if _recv_poll_count % 5000 == 1:
- logger.info(
- "[Stage-%s] _recv_loop: polling %s pending reqs: %s (poll#%s)",
- self._stage_id,
- len(pending_ids),
- pending_ids[:5],
- _recv_poll_count,
- )
-
- made_progress = False
- for req_id in pending_ids:
- if self._stop_event.is_set():
- break
- try:
- made_progress = self._poll_single_request(req_id) or made_progress
- except Exception:
- logger.warning("Error receiving data for %s", req_id, exc_info=True)
-
- if not made_progress and not self._stop_event.is_set():
- self._work_available.wait(timeout=0.001)
- self._work_available.clear()
-
- _MAX_SEND_RETRIES = 3
-
- def _save_loop(self) -> None:
- """Background thread: send outgoing data via connector."""
- while not self._stop_event.is_set():
- task = None
- with self._lock:
- for req_id in list(self._pending_save_reqs.keys()):
- dq = self._pending_save_reqs[req_id]
- if dq:
- task = dq.popleft()
- if not dq:
- del self._pending_save_reqs[req_id]
- break
- del self._pending_save_reqs[req_id]
-
- if task is not None:
- success = False
- try:
- success = self._send_single_request(task)
- except Exception:
- logger.error(
- "Error saving data for %s",
- task.get("request_id"),
- exc_info=True,
- )
- if not success:
- self._requeue_or_drop_failed_send(task)
- continue
-
- self._work_available.wait(timeout=0.01)
- self._work_available.clear()
-
- def _requeue_or_drop_failed_send(self, task: dict) -> None:
- """Re-enqueue a failed send task or drop it after max retries."""
- retry_count = task.get("_retry_count", 0) + 1
- req_id = task.get("request_id")
- if retry_count <= self._MAX_SEND_RETRIES:
- task["_retry_count"] = retry_count
- logger.warning(
- "[Stage-%s] Re-enqueuing failed send for %s (retry %d/%d)",
- getattr(self, "_stage_id", "?"),
- req_id,
- retry_count,
- self._MAX_SEND_RETRIES,
- )
- with self._lock:
- dq = self._pending_save_reqs.setdefault(req_id, deque())
- dq.appendleft(task)
- else:
- logger.error(
- "[Stage-%s] Giving up on send for %s after %d retries",
- getattr(self, "_stage_id", "?"),
- req_id,
- self._MAX_SEND_RETRIES,
- )
- self._decrement_pending_save_count(req_id)
-
- # ------------------------------------------------------------------ #
- # Chunk-level poll / send (ported from OmniChunkTransferAdapter)
- # ------------------------------------------------------------------ #
-
- def _poll_single_request(self, req_id: str) -> bool:
- """Poll connector for one chunk of a request (non-blocking)."""
- connector = self._omni_connector
- if connector is None:
- return False
-
- if self._async_chunk and self._model_mode != "ar":
- with self._lock:
- staged_payload = self._local_stage_payload_cache.get(req_id)
- metadata_in_flight = req_id in self._local_request_metadata
- scheduler_wakeup_pending = req_id in self._finished_load_reqs
- if self._payload_is_consumable(staged_payload) or metadata_in_flight or scheduler_wakeup_pending:
- logger.debug(
- "[Stage-%s] delaying recv for req=%s until staged async payload is handed to scheduler",
- self._stage_id,
- req_id,
- )
- return False
-
- target_stage_id = self._stage_id - 1
- chunk_id = self._get_req_chunk[req_id]
- external_req_id = self._request_ids_mapping.get(req_id, req_id)
- connector_get_key = f"{external_req_id}_{target_stage_id}_{chunk_id}"
-
- if self._async_chunk:
- result = self._recv_async_chunk_result(
- connector,
- str(target_stage_id),
- str(self._stage_id),
- connector_get_key,
- )
- else:
- result = self._recv_full_payload_result(
- connector,
- str(target_stage_id),
- str(self._stage_id),
- connector_get_key,
- )
-
- if result is None:
- return False
-
- payload_data, _size = result
- if not payload_data:
- return False
- if isinstance(payload_data, dict):
- logger.info(
- "[Stage-%s] recv_chunk_result: req=%s ext=%s key=%s keys=%s finished=%s",
- self._stage_id,
- req_id,
- external_req_id,
- connector_get_key,
- sorted(payload_data.keys()),
- self._payload_finished(payload_data),
- )
-
- self._get_req_chunk[req_id] += 1
-
- if self._async_chunk:
- is_finished = self._payload_finished(payload_data)
- incoming_payload_consumable = self._payload_is_consumable(payload_data)
-
- if self._model_mode == "ar":
- payload_data = self._accumulate_payload(external_req_id, payload_data)
- payload_consumable = incoming_payload_consumable
- else:
- new_ids = self._payload_audio_codes(payload_data) or []
- if not new_ids and not is_finished:
- return False
- payload_consumable = self._payload_is_consumable(payload_data)
-
- with self._lock:
- if is_finished:
- self._chunk_finished_req_ids.add(req_id)
- self._chunk_stream_completed.add(req_id)
- # Local cache (RFC §2.4) — merge, don't replace, so that
- # earlier chunk keys (e.g. thinker_prefill_embeddings from
- # chunk 0) are not overwritten by later chunks.
- existing = self._local_stage_payload_cache.get(req_id)
- if existing is not None and isinstance(existing, dict) and isinstance(payload_data, dict):
- existing.update(payload_data)
- else:
- self._local_stage_payload_cache[req_id] = payload_data
- staged_payload = self._local_stage_payload_cache[req_id]
- self._async_chunk_updated_req_ids.add(req_id)
- self.put_local_request_metadata(req_id, self._extract_scheduling_metadata(staged_payload))
- # A finish-only sentinel still needs one terminal wake-up so
- # the downstream stage can sync the merged local payload and
- # flush/finish even when the last recv carries no new
- # consumable chunk bytes.
- if payload_consumable or is_finished:
- self._finished_load_reqs.add(req_id)
- if is_finished and not payload_consumable:
- logger.debug(
- "[Stage-%s] finish sentinel arrived for req=%s without new consumable payload",
- self._stage_id,
- req_id,
- )
- elif not payload_consumable:
- logger.debug(
- "[Stage-%s] req=%s received metadata-only / non-consumable async payload; delaying wake-up",
- self._stage_id,
- req_id,
- )
- if is_finished:
- self._pending_load_reqs.pop(req_id, None)
- else:
- # full_payload_mode: the complete payload arrives in a single get(),
- # so always unregister immediately.
- if isinstance(payload_data, dict):
- engine_inputs = payload_data.get("engine_inputs", payload_data)
- else:
- engine_inputs = payload_data
- with self._lock:
- self._local_stage_payload_cache[req_id] = self._snapshot_payload(engine_inputs)
- # Publish full-payload readiness only after the aligned TP broadcast
- # path in recv_full_payload_inputs() has materialized the payload on all
- # local ranks. Publishing metadata / stage_recv from the background recv
- # thread can let the scheduler observe a request before the payload is
- # actually visible to the model thread.
- self._full_payload_pending_broadcast_req_ids.add(req_id)
- self._pending_load_reqs.pop(req_id, None)
- logger.info(
- "[Stage-%s] full_payload recv complete: req=%s key=%s payload_type=%s",
- self._stage_id,
- req_id,
- connector_get_key,
- type(engine_inputs).__name__,
- )
-
- logger.debug("[Stage-%s] Received data for key %s", self._stage_id, connector_get_key)
- return True
-
- def _build_custom_process_payload(
- self,
- request_id: str | None,
- request: Any | None,
- pooling_output: Any | None,
- ) -> Any | None:
- """Run the custom process hook with a best-effort finished kwarg."""
- if self._custom_process_func is None:
- return None
-
- kwargs = {
- "transfer_manager": self,
- "pooling_output": pooling_output,
- "request": request,
- }
- supports_is_finished = getattr(
- self,
- "_custom_process_supports_is_finished",
- self._custom_process_supports_is_finished_kwarg(),
- )
- is_finished_fn = getattr(request, "is_finished", None)
- if callable(is_finished_fn):
- try:
- if supports_is_finished is not False:
- kwargs["is_finished"] = bool(is_finished_fn())
- except Exception:
- logger.debug("request.is_finished() failed for %s", request_id, exc_info=True)
-
- try:
- return self._custom_process_func(**kwargs)
- except TypeError as exc:
- if "is_finished" not in kwargs or not self._is_unexpected_is_finished_kwarg_error(exc):
- logger.exception("custom_process_stage_input_func failed for chunk %s", request_id)
- return None
- kwargs.pop("is_finished", None)
- try:
- return self._custom_process_func(**kwargs)
- except Exception:
- logger.exception("custom_process_stage_input_func failed for chunk %s", request_id)
- return None
- except Exception:
- logger.exception("custom_process_stage_input_func failed for chunk %s", request_id)
- return None
-
- def _custom_process_supports_is_finished_kwarg(self) -> bool | None:
- """Return whether the custom process hook accepts `is_finished`."""
- if self._custom_process_func is None:
- return None
- try:
- signature = inspect.signature(self._custom_process_func)
- except (TypeError, ValueError):
- return None
-
- for param in signature.parameters.values():
- if param.kind == inspect.Parameter.VAR_KEYWORD:
- return True
-
- is_finished_param = signature.parameters.get("is_finished")
- if is_finished_param is None:
- return False
- return is_finished_param.kind in (
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- inspect.Parameter.KEYWORD_ONLY,
- )
-
- @staticmethod
- def _is_unexpected_is_finished_kwarg_error(exc: TypeError) -> bool:
- message = str(exc)
- return (
- "unexpected keyword argument 'is_finished'" in message
- or 'unexpected keyword argument "is_finished"' in message
- or "positional-only arguments passed as keyword arguments: 'is_finished'" in message
- )
-
- def _send_single_request(self, task: dict) -> bool:
- """Send one queued task via connector.put().
-
- Returns True on success. On failure (put() raises or returns
- ``success=False``), returns False **without** decrementing
- ``_pending_save_counts`` so the caller can retry or clean up.
- """
- connector = self._omni_connector
- if connector is None:
- return True
-
- request_id = task.get("request_id")
- payload_data = task.get("data")
- if payload_data is None and task.get("request") is not None:
- payload_data = self._build_custom_process_payload(
- request_id=request_id,
- request=task.get("request"),
- pooling_output=task.get("pooling_output"),
- )
- put_key = task.get("put_key")
-
- success, _size, _metadata = connector.put(
- from_stage=str(task["stage_id"]),
- to_stage=str(task["next_stage_id"]),
- put_key=put_key,
- data=payload_data,
- )
- logger.info(
- "[Stage-%s] _send_single_request: put_key=%s success=%s size=%s",
- task["stage_id"],
- put_key,
- success,
- _size,
- )
-
- if not success:
- return False
-
- self._decrement_pending_save_count(request_id)
- return True
-
- def _decrement_pending_save_count(self, request_id: str) -> None:
- """Decrement pending save count and run deferred cleanup if zero."""
- cleanup_req_id = None
- with self._lock:
- remaining = self._pending_save_counts.get(request_id, 0)
- if remaining > 1:
- self._pending_save_counts[request_id] = remaining - 1
- elif remaining == 1:
- self._pending_save_counts.pop(request_id, None)
- if request_id in self._deferred_send_cleanup:
- self._deferred_send_cleanup.remove(request_id)
- cleanup_req_id = request_id
- if cleanup_req_id is not None:
- self._put_req_chunk.pop(cleanup_req_id, None)
- self._send_side_request_payload.pop(cleanup_req_id, None)
- self._code_prompt_token_ids.pop(cleanup_req_id, None)
-
- # ------------------------------------------------------------------ #
- # Payload accumulation (ported from OmniChunkTransferAdapter)
- # ------------------------------------------------------------------ #
-
- def _accumulate_payload(self, req_id: str, payload_data: dict[str, Any]) -> dict[str, Any]:
- """Accumulate chunk payloads (concat tensors, extend lists).
-
- Returns a **shallow copy** of the accumulated state so callers
- (e.g. ``_poll_single_request``) can store it in
- ``_local_stage_payload_cache`` without aliasing the authoritative
- ``_send_side_request_payload`` dict.
- """
- if req_id not in self._send_side_request_payload:
- self._send_side_request_payload[req_id] = dict(payload_data)
- return dict(self._send_side_request_payload[req_id])
-
- origin = self._send_side_request_payload[req_id]
- merged = dict(origin)
- override_keys = payload_data.get("override_keys", ())
- drop_decode_span = False
- decode_span_handled = False
- for key, value in payload_data.items():
- if key == "finished":
- merged[key] = value
- continue
- if key == THINKER_DECODE_EMBEDDINGS_KEY:
- merged_span = merge_tensor_spans(
- get_tensor_span(
- origin,
- tensor_key=THINKER_DECODE_EMBEDDINGS_KEY,
- start_key=THINKER_DECODE_TOKEN_START_KEY,
- end_key=THINKER_DECODE_TOKEN_END_KEY,
- ),
- get_tensor_span(
- payload_data,
- tensor_key=THINKER_DECODE_EMBEDDINGS_KEY,
- start_key=THINKER_DECODE_TOKEN_START_KEY,
- end_key=THINKER_DECODE_TOKEN_END_KEY,
- ),
- )
- if merged_span is not None:
- merged[key], merged[THINKER_DECODE_TOKEN_START_KEY], merged[THINKER_DECODE_TOKEN_END_KEY] = (
- merged_span
- )
- decode_span_handled = True
- continue
- if isinstance(value, torch.Tensor) and key in origin:
- if (
- THINKER_DECODE_TOKEN_START_KEY in origin
- or THINKER_DECODE_TOKEN_END_KEY in origin
- or THINKER_DECODE_TOKEN_START_KEY in payload_data
- or THINKER_DECODE_TOKEN_END_KEY in payload_data
- ):
- logger.warning(
- "[Stage-%s] req=%s falling back to legacy thinker decode "
- "merge due to missing/invalid/non-contiguous span "
- "metadata",
- self._stage_id,
- req_id,
- )
- drop_decode_span = True
- merged[key] = torch.cat([origin[key], value], dim=0)
- continue
- merged[key] = value
- continue
- if key in {THINKER_DECODE_TOKEN_START_KEY, THINKER_DECODE_TOKEN_END_KEY}:
- if decode_span_handled or drop_decode_span:
- continue
- merged[key] = value
- continue
- if key in override_keys:
- merged[key] = value
- continue
- if isinstance(value, torch.Tensor) and key in origin:
- merged[key] = torch.cat([origin[key], value], dim=0)
- elif isinstance(value, list) and key in origin:
- merged[key] = origin[key] + value
- else:
- merged[key] = value
-
- if drop_decode_span:
- merged.pop(THINKER_DECODE_TOKEN_START_KEY, None)
- merged.pop(THINKER_DECODE_TOKEN_END_KEY, None)
- self._send_side_request_payload[req_id] = merged
- return dict(merged)
-
- def drop_inactive_request_runtime_state(self, req_id: str) -> None:
- """Clear inactive request state used by both the runner and mixin.
-
- This centralizes the model-runner-side cleanup pattern so
- ``OmniGPUModelRunner`` can reuse it instead of open-coding the same
- inactive-request state mutations.
- """
- if hasattr(self, "model_intermediate_buffer"):
- self.model_intermediate_buffer.pop(req_id, None)
- self.drop_inactive_request_delivery_state(req_id)
-
- # ------------------------------------------------------------------ #
- # Helpers
- # ------------------------------------------------------------------ #
-
- @staticmethod
- def _freeze_request_attr(value: Any) -> Any:
- if isinstance(value, list):
- return list(value)
- if isinstance(value, tuple):
- return list(value)
- if isinstance(value, torch.Tensor):
- return value.clone()
- raw_list = getattr(value, "_x", None)
- if raw_list is not None:
- return list(raw_list)
- return value
-
- def _snapshot_request_for_send(self, request: Any, external_req_id: str) -> Any:
- finished = bool(getattr(request, "is_finished", lambda: False)())
- attrs: dict[str, Any] = {}
- try:
- attrs.update(vars(request))
- except TypeError:
- pass
-
- for name in (
- "request_id",
- "req_id",
- "external_req_id",
- "prompt_token_ids",
- "output_token_ids",
- "all_token_ids",
- "additional_information",
- "sampling_params",
- "multi_modal_data",
- "mm_hashes",
- ):
- if hasattr(request, name):
- attrs[name] = self._freeze_request_attr(getattr(request, name))
-
- attrs["external_req_id"] = external_req_id
- attrs["_frozen_is_finished"] = finished
- snapshot = SimpleNamespace(**attrs)
- snapshot.is_finished = lambda: finished
- return snapshot
-
- @staticmethod
- def _create_connector(model_config: Any) -> OmniConnectorBase | None:
- """Create a connector from model_config, or None if unconfigured."""
- connector_config = getattr(model_config, "stage_connector_config", None)
- if connector_config is None:
- return None
-
- if not isinstance(connector_config, dict):
- connector_config = {
- "name": getattr(connector_config, "name", None),
- "extra": getattr(connector_config, "extra", None),
- }
-
- name = connector_config.get("name")
- if not isinstance(name, str) or not name.strip():
- raise RuntimeError("Invalid stage connector config: missing connector name")
- name = name.strip()
-
- extra = connector_config.get("extra")
- if extra is None:
- extra = {}
- elif not isinstance(extra, dict):
- raise RuntimeError(f"Invalid extra config for connector {name}: expected dict, got {type(extra).__name__}")
-
- spec = ConnectorSpec(name=name, extra=extra)
- try:
- return OmniConnectorFactory.create_connector(spec)
- except Exception as exc:
- raise RuntimeError(f"Failed to create connector {name}") from exc
-
- @staticmethod
- def _load_custom_func(model_config: Any) -> tuple[str | None, Any | None]:
- """Load the connector payload builder for the downstream stage.
-
- Preferred source is ``custom_process_next_stage_input_func``. Some
- full_payload_mode configs (async_chunk=false) only expose the next-stage prompt builder via
- ``custom_process_input_func`` (for example ``thinker2talker``), while the
- connector payload builder lives beside it as ``thinker2talker_full_payload``.
- In that case, derive the full_payload_mode builder path automatically.
- """
- candidates: list[str] = []
-
- next_stage_func = getattr(model_config, "custom_process_next_stage_input_func", None)
- if isinstance(next_stage_func, str) and next_stage_func:
- candidates.append(next_stage_func)
-
- if not getattr(model_config, "async_chunk", False):
- input_func = getattr(model_config, "custom_process_input_func", None)
- if isinstance(input_func, str) and input_func:
- try:
- module_path, func_name = input_func.rsplit(".", 1)
- if func_name.endswith("_full_payload") or func_name.endswith("_batch"):
- candidates.append(f"{module_path}.{func_name}")
- else:
- candidates.append(f"{module_path}.{func_name}_full_payload")
- candidates.append(f"{module_path}.{func_name}_batch")
- candidates.append(input_func)
- except ValueError:
- candidates.append(input_func)
-
- tried: set[str] = set()
- for func_path in candidates:
- if func_path in tried:
- continue
- tried.add(func_path)
- try:
- module_path, func_name = func_path.rsplit(".", 1)
- module = importlib.import_module(module_path)
- func = getattr(module, func_name, None)
- if callable(func):
- if not OmniConnectorModelRunnerMixin._is_connector_payload_builder(func):
- logger.debug(
- "Skipping incompatible connector payload hook %s; signature=%s",
- func_path,
- inspect.signature(func),
- )
- continue
- return func_path, func
- except Exception:
- logger.warning("Failed to load custom func: %s", func_path, exc_info=True)
-
- return None, None
-
- @staticmethod
- def _is_connector_payload_builder(func: Any) -> bool:
- """Whether *func* matches the mixin payload-builder contract."""
- try:
- signature = inspect.signature(func)
- except (TypeError, ValueError):
- return False
-
- params = signature.parameters
- if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params.values()):
- return True
-
- required = {"transfer_manager", "pooling_output", "request"}
- supported = {
- name
- for name, param in params.items()
- if param.kind
- in (
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- inspect.Parameter.KEYWORD_ONLY,
- )
- }
- return required.issubset(supported)
-
- def _resolve_external_req_id(self, request: Any, fallback_req_id: str) -> str:
- """Resolve the external request ID consistently.
-
- Checks ``_request_ids_mapping`` first (populated by
- ``register_chunk_recv``), then falls back to the request's
- ``external_req_id`` attribute, and finally to the given
- ``fallback_req_id``.
- """
- mapped = self._request_ids_mapping.get(fallback_req_id)
- if mapped is not None:
- return mapped
- if request is not None:
- return getattr(request, "external_req_id", fallback_req_id)
- return fallback_req_id
-
- def _resolve_next_stage_id(self, model_config: Any) -> int:
- """Determine the downstream stage ID from connector config.
-
- Falls back to ``stage_id + 1`` when the config does not specify
- a ``to_stage`` explicitly.
- """
- connector_config = getattr(model_config, "stage_connector_config", None)
- if connector_config is not None:
- if isinstance(connector_config, dict):
- to_stage = connector_config.get("to_stage")
- else:
- to_stage = getattr(connector_config, "to_stage", None)
- if isinstance(to_stage, int):
- return to_stage
- if isinstance(to_stage, str) and to_stage.strip():
- return int(to_stage)
- return self._stage_id + 1
-
- @staticmethod
- def _parse_rank_mapping(model_config: Any) -> dict[str, int]:
- """Parse rank_mapping from connector config (optional).
-
- Returns ``{"from_tp": int, "to_tp": int, "local_rank": int}``.
- When ``rank_mapping`` is absent, assumes 1:1 homogeneous mapping.
- """
- connector_config = getattr(model_config, "stage_connector_config", None)
- if connector_config is not None and not isinstance(connector_config, dict):
- connector_config = getattr(connector_config, "__dict__", {})
-
- rank_mapping: dict = {}
- if isinstance(connector_config, dict):
- rank_mapping = connector_config.get("rank_mapping", {})
-
- from_tp = int(rank_mapping.get("from_tp", 1))
- to_tp = int(rank_mapping.get("to_tp", 1))
-
- local_rank = 0
- try:
- local_rank = int(os.environ.get("LOCAL_RANK", "0"))
- except (ValueError, TypeError):
- pass
-
- return {"from_tp": from_tp, "to_tp": to_tp, "local_rank": local_rank}
-
- # ------------------------------------------------------------------ #
- # Heterogeneous TP rank support
- # ------------------------------------------------------------------ #
-
- def _validate_kv_tp_topology(self) -> None:
- """Reject heterogeneous TP mappings that cannot be routed losslessly."""
- if self._from_tp <= 0 or self._to_tp <= 0:
- raise ValueError(f"Invalid KV TP mapping: from_tp={self._from_tp}, to_tp={self._to_tp}")
- larger = max(self._from_tp, self._to_tp)
- smaller = min(self._from_tp, self._to_tp)
- if larger % smaller != 0:
- raise ValueError(
- f"KV TP mapping must be divisible for rank-aware routing: from_tp={self._from_tp}, to_tp={self._to_tp}"
- )
-
- def get_kv_remote_ranks(self) -> list[int]:
- """Determine which remote ranks this local rank exchanges KV with.
-
- Follows vLLM's ``TpKVTopology.get_target_remote_ranks()`` pattern:
- - ``from_tp > to_tp``: each to-rank reads from multiple from-ranks
- - ``from_tp < to_tp``: multiple to-ranks read from the same from-rank
- - ``from_tp == to_tp``: 1:1 mapping
- """
- self._validate_kv_tp_topology()
- if self._from_tp == self._to_tp:
- return [self._local_rank]
-
- if self._from_tp > self._to_tp:
- tp_ratio = self._from_tp // self._to_tp
- return [self._local_rank * tp_ratio + i for i in range(tp_ratio)]
- else:
- tp_ratio = self._to_tp // self._from_tp
- return [self._local_rank // tp_ratio]
-
- def is_data_transfer_rank(self) -> bool:
- """Whether this rank should participate in data (non-KV) transfer.
-
- Ordinary stage payloads are TP-identical, so exactly one TP rank
- should talk to the connector. When TP is initialized, use TP rank 0
- so the connector leader matches TP-local broadcast source rank.
- Otherwise fall back to LOCAL_RANK==0 for the single-rank case.
- """
- tp_group = self._get_local_tp_group()
- if tp_group is not None and getattr(tp_group, "world_size", 1) > 1:
- return getattr(tp_group, "rank_in_group", 0) == 0
- return self._local_rank == 0
-
- def get_kv_connector_key(
- self,
- req_id: str,
- from_stage: int,
- chunk_id: int,
- from_rank: int,
- to_rank: int,
- ) -> str:
- """Build connector key that includes rank info for KV transfers."""
- return f"{req_id}_{from_stage}_{chunk_id}_{from_rank}_{to_rank}"
diff --git a/vllm_omni/worker/payload_span.py b/vllm_omni/worker/payload_span.py
deleted file mode 100644
index 994392343a9..00000000000
--- a/vllm_omni/worker/payload_span.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Helpers for explicit thinker decode span metadata."""
-
-from collections.abc import Mapping
-from typing import Any
-
-import torch
-
-THINKER_DECODE_EMBEDDINGS_KEY = "thinker_decode_embeddings"
-THINKER_OUTPUT_TOKEN_IDS_KEY = "thinker_output_token_ids"
-THINKER_DECODE_TOKEN_START_KEY = "thinker_decode_embeddings_token_start"
-THINKER_DECODE_TOKEN_END_KEY = "thinker_decode_embeddings_token_end"
-
-CACHED_THINKER_DECODE_EMBEDDINGS_KEY = "cached_thinker_decode_embeddings"
-CACHED_THINKER_DECODE_TOKEN_START_KEY = "cached_thinker_decode_embeddings_token_start"
-CACHED_THINKER_DECODE_TOKEN_END_KEY = "cached_thinker_decode_embeddings_token_end"
-
-TensorSpan = tuple[torch.Tensor, int, int]
-
-
-def get_tensor_span(payload: Mapping[str, Any], *, tensor_key: str, start_key: str, end_key: str) -> TensorSpan | None:
- tensor = payload.get(tensor_key)
- start = payload.get(start_key)
- end = payload.get(end_key)
- if not isinstance(tensor, torch.Tensor):
- return None
- if not isinstance(start, int) or not isinstance(end, int):
- return None
- if start < 0 or end < start or (end - start) != int(tensor.shape[0]):
- return None
- return tensor, start, end
-
-
-def merge_tensor_spans(existing_span: TensorSpan | None, incoming_span: TensorSpan | None) -> TensorSpan | None:
- if existing_span is None or incoming_span is None:
- return None
-
- existing_tensor, existing_start, existing_end = existing_span
- incoming_tensor, incoming_start, incoming_end = incoming_span
- if incoming_tensor.device != existing_tensor.device or incoming_tensor.dtype != existing_tensor.dtype:
- incoming_tensor = incoming_tensor.to(device=existing_tensor.device, dtype=existing_tensor.dtype)
- if incoming_start == existing_end:
- return torch.cat([existing_tensor, incoming_tensor], dim=0), existing_start, incoming_end
- if incoming_start < existing_end:
- overlap = existing_end - incoming_start
- if overlap >= int(incoming_tensor.shape[0]):
- return existing_tensor, existing_start, existing_end
- trimmed_tensor = incoming_tensor[overlap:]
- return (
- torch.cat([existing_tensor, trimmed_tensor], dim=0),
- existing_start,
- existing_end + int(trimmed_tensor.shape[0]),
- )
- return None
-
-
-def get_tensor_span_row(span: TensorSpan | None, index: int) -> torch.Tensor | None:
- if span is None:
- return None
- tensor, start, end = span
- if index < start or index >= end:
- return None
- return tensor[index - start]