diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index f18aef61771..7375dd4a2c4 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -94,7 +94,7 @@ steps:
- "/fsx/hf_cache:/fsx/hf_cache"
- label: "Diffusion Parallelism Test"
- timeout_in_minutes: 20
+ timeout_in_minutes: 25
depends_on: image-build
commands:
- pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py
@@ -116,7 +116,7 @@ steps:
timeout_in_minutes: 20
depends_on: image-build
commands:
- - pytest -s -v tests/diffusion/test_gpu_worker.py
+ - pytest -s -v tests/diffusion/test_gpu_diffusion_worker.py
agents:
queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
plugins:
diff --git a/.buildkite/scripts/simple_test.sh b/.buildkite/scripts/simple_test.sh
index 33248d99cde..55ac27cec9f 100755
--- a/.buildkite/scripts/simple_test.sh
+++ b/.buildkite/scripts/simple_test.sh
@@ -52,3 +52,4 @@ VENV_PYTHON="${VENV_DIR}/bin/python"
"${VENV_PYTHON}" -m pytest -v -s tests/entrypoints/
"${VENV_PYTHON}" -m pytest -v -s tests/diffusion/cache/
"${VENV_PYTHON}" -m pytest -v -s tests/model_executor/models/qwen2_5_omni/test_audio_length.py
+"${VENV_PYTHON}" -m pytest -v -s tests/worker/
diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index bcece8c495d..86d65f15bcf 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -54,7 +54,7 @@ steps:
commands:
- export MIOPEN_DEBUG_CONV_DIRECT=0
- export MIOPEN_DEBUG_CONV_GEMM=0
- - pytest -s -v tests/diffusion/test_gpu_worker.py
+ - pytest -s -v tests/diffusion/test_gpu_diffusion_worker.py
- label: "Omni Model Test Qwen2-5-Omni"
timeout_in_minutes: 15
diff --git a/README.md b/README.md
index 829d683c3c3..1a8ad999050 100644
--- a/README.md
+++ b/README.md
@@ -9,13 +9,13 @@ Easy, fast, and cheap omni-modality model serving for everyone
-| Documentation | User Forum | Developer Slack |
+| Documentation | User Forum | Developer Slack | WeChat |
---
*Latest News* 🔥
-
+- [2026/01] We released [0.14.0rc1](https://github.com/vllm-project/vllm-omni/releases/tag/v0.14.0rc2).
- [2026/01] We released [0.12.0rc1](https://github.com/vllm-project/vllm-omni/releases/tag/v0.12.0rc1) - a major RC milestone focused on maturing the diffusion stack, strengthening OpenAI-compatible serving, expanding omni-model coverage, and improving stability across platforms (GPU/NPU/ROCm), please check our latest [design](https://docs.google.com/presentation/d/1qv4qMW1rKAqDREMXiUDLIgqqHQe7TDPj/edit?usp=sharing&ouid=110473603432222024453&rtpof=true&sd=true).
- [2025/11] vLLM community officially released [vllm-project/vllm-omni](https://github.com/vllm-project/vllm-omni) in order to support omni-modality models serving.
@@ -70,6 +70,10 @@ Please check out [Contributing to vLLM-Omni](https://vllm-omni.readthedocs.io/en
## Join the Community
Feel free to ask questions, provide feedbacks and discuss with fellow users of vLLM-Omni in `#sig-omni` slack channel at [slack.vllm.ai](https://slack.vllm.ai) or vLLM user forum at [discuss.vllm.ai](https://discuss.vllm.ai).
+## Star History
+
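+[![Star History Chart](https://api.star-history.com/svg?repos=vllm-project/vllm-omni&type=date&legend=top-left)](https://www.star-history.com/#vllm-project/vllm-omni&type=date&legend=top-left)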
+
## License
Apache License 2.0, as found in the [LICENSE](./LICENSE) file.
diff --git a/collect_env.py b/collect_env.py
index 71cec0c4a87..8b09379e1a3 100644
--- a/collect_env.py
+++ b/collect_env.py
@@ -57,6 +57,7 @@
"cpu_info",
"rocm_version", # vllm specific field
"vllm_version", # vllm specific field
+ "vllm_omni_version", # vllm-omni specific field
"vllm_build_flags", # vllm specific field
"gpu_topo", # vllm specific field
"env_vars",
@@ -289,6 +290,31 @@ def get_vllm_version():
return __version__
+def get_vllm_omni_version(run_lambda):  # Describe the installed vllm_omni version, adding git sha/date when one can be found.
+    try:
+        import vllm_omni  # imported lazily so a missing package is reported instead of crashing
+        from vllm_omni import __version__, __version_tuple__
+
+        version_str = __version_tuple__[-1]  # last tuple element; presumably a "g<sha>[.d<date>]" local segment -- TODO confirm
+        if isinstance(version_str, str) and version_str.startswith("g"):
+            if "." in version_str:
+                git_sha = version_str.split(".")[0][1:]  # text before the first ".", minus its leading "g"
+                date = version_str.split(".")[-1][1:]  # text after the last ".", minus its first character
+                return f"{__version__} (git sha: {git_sha}, date: {date})"
+            else:
+                git_sha = version_str[1:]  # no "." present: everything after the "g" is the sha
+                return f"{__version__} (git sha: {git_sha})"
+
+        package_dir = os.path.dirname(os.path.abspath(vllm_omni.__file__))  # fallback: ask git in the package directory
+        git_sha = run_and_read_all(run_lambda, f"git -C {package_dir} rev-parse --short HEAD")
+        if git_sha:  # falsy result (presumably the git command failed -- verify run_and_read_all) falls through
+            return f"{__version__} (git sha: {git_sha})"
+
+        return __version__  # no sha available from either source
+    except ImportError:
+        return "N/A (vllm_omni not installed)"
+
+
def summarize_vllm_build_flags():
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
return "CUDA Archs: {}; ROCm: {}".format(
@@ -524,6 +550,7 @@ def get_version_or_na(cfg, prefix):
rocm_version = get_rocm_version(run_lambda)
vllm_version = get_vllm_version()
+ vllm_omni_version = get_vllm_omni_version(run_lambda)
vllm_build_flags = summarize_vllm_build_flags()
gpu_topo = get_gpu_topo(run_lambda)
@@ -555,6 +582,7 @@ def get_version_or_na(cfg, prefix):
cpu_info=get_cpu_info(run_lambda),
rocm_version=rocm_version,
vllm_version=vllm_version,
+ vllm_omni_version=vllm_omni_version,
vllm_build_flags=vllm_build_flags,
gpu_topo=gpu_topo,
env_vars=get_env_vars(),
@@ -621,6 +649,7 @@ def get_version_or_na(cfg, prefix):
==============================
ROCM Version : {rocm_version}
vLLM Version : {vllm_version}
+vLLM-Omni Version : {vllm_omni_version}
vLLM Build Flags:
{vllm_build_flags}
GPU Topology:
diff --git a/docker/Dockerfile.ci b/docker/Dockerfile.ci
index 5e1d00a5f88..55599df29ed 100644
--- a/docker/Dockerfile.ci
+++ b/docker/Dockerfile.ci
@@ -1,11 +1,17 @@
ARG VLLM_BASE_IMAGE=vllm/vllm-openai
-ARG VLLM_BASE_TAG=v0.12.0
+ARG VLLM_BASE_TAG=v0.14.0rc2
FROM ${VLLM_BASE_IMAGE}:${VLLM_BASE_TAG}
ARG APP_DIR=/workspace/vllm-omni
WORKDIR ${APP_DIR}
COPY . .
+# Install system dependencies
+RUN apt-get update && \
+ apt-get install -y ffmpeg && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
# Install vllm-omni into the same uv-managed Python environment used by the base image.
RUN uv pip install --python "$(python3 -c 'import sys; print(sys.executable)')" --no-cache-dir ".[dev]"
diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index 7fabb9c3c68..80f709deae0 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -2,7 +2,7 @@ ARG BASE_IMAGE=rocm/vllm-dev:nightly_main_20251205
FROM ${BASE_IMAGE}
ARG COMMON_WORKDIR=/app
-ARG VLLM_VERSION=v0.12.0
+ARG VLLM_VERSION=v0.14.0rc2
ARG PYTORCH_ROCM_ARCH="gfx942;gfx950"
WORKDIR ${COMMON_WORKDIR}
diff --git a/docs/.nav.yml b/docs/.nav.yml
index abe4865af33..be930637175 100644
--- a/docs/.nav.yml
+++ b/docs/.nav.yml
@@ -3,7 +3,7 @@ nav:
- User Guide:
- Getting Started:
- getting_started/quickstart.md
- - getting_started/installation
+ - getting_started/installation/*
- Serving:
- OpenAI-Compatible API:
- Image Generation: serving/image_generation_api.md
@@ -26,17 +26,16 @@ nav:
- Configuration:
- configuration/README.md
- configuration/*
- - Diffusion Acceleration:
- - Overview: user_guide/diffusion_acceleration.md
- - Acceleration Methods:
- - TeaCache: user_guide/acceleration/teacache.md
- - Cache-DiT: user_guide/acceleration/cache_dit_acceleration.md
- - Parallelism Acceleration: user_guide/acceleration/parallelism_acceleration.md
- Models:
- models/supported_models.md
- Features:
- Sleep Mode: features/sleep_mode.md
- - CPU Offloading for Diffusion Model: features/cpu_offload_diffusion.md
+ - Diffusion Features:
+ - Overview: user_guide/diffusion_acceleration.md
+ - TeaCache: user_guide/diffusion/teacache.md
+ - Cache-DiT: user_guide/diffusion/cache_dit_acceleration.md
+ - Parallelism Acceleration: user_guide/diffusion/parallelism_acceleration.md
+ - CPU Offloading: user_guide/diffusion/cpu_offload_diffusion.md
- Developer Guide:
- General:
- contributing/README.md
@@ -47,7 +46,6 @@ nav:
- contributing/model/adding_omni_model.md
- contributing/model/adding_diffusion_model.md
- CI: contributing/ci
- - Tests: contributing/tests
- Design Documents:
- design/index.md
- design/architecture_overview.md
diff --git a/docs/api/README.md b/docs/api/README.md
index 4fa85cdc663..a1f07011118 100644
--- a/docs/api/README.md
+++ b/docs/api/README.md
@@ -82,6 +82,7 @@ Model execution components.
- [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_code_predictor_mtp.Qwen3OmniMoeTalkerCodePredictor][]
- [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_talker.Qwen3OmniMoeModel][]
- [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_talker.Qwen3OmniMoeTalkerForConditionalGeneration][]
+- [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_talker.Qwen3OmniMoeTalkerSharedExpertWrapper][]
- [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker.Qwen3MoeLLMForCausalLM][]
- [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker.Qwen3MoeLLMModel][]
- [vllm_omni.model_executor.models.qwen3_omni.qwen3_omni_moe_thinker.Qwen3OmniMoeConditionalGenerationMixin][]
@@ -102,8 +103,9 @@ Configuration classes.
Worker classes and model runners for distributed inference.
-- [vllm_omni.diffusion.worker.gpu_worker.GPUWorker][]
-- [vllm_omni.diffusion.worker.gpu_worker.WorkerProc][]
+- [vllm_omni.diffusion.worker.gpu_diffusion_model_runner.GPUDiffusionModelRunner][]
+- [vllm_omni.diffusion.worker.gpu_diffusion_worker.GPUDiffusionWorker][]
+- [vllm_omni.diffusion.worker.gpu_diffusion_worker.WorkerProc][]
- [vllm_omni.diffusion.worker.npu.npu_worker.NPUWorker][]
- [vllm_omni.diffusion.worker.npu.npu_worker.NPUWorkerProc][]
- [vllm_omni.worker.gpu_ar_model_runner.ExecuteModelState][]
diff --git a/docs/assets/WeChat.jpg b/docs/assets/WeChat.jpg
new file mode 100644
index 00000000000..5a63afde85c
Binary files /dev/null and b/docs/assets/WeChat.jpg differ
diff --git a/docs/configuration/README.md b/docs/configuration/README.md
index 40439d51121..02440b95dce 100644
--- a/docs/configuration/README.md
+++ b/docs/configuration/README.md
@@ -2,7 +2,7 @@
This section lists the most common options for running vLLM-Omni.
-For options within a vLLM Engine. Please refer to [vLLM Configuration](https://docs.vllm.ai/en/v0.12.0/configuration/index.html)
+For options within a vLLM Engine, please refer to [vLLM Configuration](https://docs.vllm.ai/en/v0.14.0/configuration/index.html).
Currently, the main options are maintained by stage configs for each model.
@@ -16,6 +16,6 @@ For introduction, please check [Introduction for stage config](./stage_configs.m
## Optimization Features
-- **[TeaCache Configuration](../user_guide/acceleration/teacache.md)** - Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss
-- **[Cache-DiT Configuration](../user_guide/acceleration/cache_dit_acceleration.md)** - Enable Cache-DiT as cache acceleration backends for DiT models
-- **[Parallelism Configuration](../user_guide/acceleration/parallelism_acceleration.md)** - Enable parallelism (e.g., sequence parallelism) for for DiT models
+- **[TeaCache Configuration](../user_guide/diffusion/teacache.md)** - Enable TeaCache adaptive caching for DiT models to achieve 1.5x-2.0x speedup with minimal quality loss
+- **[Cache-DiT Configuration](../user_guide/diffusion/cache_dit_acceleration.md)** - Enable Cache-DiT as cache acceleration backends for DiT models
+- **[Parallelism Configuration](../user_guide/diffusion/parallelism_acceleration.md)** - Enable parallelism (e.g., sequence parallelism) for DiT models
diff --git a/docs/contributing/ci/tests_markers.md b/docs/contributing/ci/tests_markers.md
new file mode 100644
index 00000000000..bf56914f8da
--- /dev/null
+++ b/docs/contributing/ci/tests_markers.md
@@ -0,0 +1,160 @@
+# Markers for Tests
+
+By adding markers before test functions, tests can later be executed uniformly by simply declaring the corresponding marker type.
+
+## Current Markers
+Defined in `pyproject.toml`:
+
+| Marker | Description |
+| ------------------ | ------------------------------------------------------- |
+| `core_model` | Core model tests (run in each PR) |
+| `diffusion` | Diffusion model tests |
+| `omni` | Omni model tests |
+| `cache` | Cache backend tests |
+| `parallel` | Parallelism/distributed tests |
+| `cpu` | Tests that run on CPU |
+| `gpu` | Tests that run on GPU (auto-added) |
+| `cuda` | Tests that run on CUDA (auto-added) |
+| `rocm` | Tests that run on AMD/ROCm (auto-added) |
+| `npu` | Tests that run on NPU/Ascend (auto-added) |
+| `H100` | Tests that require H100 GPU |
+| `L4` | Tests that require L4 GPU |
+| `MI325` | Tests that require MI325 GPU (AMD/ROCm) |
+| `A2` | Tests that require A2 NPU |
+| `A3` | Tests that require A3 NPU |
+| `distributed_cuda` | Tests that require multi cards on CUDA platform |
+| `distributed_rocm` | Tests that require multi cards on ROCm platform |
+| `distributed_npu` | Tests that require multi cards on NPU platform |
+| `skipif_cuda` | Skip if the num of CUDA cards is less than the required |
+| `skipif_rocm` | Skip if the num of ROCm cards is less than the required |
+| `skipif_npu` | Skip if the num of NPU cards is less than the required |
+| `slow` | Slow tests (may skip in quick CI) |
+| `benchmark` | Benchmark tests |
+
+Markers shown as auto-added are applied automatically by the `@hardware_test` decorator.
+
+### Example usage for markers
+
+```python
+from tests.utils import hardware_test
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(
+ res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
+ num_cards=2,
+)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_video_to_audio():
+ ...
+```
+### Decorator: `@hardware_test`
+
+This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. The `hardware_test` decorator in `vllm-omni/tests/utils.py` performs the following actions:
+
+1. **Applies platform and resource markers**
+ Adds the appropriate pytest markers for each specified hardware platform (e.g., `cuda`, `rocm`, `npu`) and resource type (e.g., `L4`, `H100`, `MI325`, `A2`, `A3`).
+ ```
+ @pytest.mark.cuda
+ @pytest.mark.L4
+ ```
+2. **Handles multi-card (distributed) scenarios**
+ For tests requiring multiple cards, it automatically adds distributed markers such as `distributed_cuda`, `distributed_rocm`, or `distributed_npu`.
+ ```
+ @pytest.mark.distributed_cuda(num_cards=num_cards)
+ ```
+3. **Supports flexible card requirements**
+ Accepts `num_cards` as either a single integer for all platforms or as a dictionary with per-platform values. If not specified, defaults to 1 card per platform.
+
+4. **Integrates resource validation**
+ On CUDA, adds a skip marker (`skipif_cuda`) if the system does not have the required number of devices.
+ Support for `skipif_rocm` and `skipif_npu` will be implemented later.
+
+
+5. **Runs each test in a new process**
+ Automatically wraps the distributed test with a decorator (`@create_new_process_for_each_test`) to ensure isolation and compatibility with multi-process hardware backends.
+
+6. **Works with pytest filtering**
+ Allows tests to be filtered and selected at runtime using standard pytest marker expressions (e.g., `-m "distributed_cuda and L4"`).
+
+#### Example usage for decorator
+- Single call for multiple platforms:
+ ```python
+ @hardware_test(
+ res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
+ num_cards={"cuda": 2, "rocm": 2, "npu": 2},
+ )
+ ```
+ or
+ ```python
+ @hardware_test(
+ res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
+ num_cards=2,
+ )
+ ```
+- `res` must be a dict; supported resources: CUDA (L4/H100), ROCm (MI325), NPU (A2/A3)
+- `num_cards` can be int (all platforms) or dict (per platform); defaults to 1 when missing
+- `hardware_test` automatically applies `@create_new_process_for_each_test` for distributed tests.
+- Distributed markers (`distributed_cuda`, `distributed_rocm`, `distributed_npu`) are auto-added for multi-card cases
+- Filtering examples:
+ - CUDA only: `pytest -m "distributed_cuda and L4"`
+ - ROCm only: `pytest -m "distributed_rocm and MI325"`
+ - NPU only: `pytest -m "distributed_npu"`
+
+## Add Support for a New Platform
+
+If you want to add support for a new platform (e.g., "tpu" for a new accelerator), follow these steps:
+
+1. **Extend the marker list in your pytest config** so that platform/resource markers are defined:
+ ```toml
+ # In pyproject.toml or pytest.ini
+ [tool.pytest.ini_options]
+ markers = [
+ # ... existing markers ...
+ "tpu: Tests that require TPU device",
+ "TPU_V3: Tests that require TPU v3 hardware",
+ "distributed_tpu: Tests that require multiple TPU devices",
+ ]
+ ```
+2. **Implement a marker construction function for your platform** in `vllm-omni/tests/utils.py`:
+ ```python
+ # In vllm-omni/tests/utils.py
+
+ def tpu_marks(*, res: str, num_cards: int):
+ test_platform = pytest.mark.tpu
+ if res == "TPU_V3":
+ test_resource = pytest.mark.TPU_V3
+ else:
+ raise ValueError(
+ f"Invalid TPU resource type: {res}. Supported: TPU_V3")
+
+ if num_cards == 1:
+ return [test_platform, test_resource]
+ else:
+ test_distributed = pytest.mark.distributed_tpu(num_cards=num_cards)
+ # Optionally: add skipif_tpu when implemented
+ return [test_platform, test_resource, test_distributed]
+ ```
+3. **Update `hardware_test` to recognize your new platform**:
+ In the relevant place (see the `hardware_test` implementation), add:
+ ```python
+ if platform == "tpu":
+ marks = tpu_marks(res=resource, num_cards=cards)
+ ```
+4. **(Recommended) Add a test using your new markers**:
+ ```python
+ @hardware_test(
+ res={"tpu": "TPU_V3"},
+ num_cards=2,
+ )
+ def test_my_tpu_feature():
+ ...
+ ```
+
+**Summary**:
+- Add pytest markers for your new platform/resources
+- Implement a marker function (`xxx_marks`)
+- Plug into `hardware_test`
+- You're done: tests decorated with `@hardware_test` using your platform now automatically get the correct markers, distribution, and isolation!
+
+See code in `vllm-omni/tests/utils.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`).
diff --git a/docs/contributing/tests/tests_style.md b/docs/contributing/ci/tests_style.md
similarity index 94%
rename from docs/contributing/tests/tests_style.md
rename to docs/contributing/ci/tests_style.md
index c88e17dee34..65c2b044346 100644
--- a/docs/contributing/tests/tests_style.md
+++ b/docs/contributing/ci/tests_style.md
@@ -139,7 +139,7 @@ vllm_omni/ tests/
4. **Documentation**: Add docstrings to all test functions
5. **Environment variables**: Set uniformly in `conftest.py` or at the top of files
6. **Type annotations**: Add type annotations to all test function parameters
-7. **Resources**, Using pytest tag to specify the computation resources the test required.
+7. **Pytest Markers**: Add necessary markers like `@pytest.mark.core_model` and use `@hardware_test` to declare hardware requirements (see details in [Markers for Tests](../ci/tests_markers.md)).
### Template
#### E2E - Online serving
@@ -155,6 +155,7 @@ from pathlib import Path
import pytest
import openai
+from tests.utils import hardware_test
# Optional: set process start method for workers
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -184,6 +185,12 @@ def base64_encoded_video() -> str:
def dummy_messages_from_video_data(video_data_url: str, content_text: str) -> str:
xxx
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(
+ res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
+ num_cards={"cuda": 2, "rocm": 2, "npu": 4},
+)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_video_to_audio(
client: openai.OpenAI,
@@ -226,6 +233,7 @@ from pathlib import Path
import pytest
from vllm.assets.video import VideoAsset
+from tests.utils import hardware_test
from ..multi_stages.conftest import OmniRunner
# Optional: set process start method for workers
@@ -239,7 +247,12 @@ test_params = [(model, stage_config) for model in models for stage_config in sta
# function name: test_{input_modality}_to_{output_modality}
# modality candidate: text, image, audio, video, mixed_modalities
-@pytest.mark.gpu_mem_high # requires high-memory GPU node
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(
+ res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
+ num_cards=2,
+)
@pytest.mark.parametrize("test_config", test_params)
def test_video_to_audio(omni_runner: type[OmniRunner], model: str) -> None:
"""Offline inference: video input, audio output."""
@@ -263,4 +276,5 @@ def test_video_to_audio(omni_runner: type[OmniRunner], model: str) -> None:
1. The file is saved in an appropriate place and the file name is clear.
2. The coding style follows the requirements outlined above.
-3. For e2e model test, please ensure the test is configured under the `./buildkite/` folder.
+3. **All test functions have appropriate pytest markers**
+4. For tests that need to run in CI, please ensure the test is configured under the `.buildkite/` folder.
diff --git a/docs/contributing/model/adding_diffusion_model.md b/docs/contributing/model/adding_diffusion_model.md
index 70fdc6a0817..7eb56d5f5bc 100644
--- a/docs/contributing/model/adding_diffusion_model.md
+++ b/docs/contributing/model/adding_diffusion_model.md
@@ -140,7 +140,7 @@ Key point for writing the example:
+ Save or display the generated results so users can validate the integration.
# Testing
-For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../tests/tests_style.md).
+For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../ci/tests_style.md).
## Adding a Model Recipe
diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md
index 2a91a305091..81499118623 100644
--- a/docs/contributing/model/adding_omni_model.md
+++ b/docs/contributing/model/adding_omni_model.md
@@ -572,7 +572,7 @@ def talker2code2wav(
## Testing
-For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../tests/tests_style.md).
+For comprehensive testing guidelines, please refer to the [Test File Structure and Style Guide](../ci/tests_style.md).
## Adding a Model Recipe
diff --git a/docs/design/architecture_overview.md b/docs/design/architecture_overview.md
index ea7eff4397d..6793895cd46 100644
--- a/docs/design/architecture_overview.md
+++ b/docs/design/architecture_overview.md
@@ -64,13 +64,13 @@ According to analysis for current popular open-source models, most of them have
## Key Components
-| Component | Description |
-|-----------|-------------|
-| **OmniRouter** | provide an intelligent router for Omni-modality requests dispatch |
-| **EntryPoints** | define the APIs for offline/online serving (APIServer, Omni/AsyncOmni) and provide the OmniStage abstraction for different AR/DiT stages |
-| **AR** | adapted for omni-modality models while inheriting efficient features from vLLM, such as cache management |
-| **Diffusion** | natively implemented and optimized using acceleration components |
-| **OmniConnector** | supports fully disaggregation based on E/P/D/G (Encoding/Processing/Decoding/Generation) disaggregation across stages |
+| Component | Description |
+| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
+| **OmniRouter** | provide an intelligent router for Omni-modality requests dispatch |
+| **EntryPoints** | define the APIs for offline/online serving (APIServer, Omni/AsyncOmni) and provide the OmniStage abstraction for different AR/DiT stages |
+| **AR** | adapted for omni-modality models while inheriting efficient features from vLLM, such as cache management |
+| **Diffusion** | natively implemented and optimized using acceleration components |
+| **OmniConnector** | supports full E/P/D/G (Encoding/Processing/Decoding/Generation) disaggregation across stages |
Disaggregated stages are managed through configuration, such as in the Qwen3-Omni example, where stages like Thinker, Talker, and Code2wav are defined as separate OmniStage instances with specific resources and input/output type.
@@ -192,4 +192,4 @@ curl -sS -X POST http://localhost:8091/v1/chat/completions \
}
```
-For more usages, please refer to [examples](../user_guide/examples/).
+For more usages, please refer to [examples](https://github.com/vllm-project/vllm-omni/tree/main/examples).
diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md
index 540b1852947..c073c152ee9 100644
--- a/docs/getting_started/installation/gpu/cuda.inc.md
+++ b/docs/getting_started/installation/gpu/cuda.inc.md
@@ -58,11 +58,11 @@ If you want to check, modify or debug with source code of vLLM, install the libr
```bash
git clone https://github.com/vllm-project/vllm.git
cd vllm
-git checkout v0.12.0
+git checkout v0.14.0rc2
```
Set up environment variables to get pre-built wheels. If there are internet problems, just download the whl file manually. And set `VLLM_PRECOMPILED_WHEEL_LOCATION` as your local absolute path of whl file.
```bash
-export VLLM_PRECOMPILED_WHEEL_LOCATION=https://github.com/vllm-project/vllm/releases/download/v0.12.0/vllm-0.12.0-cp38-abi3-manylinux_2_31_x86_64.whl
+export VLLM_PRECOMPILED_WHEEL_LOCATION=https://github.com/vllm-project/vllm/releases/download/v0.14.0/vllm-0.14.0rc2-cp38-abi3-manylinux_2_31_x86_64.whl
```
Install vllm with command below (If you have no existing PyTorch).
```bash
@@ -93,7 +93,7 @@ docker run --runtime nvidia --gpus 2 \
--env "HF_TOKEN=$HF_TOKEN" \
-p 8091:8091 \
--ipc=host \
- vllm/vllm-omni:v0.12.0rc1 \
+ vllm/vllm-omni:v0.14.0rc2 \
--model Qwen/Qwen3-Omni-30B-A3B-Instruct --port 8091
```
diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md
index 1fa751e2508..638c689d4be 100644
--- a/docs/getting_started/installation/gpu/rocm.inc.md
+++ b/docs/getting_started/installation/gpu/rocm.inc.md
@@ -68,7 +68,7 @@ docker run -it \
-v :/app/model \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HF_TOKEN=$HF_TOKEN" \
- vllm/vllm-omni-rocm:v0.12.0rc1 \
+ vllm/vllm-omni-rocm:v0.14.0rc2 \
vllm serve --model Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
```
@@ -86,7 +86,7 @@ docker run -it \
-v :/app/model \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HF_TOKEN=$HF_TOKEN" \
- vllm/vllm-omni-rocm:v0.12.0rc1 \
+ vllm/vllm-omni-rocm:v0.14.0rc2 \
bash
```
diff --git a/docs/getting_started/installation/npu/npu.inc.md b/docs/getting_started/installation/npu/npu.inc.md
index 9c36c2be626..5714449a70a 100644
--- a/docs/getting_started/installation/npu/npu.inc.md
+++ b/docs/getting_started/installation/npu/npu.inc.md
@@ -13,10 +13,10 @@ export DEVICE0=/dev/davinci0
export DEVICE1=/dev/davinci1
# Update the vllm-ascend image
# Atlas A2:
-# export IMAGE=quay.io/ascend/vllm-ascend:v0.12.0rc1
+# export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0rc2
# Atlas A3:
-# export IMAGE=quay.io/ascend/vllm-ascend:v0.12.0rc1-a3
-export IMAGE=quay.io/ascend/vllm-ascend:v0.12.0rc1
+# export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0rc2-a3
+export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0rc2
docker run --rm \
--name vllm-omni-npu \
--shm-size=1g \
@@ -42,7 +42,7 @@ source ~/.bashrc
# Inside the container, install vLLM-Omni from source
cd /vllm-workspace
-git clone -b v0.12.0rc1 https://github.com/vllm-project/vllm-omni.git
+git clone -b v0.14.0rc2 https://github.com/vllm-project/vllm-omni.git
cd vllm-omni
pip install -v -e .
export VLLM_WORKER_MULTIPROC_METHOD=spawn
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 1b993152837..cd70019de32 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -32,6 +32,7 @@ th {
|`LongcatImagePipeline` | LongCat-Image | `meituan-longcat/LongCat-Image` |
|`LongCatImageEditPipeline` | LongCat-Image-Edit | `meituan-longcat/LongCat-Image-Edit` |
|`StableDiffusion3Pipeline` | Stable-Diffusion-3 | `stabilityai/stable-diffusion-3.5-medium` |
+|`Flux2KleinPipeline` | FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B`, `black-forest-labs/FLUX.2-klein-9B` |
|`StableAudioPipeline` | Stable-Audio-Open | `stabilityai/stable-audio-open-1.0` |
diff --git a/docs/user_guide/acceleration/cache_dit_acceleration.md b/docs/user_guide/diffusion/cache_dit_acceleration.md
similarity index 100%
rename from docs/user_guide/acceleration/cache_dit_acceleration.md
rename to docs/user_guide/diffusion/cache_dit_acceleration.md
diff --git a/docs/features/cpu_offload_diffusion.md b/docs/user_guide/diffusion/cpu_offload_diffusion.md
similarity index 93%
rename from docs/features/cpu_offload_diffusion.md
rename to docs/user_guide/diffusion/cpu_offload_diffusion.md
index aaa4243a3a2..533b6b3b964 100644
--- a/docs/features/cpu_offload_diffusion.md
+++ b/docs/user_guide/diffusion/cpu_offload_diffusion.md
@@ -23,7 +23,7 @@ if __name__ == "__main__":
m = Omni(model="Qwen/Qwen-Image",enable_cpu_offload=True)
```
-- **CLI**: pass `--dit-cpu-offload` to the diffusion service entrypoint.
+- **CLI**: pass `--enable-cpu-offload` to the diffusion service entrypoint.
## Known Limitations
- Cold start latency increases for over one minute for some models(e.g., Qwen-Image)
diff --git a/docs/user_guide/acceleration/parallelism_acceleration.md b/docs/user_guide/diffusion/parallelism_acceleration.md
similarity index 95%
rename from docs/user_guide/acceleration/parallelism_acceleration.md
rename to docs/user_guide/diffusion/parallelism_acceleration.md
index dfacd2183ff..324301158d8 100644
--- a/docs/user_guide/acceleration/parallelism_acceleration.md
+++ b/docs/user_guide/diffusion/parallelism_acceleration.md
@@ -23,13 +23,23 @@ The following table shows which models are currently supported by parallelism me
| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ✅ | ✅ | ❌ | ❌ |
| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ✅ | ✅ | ❌ | ❌ |
| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ❌ | ❌ | ❌ |
-| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ❌ |
+| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ |
| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ❌ |
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ❌ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ❌ | ❌ | ✅ (TP=2 only) |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ |
+
+!!! note "TP Limitations for Diffusion Models"
+ We currently implement Tensor Parallelism (TP) only for the DiT (Diffusion Transformer) blocks. This is because the `text_encoder` component in vLLM-Omni uses the original Transformers implementation, which does not yet support TP.
+
+ - Good news: The text_encoder typically has minimal impact on overall inference performance.
+ - Bad news: When TP is enabled, every TP process retains a full copy of the text_encoder weights, leading to significant GPU memory waste.
+
+ We are actively refactoring this design to address this. For details and progress, please refer to [Issue #771](https://github.com/vllm-project/vllm-omni/issues/771).
+
+
!!! note "Why Z-Image is TP=2 only"
Z-Image Turbo is currently limited to `tensor_parallel_size` of **1 or 2** due to model shape divisibility constraints.
For example, the model has `n_heads=30` and a final projection out dimension of `64`, so valid TP sizes must divide both 30 and 64; the only common divisors are **1 and 2**.
diff --git a/docs/user_guide/acceleration/teacache.md b/docs/user_guide/diffusion/teacache.md
similarity index 100%
rename from docs/user_guide/acceleration/teacache.md
rename to docs/user_guide/diffusion/teacache.md
diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md
index 8f78ae32e50..0184a8fcd39 100644
--- a/docs/user_guide/diffusion_acceleration.md
+++ b/docs/user_guide/diffusion_acceleration.md
@@ -6,8 +6,8 @@ vLLM-Omni supports various cache acceleration methods to speed up diffusion mode
vLLM-Omni currently supports two main cache acceleration backends:
-1. **[TeaCache](acceleration/teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar
-2. **[Cache-DiT](acceleration/cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques:
+1. **[TeaCache](diffusion/teacache.md)** - Hook-based adaptive caching that caches transformer computations when consecutive timesteps are similar
+2. **[Cache-DiT](diffusion/cache_dit_acceleration.md)** - Library-based acceleration using multiple techniques:
- **DBCache** (Dual Block Cache): Caches intermediate transformer block outputs based on residual differences
- **TaylorSeer**: Uses Taylor expansion-based forecasting for faster inference
- **SCM** (Step Computation Masking): Selectively computes steps based on adaptive masking
@@ -16,11 +16,11 @@ Both methods can provide significant speedups (typically **1.5x-2.0x**) while ma
vLLM-Omni also supports parallelism methods for diffusion models, including:
-1. [Ulysses-SP](acceleration/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads.
+1. [Ulysses-SP](diffusion/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads.
-2. [Ring-Attention](acceleration/parallelism_acceleration.md#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded.
+2. [Ring-Attention](diffusion/parallelism_acceleration.md#ring-attention) - splits the input along the sequence dimension and uses ring-based P2P communication to accumulate attention results, keeping the sequence dimension sharded.
-3. [CFG-Parallel](acceleration/parallelism_acceleration.md#cfg-parallel) - runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step.
+3. [CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel) - runs the positive/negative prompts of classifier-free guidance (CFG) on different devices, then merges on a single device to perform the scheduler step.
## Quick Comparison
@@ -49,6 +49,7 @@ The following table shows which models are currently supported by each accelerat
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ❌ | ✅ |❌ | ❌ | ❌ |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ❌ | ✅ | ❌ | ❌ | ❌ |
### VideoGen
@@ -197,7 +198,7 @@ outputs = omni.generate(prompt="turn this cat to a dog",
For detailed information on each acceleration method:
-- **[TeaCache Guide](acceleration/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices
-- **[Cache-DiT Acceleration Guide](acceleration/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters
-- **[Sequence Parallelism](acceleration/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration.
-- **[CFG-Parallel](acceleration/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks.
+- **[TeaCache Guide](diffusion/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices
+- **[Cache-DiT Acceleration Guide](diffusion/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters
+- **[Sequence Parallelism](diffusion/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration.
+- **[CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks.
diff --git a/docs/user_guide/examples/online_serving/gradio_demo.md b/docs/user_guide/examples/online_serving/gradio_demo.md
deleted file mode 100644
index 38278d9cf5a..00000000000
--- a/docs/user_guide/examples/online_serving/gradio_demo.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Gradio Demo
-
-Source .
-
-``````py
---8<-- "examples/online_serving/gradio_demo.py"
-``````
diff --git a/docs/user_guide/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.md b/docs/user_guide/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.md
deleted file mode 100644
index ca3fa8306b3..00000000000
--- a/docs/user_guide/examples/online_serving/openai_chat_completion_client_for_multimodal_generation.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# OpenAI Chat Completion Client For Multimodal Generation
-
-Source .
-
-``````py
---8<-- "examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py"
-``````
diff --git a/examples/offline_inference/image_to_image/image_edit.py b/examples/offline_inference/image_to_image/image_edit.py
index 8f4dbeef98e..c31d098252b 100644
--- a/examples/offline_inference/image_to_image/image_edit.py
+++ b/examples/offline_inference/image_to_image/image_edit.py
@@ -55,7 +55,15 @@
--prompt "Edit description" \
--cfg_parallel_size 2 \
--num_inference_steps 50 \
- --cfg_scale 4.0 \
+ --cfg_scale 4.0
+
+Usage (disable torch.compile):
+ python image_edit.py \
+ --image input.png \
+ --prompt "Edit description" \
+ --enforce_eager \
+ --num_inference_steps 50 \
+ --cfg_scale 4.0
For more options, run:
python image_edit.py --help
@@ -173,6 +181,12 @@ def parse_args() -> argparse.Namespace:
default=1,
help="Number of GPUs used for ring sequence parallelism.",
)
+ parser.add_argument(
+ "--tensor_parallel_size",
+ type=int,
+ default=1,
+ help="Number of GPUs used for tensor parallelism (TP) inside the DiT.",
+ )
parser.add_argument("--layers", type=int, default=4, help="Number of layers to decompose the input image into.")
parser.add_argument(
"--resolution",
@@ -260,6 +274,11 @@ def parse_args() -> argparse.Namespace:
choices=[1, 2],
help="Number of GPUs used for classifier free guidance parallel size.",
)
+ parser.add_argument(
+ "--enforce_eager",
+ action="store_true",
+ help="Disable torch.compile and force eager execution.",
+ )
return parser.parse_args()
@@ -288,7 +307,10 @@ def main():
vae_use_slicing = is_npu()
vae_use_tiling = is_npu()
parallel_config = DiffusionParallelConfig(
- ulysses_degree=args.ulysses_degree, ring_degree=args.ring_degree, cfg_parallel_size=args.cfg_parallel_size
+ ulysses_degree=args.ulysses_degree,
+ ring_degree=args.ring_degree,
+ cfg_parallel_size=args.cfg_parallel_size,
+ tensor_parallel_size=args.tensor_parallel_size,
)
# Configure cache based on backend type
@@ -321,6 +343,7 @@ def main():
cache_backend=args.cache_backend,
cache_config=cache_config,
parallel_config=parallel_config,
+ enforce_eager=args.enforce_eager,
)
print("Pipeline loaded")
@@ -337,7 +360,7 @@ def main():
else:
print(f" Input image size: {input_image.size}")
print(
- f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}"
+ f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}"
)
print(f"{'=' * 60}\n")
diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py
index d1e0264cd7a..22f2161c22b 100644
--- a/examples/offline_inference/qwen2_5_omni/end2end.py
+++ b/examples/offline_inference/qwen2_5_omni/end2end.py
@@ -278,9 +278,9 @@ def get_audio_query(question: str = None, audio_path: str | None = None, samplin
query_map = {
- "mixed_modalities": get_mixed_modalities_query,
+ "use_mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
- "multi_audios": get_multi_audios_query,
+ "use_multi_audios": get_multi_audios_query,
"use_image": get_image_query,
"use_video": get_video_query,
"use_audio": get_audio_query,
@@ -434,7 +434,7 @@ def parse_args():
"--query-type",
"-q",
type=str,
- default="mixed_modalities",
+ default="use_mixed_modalities",
choices=query_map.keys(),
help="Query type.",
)
diff --git a/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh
index 26cbf0172c5..c8e4cd2cbf3 100644
--- a/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh
+++ b/examples/offline_inference/qwen2_5_omni/run_single_prompt.sh
@@ -1,2 +1,2 @@
python end2end.py --output-wav output_audio \
- --query-type mixed_modalities
+ --query-type use_mixed_modalities
diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py
index 9a1324305cf..3cd8918208e 100644
--- a/examples/offline_inference/qwen3_omni/end2end.py
+++ b/examples/offline_inference/qwen3_omni/end2end.py
@@ -12,8 +12,8 @@
import librosa
import numpy as np
import soundfile as sf
-from PIL import Image
import vllm
+from PIL import Image
from vllm import SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
@@ -226,19 +226,76 @@ def get_multi_audios_query() -> QueryResult:
)
+# def get_use_audio_in_video_query(video_path: str | None = None) -> QueryResult:
+# question = (
+# "Describe the content of the video in details, then convert what the "
+# "baby say into text."
+# )
+# prompt = (
+# f"<|im_start|>system\n{default_system}<|im_end|>\n"
+# "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
+# f"{question}<|im_end|>\n"
+# f"<|im_start|>assistant\n"
+# )
+# if video_path:
+# if not os.path.exists(video_path):
+# raise FileNotFoundError(f"Video file not found: {video_path}")
+# video_frames = video_to_ndarrays(video_path, num_frames=16)
+# else:
+# video_frames = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays
+# audio = extract_video_audio(video_path, sampling_rate=16000)
+# return QueryResult(
+# inputs={
+# "prompt": prompt,
+# "multi_modal_data": {
+# "video": video_frames,
+# "audio": audio,
+# },
+# "mm_processor_kwargs": {
+# "use_audio_in_video": True,
+# },
+# },
+# limit_mm_per_prompt={"audio": 1, "video": 1},
+# )
+def get_use_audio_in_video_query() -> QueryResult:
+ question = "Describe the content of the video in details, then convert what the baby say into text."
+ prompt = (
+ f"<|im_start|>system\n{default_system}<|im_end|>\n"
+ "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
+ f"{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ asset = VideoAsset(name="baby_reading", num_frames=16)
+ audio = asset.get_audio(sampling_rate=16000)
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {
+ "video": asset.np_ndarrays,
+ "audio": audio,
+ },
+ "mm_processor_kwargs": {
+ "use_audio_in_video": True,
+ },
+ },
+ limit_mm_per_prompt={"audio": 1, "video": 1},
+ )
+
+
query_map = {
"text": get_text_query,
"use_audio": get_audio_query,
"use_image": get_image_query,
"use_video": get_video_query,
- "multi_audios": get_multi_audios_query,
- "mixed_modalities": get_mixed_modalities_query,
+ "use_multi_audios": get_multi_audios_query,
+ "use_mixed_modalities": get_mixed_modalities_query,
+ "use_audio_in_video": get_use_audio_in_video_query,
}
def main(args):
model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
- print(f"="*20,"\n",f"vllm version: {vllm.__version__}","\n","="*20)
+ print("=" * 20, "\n", f"vllm version: {vllm.__version__}", "\n", "=" * 20)
# Get paths from args
video_path = getattr(args, "video_path", None)
@@ -261,6 +318,10 @@ def main(args):
num_frames=getattr(args, "num_frames", 16),
sampling_rate=getattr(args, "sampling_rate", 16000),
)
+ elif args.query_type == "multi_audios":
+ query_result = query_func()
+ elif args.query_type == "use_audio_in_video":
+ query_result = query_func()
else:
query_result = query_func()
@@ -272,7 +333,7 @@ def main(args):
)
thinker_sampling_params = SamplingParams(
- temperature=0.4,
+ temperature=0.9,
top_p=0.9,
top_k=-1,
max_tokens=1200,
@@ -304,8 +365,8 @@ def main(args):
sampling_params_list = [
thinker_sampling_params,
- # talker_sampling_params, # code predictor is integrated into talker for Qwen3 Omni
- # code2wav_sampling_params,
+ talker_sampling_params, # code predictor is integrated into talker for Qwen3 Omni
+ code2wav_sampling_params,
]
if args.txt_prompts is None:
@@ -333,6 +394,8 @@ def main(args):
total_requests = len(prompts)
processed_count = 0
+ print(f"query type: {args.query_type}")
+
for stage_outputs in omni_generator:
if stage_outputs.final_output_type == "text":
for output in stage_outputs.request_output:
@@ -387,7 +450,7 @@ def parse_args():
"--query-type",
"-q",
type=str,
- default="mixed_modalities",
+ default="use_mixed_modalities",
choices=query_map.keys(),
help="Query type.",
)
diff --git a/examples/online_serving/image_to_image/openai_chat_client.py b/examples/online_serving/image_to_image/openai_chat_client.py
index 14bec8a3be4..0fe4b0edece 100644
--- a/examples/online_serving/image_to_image/openai_chat_client.py
+++ b/examples/online_serving/image_to_image/openai_chat_client.py
@@ -127,7 +127,7 @@ def main():
parser.add_argument("--width", type=int, default=1024, help="Output image width")
parser.add_argument("--steps", type=int, default=50, help="Inference steps")
parser.add_argument("--guidance", type=float, default=7.5, help="Guidance scale")
- parser.add_argument("--seed", type=int, help="Random seed")
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--negative", help="Negative prompt")
args = parser.parse_args()
diff --git a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py
index 0d502063576..b9a76858161 100644
--- a/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py
+++ b/examples/online_serving/qwen3_omni/openai_chat_completion_client_for_multimodal_generation.py
@@ -304,6 +304,25 @@ def get_multi_audios_query(custom_prompt: str | None = None):
],
}
+def get_use_audio_in_video_query(
+ video_path: str | None = None,
+ audio_path: str | None = None,
+ custom_prompt: str | None = None,
+):
+ question = custom_prompt or (
+ "Describe the content of the video in details, then convert what the "
+ "baby say into text."
+ )
+ video_url = get_video_url_from_path(video_path)
+ audio_url = get_audio_url_from_path(audio_path)
+ return {
+ "role": "user",
+ "content": [
+ {"type": "video_url", "video_url": {"url": video_url}},
+ {"type": "audio_url", "audio_url": {"url": audio_url}},
+ {"type": "text", "text": question},
+ ],
+ }
query_map = {
"text": get_text_query,
@@ -312,6 +331,7 @@ def get_multi_audios_query(custom_prompt: str | None = None):
"use_video": get_video_query,
"use_mixed_modalities": get_mixed_modalities_query,
"use_multi_audios": get_multi_audios_query,
+ "use_audio_in_video": get_use_audio_in_video_query,
}
@@ -372,6 +392,12 @@ def run_multimodal_generation(args) -> None:
prompt = query_func(audio_path=audio_path, custom_prompt=custom_prompt)
elif args.query_type == "text":
prompt = query_func(custom_prompt=custom_prompt)
+ elif args.query_type == "use_audio_in_video":
+ prompt = query_func(
+ video_path=video_path,
+ audio_path=audio_path,
+ custom_prompt=custom_prompt,
+ )
else:
prompt = query_func()
diff --git a/examples/online_serving/text_to_image/README.md b/examples/online_serving/text_to_image/README.md
index a4f1ad63321..744b7b2921d 100644
--- a/examples/online_serving/text_to_image/README.md
+++ b/examples/online_serving/text_to_image/README.md
@@ -116,6 +116,7 @@ Use `extra_body` to pass generation parameters:
| `seed` | int | None | Random seed (reproducible) |
| `negative_prompt` | str | None | Negative prompt |
| `num_outputs_per_prompt` | int | 1 | Number of images to generate |
+| `--cfg-parallel-size` | int | 1 | Number of GPUs for CFG parallelism |
## Response Format
diff --git a/examples/online_serving/text_to_image/openai_chat_client.py b/examples/online_serving/text_to_image/openai_chat_client.py
index c529bf203fd..39fa7dc22b7 100644
--- a/examples/online_serving/text_to_image/openai_chat_client.py
+++ b/examples/online_serving/text_to_image/openai_chat_client.py
@@ -100,7 +100,7 @@ def main():
parser.add_argument("--width", type=int, default=1024, help="Image width")
parser.add_argument("--steps", type=int, default=50, help="Inference steps")
parser.add_argument("--cfg-scale", type=float, default=4.0, help="True CFG scale")
- parser.add_argument("--seed", type=int, default=42, help="Random seed")
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument("--negative", help="Negative prompt")
args = parser.parse_args()
diff --git a/mkdocs.yml b/mkdocs.yml
index 71cfe030569..1e8e38f5104 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -60,6 +60,10 @@ hooks:
- docs/mkdocs/hooks/url_schemes.py
- docs/mkdocs/hooks/generate_examples.py
+# Exclude include files from navigation warnings
+exclude_docs: |
+ **/*.inc.md
+
# Plugins
plugins:
- meta
diff --git a/pyproject.toml b/pyproject.toml
index 209a085bf87..2e2cddc5b7e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,8 @@ dev = [
"pytest-cov>=4.0.0",
"mypy==1.11.1",
"pre-commit==4.0.1",
+ "openai-whisper>=20250625",
+ "psutil>=7.2.0"
]
docs = [
@@ -151,11 +153,34 @@ addopts = [
"--cov-report=xml",
]
markers = [
- "unit: Unit tests",
- "integration: Integration tests",
+ # ci/cd required
+ "core_model: Core model tests (run in each PR)",
+ # function module markers
+ "diffusion: Diffusion model tests",
+ "omni: Omni model tests",
+ "cache: Cache backend tests",
+ "parallel: Parallelism/distributed tests",
+ # platform markers
+ "cpu: Tests that run on CPU",
+ "gpu: Tests that run on GPU (auto-added)",
+ "cuda: Tests that run on CUDA (auto-added)",
+ "rocm: Tests that run on AMD/ROCm (auto-added)",
+ "npu: Tests that run on NPU/Ascend (auto-added)",
+ # specified computation resources marks (auto-added)
+ "H100: Tests that require H100 GPU",
+ "L4: Tests that require L4 GPU",
+ "MI325: Tests that require MI325 GPU (AMD/ROCm)",
+ "A2: Tests that require A2 NPU",
+ "A3: Tests that require A3 NPU",
+ "distributed_cuda: Tests that require multi cards on CUDA platform",
+ "distributed_rocm: Tests that require multi cards on ROCm platform",
+ "distributed_npu: Tests that require multi cards on NPU platform",
+ "skipif_cuda: Skip if the num of CUDA cards is less than the required",
+ "skipif_rocm: Skip if the num of ROCm cards is less than the required",
+ "skipif_npu: Skip if the num of NPU cards is less than the required",
+ # more detailed markers
+ "slow: Slow tests (may skip in quick CI)",
"benchmark: Benchmark tests",
- "slow: Slow tests",
- "core_model: enable this model test in each PR instead of only nightly",
]
[tool.typos.default]
diff --git a/pytest.ini b/pytest.ini
deleted file mode 100644
index 8fb4beb9755..00000000000
--- a/pytest.ini
+++ /dev/null
@@ -1,3 +0,0 @@
-[pytest]
-markers =
- gpu_mem_high: needs high VRAM
diff --git a/tests/conftest.py b/tests/conftest.py
index 82c959f07ca..5b21f671bdb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,19 @@
+import base64
import os
+import socket
+import subprocess
+import sys
+import time
+from pathlib import Path
+from typing import Any
+import psutil
import pytest
import torch
+import whisper
+import yaml
from vllm.logger import init_logger
+from vllm.utils import get_open_port
logger = init_logger(__name__)
@@ -34,3 +45,286 @@ def clean_gpu_memory_between_tests():
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
+
+
+def dummy_messages_from_mix_data(
+ system_prompt: dict[str, Any] = None,
+ video_data_url: Any = None,
+ audio_data_url: Any = None,
+ image_data_url: Any = None,
+ content_text: str = None,
+):
+    """Create messages with video, image, and audio data URLs for the OpenAI API."""
+
+ if content_text is not None:
+ content = [{"type": "text", "text": content_text}]
+ else:
+ content = []
+
+ media_items = []
+ if isinstance(video_data_url, list):
+ for video_url in video_data_url:
+ media_items.append((video_url, "video"))
+ else:
+ media_items.append((video_data_url, "video"))
+
+ if isinstance(image_data_url, list):
+ for url in image_data_url:
+ media_items.append((url, "image"))
+ else:
+ media_items.append((image_data_url, "image"))
+
+ if isinstance(audio_data_url, list):
+ for url in audio_data_url:
+ media_items.append((url, "audio"))
+ else:
+ media_items.append((audio_data_url, "audio"))
+
+ content.extend(
+ {"type": f"{media_type}_url", f"{media_type}_url": {"url": url}}
+ for url, media_type in media_items
+ if url is not None
+ )
+ messages = [{"role": "user", "content": content}]
+ if system_prompt is not None:
+ messages = [system_prompt] + messages
+ return messages
+
+
+def cosine_similarity_text(s1, s2):
+ """
+ Calculate cosine similarity between two text strings.
+ Notes:
+ ------
+ - Higher score means more similar texts
+ - Score of 1.0 means identical word composition (bag-of-words)
+ - Score of 0.0 means completely different vocabulary
+ """
+ from sklearn.feature_extraction.text import CountVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+
+ vectorizer = CountVectorizer().fit_transform([s1, s2])
+ vectors = vectorizer.toarray()
+ return cosine_similarity([vectors[0]], [vectors[1]])[0][0]
+
+
+def convert_audio_to_text(audio_data):
+ """
+ Convert base64 encoded audio data to text using speech recognition.
+ """
+
+ audio_data = base64.b64decode(audio_data)
+ output_path = f"./test_{int(time.time())}"
+ with open(output_path, "wb") as audio_file:
+ audio_file.write(audio_data)
+
+ print(f"audio data is saved: {output_path}")
+ model = whisper.load_model("base")
+ text = model.transcribe(output_path)["text"]
+ if text:
+ return text
+ else:
+ return ""
+
+
+def modify_stage_config(
+ yaml_path: str,
+ stage_updates: dict[int, dict[str, Any]],
+) -> str:
+ """
+ Batch modify configurations for multiple stages in a YAML file.
+
+ Args:
+ yaml_path: Path to the YAML configuration file.
+ stage_updates: Dictionary where keys are stage IDs and values are dictionaries of
+ modifications for that stage. Each modification dictionary uses
+ dot-separated paths as keys and new configuration values as values.
+ Example: {
+ 0: {'engine_args.max_model_len': 5800},
+ 1: {'runtime.max_batch_size': 2}
+ }
+
+ Returns:
+ str: Path to the newly created modified YAML file with timestamp suffix.
+
+ Example:
+ >>> output_file = modify_stage_config(
+ ... 'config.yaml',
+ ... {
+ ... 0: {'engine_args.max_model_len': 5800},
+ ... 1: {'runtime.max_batch_size': 2}
+ ... }
+ ... )
+ >>> print(f"Modified configuration saved to: {output_file}")
+ Modified configuration saved to: config_1698765432.yaml
+ """
+ path = Path(yaml_path)
+ if not path.exists():
+ raise FileNotFoundError(f"yaml does not exist: {path}")
+ try:
+ with open(yaml_path, encoding="utf-8") as f:
+ config = yaml.safe_load(f) or {}
+ except Exception as e:
+ raise ValueError(f"Cannot parse YAML file: {e}")
+
+ stage_args = config.get("stage_args", [])
+ if not stage_args:
+ raise ValueError("the stage_args does not exist")
+
+ for stage_id, config_dict in stage_updates.items():
+ target_stage = None
+ for stage in stage_args:
+ if stage.get("stage_id") == stage_id:
+ target_stage = stage
+ break
+
+ if target_stage is None:
+ available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s]
+ raise KeyError(f"Stage ID {stage_id} is not exist, available IDs: {available_ids}")
+
+ for key_path, value in config_dict.items():
+ current = target_stage
+ keys = key_path.split(".")
+ for i in range(len(keys) - 1):
+ key = keys[i]
+ if key not in current:
+ raise KeyError(f"the {'.'.join(keys[: i + 1])} does not exist")
+
+ elif not isinstance(current[key], dict) and i < len(keys) - 2:
+ raise ValueError(f"{'.'.join(keys[: i + 1])}' cannot continue deeper because it's not a dict")
+ current = current[key]
+ current[keys[-1]] = value
+
+ output_path = f"{yaml_path.split('.')[0]}_{int(time.time())}.yaml"
+ with open(output_path, "w", encoding="utf-8") as f:
+ yaml.dump(config, f, default_flow_style=False, sort_keys=False, allow_unicode=True, indent=2)
+
+ return output_path
+
+
+class OmniServer:
+ """Omniserver for vLLM-Omni tests."""
+
+ def __init__(
+ self,
+ model: str,
+ serve_args: list[str],
+ *,
+ env_dict: dict[str, str] | None = None,
+ ) -> None:
+ self.model = model
+ self.serve_args = serve_args
+ self.env_dict = env_dict
+ self.proc: subprocess.Popen | None = None
+ self.host = "127.0.0.1"
+ self.port = get_open_port()
+
+ def _start_server(self) -> None:
+ """Start the vLLM-Omni server subprocess."""
+ env = os.environ.copy()
+ env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+ if self.env_dict is not None:
+ env.update(self.env_dict)
+
+ cmd = [
+ sys.executable,
+ "-m",
+ "vllm_omni.entrypoints.cli.main",
+ "serve",
+ self.model,
+ "--omni",
+ "--host",
+ self.host,
+ "--port",
+ str(self.port),
+ ] + self.serve_args
+
+ print(f"Launching OmniServer with: {' '.join(cmd)}")
+ self.proc = subprocess.Popen(
+ cmd,
+ env=env,
+ cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root
+ )
+
+ # Wait for server to be ready
+ max_wait = 600 # 10 minutes
+ start_time = time.time()
+ while time.time() - start_time < max_wait:
+ try:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.settimeout(1)
+ result = sock.connect_ex((self.host, self.port))
+ if result == 0:
+ print(f"Server ready on {self.host}:{self.port}")
+ return
+ except Exception:
+ pass
+ time.sleep(2)
+
+ raise RuntimeError(f"Server failed to start within {max_wait} seconds")
+
+ def _kill_process_tree(self, pid):
+ """kill process and its children"""
+ try:
+ parent = psutil.Process(pid)
+ children = parent.children(recursive=True)
+ for child in children:
+ try:
+ child.terminate()
+ except psutil.NoSuchProcess:
+ pass
+
+ gone, still_alive = psutil.wait_procs(children, timeout=10)
+
+ for child in still_alive:
+ try:
+ child.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+ try:
+ parent.terminate()
+ parent.wait(timeout=10)
+ except (psutil.NoSuchProcess, psutil.TimeoutExpired):
+ try:
+ parent.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+ except psutil.NoSuchProcess:
+ pass
+
+ def __enter__(self):
+ self._start_server()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.proc:
+ try:
+ parent = psutil.Process(self.proc.pid)
+ children = parent.children(recursive=True)
+ for child in children:
+ try:
+ child.terminate()
+ except psutil.NoSuchProcess:
+ pass
+
+ gone, still_alive = psutil.wait_procs(children, timeout=10)
+
+ for child in still_alive:
+ try:
+ child.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+ try:
+ parent.terminate()
+ parent.wait(timeout=10)
+ except (psutil.NoSuchProcess, psutil.TimeoutExpired):
+ try:
+ parent.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+ except psutil.NoSuchProcess:
+ pass
diff --git a/tests/diffusion/attention/test_flash_attn.py b/tests/diffusion/attention/test_flash_attn.py
new file mode 100644
index 00000000000..3f3862405ed
--- /dev/null
+++ b/tests/diffusion/attention/test_flash_attn.py
@@ -0,0 +1,290 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+Test script for FlashAttention backend with padding handling.
+
+This script tests two main scenarios:
+1. Case 1: Comparing padded vs unpadded inputs for batch_size=1
+2. Case 2: Comparing FlashAttention and SDPA backends for batch_size=2 with padding
+"""
+
+import pytest
+import torch
+
+from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
+from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl
+from vllm_omni.diffusion.attention.backends.sdpa import SDPAImpl
+
+
+def create_attention_mask(batch_size: int, seq_len: int, valid_len: int, device: torch.device) -> torch.Tensor:
+ """
+ Create attention mask where first valid_len tokens are valid (1) and rest are padding (0).
+
+ Args:
+ batch_size: Batch size
+ seq_len: Total sequence length (including padding)
+ valid_len: Number of valid (non-padded) tokens
+
+ Returns:
+ Attention mask of shape (batch_size, seq_len)
+ """
+ mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
+ mask[:, :valid_len] = True
+ return mask
+
+
+def pad_tensor(tensor: torch.Tensor, target_seq_len: int, pad_value: float = 0.0) -> torch.Tensor:
+ """
+ Pad tensor along sequence dimension (dim=1).
+
+ Args:
+ tensor: Input tensor of shape (batch_size, seq_len, num_heads, head_dim)
+ target_seq_len: Target sequence length after padding
+ pad_value: Value to use for padding
+
+ Returns:
+ Padded tensor of shape (batch_size, target_seq_len, num_heads, head_dim)
+ """
+ batch_size, seq_len, num_heads, head_dim = tensor.shape
+ if target_seq_len <= seq_len:
+ return tensor
+
+ padding = torch.full(
+ (batch_size, target_seq_len - seq_len, num_heads, head_dim), pad_value, dtype=tensor.dtype, device=tensor.device
+ )
+ return torch.cat([tensor, padding], dim=1)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA")
+def test_padding_equivalence():
+ """
+ Case 1: Test that padded and unpadded inputs produce similar outputs.
+
+ - Input A: batch_size=1, hidden_states (1, 48), encoder_hidden_states (1, 16)
+ Concatenated length: 64, NO attention_mask
+ - Input B: Same data but padded: hidden_states (1, 58), encoder_hidden_states (1, 26)
+ Concatenated length: 84, WITH attention_mask
+
+ Expected: Output A and Output B should be very close.
+ """
+ device = torch.device("cuda")
+ dtype = torch.bfloat16
+
+ # Configuration
+ batch_size = 1
+ hidden_seq_len = 48
+ encoder_seq_len = 16
+ pad_length = 10
+ num_heads = 8
+ head_dim = 64
+
+ # Initialize FlashAttention
+ fa_impl = FlashAttentionImpl(
+ num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False
+ )
+
+ # Create base tensors with random values (same for both A and B)
+ torch.manual_seed(42)
+ hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype)
+ encoder_hidden_states_base = torch.randn(
+ batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype
+ )
+
+ # ========== Input A: Unpadded, no attention mask ==========
+ query_a = torch.cat([hidden_states_base, encoder_hidden_states_base], dim=1)
+ key_a = query_a.clone()
+ value_a = query_a.clone()
+
+ attn_metadata_a = AttentionMetadata(attn_mask=None)
+
+ output_a = fa_impl.forward(query=query_a, key=key_a, value=value_a, attn_metadata=attn_metadata_a)
+
+ # ========== Input B: Padded with attention mask ==========
+ hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length)
+ encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length)
+
+ query_b = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1)
+ key_b = query_b.clone()
+ value_b = query_b.clone()
+
+ # Create attention mask
+ attn_mask_b = torch.cat(
+ [
+ create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device),
+ create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device),
+ ],
+ dim=1,
+ )
+
+ attn_metadata_b = AttentionMetadata(attn_mask=attn_mask_b)
+
+ output_b = fa_impl.forward(query=query_b, key=key_b, value=value_b, attn_metadata=attn_metadata_b)
+
+ # Extract non-padded portion from output_b
+ output_b_unpadded = torch.cat(
+ [
+ output_b[:, :hidden_seq_len, :, :],
+ output_b[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :],
+ ],
+ dim=1,
+ )
+
+ # Compare outputs
+ max_diff = torch.max(torch.abs(output_a - output_b_unpadded)).item()
+ mean_diff = torch.mean(torch.abs(output_a - output_b_unpadded)).item()
+
+ print("\n=== Case 1: Padding Equivalence Test ===")
+ print(f"Output A shape: {output_a.shape}")
+ print(f"Output B shape: {output_b.shape}")
+ print(f"Output B unpadded shape: {output_b_unpadded.shape}")
+ print(f"Max absolute difference: {max_diff:.6f}")
+ print(f"Mean absolute difference: {mean_diff:.6f}")
+
+ # Assert that outputs are close
+ # Using higher tolerance for bfloat16
+ assert max_diff < 0.1, f"Max difference {max_diff} exceeds threshold 0.1"
+ assert mean_diff < 0.01, f"Mean difference {mean_diff} exceeds threshold 0.01"
+
+ print("β Case 1 PASSED: Padded and unpadded outputs are very close!")
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires CUDA")
+def test_fa_vs_sdpa():
+ """
+ Case 2: Compare FlashAttention and SDPA backends with padding.
+
+ - batch_size=2
+ - hidden_states: (2, 48) padded to (2, 58)
+ - encoder_hidden_states: (2, 16) padded to (2, 26)
+ - Concatenated length: 84
+ - Compare FA and SDPA outputs
+
+ Expected: FA and SDPA outputs should be very close.
+ """
+ device = torch.device("cuda")
+ dtype = torch.bfloat16
+
+ # Configuration
+ batch_size = 2
+ hidden_seq_len = 48
+ encoder_seq_len = 16
+ pad_length = 10
+ num_heads = 8
+ head_dim = 64
+
+ # Initialize both backends
+ fa_impl = FlashAttentionImpl(
+ num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False
+ )
+
+ sdpa_impl = SDPAImpl(num_heads=num_heads, head_size=head_dim, softmax_scale=1.0 / (head_dim**0.5), causal=False)
+
+ # Create base tensors
+ torch.manual_seed(123)
+ hidden_states_base = torch.randn(batch_size, hidden_seq_len, num_heads, head_dim, device=device, dtype=dtype)
+ encoder_hidden_states_base = torch.randn(
+ batch_size, encoder_seq_len, num_heads, head_dim, device=device, dtype=dtype
+ )
+
+ # Pad tensors
+ hidden_states_padded = pad_tensor(hidden_states_base, hidden_seq_len + pad_length)
+ encoder_hidden_states_padded = pad_tensor(encoder_hidden_states_base, encoder_seq_len + pad_length)
+
+ # Concatenate
+ query = torch.cat([hidden_states_padded, encoder_hidden_states_padded], dim=1)
+ key = query.clone()
+ value = query.clone()
+
+ # Create attention mask
+ attn_mask = torch.cat(
+ [
+ create_attention_mask(batch_size, hidden_seq_len + pad_length, hidden_seq_len, device),
+ create_attention_mask(batch_size, encoder_seq_len + pad_length, encoder_seq_len, device),
+ ],
+ dim=1,
+ )
+
+ attn_metadata = AttentionMetadata(attn_mask=attn_mask)
+
+ # Run FlashAttention
+ output_fa = fa_impl.forward(query=query.clone(), key=key.clone(), value=value.clone(), attn_metadata=attn_metadata)
+
+ # Run SDPA
+ # SDPA expects 4D attention mask: (batch_size, 1, seq_len, seq_len) or (batch_size, seq_len)
+ # For causal=False, we need to convert 2D mask to 4D
+ if attn_mask is not None:
+ # Expand mask for SDPA: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
+ attn_mask_4d = attn_mask.unsqueeze(1).unsqueeze(2)
+ # Convert bool to float: True -> 0.0, False -> -inf
+ attn_mask_float = torch.zeros_like(attn_mask_4d, dtype=dtype)
+ attn_mask_float.masked_fill_(~attn_mask_4d, float("-inf"))
+ attn_metadata_sdpa = AttentionMetadata(attn_mask=attn_mask_float)
+ else:
+ attn_metadata_sdpa = AttentionMetadata(attn_mask=None)
+
+ output_sdpa = sdpa_impl.forward(
+ query=query.clone(), key=key.clone(), value=value.clone(), attn_metadata=attn_metadata_sdpa
+ )
+
+ # Compare outputs (only compare valid regions)
+ output_fa_valid = torch.cat(
+ [
+ output_fa[:, :hidden_seq_len, :, :],
+ output_fa[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :],
+ ],
+ dim=1,
+ )
+ output_sdpa_valid = torch.cat(
+ [
+ output_sdpa[:, :hidden_seq_len, :, :],
+ output_sdpa[:, hidden_seq_len + pad_length : hidden_seq_len + pad_length + encoder_seq_len, :, :],
+ ],
+ dim=1,
+ )
+
+ max_diff = torch.max(torch.abs(output_fa_valid - output_sdpa_valid)).item()
+ mean_diff = torch.mean(torch.abs(output_fa_valid - output_sdpa_valid)).item()
+
+ print("\n=== Case 2: FA vs SDPA Comparison ===")
+ print(f"Batch size: {batch_size}")
+ print(f"FA output shape: {output_fa.shape}")
+ print(f"SDPA output shape: {output_sdpa.shape}")
+ print(f"Max absolute difference (valid region): {max_diff:.6f}")
+ print(f"Mean absolute difference (valid region): {mean_diff:.6f}")
+
+ # Assert that outputs are close
+ # Using higher tolerance for bfloat16 and different implementations
+ assert max_diff < 0.01, f"Max difference {max_diff} exceeds threshold 0.01"
+ assert mean_diff < 0.001, f"Mean difference {mean_diff} exceeds threshold 0.001"
+
+ print("β Case 2 PASSED: FA and SDPA outputs are very close!")
+
+
+if __name__ == "__main__":
+ print("Running FlashAttention Padding Tests...")
+ print("=" * 60)
+
+ # Try to run CUDA tests
+ if torch.cuda.is_available():
+ try:
+ print("\n[Running Case 1: Padding Equivalence for FA]")
+ test_padding_equivalence()
+ except Exception as e:
+ print(f"β Case 1 failed: {e}")
+ import traceback
+
+ traceback.print_exc()
+
+ try:
+ print("\n[Running Case 2: FA vs SDPA]")
+ test_fa_vs_sdpa()
+ except Exception as e:
+ print(f"β Case 2 failed: {e}")
+ import traceback
+
+ traceback.print_exc()
+ else:
+ raise RuntimeError("CUDA is not available")
+ print("\n" + "=" * 60)
+ print("Test suite completed!")
diff --git a/tests/diffusion/test_gpu_worker.py b/tests/diffusion/test_gpu_diffusion_worker.py
similarity index 81%
rename from tests/diffusion/test_gpu_worker.py
rename to tests/diffusion/test_gpu_diffusion_worker.py
index defeffe5b56..7a43710c878 100644
--- a/tests/diffusion/test_gpu_worker.py
+++ b/tests/diffusion/test_gpu_diffusion_worker.py
@@ -2,9 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
-Unit tests for GPUWorker class.
+Unit tests for GPUDiffusionWorker class.
-This module tests the GPUWorker implementation:
+This module tests the GPUDiffusionWorker implementation:
- load_weights: Loading model weights
- sleep: Putting worker into sleep mode (levels 1 and 2)
- wake_up: Waking worker from sleep mode
@@ -15,7 +15,7 @@
import pytest
import torch
-from vllm_omni.diffusion.worker.gpu_worker import GPUWorker
+from vllm_omni.diffusion.worker.gpu_diffusion_worker import GPUDiffusionWorker
@pytest.fixture
@@ -33,20 +33,21 @@ def mock_od_config():
@pytest.fixture
def mock_gpu_worker(mock_od_config):
- """Create a GPUWorker with mocked initialization."""
- with patch.object(GPUWorker, "init_device_and_model"):
- worker = GPUWorker(local_rank=0, rank=0, od_config=mock_od_config)
- # Mock the pipeline
- worker.pipeline = Mock()
- worker.cache_backend = None
+ """Create a GPUDiffusionWorker with mocked initialization."""
+ with patch.object(GPUDiffusionWorker, "init_device"):
+ worker = GPUDiffusionWorker(local_rank=0, rank=0, od_config=mock_od_config)
+ # Mock the model_runner with pipeline
+ worker.model_runner = Mock()
+ worker.model_runner.pipeline = Mock()
+ worker._sleep_saved_buffers = {}
return worker
-class TestGPUWorkerLoadWeights:
- """Test GPUWorker.load_weights method."""
+class TestGPUDiffusionWorkerLoadWeights:
+ """Test GPUDiffusionWorker.load_weights method."""
def test_load_weights_calls_pipeline(self, mock_gpu_worker):
- """Test that load_weights delegates to pipeline.load_weights."""
+ """Test that load_weights delegates to model_runner.load_weights."""
# Setup mock weights
mock_weights = [
("layer1.weight", torch.randn(10, 10)),
@@ -54,30 +55,30 @@ def test_load_weights_calls_pipeline(self, mock_gpu_worker):
]
expected_loaded = {"layer1.weight", "layer2.weight"}
- # Configure pipeline mock
- mock_gpu_worker.pipeline.load_weights = Mock(return_value=expected_loaded)
+ # Configure model_runner mock
+ mock_gpu_worker.model_runner.load_weights = Mock(return_value=expected_loaded)
# Call load_weights
result = mock_gpu_worker.load_weights(mock_weights)
- # Verify pipeline.load_weights was called with the weights
- mock_gpu_worker.pipeline.load_weights.assert_called_once_with(mock_weights)
+ # Verify model_runner.load_weights was called with the weights
+ mock_gpu_worker.model_runner.load_weights.assert_called_once_with(mock_weights)
assert result == expected_loaded
def test_load_weights_empty_iterable(self, mock_gpu_worker):
"""Test load_weights with empty weights iterable."""
- mock_gpu_worker.pipeline.load_weights = Mock(return_value=set())
+ mock_gpu_worker.model_runner.load_weights = Mock(return_value=set())
result = mock_gpu_worker.load_weights([])
- mock_gpu_worker.pipeline.load_weights.assert_called_once_with([])
+ mock_gpu_worker.model_runner.load_weights.assert_called_once_with([])
assert result == set()
-class TestGPUWorkerSleep:
- """Test GPUWorker.sleep method."""
+class TestGPUDiffusionWorkerSleep:
+ """Test GPUDiffusionWorker.sleep method."""
- @patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
+ @patch("vllm_omni.diffusion.worker.gpu_diffusion_worker.torch.cuda.mem_get_info")
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_sleep_level_1(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
"""Test sleep mode level 1 (offload weights only)."""
@@ -103,7 +104,7 @@ def test_sleep_level_1(self, mock_allocator_class, mock_mem_info, mock_gpu_worke
# Verify buffers were NOT saved (level 1 doesn't save buffers)
assert len(mock_gpu_worker._sleep_saved_buffers) == 0
- @patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
+ @patch("vllm_omni.diffusion.worker.gpu_diffusion_worker.torch.cuda.mem_get_info")
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
"""Test sleep mode level 2 (offload all, save buffers)."""
@@ -121,7 +122,7 @@ def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worke
# Mock pipeline buffers
mock_buffer1 = torch.randn(10, 10)
mock_buffer2 = torch.randn(20, 20)
- mock_gpu_worker.pipeline.named_buffers = Mock(
+ mock_gpu_worker.model_runner.pipeline.named_buffers = Mock(
return_value=[
("buffer1", mock_buffer1),
("buffer2", mock_buffer2),
@@ -140,7 +141,7 @@ def test_sleep_level_2(self, mock_allocator_class, mock_mem_info, mock_gpu_worke
assert "buffer1" in mock_gpu_worker._sleep_saved_buffers
assert "buffer2" in mock_gpu_worker._sleep_saved_buffers
- @patch("vllm_omni.diffusion.worker.gpu_worker.torch.cuda.mem_get_info")
+ @patch("vllm_omni.diffusion.worker.gpu_diffusion_worker.torch.cuda.mem_get_info")
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_mem_info, mock_gpu_worker):
"""Test that sleep validates memory was actually freed."""
@@ -159,8 +160,8 @@ def test_sleep_memory_freed_validation(self, mock_allocator_class, mock_mem_info
mock_gpu_worker.sleep(level=1)
-class TestGPUWorkerWakeUp:
- """Test GPUWorker.wake_up method."""
+class TestGPUDiffusionWorkerWakeUp:
+ """Test GPUDiffusionWorker.wake_up method."""
@patch("vllm.device_allocator.cumem.CuMemAllocator")
def test_wake_up_without_buffers(self, mock_allocator_class, mock_gpu_worker):
@@ -202,7 +203,7 @@ def test_wake_up_with_buffers(self, mock_allocator_class, mock_gpu_worker):
mock_buffer2 = Mock()
mock_buffer2.data = Mock()
- mock_gpu_worker.pipeline.named_buffers = Mock(
+ mock_gpu_worker.model_runner.pipeline.named_buffers = Mock(
return_value=[
("buffer1", mock_buffer1),
("buffer2", mock_buffer2),
@@ -243,7 +244,7 @@ def test_wake_up_partial_buffer_restore(self, mock_allocator_class, mock_gpu_wor
mock_buffer2 = Mock()
mock_buffer2.data = Mock()
- mock_gpu_worker.pipeline.named_buffers = Mock(
+ mock_gpu_worker.model_runner.pipeline.named_buffers = Mock(
return_value=[
("buffer1", mock_buffer1),
("buffer2", mock_buffer2),
diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
index 0066d49b161..cefda891571 100644
--- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
@@ -1,11 +1,10 @@
import sys
-import threading
-import time
from pathlib import Path
import pytest
import torch
+from tests.utils import GPUMemoryMonitor
from vllm_omni.utils.platform_utils import is_npu, is_rocm
# ruff: noqa: E402
@@ -15,39 +14,6 @@
from vllm_omni import Omni
-
-class GPUMemoryMonitor:
- """Poll global device memory usage via CUDA APIs."""
-
- def __init__(self, device_index: int, interval: float = 0.05):
- self.device_index = device_index
- self.interval = interval
- self.peak_used_mb = 0.0
- self._stop_event = threading.Event()
- self._thread: threading.Thread | None = None
-
- def start(self) -> None:
- def monitor_loop() -> None:
- while not self._stop_event.is_set():
- try:
- with torch.cuda.device(self.device_index):
- free_bytes, total_bytes = torch.cuda.mem_get_info()
- used_mb = (total_bytes - free_bytes) / (1024**2)
- self.peak_used_mb = max(self.peak_used_mb, used_mb)
- except Exception:
- pass
- time.sleep(self.interval)
-
- self._thread = threading.Thread(target=monitor_loop, daemon=True)
- self._thread.start()
-
- def stop(self) -> None:
- if self._thread is None:
- return
- self._stop_event.set()
- self._thread.join(timeout=2.0)
-
-
models = ["riverclouds/qwen_image_random"]
@@ -73,13 +39,7 @@ def inference(offload: bool = True):
generator=torch.Generator("cuda").manual_seed(42),
)
- monitor.stop()
- torch.cuda.synchronize(device_index)
- fallback_alloc = torch.cuda.max_memory_allocated(device=device_index) / (1024**2)
- fallback_reserved = torch.cuda.max_memory_reserved(device=device_index) / (1024**2)
- peak_memory_mb = max(monitor.peak_used_mb, fallback_alloc, fallback_reserved)
-
- return peak_memory_mb
+ return monitor.peak_used_mb
offload_peak_memory = inference(offload=True)
no_offload_peak_memory = inference(offload=False)
diff --git a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py
index d32bb2b8223..60686992278 100644
--- a/tests/e2e/offline_inference/test_zimage_tensor_parallel.py
+++ b/tests/e2e/offline_inference/test_zimage_tensor_parallel.py
@@ -17,6 +17,7 @@
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
+from tests.utils import GPUMemoryMonitor
from vllm_omni import Omni
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.outputs import OmniRequestOutput
@@ -66,7 +67,12 @@ def _extract_single_image(outputs) -> Image.Image:
def _run_zimage_generate(
*, tp_size: int, height: int, width: int, num_inference_steps: int, seed: int
-) -> tuple[Image.Image, float]:
+) -> tuple[Image.Image, float, float]:
+ torch.cuda.empty_cache()
+ device_index = torch.cuda.current_device()
+ monitor = GPUMemoryMonitor(device_index=device_index, interval=0.02)
+ monitor.start()
+
m = Omni(
model=_get_zimage_model(),
parallel_config=DiffusionParallelConfig(tensor_parallel_size=tp_size),
@@ -107,7 +113,10 @@ def _run_zimage_generate(
pass
median_time_s = float(np.median(per_request_times_s))
- return _extract_single_image([last_output]), median_time_s
+
+ peak_memory_mb = monitor.peak_used_mb
+
+ return _extract_single_image([last_output]), median_time_s, peak_memory_mb
finally:
m.close()
cleanup_dist_env_and_memory()
@@ -125,14 +134,14 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
num_inference_steps = 2
seed = 42
- tp1_img, tp1_time_s = _run_zimage_generate(
+ tp1_img, tp1_time_s, tp1_peak_mem = _run_zimage_generate(
tp_size=1,
height=height,
width=width,
num_inference_steps=num_inference_steps,
seed=seed,
)
- tp2_img, tp2_time_s = _run_zimage_generate(
+ tp2_img, tp2_time_s, tp2_peak_mem = _run_zimage_generate(
tp_size=2,
height=height,
width=width,
@@ -164,3 +173,8 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
print(f"Z-Image TP perf (lower is better): tp1_time_s={tp1_time_s:.6f}, tp2_time_s={tp2_time_s:.6f}")
assert tp2_time_s < tp1_time_s, f"Expected TP=2 to be faster than TP=1 (tp1={tp1_time_s}, tp2={tp2_time_s})"
+
+ print(f"Z-Image TP peak memory (MB): tp1_peak_mem={tp1_peak_mem:.2f}, tp2_peak_mem={tp2_peak_mem:.2f}")
+ assert tp2_peak_mem < tp1_peak_mem, (
+ f"Expected TP=2 to use less peak memory than TP=1 (tp1={tp1_peak_mem}, tp2={tp2_peak_mem})"
+ )
diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py
new file mode 100644
index 00000000000..6a47e96f866
--- /dev/null
+++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py
@@ -0,0 +1,158 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+E2E Online tests for Qwen3-Omni model.
+"""
+
+import concurrent.futures
+import os
+import time
+from pathlib import Path
+
+import openai
+import pytest
+
+from tests.conftest import (
+ OmniServer,
+ convert_audio_to_text,
+ cosine_similarity_text,
+ dummy_messages_from_mix_data,
+ modify_stage_config,
+)
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
+
+# CI stage config for 2*H100-80G GPUs
+stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")]
+
+# Create parameter combinations for model and stage config
+test_params = [(model, stage_config) for model in models for stage_config in stage_configs]
+
+
+def client(omni_server):
+ """OpenAI client for the running vLLM-Omni server."""
+ return openai.OpenAI(
+ base_url=f"http://{omni_server.host}:{omni_server.port}/v1",
+ api_key="EMPTY",
+ )
+
+
+def get_system_prompt():
+ return {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": (
+ "You are Qwen, a virtual human developed by the Qwen Team, "
+ "Alibaba Group, capable of perceiving auditory and visual inputs, "
+ "as well as generating text and speech."
+ ),
+ }
+ ],
+ }
+
+
+def get_prompt(prompt_type="text_only"):
+ prompts = {
+ "text_only": "What is the capital of China?",
+ "mix": "What is recited in the audio? What is in this image? Describe the video briefly.",
+ }
+ return prompts.get(prompt_type, prompts["text_only"])
+
+
+def get_max_batch_size(size_type="few"):
+ batch_sizes = {"few": 5, "medium": 100, "large": 256}
+ return batch_sizes.get(size_type, 5)
+
+
+@pytest.mark.parametrize("test_config", test_params)
+def test_text_to_text_001(test_config: tuple[str, str]) -> None:
+ """Test processing text, generating text output via OpenAI API."""
+ model, stage_config_path = test_config
+ with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
+ messages = dummy_messages_from_mix_data(system_prompt=get_system_prompt(), content_text=get_prompt())
+
+ # Test single completion
+ api_client = client(server)
+ start_time = time.perf_counter()
+ chat_completion = api_client.chat.completions.create(
+ model=server.model, messages=messages, max_tokens=20, modalities=["text"]
+ )
+ # Verify E2E
+ print(f"the request e2e is: {time.perf_counter() - start_time}")
+ # TODO: Verify the E2E latency once a baseline has been confirmed.
+
+ # Verify only output text
+ assert len(chat_completion.choices) == 1, "The generated content includes more than just text."
+
+ # Verify text output success
+ text_choice = chat_completion.choices[0]
+ assert text_choice.message.content is not None, "No text output is generated"
+ assert chat_completion.usage.completion_tokens <= 20, "The output length more than the requested max_tokens."
+ assert "beijing" in text_choice.message.content.lower(), "The output do not contain keywords."
+
+
+@pytest.mark.parametrize("test_config", test_params)
+def test_text_to_text_audio_001(test_config: tuple[str, str]) -> None:
+ """Test processing text, generating text and audio output via OpenAI API."""
+
+ model, stage_config_path = test_config
+ num_concurrent_requests = get_max_batch_size()
+ stage_config_path = modify_stage_config(
+ stage_config_path,
+ {
+ 0: {"runtime.max_batch_size": num_concurrent_requests},
+ 1: {"runtime.max_batch_size": num_concurrent_requests},
+ },
+ )
+ with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "90"]) as server:
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(), content_text="What is the capital of China?"
+ )
+
+ # Test concurrent completions
+ api_client = client(server)
+ e2e_list = list()
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent_requests) as executor:
+ # Submit multiple completion requests concurrently
+ futures = [
+ executor.submit(api_client.chat.completions.create, model=server.model, messages=messages)
+ for _ in range(num_concurrent_requests)
+ ]
+ start_time = time.perf_counter()
+ # Wait for all requests to complete and collect results
+ chat_completions = list()
+ for future in concurrent.futures.as_completed(futures):
+ chat_completions.append(future.result())
+ # Verify E2E
+ current_e2e = time.perf_counter() - start_time
+ print(f"the request e2e is: {current_e2e}")
+ # TODO: Verify the E2E latency once a baseline has been confirmed.
+ e2e_list.append(current_e2e)
+
+ print(f"the avg e2e is: {sum(e2e_list) / len(e2e_list)}")
+ # Verify all completions succeeded
+ assert len(chat_completions) == num_concurrent_requests, "Not all requests succeeded."
+ for chat_completion in chat_completions:
+ # Verify audio output success
+ audio_message = chat_completion.choices[1].message
+ audio_data = audio_message.audio.data
+ assert audio_data is not None, "No audio output is generated"
+ assert audio_message.audio.expires_at > time.time(), "The generated audio has expired."
+
+ # Verify text output success
+ text_choice = chat_completion.choices[0]
+ text_content = text_choice.message.content
+ assert text_choice.message.content is not None, "No text output is generated"
+ assert "beijing" in text_choice.message.content.lower(), "The output do not contain keywords."
+
+ # Verify text output same as audio output
+ audio_content = convert_audio_to_text(audio_data)
+ print(f"text content is: {text_content}")
+ print(f"audio content is: {audio_content}")
+ assert cosine_similarity_text(audio_content.lower(), text_content.lower()) > 0.9, (
+ "The audio content is not same as the text"
+ )
diff --git a/tests/e2e/stage_configs/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/qwen3_omni_ci.yaml
new file mode 100644
index 00000000000..5106b185419
--- /dev/null
+++ b/tests/e2e/stage_configs/qwen3_omni_ci.yaml
@@ -0,0 +1,95 @@
+# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
+# Stage 0: Thinker (multimodal understanding + text generation)
+# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
+# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
+
+# The following config has been verified on 2x H100-80G GPUs.
+stage_args:
+- stage_id: 0
+ runtime:
+ devices: "0,1"
+ max_batch_size: 1
+ engine_args:
+ model_stage: thinker
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.6
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # Output hidden states for talker
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 32768
+ enable_prefix_caching: false
+ hf_config_name: thinker_config
+ tensor_parallel_size: 2
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 100
+ seed: 42
+ ignore_eos: False
+ detokenize: True
+ repetition_penalty: 1.05
+
+- stage_id: 1
+ runtime:
+ devices: "1"
+ max_batch_size: 1
+ engine_args:
+ model_stage: talker
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.3
+ enforce_eager: true
+ trust_remote_code: true
+ engine_output_type: latent # Output codec codes for code2wav
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ distributed_executor_backend: "mp"
+ hf_config_name: talker_config
+ engine_input_source: [0]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 100
+ seed: 42
+ detokenize: False
+ repetition_penalty: 1.05
+ stop_token_ids: [2150]
+
+- stage_id: 2
+ runtime:
+ devices: "0"
+ max_batch_size: 1
+ engine_args:
+ model_stage: code2wav
+ model_arch: Qwen3OmniMoeForConditionalGeneration
+ worker_cls: vllm_omni.worker.gpu_generation_worker.GPUGenerationWorker
+ scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
+ enforce_eager: true
+ trust_remote_code: true
+ enable_prefix_caching: false
+ engine_output_type: audio # Final output: audio waveform
+ gpu_memory_utilization: 0.1
+ distributed_executor_backend: "mp"
+ max_num_batched_tokens: 1000000
+ hf_config_name: thinker_config
+ engine_input_source: [1]
+ custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
+ final_output: true
+ final_output_type: audio
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 200
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
diff --git a/tests/utils.py b/tests/utils.py
index aba734501eb..8e5593d6501 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,11 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+# Some functions are copied from vllm/tests/utils.py
+import functools
import os
+import signal
+import subprocess
+import sys
+import tempfile
+import threading
import time
-from contextlib import contextmanager
+from collections.abc import Callable
+from contextlib import ExitStack, contextmanager, suppress
+from typing import Any, Literal
+import cloudpickle
+import pytest
+import torch
+from typing_extensions import ParamSpec
from vllm.platforms import current_platform
+from vllm.utils.torch_utils import cuda_device_count_stateless
+
+_P = ParamSpec("_P")
if current_platform.is_rocm():
from amdsmi import (
@@ -90,10 +105,16 @@ def wait_for_gpu_memory_to_clear(
print("")
if threshold_bytes is not None:
- is_free = lambda used, total: used <= threshold_bytes / 2**30 # noqa E731
+
+ def is_free(used, total):
+ return used <= threshold_bytes / 2**30 # noqa E731
+
threshold = f"{threshold_bytes / 2**30} GiB"
else:
- is_free = lambda used, total: used / total <= threshold_ratio # noqa E731
+
+ def is_free(used, total):
+ return used / total <= threshold_ratio # noqa E731
+
threshold = f"{threshold_ratio:.2f}"
dur_s = time.time() - start_time
@@ -105,3 +126,394 @@ def wait_for_gpu_memory_to_clear(
raise ValueError(f"Memory of devices {devices=} not free after {dur_s=:.02f} ({threshold=})")
time.sleep(5)
+
+
+def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
+ """Decorator to fork a new process for each test function.
+ See https://github.com/vllm-project/vllm/issues/7053 for more details.
+ """
+
+ @functools.wraps(func)
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+ # Make the process the leader of its own process group
+ # to avoid sending SIGTERM to the parent process
+ os.setpgrp()
+ from _pytest.outcomes import Skipped
+
+ # Create a unique temporary file to store exception info from child
+ # process. Use test function name and process ID to avoid collisions.
+ with (
+ tempfile.NamedTemporaryFile(
+ delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc"
+ ) as exc_file,
+ ExitStack() as delete_after,
+ ):
+ exc_file_path = exc_file.name
+ delete_after.callback(os.remove, exc_file_path)
+
+ pid = os.fork()
+ print(f"Fork a new process to run a test {pid}")
+ if pid == 0:
+ # Parent process responsible for deleting, don't delete
+ # in child.
+ delete_after.pop_all()
+ try:
+ func(*args, **kwargs)
+ except Skipped as e:
+ # convert Skipped to exit code 0
+ print(str(e))
+ os._exit(0)
+ except Exception as e:
+ import traceback
+
+ tb_string = traceback.format_exc()
+
+ # Try to serialize the exception object first
+ exc_to_serialize: dict[str, Any]
+ try:
+ # First, try to pickle the actual exception with
+ # its traceback.
+ exc_to_serialize = {"pickled_exception": e}
+ # Test if it can be pickled
+ cloudpickle.dumps(exc_to_serialize)
+ except (Exception, KeyboardInterrupt):
+ # Fall back to string-based approach.
+ exc_to_serialize = {
+ "exception_type": type(e).__name__,
+ "exception_msg": str(e),
+ "traceback": tb_string,
+ }
+ try:
+ with open(exc_file_path, "wb") as f:
+ cloudpickle.dump(exc_to_serialize, f)
+ except Exception:
+ # Fallback: just print the traceback.
+ print(tb_string)
+ os._exit(1)
+ else:
+ os._exit(0)
+ else:
+ pgid = os.getpgid(pid)
+ _pid, _exitcode = os.waitpid(pid, 0)
+ # ignore SIGTERM signal itself
+ old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
+ # kill all child processes
+ os.killpg(pgid, signal.SIGTERM)
+ # restore the signal handler
+ signal.signal(signal.SIGTERM, old_signal_handler)
+ if _exitcode != 0:
+ # Try to read the exception from the child process
+ exc_info = {}
+ if os.path.exists(exc_file_path):
+ with suppress(Exception), open(exc_file_path, "rb") as f:
+ exc_info = cloudpickle.load(f)
+
+ if (original_exception := exc_info.get("pickled_exception")) is not None:
+ # Re-raise the actual exception object if it was
+ # successfully pickled.
+ assert isinstance(original_exception, Exception)
+ raise original_exception
+
+ if (original_tb := exc_info.get("traceback")) is not None:
+ # Use string-based traceback for fallback case
+ raise AssertionError(
+ f"Test {func.__name__} failed when called with"
+ f" args {args} and kwargs {kwargs}"
+ f" (exit code: {_exitcode}):\n{original_tb}"
+ ) from None
+
+ # Fallback to the original generic error
+ raise AssertionError(
+ f"function {func.__name__} failed when called with"
+ f" args {args} and kwargs {kwargs}"
+ f" (exit code: {_exitcode})"
+ ) from None
+
+ return wrapper
+
+
+def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]:
+ """Decorator to spawn a new process for each test function."""
+
+ @functools.wraps(f)
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
+ # Check if we're already in a subprocess
+ if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
+ # If we are, just run the function directly
+ return f(*args, **kwargs)
+
+ import torch.multiprocessing as mp
+
+ with suppress(RuntimeError):
+ mp.set_start_method("spawn")
+
+ # Get the module
+ module_name = f.__module__
+
+ # Create a process with environment variable set
+ env = os.environ.copy()
+ env["RUNNING_IN_SUBPROCESS"] = "1"
+
+ with tempfile.TemporaryDirectory() as tempdir:
+ output_filepath = os.path.join(tempdir, "new_process.tmp")
+
+ # `cloudpickle` allows pickling complex functions directly
+ input_bytes = cloudpickle.dumps((f, output_filepath))
+
+ cmd = [sys.executable, "-m", f"{module_name}"]
+
+ returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env)
+
+ # check if the subprocess is successful
+ try:
+ returned.check_returncode()
+ except Exception as e:
+ # wrap raised exception to provide more information
+ raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e
+
+ return wrapper
+
+
+def create_new_process_for_each_test(
+ method: Literal["spawn", "fork"] | None = None,
+) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
+ """Creates a decorator that runs each test function in a new process.
+
+ Args:
+ method: The process creation method. Can be either "spawn" or "fork".
+ If not specified, it defaults to "spawn" on XPU platforms and
+ "fork" otherwise (ROCm currently uses "fork" as well).
+
+ Returns:
+ A decorator to run test functions in separate processes.
+ """
+ if method is None:
+ # TODO: Spawn is not working correctly on ROCm
+ # The test content will not run and tests passed immediately.
+ # For now, using `fork` for ROCm as it can run with `fork`
+ # and tests are running correctly.
+ use_spawn = current_platform.is_xpu()
+ method = "spawn" if use_spawn else "fork"
+
+ assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'"
+
+ if method == "fork":
+ return fork_new_process_for_each_test
+
+ return spawn_new_process_for_each_test
+
+
+def cuda_marks(*, res: str, num_cards: int):
+ """
+ Get a collection of pytest marks to apply for `@cuda_test`.
+
+ Args:
+ res: Resource type, e.g., "L4" or "H100".
+ num_cards: Number of GPU cards required.
+
+ Returns:
+ List of pytest marks to apply.
+ """
+ test_platform_detail = pytest.mark.cuda
+
+ if res == "L4":
+ test_resource = pytest.mark.L4
+ elif res == "H100":
+ test_resource = pytest.mark.H100
+ else:
+ raise ValueError(f"Invalid CUDA resource type: {res}. Supported: L4, H100")
+
+ marks = [test_resource, test_platform_detail]
+
+ if num_cards == 1:
+ return marks
+ else:
+ test_distributed = pytest.mark.distributed_cuda(num_cards=num_cards)
+ test_skipif = pytest.mark.skipif_cuda(
+ cuda_device_count_stateless() < num_cards,
+ reason=f"Need at least {num_cards} CUDA GPUs to run the test.",
+ )
+ return marks + [test_distributed, test_skipif]
+
+
+def rocm_marks(*, res: str, num_cards: int):
+ """
+ Get a collection of pytest marks to apply for `@rocm_test`.
+
+ Args:
+ res: Resource type, e.g., "MI325".
+ num_cards: Number of GPU cards required.
+
+ Returns:
+ List of pytest marks to apply.
+ """
+ test_platform_detail = pytest.mark.rocm
+
+ if res == "MI325":
+ test_resource = pytest.mark.MI325
+ else:
+ raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325")
+
+ marks = [test_resource, test_platform_detail]
+
+ if num_cards == 1:
+ return marks
+ else:
+ test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
+ # TODO: add ROCm support for `skipif_rocm` marker
+ return marks + [test_distributed]
+
+
+def gpu_marks(*, res: str, num_cards: int):
+ """
+ Get a collection of pytest marks to apply for `@gpu_test`.
+ Platform is automatically determined based on resource type.
+
+ Args:
+ res: Resource type, e.g., "L4", "H100" for CUDA, or "MI325" for ROCm.
+ num_cards: Number of GPU cards required.
+
+ Returns:
+ List of pytest marks to apply.
+ """
+ test_platform = pytest.mark.gpu
+ if res in ("L4", "H100"):
+ return [test_platform] + cuda_marks(res=res, num_cards=num_cards)
+ if res == "MI325":
+ return [test_platform] + rocm_marks(res=res, num_cards=num_cards)
+ raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325")
+
+
+def npu_marks(*, res: str, num_cards: int):
+ """Get a collection of pytest marks to apply for `@npu_test`."""
+ test_platform = pytest.mark.npu
+ if res == "A2":
+ test_resource = pytest.mark.A2
+ elif res == "A3":
+ test_resource = pytest.mark.A3
+ else:
+ # TODO: Currently we don't have various NPU card types defined
+ # Use None to skip resource-specific marking for unknown types
+ test_resource = None
+
+ if num_cards == 1:
+ return [mark for mark in [test_platform, test_resource] if mark is not None]
+ else:
+ # Multiple cards scenario needs distributed_npu mark
+ test_distributed = pytest.mark.distributed_npu(num_cards=num_cards)
+ # TODO: add NPU support for `skipif_npu` marker
+ return [mark for mark in [test_platform, test_resource, test_distributed] if mark is not None]
+
+
+def hardware_test(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
+ """
+ Decorate a test for multiple hardware platforms with a single call.
+ Automatically wraps the test with @create_new_process_for_each_test() for distributed tests.
+
+ Args:
+ res: Mapping from platform to resource type. Supported platforms/resources:
+ - cuda: L4, H100
+ - rocm: MI325
+ - npu: A2, A3
+ num_cards: Number of cards required. Can be:
+ - int: same card count for all platforms (default: 1)
+ - dict: per-platform card count, e.g., {"cuda": 2, "rocm": 2}
+
+ Example:
+ @hardware_test(
+ res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
+ num_cards={"cuda": 2, "rocm": 2, "npu": 2},
+ )
+ def test_multi_platform():
+ ...
+ """
+ # Validate platforms
+ # Don't validate platform details in this decorator
+ for platform, _ in res.items():
+ if platform not in ("cuda", "rocm", "npu"):
+ raise ValueError(f"Unsupported platform: {platform}")
+
+ # Normalize num_cards
+ if isinstance(num_cards, int):
+ num_cards_dict = {platform: num_cards for platform in res.keys()}
+ else:
+ num_cards_dict = num_cards
+ for platform in num_cards_dict.keys():
+ if platform not in res:
+ raise ValueError(
+ f"Platform '{platform}' in num_cards but not in res. Available platforms: {list(res.keys())}"
+ )
+ for platform in res.keys():
+ if platform not in num_cards_dict:
+ num_cards_dict[platform] = 1
+
+ # Collect marks from all platforms
+ all_marks: list[Callable[[Callable[_P, None]], Callable[_P, None]]] = []
+ for platform, resource in res.items():
+ cards = num_cards_dict[platform]
+ if platform == "cuda" or platform == "rocm":
+ marks = gpu_marks(res=resource, num_cards=cards)
+ elif platform == "npu":
+ marks = npu_marks(res=resource, num_cards=cards)
+ else:
+ raise ValueError(f"Unsupported platform: {platform}")
+ all_marks.extend(marks)
+
+ create_new_process_flag = False
+ for cards in num_cards_dict.values():
+ if cards > 1:
+ create_new_process_flag = True
+ break
+
+ def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
+ if create_new_process_flag:
+ # only for distributed tests
+ func = create_new_process_for_each_test()(f)
+ else:
+ func = f
+ for mark in reversed(all_marks):
+ func = mark(func)
+ return func
+
+ return wrapper
+
+
+class GPUMemoryMonitor:
+ """Poll global device memory usage via CUDA APIs."""
+
+ def __init__(self, device_index: int, interval: float = 0.05):
+ self.device_index = device_index
+ self.interval = interval
+ self._peak_used_mb = 0.0
+ self._stop_event = threading.Event()
+ self._thread: threading.Thread | None = None
+
+ def start(self) -> None:
+ def monitor_loop() -> None:
+ while not self._stop_event.is_set():
+ try:
+ with torch.cuda.device(self.device_index):
+ free_bytes, total_bytes = torch.cuda.mem_get_info()
+ used_mb = (total_bytes - free_bytes) / (1024**2)
+ self._peak_used_mb = max(self._peak_used_mb, used_mb)
+ except Exception:
+ pass
+ time.sleep(self.interval)
+
+ self._thread = threading.Thread(target=monitor_loop, daemon=True)
+ self._thread.start()
+
+ def stop(self) -> None:
+ if self._thread is None:
+ return
+ self._stop_event.set()
+ self._thread.join(timeout=2.0)
+
+ @property
+ def peak_used_mb(self) -> float:
+ fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2)
+ fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2)
+ return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
+
+ def __del__(self):
+ self.stop()
diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py
new file mode 100644
index 00000000000..b0132306c81
--- /dev/null
+++ b/tests/worker/test_omni_gpu_model_runner.py
@@ -0,0 +1,123 @@
+from contextlib import contextmanager
+
+import torch
+
+from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
+
+
+class DummyBuffer:
+ """A minimal buffer wrapper that exposes the `.gpu` attribute."""
+
+ def __init__(self, t: torch.Tensor):
+ self.gpu = t
+
+
+class DummyInputBatch:
+ """A minimal input batch that only provides `req_ids`."""
+
+ def __init__(self, req_ids):
+ self.req_ids = req_ids
+
+
+class DummyReqState:
+ """A minimal request state container."""
+
+ pass
+
+
+class DummyTalkerMTP(torch.nn.Module):
+ """A fake talker_mtp module for deterministic CPU testing."""
+
+ def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step):
+ # Deterministic behavior:
+ # - output embeds = input embeds + 1
+ # - output codes = [[0], [1], ...]
+ bsz = req_embeds.shape[0]
+ new_embeds = req_embeds + 1.0
+ codes = torch.arange(bsz, dtype=torch.int64).view(bsz, 1)
+ return new_embeds, codes
+
+
+@contextmanager
+def _noop_forward_context(*args, **kwargs):
+ """A no-op context manager to replace vLLM forward context in CPU tests."""
+ yield
+
+
+def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
+ # Create an instance without calling OmniGPUModelRunner.__init__
+ runner = object.__new__(OmniGPUModelRunner)
+
+ # Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward
+ runner.input_batch = DummyInputBatch(list(req_ids))
+ runner.requests = {rid: DummyReqState() for rid in req_ids}
+
+ # query_start_loc.cpu[req_index] is used to locate the token position
+ # in the flattened `inputs_embeds`.
+ runner.query_start_loc = type("QSL", (), {})()
+ # Map: r1 -> offset 0, r2 -> offset 3
+ runner.query_start_loc.cpu = torch.tensor([0, 3], dtype=torch.int32)
+
+ bsz = len(req_ids)
+ runner.talker_mtp_input_ids = DummyBuffer(torch.zeros((bsz,), dtype=torch.int64))
+ runner.talker_mtp_inputs_embeds = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32))
+ runner.last_talker_hidden = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32))
+ runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32))
+
+ runner.talker_mtp = DummyTalkerMTP()
+ runner.vllm_config = object()
+
+ # Provide a minimal implementation that returns the expected 4-tuple.
+ def _determine_batch_execution_and_padding(**kwargs):
+ return None, object(), None, None
+
+ runner._determine_batch_execution_and_padding = _determine_batch_execution_and_padding
+
+ # Use the real merge method from OmniGPUModelRunner.
+ return runner
+
+
+def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch):
+ # Patch the module-level `set_forward_context` symbol used inside
+ # OmniGPUModelRunner._talker_mtp_forward.
+ import vllm_omni.worker.gpu_model_runner as mod # Must be the same module that defines OmniGPUModelRunner
+
+ monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)
+
+ runner = _make_runner(req_ids=("r1", "r2"), hidden_size=4)
+
+ # Initialize per-request embeds (batch-major inside talker_mtp_inputs_embeds)
+ runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0])
+ runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0])
+
+ # Flattened `inputs_embeds`: offsets 0 and 3 will be overwritten
+ inputs_embeds = torch.zeros((6, 4), dtype=torch.float32)
+
+ # Call the original implementation from OmniGPUModelRunner (no re-implementation)
+ OmniGPUModelRunner._talker_mtp_forward(runner, ["r1", "r2"], inputs_embeds)
+
+ # Validate embeds were written back (+1)
+ assert torch.allclose(inputs_embeds[0], torch.tensor([2.0, 3.0, 4.0, 5.0]))
+ assert torch.allclose(inputs_embeds[3], torch.tensor([11.0, 21.0, 31.0, 41.0]))
+
+ # Validate per-request additional_information_cpu was updated
+ info_r1 = runner.requests["r1"].additional_information_cpu
+ info_r2 = runner.requests["r2"].additional_information_cpu
+ assert int(info_r1["code_predictor_codes"][0, 0]) == 0
+ assert int(info_r2["code_predictor_codes"][0, 0]) == 1
+
+
+def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
+ import vllm_omni.worker.gpu_model_runner as mod
+
+ monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)
+
+ runner = _make_runner(req_ids=("r1",), hidden_size=4)
+
+ inputs_embeds = torch.randn((2, 4))
+ before = inputs_embeds.clone()
+
+ OmniGPUModelRunner._talker_mtp_forward(runner, [], inputs_embeds)
+
+ # Ensure no changes were made
+ assert torch.allclose(inputs_embeds, before)
diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py
index 562999d7e58..db45f29900d 100644
--- a/tools/pre_commit/check_pickle_imports.py
+++ b/tools/pre_commit/check_pickle_imports.py
@@ -18,6 +18,7 @@
ALLOWED_FILES = {
"vllm_omni/entrypoints/omni_llm.py",
"tests/e2e/offline_inference/utils.py",
+ "tests/utils.py",
"vllm_omni/diffusion/distributed/group_coordinator.py",
"tests/diffusion/attention/test_sequence_parallel.py",
}
diff --git a/vllm_omni/assets/video.py b/vllm_omni/assets/video.py
new file mode 100644
index 00000000000..361e2ac785f
--- /dev/null
+++ b/vllm_omni/assets/video.py
@@ -0,0 +1,14 @@
+import librosa
+import numpy as np
+from vllm.assets.video import VideoAsset
+def extract_video_audio(path: str = None, sampling_rate: int = 16000) -> np.ndarray:
+ """ This function extracts the audio from a video file path and returns the audio as a numpy array.
+ Args:
+ path: The path to the video file. If empty or None, the "baby_reading" VideoAsset is used.
+ Returns:
+ The audio as a numpy array, loaded by librosa at `sampling_rate`.
+ """
+ if not path:
+ path = VideoAsset(name="baby_reading").video_path
+ audio_signal, sr = librosa.load(path, sr=sampling_rate)
+ return audio_signal
\ No newline at end of file
diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py
index 2e53a7af2e1..592c501a631 100644
--- a/vllm_omni/config/model.py
+++ b/vllm_omni/config/model.py
@@ -1,12 +1,9 @@
import warnings
-from importlib.util import find_spec
from typing import Any
import torch
-import vllm.envs as envs
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
-from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig, config
from vllm.config.model import (
_RUNNER_CONVERTS,
@@ -24,10 +21,12 @@
get_pooling_config,
)
from vllm.transformers_utils.gguf_utils import (
+ is_gguf,
maybe_patch_hf_config_from_gguf,
)
from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.transformers_utils.gguf_utils import is_gguf
+from vllm.v1.attention.backends.registry import AttentionBackendEnum
+
import vllm_omni.model_executor.models as me_models
logger = init_logger(__name__)
@@ -108,9 +107,7 @@ def __post_init__(
video_pruning_rate: float | None,
) -> None:
# Keep set served_model_name before maybe_model_redirect(self.model)
- self.served_model_name = get_served_model_name(
- self.model, self.served_model_name
- )
+ self.served_model_name = get_served_model_name(self.model, self.served_model_name)
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
@@ -167,9 +164,7 @@ def __post_init__(
if dict_overrides:
self._apply_dict_overrides(hf_config, dict_overrides)
self.hf_text_config = self.draw_hf_text_config()
- self.attention_chunk_size = getattr(
- self.hf_text_config, "attention_chunk_size", None
- )
+ self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=self.hf_token, revision=self.revision
@@ -182,9 +177,7 @@ def __post_init__(
is_pooling_model = registry.is_pooling_model(architectures, self)
self.runner_type = self._get_runner_type(architectures, self.runner)
- self.convert_type = self._get_convert_type(
- architectures, self.runner_type, self.convert
- )
+ self.convert_type = self._get_convert_type(architectures, self.runner_type, self.convert)
if self.runner_type == "generate" and not is_generative_model:
generate_converts = _RUNNER_CONVERTS["generate"]
@@ -244,10 +237,7 @@ def __post_init__(
# Init multimodal config if needed
if self._model_info.supports_multimodal:
- if (
- mm_encoder_tp_mode == "data"
- and not self._model_info.supports_multimodal_encoder_tp_data
- ):
+ if mm_encoder_tp_mode == "data" and not self._model_info.supports_multimodal_encoder_tp_data:
logger.warning_once(
"This model does not support `--mm-encoder-tp-mode data`. "
"Falling back to `--mm-encoder-tp-mode weights`."
@@ -269,9 +259,7 @@ def __post_init__(
video_pruning_rate=video_pruning_rate,
)
- mm_config_kwargs = {
- k: v for k, v in mm_config_kwargs.items() if v is not None
- }
+ mm_config_kwargs = {k: v for k, v in mm_config_kwargs.items() if v is not None}
self.multimodal_config = MultiModalConfig(**mm_config_kwargs)
diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py
index dc3f56ac2db..3c262b7c927 100644
--- a/vllm_omni/core/sched/omni_ar_scheduler.py
+++ b/vllm_omni/core/sched/omni_ar_scheduler.py
@@ -3,6 +3,8 @@
from collections import defaultdict
from time import time
+import numpy as np
+from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.distributed.kv_events import KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
@@ -10,10 +12,13 @@
from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.metrics.perf import PerfStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
+logger = init_logger(__name__)
+
class OmniARScheduler(VLLMScheduler):
"""
@@ -73,7 +78,7 @@ def update_from_output(
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
kv_connector_output = model_runner_output.kv_connector_output
- cudagraph_stats = model_runner_output.cudagraph_stats
+ cudagraph_stats: CUDAGraphStat | None = model_runner_output.cudagraph_stats
perf_stats: PerfStats | None = None
if self.perf_metrics and self.perf_metrics.is_enabled():
@@ -152,7 +157,7 @@ def update_from_output(
new_token_ids, stopped = self._update_request_with_output(request, new_token_ids)
if pooler_output:
- # Note: As we occupied the pooler output, for multimodal outputs, we do not intermediate stop checking for pooler output
+ # Note: For multimodal outputs, we skip intermediate stop checks.
if request.output_token_ids:
stopped = check_stop(request, self.max_model_len)
routed_experts = None
@@ -172,13 +177,10 @@ def update_from_output(
# compute slot mapping: slot = block_id * block_size + offset
slot_mapping = (
- block_offsets.reshape((1, block_size))
- + block_ids_array.reshape((num_blocks, 1)) * block_size
+ block_offsets.reshape((1, block_size)) + block_ids_array.reshape((num_blocks, 1)) * block_size
).flatten()[:num_tokens]
- routed_experts = self.routed_experts_reader.get_routed_experts(
- indices=slot_mapping
- )
+ routed_experts = self.routed_experts_reader.get_routed_experts(indices=slot_mapping)
kv_transfer_params = self._free_request(request)
if status_before_stop == RequestStatus.RUNNING:
stopped_running_reqs.add(request)
@@ -221,6 +223,7 @@ def update_from_output(
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
+ routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
)
)
@@ -234,7 +237,7 @@ def update_from_output(
if stopped_preempted_reqs:
# This is a rare case and unlikely to impact performance.
self.waiting.remove_requests(stopped_preempted_reqs)
-
+
if failed_kv_load_req_ids and not self.recompute_kv_load_failures:
requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids]
self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR)
@@ -286,7 +289,7 @@ def update_from_output(
engine_core_outputs[client_index] = EngineCoreOutputs(finished_requests=finished_set)
finished_req_ids.clear()
- if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats)) is not None:
+ if (stats := self.make_stats(spec_decoding_stats, kv_connector_stats, cudagraph_stats, perf_stats)) is not None:
# Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
# We must return the stats even if there are no request
diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py
index 7acef3bdcb5..2d04faeeea7 100644
--- a/vllm_omni/core/sched/omni_generation_scheduler.py
+++ b/vllm_omni/core/sched/omni_generation_scheduler.py
@@ -90,7 +90,7 @@ def schedule(self) -> SchedulerOutput:
any_request = self.running[0]
num_common_prefix_blocks = self.kv_cache_manager.get_num_common_prefix_blocks(any_request.request_id)
- # Assemble SchedulerOutput (align with v0.12.0)
+ # Assemble SchedulerOutput (align with v0.14.0)
if self.use_v2_model_runner:
# No resumed reqs in fast path; pass prefill_token_ids for new reqs.
new_reqs_data = [
@@ -129,7 +129,7 @@ def schedule(self) -> SchedulerOutput:
preempted_req_ids=set(),
)
- # Record the request ids scheduled in this step (v0.12.0 behavior).
+ # Record the request ids scheduled in this step (v0.14.0 behavior).
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())
@@ -176,7 +176,7 @@ def update_from_output(
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats = kv_connector_output.kv_connector_stats if kv_connector_output else None
- # Merge connector-side stats (align with v0.12.0)
+ # Merge connector-side stats (align with v0.14.0)
if kv_connector_stats and self.connector:
kv_stats = self.connector.get_kv_connector_stats()
if kv_stats:
@@ -294,7 +294,7 @@ def update_from_output(
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
- # Collect and publish KV cache events (align with v0.12.0)
+ # Collect and publish KV cache events (align with v0.14.0)
events = self.kv_cache_manager.take_events()
if self.connector is not None:
connector_events = self.connector.take_events()
diff --git a/vllm_omni/diffusion/attention/backends/flash_attn.py b/vllm_omni/diffusion/attention/backends/flash_attn.py
index 6921c79aee9..3623d49db8f 100644
--- a/vllm_omni/diffusion/attention/backends/flash_attn.py
+++ b/vllm_omni/diffusion/attention/backends/flash_attn.py
@@ -9,13 +9,14 @@
AttentionImpl,
AttentionMetadata,
)
+from vllm_omni.diffusion.attention.backends.utils.fa import _pad_input, _unpad_input, _upad_input
logger = init_logger(__name__)
try:
# only tested with flash_attn v3
# from flash_attn_interface import flash_attn_func as flash_attn_3_func # not available in flash-attn 2.8.1
- from flash_attn import flash_attn_func # can be FA2 or FA3
+ from flash_attn import flash_attn_func, flash_attn_varlen_func # can be FA2 or FA3
except ImportError:
logger.warning(
"FlashAttentionBackend is not available. You may install flash-attn "
@@ -63,12 +64,51 @@ def forward(
value: torch.Tensor,
attn_metadata: AttentionMetadata = None,
) -> torch.Tensor:
- # TODO: flash_attn_func does not support attn_mask.
- out: torch.Tensor = flash_attn_func(
- query,
- key,
- value,
- causal=self.causal,
- softmax_scale=self.softmax_scale,
- )
+ """
+ Flash attention implementation.
+
+ Args:
+ query: (batch_size, seq_len, num_heads, head_dim)
+ key: (batch_size, seq_len, num_heads, head_dim)
+ value: (batch_size, seq_len, num_heads, head_dim)
+ attn_metadata: AttentionMetadata. Attention mask is supported as attn_metadata.attn_mask
+
+ Returns:
+ (batch_size, seq_len, num_heads, head_dim)
+ """
+ query_length = query.size(1)
+ attention_mask = attn_metadata.attn_mask if attn_metadata is not None else None
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None and torch.any(~attention_mask):
+ assert attention_mask.ndim == 2, "attention_mask must be 2D, (batch_size, seq_len)"
+ q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
+ query, key, value, attention_mask, query_length, _unpad_input
+ )
+
+ out_unpad = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ cu_seqlens_q=cu_seq_lens_q,
+ cu_seqlens_k=cu_seq_lens_k,
+ max_seqlen_q=max_length_q,
+ max_seqlen_k=max_length_k,
+ **{
+ "causal": self.causal,
+ "softmax_scale": self.softmax_scale,
+ },
+ )
+ if isinstance(out_unpad, tuple):
+ out_unpad = out_unpad[0]
+
+ out = _pad_input(out_unpad, indices_q, query.size(0), query_length)
+
+ else:
+ out: torch.Tensor = flash_attn_func(
+ query,
+ key,
+ value,
+ causal=self.causal,
+ softmax_scale=self.softmax_scale,
+ )
return out
diff --git a/vllm_omni/diffusion/attention/backends/ring_flash_attn.py b/vllm_omni/diffusion/attention/backends/ring_flash_attn.py
index e163ef1ace1..dd27f88a8c1 100644
--- a/vllm_omni/diffusion/attention/backends/ring_flash_attn.py
+++ b/vllm_omni/diffusion/attention/backends/ring_flash_attn.py
@@ -28,6 +28,20 @@ def ring_flash_attn_forward(
joint_tensor_value=None,
joint_strategy="front",
):
+ # Validate causal + joint_strategy combination
+ # When causal=True and joint_strategy="rear", the causal mask would incorrectly
+ # prevent local query tokens from attending to joint key tokens (which are
+ # concatenated at the end). This breaks the semantics where joint tokens
+ # (e.g., text conditioning) should be visible to all local tokens.
+ if causal and joint_tensor_key is not None and joint_strategy == "rear":
+ raise ValueError(
+ "joint_strategy='rear' is not compatible with causal=True in Ring Attention. "
+ "When using causal attention with joint tokens, use joint_strategy='front' "
+ "to ensure joint tokens act as a visible prefix for all local tokens. "
+ "With 'rear' strategy, the causal mask would incorrectly block local tokens "
+ "from seeing the joint tokens."
+ )
+
comm = RingComm(process_group)
out = None
diff --git a/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py b/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py
index 43ee35f7098..9ed2c2076c4 100644
--- a/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py
+++ b/vllm_omni/diffusion/attention/backends/ring_pytorch_attn.py
@@ -62,6 +62,20 @@ def forward(
joint_tensor_value=None,
joint_strategy="front",
):
+ # Validate causal + joint_strategy combination
+ # When causal=True and joint_strategy="rear", the causal mask would incorrectly
+ # prevent local query tokens from attending to joint key tokens (which are
+ # concatenated at the end). This breaks the semantics where joint tokens
+ # (e.g., text conditioning) should be visible to all local tokens.
+ if is_causal and joint_tensor_key is not None and joint_strategy == "rear":
+ raise ValueError(
+ "joint_strategy='rear' is not compatible with causal=True in Ring Attention. "
+ "When using causal attention with joint tokens, use joint_strategy='front' "
+ "to ensure joint tokens act as a visible prefix for all local tokens. "
+ "With 'rear' strategy, the causal mask would incorrectly block local tokens "
+ "from seeing the joint tokens."
+ )
+
comm = RingComm(group)
# Ensure tensors are contiguous for P2P communication
q = q.contiguous()
diff --git a/vllm_omni/diffusion/attention/backends/utils/__init__.py b/vllm_omni/diffusion/attention/backends/utils/__init__.py
new file mode 100644
index 00000000000..92c7c8027cb
--- /dev/null
+++ b/vllm_omni/diffusion/attention/backends/utils/__init__.py
@@ -0,0 +1,13 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Utils for attention backends.
+"""
+
+from vllm_omni.diffusion.attention.backends.utils.fa import _pad_input, _unpad_input, _upad_input
+
+__all__ = [
+ "_pad_input",
+ "_unpad_input",
+ "_upad_input",
+]
diff --git a/vllm_omni/diffusion/attention/backends/utils/fa.py b/vllm_omni/diffusion/attention/backends/utils/fa.py
new file mode 100644
index 00000000000..d89082e717a
--- /dev/null
+++ b/vllm_omni/diffusion/attention/backends/utils/fa.py
@@ -0,0 +1,209 @@
+# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flash_attention_utils.py
+import torch
+import torch.nn.functional as F
+
+
+def _index_first_axis(tensor, indices):
+ """
+ A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
+ after flattening the first two dimensions of the tensor. This is functionally equivalent to
+ FA2's `index_first_axis` and replaces the need to import it.
+ """
+ # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
+ # two dimensions to get (total_tokens, ...) before indexing.
+ reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
+ return reshaped_tensor[indices]
+
+
+def _unpad_input(hidden_states, attention_mask, unused_mask=None):
+ """
+ unpad_input function for flash attention variants that do not provide one in their own package, e.g. fa3.
+
+ Arguments:
+ hidden_states: (batch, seqlen, ...)
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+ unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+
+ Return:
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+ indices: (total_nnz), the indices of the selected (non-zero mask) tokens in the flattened input sequence.
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+ max_seqlen_in_batch: int
+ seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+ """
+ all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+ seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+ used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+
+ return (
+ _index_first_axis(hidden_states, indices),
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ used_seqlens_in_batch,
+ )
+
+
+def _pad_input(hidden_states, indices, batch, seqlen):
+ """
+ pad_input function for flash attention variants that do not provide one in their own package, e.g. fa3.
+
+ Arguments:
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+ batch: int, batch size for the padded sequence.
+ seqlen: int, maximum sequence length for the padded sequence.
+
+ Return:
+ hidden_states: (batch, seqlen, ...)
+ """
+ dim = hidden_states.shape[1:]
+ output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
+ output[indices] = hidden_states
+ return output.view(batch, seqlen, *dim)
+
+
+def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
+ """
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
+
+ Arguments:
+ attention_mask (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
+ Return:
+ indices (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input sequence.
+ cu_seqlens (`torch.Tensor`):
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors.
+ `cu_seqlens` shape is (batch_size + 1,).
+ max_seqlen_in_batch (`int`):
+ Maximum sequence length in batch.
+ """
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
+ # this might cause a graph break
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def _upad_input(
+ query_layer: torch.Tensor,
+ key_layer: torch.Tensor,
+ value_layer: torch.Tensor,
+ attention_mask: torch.Tensor,
+ query_length: int,
+ unpad_input_func,
+):
+ """
+ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong
+ to different batches. This function is used instead of `flash_attn.bert_padding.unpad_input` in
+ order to avoid the recomputation of the same intermediary tensors for query, key, value tensors.
+
+ Arguments:
+ query_layer (`torch.Tensor`):
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
+ key_layer (`torch.Tensor`):
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ value_layer (`torch.Tensor`):
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ attention_mask (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+ query_length (`int`):
+ Target length.
+ unpad_input_func:
+ The function to use for unpadding the input tensors.
+
+ Return:
+ query_layer (`torch.Tensor`):
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
+ key_layer (`torch.Tensor`):
+ Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ value_layer (`torch.Tensor`):
+ Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ indices_q (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input target sequence.
+ (cu_seqlens_q, cu_seqlens_k) (`tuple[torch.Tensor]`):
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into
+ ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
+ `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ """
+ if torch.compiler.is_compiling():
+ # Allow the PyTorch compiler to capture operations that return scalar values (like `.item()`).
+ torch._dynamo.config.capture_scalar_outputs = True
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+
+ # With static caches, the k/v states may be larger than the mask ->
+ # we need to slice them to avoid generating garbage
+ # It's a bit of an anti-pattern, but otherwise we silently compute wrong attention scores
+ if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
+ key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
+
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = _index_first_axis(key_layer, indices_k)
+ value_layer = _index_first_axis(value_layer, indices_k)
+ if query_length == kv_seq_len:
+ query_layer = _index_first_axis(query_layer, indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+def _is_packed_sequence(position_ids, batch_size):
+ """
+ Check whether the position ids indicate packed sequences:
+ 1. Position ids exist
+ 2. Flattened sequences only are supported
+ 3. Compile-friendly `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e.
+ we have multiple increasing sequences
+ """
+ if position_ids is None:
+ return False
+
+ increasing_position_sequences = torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
+ return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
diff --git a/vllm_omni/diffusion/attention/selector.py b/vllm_omni/diffusion/attention/selector.py
index f2f60243f04..b9f62e0681d 100644
--- a/vllm_omni/diffusion/attention/selector.py
+++ b/vllm_omni/diffusion/attention/selector.py
@@ -31,6 +31,8 @@
"ASCEND": {"module": "vllm_omni.diffusion.attention.backends.ascend_attn", "class": "AscendAttentionBackend"},
}
+_BACKENDS_SUPPORT_ATTENTION_MASK = ["SDPA", "ASCEND", "FLASH_ATTN"]
+
def load_backend(backend_name: str) -> type[AttentionBackend]:
config = _BACKEND_CONFIG[backend_name]
diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index 3485e2262c9..0c43659c9a6 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -7,11 +7,19 @@
pipelines in vllm-omni, supporting both single and dual-transformer architectures.
"""
+import functools
from collections.abc import Callable
+from contextlib import ExitStack
from typing import Any, Optional
import cache_dit
+import torch
from cache_dit import BlockAdapter, DBCacheConfig, ForwardPattern, ParamsModifier, TaylorSeerCalibratorConfig
+from cache_dit.caching.block_adapters import FakeDiffusionPipeline
+from cache_dit.caching.cache_adapters.cache_adapter import CachedAdapter
+from cache_dit.caching.cache_blocks.pattern_0_1_2 import CachedBlocks_Pattern_0_1_2
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts.cache_manager import CachedContextManager
from vllm.logger import init_logger
from vllm_omni.diffusion.cache.base import CacheBackend
@@ -401,6 +409,384 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
return refresh_cache_context
+class BagelCachedContextManager(CachedContextManager):
+ """
+ Custom CachedContextManager for Bagel that safely handles NaiveCache objects
+ (mapped to encoder_hidden_states) by skipping tensor operations on them.
+ """
+
+ @torch.compiler.disable
+ def apply_cache(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ prefix: str = "Bn",
+ encoder_prefix: str = "Bn_encoder",
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ # Allow Bn and Fn prefix to be used for residual cache.
+ if "Bn" in prefix:
+ hidden_states_prev = self.get_Bn_buffer(prefix)
+ else:
+ hidden_states_prev = self.get_Fn_buffer(prefix)
+
+ assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"
+
+ if self.is_cache_residual():
+ hidden_states = hidden_states_prev + hidden_states
+ else:
+ # If cache is not residual, we use the hidden states directly
+ hidden_states = hidden_states_prev
+
+ hidden_states = hidden_states.contiguous()
+
+ if encoder_hidden_states is not None:
+ if "Bn" in encoder_prefix:
+ encoder_hidden_states_prev = self.get_Bn_encoder_buffer(encoder_prefix)
+ else:
+ encoder_hidden_states_prev = self.get_Fn_encoder_buffer(encoder_prefix)
+
+ if encoder_hidden_states_prev is not None:
+ if self.is_encoder_cache_residual():
+ # FIX: Check if encoder_hidden_states is a tensor before adding
+ if isinstance(encoder_hidden_states, torch.Tensor) and isinstance(
+ encoder_hidden_states_prev, torch.Tensor
+ ):
+ encoder_hidden_states = encoder_hidden_states_prev + encoder_hidden_states
+ else:
+ # If encoder cache is not residual, we use the encoder hidden states directly
+ encoder_hidden_states = encoder_hidden_states_prev
+
+ # FIX: Check if encoder_hidden_states is a tensor before calling contiguous
+ if isinstance(encoder_hidden_states, torch.Tensor):
+ encoder_hidden_states = encoder_hidden_states.contiguous()
+
+ return hidden_states, encoder_hidden_states
+
+
+class BagelCachedBlocks(CachedBlocks_Pattern_0_1_2):
+ """
+ Custom CachedBlocks for Bagel that safely handles NaiveCache objects
+ by adding isinstance checks in call_Mn_blocks and compute_or_prune.
+ """
+
+ def call_Mn_blocks(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ *args,
+ **kwargs,
+ ):
+ original_hidden_states = hidden_states
+ original_encoder_hidden_states = encoder_hidden_states
+ for block in self._Mn_blocks():
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ *args,
+ **kwargs,
+ )
+ hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states)
+
+ # compute hidden_states residual
+ hidden_states = hidden_states.contiguous()
+
+ hidden_states_residual = hidden_states - original_hidden_states
+
+ if (
+ encoder_hidden_states is not None
+ and original_encoder_hidden_states is not None
+ and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check
+ ):
+ encoder_hidden_states = encoder_hidden_states.contiguous()
+ encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
+ else:
+ encoder_hidden_states_residual = None
+
+ return (
+ hidden_states,
+ encoder_hidden_states,
+ hidden_states_residual,
+ encoder_hidden_states_residual,
+ )
+
+ def compute_or_prune(
+ self,
+ block_id: int, # Block index in the transformer blocks
+ # Below are the inputs to the block
+ block, # The transformer block to be executed
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ *args,
+ **kwargs,
+ ):
+ # NOTE: Although Bagel likely won't use pruning, implementing safe version just in case.
+ # Copy-pasted from original but adding checks.
+
+ original_hidden_states = hidden_states
+ original_encoder_hidden_states = encoder_hidden_states
+
+ can_use_prune = self._maybe_prune(
+ block_id,
+ hidden_states,
+ prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+ )
+
+ torch._dynamo.graph_break()
+ if can_use_prune:
+ self.context_manager.add_pruned_step()
+ hidden_states, encoder_hidden_states = self.context_manager.apply_prune(
+ hidden_states,
+ encoder_hidden_states,
+ prefix=(
+ f"{self.cache_prefix}_{block_id}_Bn_residual"
+ if self.context_manager.is_cache_residual()
+ else f"{self.cache_prefix}_Bn_hidden_states"
+ ),
+ encoder_prefix=(
+ f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+ if self.context_manager.is_encoder_cache_residual()
+ else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+ ),
+ )
+ torch._dynamo.graph_break()
+ else:
+ # Normal steps: Compute the block and cache the residuals.
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ *args,
+ **kwargs,
+ )
+ hidden_states, encoder_hidden_states = self._process_block_outputs(hidden_states, encoder_hidden_states)
+ if not self._skip_prune(block_id):
+ hidden_states = hidden_states.contiguous()
+ hidden_states_residual = hidden_states - original_hidden_states
+
+ if (
+ encoder_hidden_states is not None
+ and original_encoder_hidden_states is not None
+ and isinstance(encoder_hidden_states, torch.Tensor) # FIX: Added Check
+ ):
+ encoder_hidden_states = encoder_hidden_states.contiguous()
+ encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
+ else:
+ encoder_hidden_states_residual = None
+
+ self.context_manager.set_Fn_buffer(
+ original_hidden_states,
+ prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+ )
+ if self.context_manager.is_cache_residual():
+ self.context_manager.set_Bn_buffer(
+ hidden_states_residual,
+ prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+ )
+ else:
+ self.context_manager.set_Bn_buffer(
+ hidden_states,
+ prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+ )
+ if encoder_hidden_states_residual is not None:
+ if self.context_manager.is_encoder_cache_residual():
+ self.context_manager.set_Bn_encoder_buffer(
+ encoder_hidden_states_residual,
+ prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+ )
+ else:
+ self.context_manager.set_Bn_encoder_buffer(
+ encoder_hidden_states_residual,
+ prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+ )
+ torch._dynamo.graph_break()
+
+ return hidden_states, encoder_hidden_states
+
+
+class BagelCachedAdapter(CachedAdapter):
+ """
+ Custom CachedAdapter for Bagel that uses BagelCachedContextManager and BagelCachedBlocks.
+ """
+
+ @classmethod
+ def create_context(
+ cls,
+ block_adapter: BlockAdapter,
+ **context_kwargs,
+ ) -> tuple[list[str], list[dict[str, Any]]]:
+ # Override to use BagelCachedContextManager
+
+ BlockAdapter.assert_normalized(block_adapter)
+
+ if BlockAdapter.is_cached(block_adapter.pipe):
+ return block_adapter.pipe
+
+ # Check context_kwargs
+ context_kwargs = cls.check_context_kwargs(block_adapter, **context_kwargs)
+
+ # Each Pipeline should have its own context manager instance.
+ cache_config: BasicCacheConfig = context_kwargs.get("cache_config", None)
+ assert cache_config is not None, "cache_config can not be None."
+
+ # Apply cache on pipeline: wrap cache context
+ pipe_cls_name = block_adapter.pipe.__class__.__name__
+
+ # USE CUSTOM CONTEXT MANAGER
+ context_manager = BagelCachedContextManager(
+ name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
+ persistent_context=isinstance(block_adapter.pipe, FakeDiffusionPipeline),
+ )
+
+ flatten_contexts, contexts_kwargs = cls.modify_context_params(block_adapter, **context_kwargs)
+
+ block_adapter.pipe._context_manager = context_manager # instance level
+
+ if not context_manager.persistent_context:
+ original_call = block_adapter.pipe.__class__.__call__
+
+ @functools.wraps(original_call)
+ def new_call(self, *args, **kwargs):
+ with ExitStack() as stack:
+ # cache context will be reset for each pipe inference
+ for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
+ stack.enter_context(
+ context_manager.enter_context(
+ context_manager.reset_context(
+ context_name,
+ **context_kwargs,
+ ),
+ )
+ )
+ outputs = original_call(self, *args, **kwargs)
+ cls.apply_stats_hooks(block_adapter)
+ return outputs
+
+ block_adapter.pipe.__class__.__call__ = new_call
+ block_adapter.pipe.__class__._original_call = original_call
+
+ else:
+ # Init persistent cache context for transformer
+ for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
+ context_manager.reset_context(
+ context_name,
+ **context_kwargs,
+ )
+
+ block_adapter.pipe.__class__._is_cached = True
+
+ cls.apply_params_hooks(block_adapter, contexts_kwargs)
+
+ return flatten_contexts, contexts_kwargs
+
+ @classmethod
+ def collect_unified_blocks(
+ cls,
+ block_adapter: BlockAdapter,
+ contexts_kwargs: list[dict],
+ ) -> list[dict[str, torch.nn.ModuleList]]:
+ # Override to use BagelCachedBlocks
+
+ BlockAdapter.assert_normalized(block_adapter)
+
+ total_cached_blocks: list[dict[str, torch.nn.ModuleList]] = []
+ assert hasattr(block_adapter.pipe, "_context_manager")
+ # Skipping isinstance check for ContextManager._supported_managers to avoid import issues
+
+ for i in range(len(block_adapter.transformer)):
+ unified_blocks_bind_context = {}
+ for j in range(len(block_adapter.blocks[i])):
+ cache_config: BasicCacheConfig = contexts_kwargs[i * len(block_adapter.blocks[i]) + j]["cache_config"]
+
+ # Directly instantiate BagelCachedBlocks
+ unified_blocks_bind_context[block_adapter.unique_blocks_name[i][j]] = torch.nn.ModuleList(
+ [
+ BagelCachedBlocks(
+ # 0. Transformer blocks configuration
+ block_adapter.blocks[i][j],
+ transformer=block_adapter.transformer[i],
+ forward_pattern=block_adapter.forward_pattern[i][j],
+ check_forward_pattern=block_adapter.check_forward_pattern,
+ check_num_outputs=block_adapter.check_num_outputs,
+ # 1. Cache/Prune context configuration
+ cache_prefix=block_adapter.blocks_name[i][j],
+ cache_context=block_adapter.unique_blocks_name[i][j],
+ context_manager=block_adapter.pipe._context_manager,
+ cache_type=cache_config.cache_type,
+ )
+ ]
+ )
+
+ total_cached_blocks.append(unified_blocks_bind_context)
+
+ return total_cached_blocks
+
+
+def enable_cache_for_bagel(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+ """Enable cache-dit for Bagel model (via OmniDiffusion pipeline).
+
+ Args:
+ pipeline: The OmniDiffusion pipeline instance.
+ cache_config: DiffusionCacheConfig instance with cache configuration.
+
+ Returns:
+ A refresh function that can be called to update cache context with new num_inference_steps.
+ """
+ # Build DBCacheConfig
+ db_cache_config = _build_db_cache_config(cache_config)
+
+ # Build calibrator config if TaylorSeer is enabled
+ calibrator_config = None
+ if cache_config.enable_taylorseer:
+ taylorseer_order = cache_config.taylorseer_order
+ calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
+ logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
+
+ # Access the transformer: BagelPipeline -> Qwen2MoTForCausalLM -> Qwen2MoTModel
+ # BagelPipeline has self.language_model which is Qwen2MoTForCausalLM
+ # Qwen2MoTForCausalLM has self.model which is Qwen2MoTModel
+ transformer = pipeline.language_model.model
+
+ logger.info(
+ f"Enabling cache-dit on Bagel transformer: "
+ f"Fn={db_cache_config.Fn_compute_blocks}, "
+ f"Bn={db_cache_config.Bn_compute_blocks}, "
+ f"W={db_cache_config.max_warmup_steps}, "
+ )
+
+ # Enable cache-dit on the transformer
+ # Pattern_0 corresponds to (hidden_states, encoder_hidden_states) input, output
+ # Custom adapter for Bagel to handle NaiveCache correctly
+ # from vllm_omni.diffusion.cache.bagel_cache_adapter import BagelCachedAdapter # No longer needed
+ BagelCachedAdapter.apply(
+ BlockAdapter(
+ transformer=transformer,
+ blocks=transformer.layers,
+ forward_pattern=ForwardPattern.Pattern_0,
+ ),
+ cache_config=db_cache_config,
+ calibrator_config=calibrator_config,
+ )
+
+ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+ transformer = pipeline.language_model.model
+ if cache_config.scm_steps_mask_policy is None:
+ cache_dit.refresh_context(transformer, num_inference_steps=num_inference_steps, verbose=verbose)
+ else:
+ cache_dit.refresh_context(
+ transformer,
+ cache_config=DBCacheConfig().reset(
+ num_inference_steps=num_inference_steps,
+ steps_computation_mask=cache_dit.steps_mask(
+ mask_policy=cache_config.scm_steps_mask_policy,
+ total_steps=num_inference_steps,
+ ),
+ steps_computation_policy=cache_config.scm_steps_policy,
+ ),
+ verbose=verbose,
+ )
+
+ return refresh_cache_context
+
+
# Register custom cache-dit enablers after function definitions
CUSTOM_DIT_ENABLERS.update(
{
@@ -409,6 +795,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"LongCatImagePipeline": enable_cache_for_longcat_image,
"LongCatImageEditPipeline": enable_cache_for_longcat_image,
"StableDiffusion3Pipeline": enable_cache_for_sd3,
+ "BagelPipeline": enable_cache_for_bagel,
}
)
diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py
index 64184caa749..1948b40a3b4 100644
--- a/vllm_omni/diffusion/diffusion_engine.py
+++ b/vllm_omni/diffusion/diffusion_engine.py
@@ -293,6 +293,7 @@ def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
def _dummy_run(self):
"""A dummy run to warm up the model."""
prompt = "dummy run"
+ # Note that num_inference_steps=1 will cause timestep and temb to be None in the pipeline.
num_inference_steps = 1
height = 1024
width = 1024
diff --git a/vllm_omni/diffusion/models/bagel/bagel_transformer.py b/vllm_omni/diffusion/models/bagel/bagel_transformer.py
index 7950389ef1f..256f25e0839 100644
--- a/vllm_omni/diffusion/models/bagel/bagel_transformer.py
+++ b/vllm_omni/diffusion/models/bagel/bagel_transformer.py
@@ -314,10 +314,12 @@ def __init__(
def forward(
self,
- packed_query_sequence: torch.Tensor,
- query_lens: torch.Tensor,
- packed_query_position_embeddings: torch.Tensor,
- packed_query_indexes: torch.Tensor,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor | None = None,
+ packed_query_sequence: torch.Tensor | None = None,
+ query_lens: torch.Tensor = None,
+ packed_query_position_embeddings: torch.Tensor = None,
+ packed_query_indexes: torch.Tensor = None,
past_key_values: NaiveCache | None = None,
key_values_lens: torch.Tensor | None = None,
packed_key_value_indexes: torch.Tensor | None = None,
@@ -327,6 +329,8 @@ def forward(
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast:
+ if packed_query_sequence is None:
+ packed_query_sequence = hidden_states
residual = packed_query_sequence
if mode == "und":
packed_query_sequence = self.input_layernorm(packed_query_sequence)
@@ -437,7 +441,8 @@ def forward(
for layer_idx, decoder_layer in enumerate(self.layers):
packed_query_sequence, past_key_values = decoder_layer(
- packed_query_sequence=packed_query_sequence,
+ hidden_states=packed_query_sequence,
+ encoder_hidden_states=None,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
diff --git a/vllm_omni/diffusion/models/flux2_klein/__init__.py b/vllm_omni/diffusion/models/flux2_klein/__init__.py
new file mode 100644
index 00000000000..0d477ab0a48
--- /dev/null
+++ b/vllm_omni/diffusion/models/flux2_klein/__init__.py
@@ -0,0 +1,17 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Flux2 klein diffusion model components."""
+
+from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import (
+ Flux2Transformer2DModel,
+)
+from vllm_omni.diffusion.models.flux2_klein.pipeline_flux2_klein import (
+ Flux2KleinPipeline,
+ get_flux2_klein_post_process_func,
+)
+
+__all__ = [
+ "Flux2KleinPipeline",
+ "Flux2Transformer2DModel",
+ "get_flux2_klein_post_process_func",
+]
diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
new file mode 100644
index 00000000000..86658a01deb
--- /dev/null
+++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
@@ -0,0 +1,723 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Iterable
+from types import SimpleNamespace
+from typing import Any
+
+import torch
+import torch.nn as nn
+from diffusers.models.embeddings import (
+ TimestepEmbedding,
+ Timesteps,
+ get_1d_rotary_pos_embed,
+)
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.normalization import AdaLayerNormContinuous
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
+from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.layers.rope import RotaryEmbedding
+
+
+class Flux2SwiGLU(nn.Module):
+    """Gated SiLU activation: halves the last dimension and returns ``silu(a) * b``."""
+
+    def __init__(self):
+        super().__init__()
+        self.gate_fn = nn.SiLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # The first half of the last dimension is the gate, the second half the value.
+        gate, value = torch.chunk(x, 2, dim=-1)
+        activated_gate = self.gate_fn(gate)
+        return activated_gate * value
+
+
+class Flux2FeedForward(nn.Module):
+    """SwiGLU MLP: ``linear_in`` -> ``Flux2SwiGLU`` -> ``linear_out``.
+
+    Args:
+        dim: Input feature size.
+        dim_out: Output feature size; a falsy value falls back to ``dim``.
+        mult: Expansion factor used to derive the hidden width when ``inner_dim`` is None.
+        inner_dim: Explicit hidden width. ``linear_in`` emits twice this many features
+            because the activation halves the last dimension.
+        bias: Whether both linear layers carry a bias term.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int | None = None,
+        mult: float = 3.0,
+        inner_dim: int | None = None,
+        bias: bool = False,
+    ):
+        super().__init__()
+        hidden_width = int(dim * mult) if inner_dim is None else inner_dim
+        output_width = dim_out or dim
+
+        # Sub-modules are registered in the same order as before: linear_in, act_fn, linear_out.
+        self.linear_in = nn.Linear(dim, hidden_width * 2, bias=bias)
+        self.act_fn = Flux2SwiGLU()
+        self.linear_out = nn.Linear(hidden_width, output_width, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear_out(self.act_fn(self.linear_in(x)))
+
+
+class Flux2Attention(nn.Module):
+    """Joint attention over an image stream and an optional context stream.
+
+    Q/K/V for ``hidden_states`` come from the fused ``to_qkv`` projection. When
+    ``added_kv_proj_dim`` is set, a second fused projection (``add_kv_proj``) produces
+    Q/K/V for ``encoder_hidden_states``; those are prepended along dim 1 before
+    attention and split off again afterwards.
+
+    Args:
+        query_dim: Feature size of ``hidden_states``.
+        heads: Number of heads; replaced by ``out_dim // dim_head`` when ``out_dim`` is given.
+        dim_head: Per-head feature size.
+        dropout: Probability for the ``nn.Dropout`` that follows the output projection.
+        bias: Whether ``to_qkv`` has a bias.
+        added_kv_proj_dim: Feature size of the context stream; ``None`` disables the
+            context projections entirely.
+        added_proj_bias: Whether ``add_kv_proj`` has a bias.
+        out_bias: Whether ``to_out[0]`` and ``to_add_out`` have a bias.
+        eps: Epsilon of the per-head RMSNorm layers.
+        out_dim: Optional override of the attention inner/output width.
+        elementwise_affine: Accepted for signature compatibility; not used by this class.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_kv_proj_dim: int | None = None,
+        added_proj_bias: bool | None = True,
+        out_bias: bool = True,
+        eps: float = 1e-5,
+        out_dim: int | None = None,
+        elementwise_affine: bool = True,
+    ):
+        super().__init__()
+        self.head_dim = dim_head
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.dropout = dropout
+        self.added_kv_proj_dim = added_kv_proj_dim
+
+        self.to_qkv = QKVParallelLinear(
+            hidden_size=query_dim,
+            head_size=self.head_dim,
+            total_num_heads=self.heads,
+            disable_tp=True,
+            bias=bias,
+        )
+
+        self.norm_q = RMSNorm(dim_head, eps=eps)
+        self.norm_k = RMSNorm(dim_head, eps=eps)
+
+        self.to_out = nn.ModuleList(
+            [ReplicatedLinear(self.inner_dim, self.out_dim, bias=out_bias), nn.Dropout(dropout)]
+        )
+
+        # The context-stream projections only exist when a context width was configured.
+        if added_kv_proj_dim is not None:
+            self.norm_added_q = RMSNorm(dim_head, eps=eps)
+            self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            self.add_kv_proj = QKVParallelLinear(
+                hidden_size=added_kv_proj_dim,
+                head_size=self.head_dim,
+                total_num_heads=self.heads,
+                disable_tp=True,
+                bias=added_proj_bias,
+            )
+            self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias)
+
+        self.rope = RotaryEmbedding(is_neox_style=False)
+        self.attn = Attention(
+            num_heads=self.heads,
+            head_size=self.head_dim,
+            softmax_scale=1.0 / (self.head_dim**0.5),
+            causal=False,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        **kwargs,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """Run attention, optionally jointly with a context stream.
+
+        Args:
+            hidden_states: Image-stream tokens.
+            encoder_hidden_states: Optional context tokens; requires the layer to have
+                been built with ``added_kv_proj_dim``.
+            attention_mask: Optional mask; a 3-D mask gets a singleton dim inserted at index 1.
+            image_rotary_emb: Optional ``(cos, sin)`` pair applied to both query and key.
+            **kwargs: Ignored.
+
+        Returns:
+            ``hidden_states`` alone when no context is given, otherwise
+            ``(hidden_states, encoder_hidden_states)``.
+
+        Raises:
+            ValueError: If ``encoder_hidden_states`` is given but the layer has no
+                context projections.
+        """
+        has_context = encoder_hidden_states is not None
+        if has_context and self.added_kv_proj_dim is None:
+            # Previously this case skipped the context concat but still split the output
+            # and called the non-existent ``to_add_out``; fail with a clear message instead.
+            raise ValueError(
+                "encoder_hidden_states was provided, but this Flux2Attention was built without "
+                "added_kv_proj_dim and therefore has no context projections."
+            )
+
+        qkv, _ = self.to_qkv(hidden_states)
+        query, key, value = qkv.chunk(3, dim=-1)
+
+        query = query.unflatten(-1, (self.heads, -1))
+        key = key.unflatten(-1, (self.heads, -1))
+        value = value.unflatten(-1, (self.heads, -1))
+
+        query = self.norm_q(query)
+        key = self.norm_k(key)
+
+        if has_context:
+            encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states)
+            encoder_query, encoder_key, encoder_value = encoder_qkv.chunk(3, dim=-1)
+
+            encoder_query = encoder_query.unflatten(-1, (self.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (self.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (self.heads, -1))
+
+            encoder_query = self.norm_added_q(encoder_query)
+            encoder_key = self.norm_added_k(encoder_key)
+
+            # Context tokens go first along dim 1; the split below relies on this order.
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
+
+        if image_rotary_emb is not None:
+            cos, sin = image_rotary_emb
+            cos = cos.to(query.dtype)
+            sin = sin.to(query.dtype)
+            query = self.rope(query, cos, sin)
+            key = self.rope(key, cos, sin)
+
+        attn_metadata = None
+        if attention_mask is not None:
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+
+        hidden_states = self.attn(query, key, value, attn_metadata)
+        # Merge the head and per-head dims back into a single feature dim.
+        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
+
+        if has_context:
+            context_len = encoder_hidden_states.shape[1]
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [context_len, hidden_states.shape[1] - context_len],
+                dim=1,
+            )
+            encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states)
+
+        hidden_states, _ = self.to_out[0](hidden_states)
+        hidden_states = self.to_out[1](hidden_states)
+
+        if has_context:
+            return hidden_states, encoder_hidden_states
+        return hidden_states
+
+
+class Flux2ParallelSelfAttention(nn.Module):
+    """
+    Parallel attention block that fuses QKV projections with MLP input projections.
+
+    One linear layer (``to_qkv_mlp_proj``) emits Q, K, V and the MLP input together.
+    Attention and the SwiGLU MLP branch are evaluated separately, their outputs are
+    concatenated on the last dimension, and a shared ``to_out`` projection follows.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        out_bias: bool = True,
+        eps: float = 1e-5,
+        out_dim: int = None,
+        elementwise_affine: bool = True,
+        mlp_ratio: float = 4.0,
+        mlp_mult_factor: int = 2,
+    ):
+        super().__init__()
+        self.head_dim = dim_head
+        # An explicit out_dim overrides the inner width, the output width and the head count.
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.dropout = dropout
+
+        self.mlp_ratio = mlp_ratio
+        self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
+        self.mlp_mult_factor = mlp_mult_factor
+
+        # Fused projection: 3 * inner_dim features for Q/K/V plus the MLP input features.
+        self.to_qkv_mlp_proj = nn.Linear(
+            self.query_dim,
+            self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
+            bias=bias,
+        )
+        self.mlp_act_fn = Flux2SwiGLU()
+
+        self.norm_q = RMSNorm(dim_head, eps=eps)
+        self.norm_k = RMSNorm(dim_head, eps=eps)
+
+        # NOTE(review): the input width here is inner_dim + mlp_hidden_dim, while the SwiGLU
+        # halves mlp_hidden_dim * mlp_mult_factor, so this only lines up for mlp_mult_factor == 2.
+        self.to_out = nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
+        self.rope = RotaryEmbedding(is_neox_style=False)
+        self.attn = Attention(
+            num_heads=self.heads,
+            head_size=self.head_dim,
+            softmax_scale=1.0 / (self.head_dim**0.5),
+            causal=False,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """Apply the fused attention + MLP block.
+
+        Args:
+            hidden_states: Input tokens; projected by ``to_qkv_mlp_proj``.
+            attention_mask: Optional mask; a 3-D mask gets a singleton dim inserted at index 1.
+            image_rotary_emb: Optional ``(cos, sin)`` pair applied to query and key.
+            **kwargs: Ignored.
+
+        Returns:
+            The output of ``to_out`` applied to the concatenated attention and MLP features.
+        """
+        hidden_states = self.to_qkv_mlp_proj(hidden_states)
+        # Undo the fusion: the leading 3 * inner_dim features are Q/K/V, the rest feed the MLP.
+        qkv, mlp_hidden_states = torch.split(
+            hidden_states,
+            [3 * self.inner_dim, self.mlp_hidden_dim * self.mlp_mult_factor],
+            dim=-1,
+        )
+
+        query, key, value = qkv.chunk(3, dim=-1)
+        # Split the feature dim into (heads, per-head features).
+        query = query.unflatten(-1, (self.heads, -1))
+        key = key.unflatten(-1, (self.heads, -1))
+        value = value.unflatten(-1, (self.heads, -1))
+
+        # Only query and key are normalized; value is left untouched.
+        query = self.norm_q(query)
+        key = self.norm_k(key)
+
+        if image_rotary_emb is not None:
+            cos, sin = image_rotary_emb
+            cos = cos.to(query.dtype)
+            sin = sin.to(query.dtype)
+            query = self.rope(query, cos, sin)
+            key = self.rope(key, cos, sin)
+
+        attn_metadata = None
+        if attention_mask is not None:
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+
+        attn_output = self.attn(query, key, value, attn_metadata)
+        # Merge the head and per-head dims back into one feature dim.
+        attn_output = attn_output.flatten(2, 3).to(query.dtype)
+
+        mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states)
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1)
+        return self.to_out(hidden_states)
+
+
+class Flux2SingleTransformerBlock(nn.Module):
+    """Single-stream block: modulated LayerNorm followed by fused attention + MLP.
+
+    Args:
+        dim: Token feature size.
+        num_attention_heads: Number of attention heads.
+        attention_head_dim: Per-head feature size.
+        mlp_ratio: Hidden-width multiplier forwarded to the fused attention/MLP module.
+        eps: Epsilon for the LayerNorm and the attention RMSNorm layers.
+        bias: Whether the fused input and output projections carry a bias.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 3.0,
+        eps: float = 1e-6,
+        bias: bool = False,
+    ):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+        self.attn = Flux2ParallelSelfAttention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=bias,
+            out_bias=bias,
+            eps=eps,
+            mlp_ratio=mlp_ratio,
+            mlp_mult_factor=2,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None,
+        temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        joint_attention_kwargs: dict[str, Any] | None = None,
+        split_hidden_states: bool = False,
+        text_seq_len: int | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """Apply one single-stream block.
+
+        Args:
+            hidden_states: Token sequence. If ``encoder_hidden_states`` is given it is
+                prepended to this tensor along dim 1.
+            encoder_hidden_states: Optional context tokens to prepend; when given, its
+                length overrides ``text_seq_len``.
+            temb_mod_params: ``(shift, scale, gate)`` modulation tensors.
+            image_rotary_emb: Optional ``(cos, sin)`` pair forwarded to the attention module.
+            joint_attention_kwargs: Extra keyword arguments forwarded to the attention module.
+            split_hidden_states: If True, return ``(context_part, remaining_part)`` split at
+                ``text_seq_len`` instead of the joined sequence.
+            text_seq_len: Number of leading context tokens, used only for the split.
+
+        Returns:
+            The updated sequence, or a ``(encoder_hidden_states, hidden_states)`` tuple when
+            ``split_hidden_states`` is True.
+
+        Raises:
+            ValueError: If ``split_hidden_states`` is True but the split point is unknown.
+        """
+        if encoder_hidden_states is not None:
+            text_seq_len = encoder_hidden_states.shape[1]
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        if split_hidden_states and text_seq_len is None:
+            # Slicing with None would silently return the full sequence for both halves.
+            raise ValueError(
+                "split_hidden_states=True requires text_seq_len or encoder_hidden_states "
+                "to locate the split point."
+            )
+
+        mod_shift, mod_scale, mod_gate = temb_mod_params
+
+        norm_hidden_states = self.norm(hidden_states)
+        norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
+
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
+        )
+
+        # Gated residual connection.
+        hidden_states = hidden_states + mod_gate * attn_output
+        if hidden_states.dtype == torch.float16:
+            # Keep values inside the finite float16 range.
+            hidden_states = hidden_states.clip(-65504, 65504)
+
+        if split_hidden_states:
+            encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+            return encoder_hidden_states, hidden_states
+        return hidden_states
+
+
+class Flux2TransformerBlock(nn.Module):
+    """Double-stream block: joint attention plus a separate feed-forward per stream.
+
+    The ``hidden_states`` stream and the ``encoder_hidden_states`` stream each have their
+    own LayerNorms and feed-forward network, and share one ``Flux2Attention`` call.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 3.0,
+        eps: float = 1e-6,
+        bias: bool = False,
+    ):
+        super().__init__()
+        # LayerNorms carry no learnable affine; scale/shift come from the modulation tensors.
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+        self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+
+        self.attn = Flux2Attention(
+            query_dim=dim,
+            added_kv_proj_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=bias,
+            added_proj_bias=bias,
+            out_bias=bias,
+            eps=eps,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+        self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
+
+        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+        self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
+        temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
+        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        joint_attention_kwargs: dict[str, Any] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Apply one double-stream block.
+
+        Args:
+            hidden_states: Tokens of the first stream (normalized by ``norm1``/``norm2``).
+            encoder_hidden_states: Tokens of the context stream.
+            temb_mod_params_img: Two ``(shift, scale, gate)`` triples for ``hidden_states``:
+                the first for attention, the second for the feed-forward.
+            temb_mod_params_txt: Same layout, for ``encoder_hidden_states``.
+            image_rotary_emb: Optional ``(cos, sin)`` pair forwarded to the attention module.
+            joint_attention_kwargs: Extra keyword arguments forwarded to the attention module.
+
+        Returns:
+            ``(encoder_hidden_states, hidden_states)`` -- note the context stream comes first.
+        """
+        joint_attention_kwargs = joint_attention_kwargs or {}
+
+        (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
+        (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
+
+        # Modulated pre-attention norm for both streams.
+        norm_hidden_states = self.norm1(hidden_states)
+        norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
+
+        norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+        norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
+
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
+        )
+
+        # First stream: gated attention residual, then gated feed-forward residual.
+        attn_output = gate_msa * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+        ff_output = self.ff(norm_hidden_states)
+        hidden_states = hidden_states + gate_mlp * ff_output
+
+        # Context stream: same two residual steps with its own modulation and modules.
+        context_attn_output = c_gate_msa * context_attn_output
+        encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
+        # Only the context stream is clipped to the finite float16 range here.
+        if encoder_hidden_states.dtype == torch.float16:
+            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+        return encoder_hidden_states, hidden_states
+
+
+class Flux2PosEmbed(nn.Module):
+    """Build rotary ``(cos, sin)`` tables from multi-axis position ids.
+
+    Args:
+        theta: Base passed to ``get_1d_rotary_pos_embed`` as ``theta``.
+        axes_dim: Rotary dimension for each position axis; axis ``i`` reads ``ids[..., i]``.
+    """
+
+    def __init__(self, theta: int, axes_dim: list[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return the per-axis cos and sin tables concatenated on the last dimension."""
+        positions = ids.float()
+        target_device = ids.device
+        # float32 is used on "mps" and "npu" devices, float64 everywhere else.
+        low_precision_device = target_device.type in ("mps", "npu")
+        freqs_dtype = torch.float32 if low_precision_device else torch.float64
+
+        cos_parts: list[torch.Tensor] = []
+        sin_parts: list[torch.Tensor] = []
+        for axis, axis_dim in enumerate(self.axes_dim):
+            freqs_cis = get_1d_rotary_pos_embed(
+                axis_dim,
+                positions[..., axis],
+                theta=self.theta,
+                use_real=False,
+                freqs_dtype=freqs_dtype,
+            )
+            # use_real=False yields a complex tensor; keep its two components separately.
+            cos_parts.append(freqs_cis.real)
+            sin_parts.append(freqs_cis.imag)
+
+        freqs_cos = torch.cat(cos_parts, dim=-1).to(target_device)
+        freqs_sin = torch.cat(sin_parts, dim=-1).to(target_device)
+        return freqs_cos, freqs_sin
+
+
+class Flux2TimestepGuidanceEmbeddings(nn.Module):
+    """Embed the timestep and, optionally, the guidance value into one vector.
+
+    Args:
+        in_channels: Channel count of the shared ``Timesteps`` projection.
+        embedding_dim: Output size of each ``TimestepEmbedding``.
+        bias: Forwarded as ``sample_proj_bias`` to the embedders.
+        guidance_embeds: Whether to create a separate embedder for the guidance value.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 256,
+        embedding_dim: int = 6144,
+        bias: bool = False,
+        guidance_embeds: bool = True,
+    ):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(
+            in_channels=in_channels,
+            time_embed_dim=embedding_dim,
+            sample_proj_bias=bias,
+        )
+        # The guidance embedder mirrors the timestep embedder and is absent when disabled.
+        self.guidance_embedder = (
+            TimestepEmbedding(
+                in_channels=in_channels,
+                time_embed_dim=embedding_dim,
+                sample_proj_bias=bias,
+            )
+            if guidance_embeds
+            else None
+        )
+
+    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor | None) -> torch.Tensor:
+        """Return the timestep embedding, plus the guidance embedding when both are available."""
+        projected_timestep = self.time_proj(timestep).to(timestep.dtype)
+        timestep_embedding = self.timestep_embedder(projected_timestep)
+
+        if guidance is None or self.guidance_embedder is None:
+            return timestep_embedding
+
+        # The same sinusoidal projection is reused for the guidance value.
+        projected_guidance = self.time_proj(guidance).to(guidance.dtype)
+        guidance_embedding = self.guidance_embedder(projected_guidance)
+        return timestep_embedding + guidance_embedding
+
+
+class Flux2Modulation(nn.Module):
+    """Project a conditioning embedding into ``mod_param_sets`` groups of three tensors.
+
+    Args:
+        dim: Feature size of the embedding and of each produced tensor.
+        mod_param_sets: Number of three-tensor groups to produce.
+        bias: Whether the linear projection carries a bias.
+    """
+
+    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
+        super().__init__()
+        self.mod_param_sets = mod_param_sets
+        self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
+        """Return a tuple with one three-tensor tuple per parameter set."""
+        projected = self.linear(self.act_fn(temb))
+        if projected.ndim == 2:
+            # Insert a singleton dim at index 1 so the result broadcasts over a middle axis.
+            projected = projected.unsqueeze(1)
+
+        pieces = projected.chunk(3 * self.mod_param_sets, dim=-1)
+        grouped = []
+        for set_index in range(self.mod_param_sets):
+            first = 3 * set_index
+            grouped.append(pieces[first : first + 3])
+        return tuple(grouped)
+
+
+class Flux2Transformer2DModel(nn.Module):
+    """
+    The Transformer model introduced in Flux 2.
+
+    The forward pass runs ``num_layers`` double-stream blocks on separate image and
+    context token streams, concatenates the two streams (context first), runs
+    ``num_single_layers`` single-stream blocks on the joined sequence, then drops the
+    context positions and projects the remaining tokens with ``proj_out``.
+    """
+
+    _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
+
+    def __init__(
+        self,
+        patch_size: int = 1,
+        in_channels: int = 128,
+        out_channels: int | None = None,
+        num_layers: int = 8,
+        num_single_layers: int = 48,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 48,
+        joint_attention_dim: int = 15360,
+        timestep_guidance_channels: int = 256,
+        mlp_ratio: float = 3.0,
+        axes_dims_rope: tuple[int, ...] = (32, 32, 32, 32),
+        rope_theta: int = 2000,
+        eps: float = 1e-6,
+        guidance_embeds: bool = True,
+    ):
+        super().__init__()
+        # A falsy out_channels falls back to in_channels.
+        self.out_channels = out_channels or in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+        # Plain namespace mirroring the constructor arguments (with out_channels resolved).
+        self.config = SimpleNamespace(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            out_channels=self.out_channels,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
+            attention_head_dim=attention_head_dim,
+            num_attention_heads=num_attention_heads,
+            joint_attention_dim=joint_attention_dim,
+            timestep_guidance_channels=timestep_guidance_channels,
+            mlp_ratio=mlp_ratio,
+            axes_dims_rope=axes_dims_rope,
+            rope_theta=rope_theta,
+            eps=eps,
+            guidance_embeds=guidance_embeds,
+        )
+
+        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope))
+        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
+            in_channels=timestep_guidance_channels,
+            embedding_dim=self.inner_dim,
+            bias=False,
+            guidance_embeds=guidance_embeds,
+        )
+
+        # Modulation modules live on the model, not on each block: one set is computed per
+        # forward pass and passed to every block of the matching kind.
+        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
+        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
+        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
+
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                Flux2TransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    mlp_ratio=mlp_ratio,
+                    eps=eps,
+                    bias=False,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                Flux2SingleTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    mlp_ratio=mlp_ratio,
+                    eps=eps,
+                    bias=False,
+                )
+                for _ in range(num_single_layers)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormContinuous(
+            self.inner_dim,
+            self.inner_dim,
+            elementwise_affine=False,
+            eps=eps,
+            bias=False,
+        )
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """Dtype of the first parameter yielded by ``self.parameters()``."""
+        return next(self.parameters()).dtype
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        img_ids: torch.Tensor,
+        txt_ids: torch.Tensor,
+        guidance: torch.Tensor | None = None,
+        joint_attention_kwargs: dict[str, Any] | None = None,
+        return_dict: bool = True,
+    ) -> torch.Tensor | Transformer2DModelOutput:
+        """Run the transformer on image and context tokens.
+
+        Args:
+            hidden_states: Image tokens, embedded by ``x_embedder``.
+            encoder_hidden_states: Context tokens, embedded by ``context_embedder``;
+                ``shape[1]`` is taken as the number of context tokens.
+            timestep: Timestep tensor; cast to ``hidden_states.dtype`` and multiplied by 1000.
+            img_ids: Position ids for the image tokens; if 3-D, only index 0 is used.
+            txt_ids: Position ids for the context tokens; if 3-D, only index 0 is used.
+            guidance: Optional guidance tensor, scaled the same way as ``timestep``.
+            joint_attention_kwargs: Extra keyword arguments forwarded to every block.
+            return_dict: If False, return a one-element tuple instead of an output object.
+
+        Returns:
+            ``Transformer2DModelOutput(sample=output)``, or ``(output,)`` when
+            ``return_dict`` is False.
+        """
+        joint_attention_kwargs = joint_attention_kwargs or {}
+
+        # Remembered so the context positions can be dropped after the single-stream blocks.
+        num_txt_tokens = encoder_hidden_states.shape[1]
+
+        timestep = timestep.to(hidden_states.dtype) * 1000
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+
+        temb = self.time_guidance_embed(timestep, guidance)
+
+        double_stream_mod_img = self.double_stream_modulation_img(temb)
+        double_stream_mod_txt = self.double_stream_modulation_txt(temb)
+        # mod_param_sets=1, so take the single (shift, scale, gate) triple.
+        single_stream_mod = self.single_stream_modulation(temb)[0]
+
+        hidden_states = self.x_embedder(hidden_states)
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        # NOTE(review): batched ids are reduced to the first entry -- presumably all batch
+        # items are expected to share the same position ids; confirm against the caller.
+        if img_ids.ndim == 3:
+            img_ids = img_ids[0]
+        if txt_ids.ndim == 3:
+            txt_ids = txt_ids[0]
+
+        image_rotary_emb = self.pos_embed(img_ids)
+        text_rotary_emb = self.pos_embed(txt_ids)
+        # Context entries first, then image entries, matching the token order used in attention.
+        concat_rotary_emb = (
+            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+        )
+
+        for block in self.transformer_blocks:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb_mod_params_img=double_stream_mod_img,
+                temb_mod_params_txt=double_stream_mod_txt,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # Join the streams once here; the single-stream blocks receive the joined sequence.
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        for block in self.single_transformer_blocks:
+            hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=None,
+                temb_mod_params=single_stream_mod,
+                image_rotary_emb=concat_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # Drop the leading context positions; only the remaining tokens are projected out.
+        hidden_states = hidden_states[:, num_txt_tokens:, ...]
+        hidden_states = self.norm_out(hidden_states, temb)
+        output = self.proj_out(hidden_states)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load checkpoint tensors, folding separate q/k/v weights into the fused layers.
+
+        Args:
+            weights: Iterable of ``(name, tensor)`` pairs.
+
+        Returns:
+            The set of parameter names (after renaming) that were loaded.
+
+        Raises:
+            KeyError: If a (renamed) weight name has no matching parameter.
+        """
+        # (fused parameter substring, checkpoint substring, shard id)
+        stacked_params_mapping = [
+            (".to_qkv", ".to_q", "q"),
+            (".to_qkv", ".to_k", "k"),
+            (".to_qkv", ".to_v", "v"),
+            (".add_kv_proj", ".add_q_proj", "q"),
+            (".add_kv_proj", ".add_k_proj", "k"),
+            (".add_kv_proj", ".add_v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        # Buffers whose names end in ".beta" or ".eps" are made loadable as well.
+        for name, buffer in self.named_buffers():
+            if name.endswith(".beta") or name.endswith(".eps"):
+                params_dict[name] = buffer
+
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            # NOTE(review): "to_qkvkv_mlp_proj" is what a ".to_q" -> ".to_qkv" substitution
+            # produces on "to_qkv_mlp_proj"; presumably names can arrive already remapped --
+            # confirm where that happens.
+            if "to_qkvkv_mlp_proj" in name:
+                name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj")
+            # Handled before the mapping loop because ".to_q" is a substring of this name.
+            if "to_qkv_mlp_proj" in name:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # No stacked mapping matched: load the tensor under its own name.
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
new file mode 100644
index 00000000000..ba29e681c32
--- /dev/null
+++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py
@@ -0,0 +1,963 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import math
+import os
+from collections.abc import Callable, Iterable
+from typing import Any
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.autoencoders.autoencoder_kl_flux2 import AutoencoderKLFlux2
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
+from vllm.logger import init_logger
+from vllm.model_executor.models.utils import AutoWeightsLoader
+
+from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
+from vllm_omni.diffusion.distributed.utils import get_local_device
+from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import (
+ Flux2Transformer2DModel,
+)
+from vllm_omni.diffusion.models.interface import SupportImageInput
+from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
+from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
+
+logger = init_logger(__name__)
+
+
+class Flux2ImageProcessor(VaeImageProcessor):
+    """Image processor to preprocess the reference image for Flux2 klein."""
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 16,
+        vae_latent_channels: int = 32,
+        do_normalize: bool = True,
+        do_convert_rgb: bool = True,
+    ):
+        super().__init__(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            do_normalize=do_normalize,
+            do_convert_rgb=do_convert_rgb,
+        )
+
+    @staticmethod
+    def check_image_input(
+        image: PIL.Image.Image,
+        max_aspect_ratio: int = 8,
+        min_side_length: int = 64,
+        max_area: int = 1024 * 1024,
+    ) -> PIL.Image.Image:
+        """Validate a reference image and return it unchanged.
+
+        Args:
+            image: Image to validate.
+            max_aspect_ratio: Largest allowed long-side / short-side ratio.
+            min_side_length: Smallest allowed width and height in pixels.
+            max_area: Pixel area above which a warning is logged (not an error).
+
+        Raises:
+            ValueError: If ``image`` is not a PIL image, is too small, or is too elongated.
+        """
+        if not isinstance(image, PIL.Image.Image):
+            raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}")
+
+        width, height = image.size
+        if width < min_side_length or height < min_side_length:
+            raise ValueError(f"Image too small: {width}x{height}. Both dimensions must be at least {min_side_length}px")
+
+        aspect_ratio = max(width / height, height / width)
+        if aspect_ratio > max_aspect_ratio:
+            raise ValueError(
+                f"Aspect ratio too extreme: {width}x{height} (ratio: {aspect_ratio:.1f}:1). "
+                f"Maximum allowed ratio is {max_aspect_ratio}:1"
+            )
+
+        if width * height > max_area:
+            logger.warning("Image area exceeds recommended maximum; resizing will be applied.")
+
+        return image
+
+    @staticmethod
+    def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
+        """Rescale ``image`` (up or down) so its area is about ``target_area``, keeping the aspect ratio."""
+        image_width, image_height = image.size
+        scale = math.sqrt(target_area / (image_width * image_height))
+        width = int(image_width * scale)
+        height = int(image_height * scale)
+        return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
+
+    @staticmethod
+    def _resize_if_exceeds_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
+        """Downscale ``image`` to ``target_area`` only when it is larger; otherwise return it as-is."""
+        image_width, image_height = image.size
+        if image_width * image_height <= target_area:
+            return image
+        return Flux2ImageProcessor._resize_to_target_area(image, target_area)
+
+    def _resize_and_crop(self, image: PIL.Image.Image, width: int, height: int) -> PIL.Image.Image:
+        """Center-crop ``image`` to ``width`` x ``height``; despite the name, nothing is resized here."""
+        image_width, image_height = image.size
+        left = (image_width - width) // 2
+        top = (image_height - height) // 2
+        right = left + width
+        bottom = top + height
+        return image.crop((left, top, right, bottom))
+
+    @staticmethod
+    def concatenate_images(images: list[PIL.Image.Image]) -> PIL.Image.Image:
+        """Paste images side by side on a white canvas, each vertically centered.
+
+        A single image is returned as a copy. Non-RGB images are converted to RGB first.
+
+        Raises:
+            ValueError: If ``images`` is empty.
+        """
+        if not images:
+            # Same exception type the bare ``max()`` call raised before, with a usable message.
+            raise ValueError("concatenate_images requires at least one image")
+        if len(images) == 1:
+            return images[0].copy()
+
+        images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]
+        total_width = sum(img.width for img in images)
+        max_height = max(img.height for img in images)
+        background_color = (255, 255, 255)
+        new_img = PIL.Image.new("RGB", (total_width, max_height), background_color)
+
+        x_offset = 0
+        for img in images:
+            y_offset = (max_height - img.height) // 2
+            new_img.paste(img, (x_offset, y_offset))
+            x_offset += img.width
+
+        return new_img
+
+
+def get_flux2_klein_post_process_func(
+    od_config: OmniDiffusionConfig,
+):
+    """Build the post-processing callable for Flux2 klein outputs.
+
+    ``od_config.model`` is used directly when it exists as a local path; otherwise the
+    weights are fetched with ``download_weights_from_hf_specific``. The VAE scale factor
+    is derived from ``vae/config.json`` (``block_out_channels``), defaulting to 8 when
+    that key is absent.
+
+    Args:
+        od_config: Diffusion configuration providing the model name or path.
+
+    Returns:
+        A function that passes a tensor to ``Flux2ImageProcessor.postprocess``.
+    """
+    model_name = od_config.model
+    if os.path.exists(model_name):
+        model_path = model_name
+    else:
+        model_path = download_weights_from_hf_specific(model_name, None, ["*"])
+
+    vae_config_path = os.path.join(model_path, "vae/config.json")
+    # JSON is UTF-8 by specification; do not depend on the platform default encoding.
+    with open(vae_config_path, encoding="utf-8") as f:
+        vae_config = json.load(f)
+    vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1) if "block_out_channels" in vae_config else 8
+
+    # The processor is built with twice the VAE scale factor.
+    image_processor = Flux2ImageProcessor(vae_scale_factor=vae_scale_factor * 2)
+
+    def post_process_func(images: torch.Tensor):
+        return image_processor.postprocess(images)
+
+    return post_process_func
+
+
+# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
+def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
+    """Return the empirical ``mu`` for a given image sequence length and step count.
+
+    Two linear fits in ``image_seq_len`` are evaluated, one anchored at 10 steps and one
+    at 200 steps. Above 4300 tokens the 200-step fit is returned directly; otherwise the
+    result is the straight line through both anchors evaluated at ``num_steps``.
+    """
+    slope_10, offset_10 = 8.73809524e-05, 1.89833333
+    slope_200, offset_200 = 0.00016927, 0.45666666
+
+    mu_at_200 = slope_200 * image_seq_len + offset_200
+    if image_seq_len > 4300:
+        return float(mu_at_200)
+
+    mu_at_10 = slope_10 * image_seq_len + offset_10
+
+    # Line through (10, mu_at_10) and (200, mu_at_200).
+    gradient = (mu_at_200 - mu_at_10) / 190.0
+    intercept = mu_at_200 - 200.0 * gradient
+    return float(gradient * num_steps + intercept)
+
+
+class Flux2KleinPipeline(nn.Module, SupportImageInput):
+ """Flux2 klein pipeline for text-to-image generation."""
+
+ support_image_input = True
+
+ def __init__(
+ self,
+ *,
+ od_config: OmniDiffusionConfig,
+ prefix: str = "",
+ is_distilled: bool = False,
+ ):
+ super().__init__()
+ self.od_config = od_config
+ self.is_distilled = is_distilled
+ self.weights_sources = [
+ DiffusersPipelineLoader.ComponentSource(
+ model_or_path=od_config.model,
+ subfolder="transformer",
+ revision=None,
+ prefix="transformer.",
+ fall_back_to_pt=True,
+ )
+ ]
+
+ self._execution_device = get_local_device()
+ model = od_config.model
+ local_files_only = os.path.exists(model)
+
+ self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ model,
+ subfolder="scheduler",
+ local_files_only=local_files_only,
+ )
+ self.text_encoder = Qwen3ForCausalLM.from_pretrained(
+ model,
+ subfolder="text_encoder",
+ local_files_only=local_files_only,
+ )
+ self.tokenizer = Qwen2TokenizerFast.from_pretrained(
+ model,
+ subfolder="tokenizer",
+ local_files_only=local_files_only,
+ )
+ self.vae = AutoencoderKLFlux2.from_pretrained(
+ model,
+ subfolder="vae",
+ local_files_only=local_files_only,
+ ).to(self._execution_device)
+
+ transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel)
+ self.transformer = Flux2Transformer2DModel(**transformer_kwargs)
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 512
+ self.default_sample_size = 128
+
+ self._guidance_scale = None
+ self._attention_kwargs = None
+ self._num_timesteps = None
+ self._current_timestep = None
+ self._interrupt = False
+
+ @staticmethod
+ def _get_qwen3_prompt_embeds(  # encode prompt(s); returns (batch, seq_len, n_layers * hidden_dim)
+ text_encoder: Qwen3ForCausalLM,
+ tokenizer: Qwen2TokenizerFast,
+ prompt: str | list[str],
+ dtype: torch.dtype | None = None,
+ device: torch.device | None = None,
+ max_sequence_length: int = 512,
+ hidden_states_layers: tuple[int, ...] | list[int] = (9, 18, 27),  # indices into output.hidden_states
+ ):
+ dtype = text_encoder.dtype if dtype is None else dtype  # default to the encoder's own dtype
+ device = text_encoder.device if device is None else device  # default to the encoder's own device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ all_input_ids = []
+ all_attention_masks = []
+
+ for single_prompt in prompt:  # each prompt is wrapped as a single user turn and tokenized on its own
+ messages = [{"role": "user", "content": single_prompt}]
+ text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=False,
+ )
+ inputs = tokenizer(
+ text,
+ return_tensors="pt",
+ padding="max_length",  # pad every prompt to max_sequence_length so the rows can be concatenated
+ truncation=True,
+ max_length=max_sequence_length,
+ )
+
+ all_input_ids.append(inputs["input_ids"])
+ all_attention_masks.append(inputs["attention_mask"])
+
+ input_ids = torch.cat(all_input_ids, dim=0).to(device)
+ attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
+
+ # Forward pass through the model
+ output = text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ use_cache=False,
+ )
+
+ # Only use outputs from intermediate layers and stack them
+ out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
+ out = out.to(dtype=dtype, device=device)
+
+ batch_size, num_channels, seq_len, hidden_dim = out.shape  # num_channels == number of selected layers
+ prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+
+ return prompt_embeds
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
+ def _prepare_text_ids(  # build (t, h, w, seq_position) coordinates for every text token
+ x: torch.Tensor, # (B, L, D); a 2-D (L, D) input fails the three-way unpack below
+ t_coord: torch.Tensor | None = None,  # optional per-batch-item t values, indexed as t_coord[i]
+ ):
+ B, L, _ = x.shape
+ out_ids = []
+
+ for i in range(B):
+ t = torch.arange(1) if t_coord is None else t_coord[i]
+ h = torch.arange(1)  # h and w are always the single value 0 for text tokens
+ w = torch.arange(1)
+ seq_positions = torch.arange(L)
+
+ coords = torch.cartesian_prod(t, h, w, seq_positions)  # one row of 4 coordinates per combination
+ out_ids.append(coords)
+
+ return torch.stack(out_ids)  # (B, L, 4) when t_coord is None
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
+ def _prepare_latent_ids(
+ latents: torch.Tensor, # (B, C, H, W)
+ ):
+ r"""
+ Generates 4D position coordinates (T, H, W, L) for latent tensors.
+
+ Args:
+ latents (torch.Tensor):
+ Latent tensor of shape (B, C, H, W). Only its shape is read; the values are not used.
+
+ Returns:
+ torch.Tensor:
+ Position IDs tensor of shape (B, H*W, 4). All batches share the same coordinate structure: T=0,
+ H=[0..H-1], W=[0..W-1], L=0
+ """
+
+ batch_size, _, height, width = latents.shape
+
+ t = torch.arange(1) # [0] - time dimension
+ h = torch.arange(height)
+ w = torch.arange(width)
+ layer_ids = torch.arange(1) # [0] - layer dimension
+
+ # Create position IDs: (H*W, 4)
+ latent_ids = torch.cartesian_prod(t, h, w, layer_ids)
+
+ # Expand to batch: (B, H*W, 4)
+ latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)  # expand() broadcasts without copying
+
+ return latent_ids
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
+ def _prepare_image_ids(
+ image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
+ scale: int = 10,
+ ):
+ r"""
+ Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
+
+ This function creates a coordinate for every latent position across all input latents, which may have
+ different dimensions.
+
+ Args:
+ image_latents (List[torch.Tensor]):
+ A list of image latent feature tensors; a leading size-1 batch dim is squeezed, leaving (C, H, W).
+ scale (int, optional):
+ A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
+ latent is: 'scale + scale * i'. Defaults to 10.
+
+ Returns:
+ torch.Tensor:
+ The combined coordinate tensor. Shape: (1, N_total, 4), where N_total is the sum of (H * W) for all
+ input latents.
+
+ Coordinate Components (Dimension 4):
+ - T (Time): The unique index indicating which latent image the coordinate belongs to.
+ - H (Height): The row index within that latent image.
+ - W (Width): The column index within that latent image.
+ - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
+ """
+
+ if not isinstance(image_latents, list):
+ raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
+
+ # create time offset for each reference image
+ t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
+ t_coords = [t.view(-1) for t in t_coords]  # 0-d tensors -> 1-element vectors, as cartesian_prod needs 1-D inputs
+
+ image_latent_ids = []
+ for x, t in zip(image_latents, t_coords):
+ x = x.squeeze(0)
+ _, height, width = x.shape
+
+ x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
+ image_latent_ids.append(x_ids)
+
+ image_latent_ids = torch.cat(image_latent_ids, dim=0)  # all images share one token axis
+ image_latent_ids = image_latent_ids.unsqueeze(0)
+
+ return image_latent_ids
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
+ def _patchify_latents(latents):  # (B, C, H, W) -> (B, C * 4, H // 2, W // 2): folds each 2x2 patch into channels
+ batch_size, num_channels_latents, height, width = latents.shape
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)  # needs even H and W
+ latents = latents.permute(0, 1, 3, 5, 2, 4)  # move the two in-patch axes next to the channel axis
+ latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
+ def _unpatchify_latents(latents):  # (B, C, H, W) -> (B, C // 4, H * 2, W * 2): unfolds channels into 2x2 patches
+ batch_size, num_channels_latents, height, width = latents.shape
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
+ latents = latents.permute(0, 1, 4, 2, 5, 3)  # interleave each in-patch axis with its spatial axis
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
+ def _pack_latents(latents):
+ """
+ Flatten the spatial dims: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
+ """
+
+ batch_size, num_channels, height, width = latents.shape
+ latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)  # tokens first, channels last
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
+ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
+ """
+ Scatter packed tokens back to (B, C, H, W) using the H (column 1) and W (column 2) entries of `x_ids`.
+ """
+ x_list = []
+ for data, pos in zip(x, x_ids):  # one (tokens, channels) slab and its (tokens, 4) ids per batch item
+ _, ch = data.shape # noqa: F841
+ h_ids = pos[:, 1].to(torch.int64)
+ w_ids = pos[:, 2].to(torch.int64)
+
+ h = torch.max(h_ids) + 1  # grid size is inferred from the largest coordinate present
+ w = torch.max(w_ids) + 1
+
+ flat_ids = h_ids * w + w_ids  # row-major index of each token in the (h, w) grid
+
+ out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)  # positions with no token stay zero
+ out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
+
+ # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
+
+ out = out.view(h, w, ch).permute(2, 0, 1)
+ x_list.append(out)
+
+ return torch.stack(x_list, dim=0)  # a single tensor, so every batch item must yield the same (C, H, W)
+
+ def encode_prompt(  # returns (prompt_embeds, text_ids), both repeated num_images_per_prompt times
+ self,
+ prompt: str | list[str],
+ device: torch.device | None = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: torch.Tensor | None = None,
+ max_sequence_length: int = 512,
+ text_encoder_out_layers: tuple[int, ...] = (9, 18, 27),
+ ):
+ device = device or self._execution_device
+
+ if prompt is None:  # a missing prompt is encoded as the empty string
+ prompt = ""
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:  # precomputed embeddings bypass the text encoder entirely
+ prompt_embeds = self._get_qwen3_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ hidden_states_layers=text_encoder_out_layers,
+ )
+
+ batch_size, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)  # tile along seq, then fold into the batch
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ text_ids = self._prepare_text_ids(prompt_embeds)  # ids are built for the already-repeated batch
+ text_ids = text_ids.to(device)
+ return prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):  # image -> normalised, patchified latents
+ if image.ndim != 4:
+ raise ValueError(f"Expected image dims 4, got {image.ndim}.")
+
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+ image_latents = self._patchify_latents(image_latents)  # 2x2 patches folded into channels before normalising
+
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)  # NOTE(review): unlike the mean, not moved to image_latents' device/dtype — confirm they always match
+ image_latents = (image_latents - latents_bn_mean) / latents_bn_std  # normalise with the VAE's batch-norm running stats
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
+ def prepare_latents(  # returns (packed latents [B, H*W, C], latent position ids [B, H*W, 4])
+ self,
+ batch_size,
+ num_latents_channels,
+ height,
+ width,
+ dtype,
+ device,
+ generator: torch.Generator,
+ latents: torch.Tensor | None = None,  # if given, must still be unpacked (B, C, H, W); it is packed below
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))  # pixel size -> even latent size (rounds down)
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)  # patchified layout: 4x channels, half size
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)  # caller-supplied latents are not checked against `shape`
+
+ latent_ids = self._prepare_latent_ids(latents)
+ latent_ids = latent_ids.to(device)
+
+ latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
+ return latents, latent_ids
+
+ # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
+ def prepare_image_latents(  # returns (reference-image tokens, their position ids), each repeated batch_size times
+ self,
+ images: list[torch.Tensor],
+ batch_size,
+ generator: torch.Generator,
+ device,
+ dtype,
+ ):
+ image_latents = []
+ for image in images:
+ image = image.to(device=device, dtype=dtype)
+ image_latent = self._encode_vae_image(image=image, generator=generator)
+ image_latents.append(image_latent) # (1, 128, 32, 32)
+
+ image_latent_ids = self._prepare_image_ids(image_latents)  # built from the unpacked (1, C, H, W) latents
+
+ # Pack each latent and concatenate
+ packed_latents = []
+ for latent in image_latents:
+ # latent: (1, 128, 32, 32)
+ packed = self._pack_latents(latent) # (1, 1024, 128)
+ packed = packed.squeeze(0) # (1024, 128) - remove batch dim
+ packed_latents.append(packed)
+
+ # Concatenate all reference tokens along sequence dimension
+ image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
+ image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
+
+ image_latents = image_latents.repeat(batch_size, 1, 1)  # the same reference tokens for every batch item
+ image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
+ image_latent_ids = image_latent_ids.to(device)
+
+ return image_latents, image_latent_ids
+
+ def check_inputs(  # validate forward() arguments: warns about sizes, raises ValueError on bad prompt inputs
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ guidance_scale=None,
+ ):
+ if (
+ height is not None
+ and height % (self.vae_scale_factor * 2) != 0
+ or width is not None
+ and width % (self.vae_scale_factor * 2) != 0
+ ):
+ logger.warning(  # warning only: indivisible sizes are not rejected here
+ "`height` and `width` have to be divisible by %s but are %s and %s. "
+ "Dimensions will be resized accordingly",
+ self.vae_scale_factor * 2,
+ height,
+ width,
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in ["latents", "prompt_embeds"] for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError("`callback_on_step_end_tensor_inputs` must be a subset of ['latents', 'prompt_embeds'].")
+
+ if prompt is not None and prompt_embeds is not None:  # exactly one of prompt / prompt_embeds is required
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if guidance_scale is not None and guidance_scale > 1.0 and self.is_distilled:  # None is the default; guard before comparing
+ logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.")
+
+ @property
+ def guidance_scale(self):  # value stored by the most recent forward() call; None until then
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):  # True only when guidance_scale is set, > 1, and the model is not distilled
+ return self._guidance_scale is not None and self._guidance_scale > 1 and not self.is_distilled
+
+ @property
+ def attention_kwargs(self):  # kwargs stored by forward() and passed to the transformer as joint_attention_kwargs
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):  # len(timesteps) from the most recent forward() call; None until then
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):  # timestep being denoised; reset to None before and after the denoising loop
+ return self._current_timestep
+
+ @property
+ def interrupt(self):  # when True, the denoising loop in forward() skips its remaining steps
+ return self._interrupt
+
+ @torch.no_grad()
+ def forward(
+ self,
+ req: OmniDiffusionRequest,
+ image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
+ prompt: str | list[str] | None = None,
+ height: int | None = None,
+ width: int | None = None,
+ num_inference_steps: int = 50,
+ sigmas: list[float] | None = None,
+ guidance_scale: float | None = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: torch.Generator | list[torch.Generator] | None = None,
+ latents: torch.Tensor | None = None,
+ prompt_embeds: torch.Tensor | None = None,
+ negative_prompt_embeds: torch.Tensor | None = None,
+ output_type: str | None = "pil",
+ return_dict: bool = True,
+ attention_kwargs: dict[str, Any] | None = None,
+ callback_on_step_end: Callable[[int, int, dict], None] | None = None,
+ callback_on_step_end_tensor_inputs: list[str] = ["latents"],
+ max_sequence_length: int = 512,
+ text_encoder_out_layers: tuple[int, ...] = (9, 18, 27),
+ ) -> DiffusionOutput:
+ r"""
+ Run one generation request; fields set on `req` take precedence over the matching keyword arguments.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or list of these):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models,
+ `guidance_scale` is ignored.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. With conditioning images, defaults to the first image's processed height.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. With conditioning images, defaults to the first image's processed width.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline.
+ If not provided, will be generated from "".
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ `"latent"` returns the unpatchified latents; any other value returns the tensor produced by
+ `self.vae.decode` (this method performs no PIL/NumPy conversion itself).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Not read by this method; a `DiffusionOutput` is always returned.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+ text_encoder_out_layers (`Tuple[int]`):
+ Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
+
+ Examples:
+
+ Returns:
+ `DiffusionOutput`: its `output` field holds the unpatchified latents when `output_type == "latent"`,
+ otherwise the tensor returned by `self.vae.decode`. A `DiffusionOutput` is returned regardless of
+ `return_dict`.
+ """
+
+ prompt = req.prompt if req.prompt is not None else prompt  # request fields override the keyword arguments
+ image = req.pil_image if req.pil_image is not None else image
+ height = req.height or height  # falsy request values (None or 0) fall back to the argument
+ width = req.width or width
+ num_inference_steps = req.num_inference_steps or num_inference_steps
+ guidance_scale = req.guidance_scale if req.guidance_scale is not None else guidance_scale
+ generator = req.generator or generator
+ req_num_outputs = getattr(req, "num_outputs_per_prompt", None)
+ if req_num_outputs and req_num_outputs > 0:
+ num_images_per_prompt = req_num_outputs
+
+ if isinstance(req.prompt_embeds, torch.Tensor):
+ prompt_embeds = req.prompt_embeds
+ if isinstance(req.negative_prompt_embeds, torch.Tensor):
+ negative_prompt_embeds = req.negative_prompt_embeds
+
+ if req.max_sequence_length is not None:
+ max_sequence_length = req.max_sequence_length
+ if getattr(req, "text_encoder_out_layers", None) is not None:
+ text_encoder_out_layers = req.text_encoder_out_layers
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ prompt_embeds=prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ guidance_scale=guidance_scale,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. prepare text embeddings
+ prompt_embeds, text_ids = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ text_encoder_out_layers=text_encoder_out_layers,
+ )
+
+ if self.do_classifier_free_guidance:
+ negative_prompt = ""  # the unconditional branch always uses the empty prompt
+ if prompt is not None and isinstance(prompt, list):
+ negative_prompt = [negative_prompt] * len(prompt)
+ negative_prompt_embeds, negative_text_ids = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ text_encoder_out_layers=text_encoder_out_layers,
+ )
+
+ # 4. process images
+ if image is not None and not isinstance(image, list):
+ image = [image]
+
+ condition_images = None
+ if image is not None:
+ for img in image:
+ self.image_processor.check_image_input(img)
+
+ condition_images = []
+ for img in image:
+ image_width, image_height = img.size
+ if image_width * image_height > 1024 * 1024:  # larger conditioning images are resized to a 1024*1024 target area
+ img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
+ image_width, image_height = img.size
+
+ multiple_of = self.vae_scale_factor * 2
+ image_width = (image_width // multiple_of) * multiple_of  # round down to a multiple the VAE + packing accept
+ image_height = (image_height // multiple_of) * multiple_of
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
+ condition_images.append(img)
+ height = height or image_height  # an unset output size is taken from the first conditioning image
+ width = width or image_width
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 5. prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_ids = self.prepare_latents(
+ batch_size=batch_size * num_images_per_prompt,
+ num_latents_channels=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ image_latents = None
+ image_latent_ids = None
+ if condition_images is not None:
+ image_latents, image_latent_ids = self.prepare_image_latents(
+ images=condition_images,
+ batch_size=batch_size * num_images_per_prompt,
+ generator=generator,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ # 6. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+ sigmas = None  # drop the custom schedule when the scheduler is configured with use_flow_sigmas
+ image_seq_len = latents.shape[1]
+ mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 7. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ self.scheduler.set_begin_index(0)
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue  # interrupted: remaining steps are skipped, but the loop is not exited early
+
+ self._current_timestep = t
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ latent_model_input = latents.to(self.transformer.dtype)
+ latent_image_ids = latent_ids
+
+ if image_latents is not None:  # reference-image tokens are appended after the noise tokens
+ latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
+ latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,  # NOTE(review): presumably rescales scheduler timesteps to [0, 1] — confirm
+ guidance=None,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = noise_pred[:, : latents.size(1) :]  # keep only the noise-token predictions (drops reference tokens)
+
+ if self.do_classifier_free_guidance:
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=None,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1) :]
+ noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)  # classifier-free guidance mix
+
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype and torch.backends.mps.is_available():
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]  # tensors are looked up by local-variable name
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)  # the callback may replace these tensors
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ self._current_timestep = None
+
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
+
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
+ latents.device, latents.dtype
+ )
+ latents = latents * latents_bn_std + latents_bn_mean  # inverse of the normalisation in _encode_vae_image
+ latents = self._unpatchify_latents(latents)
+ if output_type == "latent":
+ image = latents
+ else:
+ if latents.dtype != self.vae.dtype:
+ latents = latents.to(self.vae.dtype)
+ image = self.vae.decode(latents, return_dict=False)[0]  # raw decoder output; no image post-processing here
+
+ return DiffusionOutput(output=image)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:  # delegates to AutoWeightsLoader
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)  # NOTE(review): presumably the set of loaded parameter names — confirm
diff --git a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
index 09f7b17e133..615b9194af4 100644
--- a/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
+++ b/vllm_omni/diffusion/models/glm_image/glm_image_transformer.py
@@ -8,8 +8,8 @@
import torch
import torch.nn as nn
from diffusers.models.attention import FeedForward
-from diffusers.models.transformers.transformer_glm_image import GlmImageCombinedTimestepSizeEmbeddings
from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_glm_image import GlmImageCombinedTimestepSizeEmbeddings
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -414,9 +414,10 @@ def forward(
query_img = query[:, text_seq_length:, :, :]
key_img = key[:, text_seq_length:, :, :]
from diffusers.models.embeddings import apply_rotary_emb
- query_img = apply_rotary_emb(query_img,image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
+
+ query_img = apply_rotary_emb(query_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
# key_img = self.rope(key_img, cos, sin)
- key_img = apply_rotary_emb(key_img,image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
+ key_img = apply_rotary_emb(key_img, image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2)
query = torch.cat([query[:, :text_seq_length, :, :], query_img], dim=1)
key = torch.cat([key[:, :text_seq_length, :, :], key_img], dim=1)
@@ -555,7 +556,7 @@ def __init__(
od_config: OmniDiffusionConfig,
):
super().__init__()
-
+
patch_size = od_config.tf_model_config.patch_size
in_channels = od_config.tf_model_config.in_channels
out_channels = od_config.tf_model_config.out_channels
@@ -565,8 +566,6 @@ def __init__(
condition_dim = od_config.tf_model_config.condition_dim
prior_vq_quantizer_codebook_size = od_config.tf_model_config.prior_vq_quantizer_codebook_size
text_embed_dim = od_config.tf_model_config.text_embed_dim
-
-
# Get num_layers from config if available
model_config = od_config.tf_model_config
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
index 6401b526a5f..87c7c2f73c2 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
@@ -36,6 +36,7 @@
QwenImageTransformer2DModel,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
@@ -274,7 +275,8 @@ def __init__(
self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
- self.transformer = QwenImageTransformer2DModel(od_config=od_config)
+ transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
+ self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
index 53609c87757..901625e34b6 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit.py
@@ -38,6 +38,7 @@
QwenImageTransformer2DModel,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
@@ -231,7 +232,8 @@ def __init__(
self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
- self.transformer = QwenImageTransformer2DModel(od_config=od_config)
+ transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
+ self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.processor = Qwen2VLProcessor.from_pretrained(
model, subfolder="processor", local_files_only=local_files_only
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
index 80e48112c71..6df502d48f0 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
@@ -41,6 +41,7 @@
QwenImageTransformer2DModel,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
@@ -191,7 +192,9 @@ def __init__(
self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)
- self.transformer = QwenImageTransformer2DModel(od_config=od_config)
+
+ transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
+ self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
self.processor = Qwen2VLProcessor.from_pretrained(
model, subfolder="processor", local_files_only=local_files_only
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
index eb5e2aa987a..4642b3eb418 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_layered.py
@@ -37,6 +37,7 @@
QwenImageTransformer2DModel,
)
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
@@ -211,18 +212,8 @@ def __init__(
)
]
- use_additional_t_cond = od_config.tf_model_config.use_additional_t_cond
- zero_cond_t = od_config.tf_model_config.zero_cond_t
- use_layer3d_rope = od_config.tf_model_config.use_layer3d_rope
- guidance_embeds = od_config.tf_model_config.guidance_embeds
-
- self.transformer = QwenImageTransformer2DModel(
- od_config=od_config,
- use_additional_t_cond=use_additional_t_cond,
- zero_cond_t=zero_cond_t,
- use_layer3d_rope=use_layer3d_rope,
- guidance_embeds=guidance_embeds,
- )
+ transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
+ self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
# Pipeline configuration & processing parameters
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
index 2aa9c29104d..cbf0b7e10ac 100644
--- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
+++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
@@ -8,7 +8,7 @@
import torch
import torch.nn as nn
-from diffusers.models.attention import FeedForward
+import torch.nn.functional as F
# TODO replace this with vLLM implementation
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
@@ -16,14 +16,18 @@
from diffusers.models.normalization import AdaLayerNormContinuous
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import QKVParallelLinear, ReplicatedLinear
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm_omni.diffusion.attention.backends.abstract import (
AttentionMetadata,
)
from vllm_omni.diffusion.attention.layer import Attention
-from vllm_omni.diffusion.attention.selector import get_attn_backend
+from vllm_omni.diffusion.attention.selector import _BACKENDS_SUPPORT_ATTENTION_MASK, get_attn_backend
from vllm_omni.diffusion.cache.base import CachedTransformer
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.distributed.parallel_state import (
@@ -287,79 +291,133 @@ def _compute_video_freqs(self, frame, height, width, idx=0):
return freqs.clone().contiguous()
+class ColumnParallelApproxGELU(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = True):
+ super().__init__()
+ self.proj = ColumnParallelLinear(
+ dim_in,
+ dim_out,
+ bias=bias,
+ gather_output=False,
+ return_bias=False,
+ )
+ self.approximate = approximate
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ return F.gelu(x, approximate=self.approximate)
+
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: int | None = None,
+ mult: int = 4,
+ activation_fn: str = "gelu-approximate",
+ inner_dim: int | None = None,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ assert activation_fn == "gelu-approximate", "Only gelu-approximate is supported."
+
+ inner_dim = inner_dim or int(dim * mult)
+ dim_out = dim_out or dim
+
+ layers: list[nn.Module] = [
+ ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias),
+ nn.Identity(),  # placeholder: keeps the output projection at index 2 of `net` for weight loading
+ RowParallelLinear(
+ inner_dim,
+ dim_out,
+ input_is_parallel=True,
+ return_bias=False,
+ ),
+ ]
+
+ self.net = nn.ModuleList(layers)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
class QwenImageCrossAttention(nn.Module):
def __init__(
self,
dim: int, # query_dim
num_heads: int,
head_dim: int,
- window_size=(-1, -1),
- added_kv_proj_dim: int = None,
+ added_kv_proj_dim: int,
+ window_size: tuple[int, int] = (-1, -1),
out_bias: bool = True,
- qk_norm=True, # rmsnorm
- eps=1e-6,
- pre_only=False,
+ qk_norm: bool = True,
+ eps: float = 1e-6,
+ pre_only: bool = False,
context_pre_only: bool = False,
- parallel_attention=False,
- out_dim: int = None,
+ out_dim: int | None = None,
) -> None:
- assert dim % num_heads == 0
super().__init__()
+ assert dim % num_heads == 0
+
self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
+ self.head_dim = head_dim
+ self.total_num_heads = num_heads
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
- self.parallel_attention = parallel_attention
- # layers
- # self.to_q = ReplicatedLinear(dim, dim)
- # self.to_k = ReplicatedLinear(dim, dim)
- # self.to_v = ReplicatedLinear(dim, dim)
self.to_qkv = QKVParallelLinear(
hidden_size=dim,
head_size=self.head_dim,
total_num_heads=num_heads,
- disable_tp=True,
)
+ self.query_num_heads = self.to_qkv.num_heads
+ self.kv_num_heads = self.to_qkv.num_kv_heads
+
self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
- self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads
- self.inner_kv_dim = self.inner_dim
- if added_kv_proj_dim is not None:
- assert context_pre_only is not None
- # self.add_k_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
- # self.add_v_proj = ReplicatedLinear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
- # self.add_q_proj = ReplicatedLinear(
- # added_kv_proj_dim, self.inner_dim, bias=True
- # )
- self.add_kv_proj = QKVParallelLinear(
- added_kv_proj_dim,
- head_size=self.inner_kv_dim // self.num_heads,
- total_num_heads=self.num_heads,
- disable_tp=True,
- )
- if context_pre_only is not None and not context_pre_only:
- self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias)
- else:
- self.to_add_out = None
+ self.inner_dim = out_dim if out_dim is not None else head_dim * self.total_num_heads
- if not pre_only:
- self.to_out = nn.ModuleList([])
- self.to_out.append(ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias))
- else:
- self.to_out = None
+ assert context_pre_only is not None
+ self.add_kv_proj = QKVParallelLinear(
+ hidden_size=added_kv_proj_dim,
+ head_size=head_dim,
+ total_num_heads=num_heads,
+ )
+ self.add_query_num_heads = self.add_kv_proj.num_heads
+ self.add_kv_num_heads = self.add_kv_proj.num_kv_heads
+
+ assert not context_pre_only
+ self.to_add_out = RowParallelLinear(
+ self.inner_dim,
+ self.dim,
+ bias=out_bias,
+ input_is_parallel=True,
+ return_bias=False,
+ )
+
+ assert not pre_only
+ self.to_out = RowParallelLinear(
+ self.inner_dim,
+ self.dim,
+ bias=out_bias,
+ input_is_parallel=True,
+ return_bias=False,
+ )
self.norm_added_q = RMSNorm(head_dim, eps=eps)
self.norm_added_k = RMSNorm(head_dim, eps=eps)
self.attn = Attention(
- num_heads=num_heads,
+ num_heads=self.query_num_heads,
head_size=self.head_dim,
softmax_scale=1.0 / (self.head_dim**0.5),
causal=False,
+ num_kv_heads=self.kv_num_heads,
)
self.rope = RotaryEmbedding(is_neox_style=False)
@@ -377,61 +435,55 @@ def forward(
txt_freqs: torch.Tensor,
hidden_states_mask: torch.Tensor | None = None,
encoder_hidden_states_mask: torch.Tensor | None = None,
- ):
- # if mask is all true, set it to None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
if hidden_states_mask is not None and hidden_states_mask.all():
hidden_states_mask = None
if encoder_hidden_states_mask is not None and encoder_hidden_states_mask.all():
encoder_hidden_states_mask = None
- seq_len_txt = encoder_hidden_states.shape[1]
- # Compute QKV for image stream (sample projections)
- qkv, _ = self.to_qkv(hidden_states)
- img_query, img_key, img_value = qkv.chunk(3, dim=-1)
+ img_qkv, _ = self.to_qkv(hidden_states)
+ q_size = self.query_num_heads * self.head_dim
+ kv_size = self.kv_num_heads * self.head_dim
+ img_query, img_key, img_value = img_qkv.split([q_size, kv_size, kv_size], dim=-1)
- # Compute QKV for text stream (context projections)
- qkv, _ = self.add_kv_proj(encoder_hidden_states)
- txt_query, txt_key, txt_value = qkv.chunk(3, dim=-1)
+ txt_qkv, _ = self.add_kv_proj(encoder_hidden_states)
+ add_q_size = self.add_query_num_heads * self.head_dim
+ add_kv_size = self.add_kv_num_heads * self.head_dim
+ txt_query, txt_key, txt_value = txt_qkv.split([add_q_size, add_kv_size, add_kv_size], dim=-1)
- # Reshape for multi-head attention
- img_query = img_query.unflatten(-1, (self.num_heads, -1))
- img_key = img_key.unflatten(-1, (self.num_heads, -1))
- img_value = img_value.unflatten(-1, (self.num_heads, -1))
+ img_query = img_query.unflatten(-1, (self.query_num_heads, self.head_dim))
+ img_key = img_key.unflatten(-1, (self.kv_num_heads, self.head_dim))
+ img_value = img_value.unflatten(-1, (self.kv_num_heads, self.head_dim))
- txt_query = txt_query.unflatten(-1, (self.num_heads, -1))
- txt_key = txt_key.unflatten(-1, (self.num_heads, -1))
- txt_value = txt_value.unflatten(-1, (self.num_heads, -1))
+ txt_query = txt_query.unflatten(-1, (self.add_query_num_heads, self.head_dim))
+ txt_key = txt_key.unflatten(-1, (self.add_kv_num_heads, self.head_dim))
+ txt_value = txt_value.unflatten(-1, (self.add_kv_num_heads, self.head_dim))
- # Apply QK normalization
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
- # Apply RoPE
img_cos = vid_freqs.real.to(img_query.dtype)
img_sin = vid_freqs.imag.to(img_query.dtype)
txt_cos = txt_freqs.real.to(txt_query.dtype)
txt_sin = txt_freqs.imag.to(txt_query.dtype)
+
img_query = self.rope(img_query, img_cos, img_sin)
img_key = self.rope(img_key, img_cos, img_sin)
txt_query = self.rope(txt_query, txt_cos, txt_sin)
txt_key = self.rope(txt_key, txt_cos, txt_sin)
- # Concatenate for joint attention
- # Order: [text, image]
+ seq_len_txt = encoder_hidden_states.shape[1]
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
- # Compute joint attention
if (
self.parallel_config is not None
and self.parallel_config.sequence_parallel_size > 1
and not get_forward_context().split_text_embed_in_sp
):
- # if using sequence parallel, but not splitting text embed,
- # we need to pass text embedding to attention layer as joint qkv
attn_metadata = AttentionMetadata(
joint_query=txt_query,
joint_key=txt_key,
@@ -443,22 +495,17 @@ def forward(
if encoder_hidden_states_mask is not None:
attn_metadata.joint_attn_mask = encoder_hidden_states_mask
- joint_hidden_states = self.attn(
- img_query,
- img_key,
- img_value,
- attn_metadata,
- )
+ joint_hidden_states = self.attn(img_query, img_key, img_value, attn_metadata)
else:
attn_metadata = None
if hidden_states_mask is not None or encoder_hidden_states_mask is not None:
- mask_list = []
+ mask_list: list[torch.Tensor] = []
if encoder_hidden_states_mask is not None:
mask_list.append(encoder_hidden_states_mask)
else:
mask_list.append(
torch.ones(
- [encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]],
+ encoder_hidden_states.shape[:2],
dtype=torch.bool,
device=encoder_hidden_states.device,
)
@@ -468,34 +515,22 @@ def forward(
else:
mask_list.append(
torch.ones(
- [hidden_states.shape[0], hidden_states.shape[1]],
+ hidden_states.shape[:2],
dtype=torch.bool,
device=hidden_states.device,
)
)
- joint_mask = (
- None if len(mask_list) == 0 else torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0]
- )
+ joint_mask = torch.cat(mask_list, dim=1) if len(mask_list) > 1 else mask_list[0]
attn_metadata = AttentionMetadata(attn_mask=joint_mask)
- joint_hidden_states = self.attn(
- joint_query,
- joint_key,
- joint_value,
- attn_metadata,
- )
- joint_hidden_states = joint_hidden_states.flatten(2, 3)
- joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
- # Split attention outputs back
- txt_attn_output = joint_hidden_states[:, :seq_len_txt, :] # Text part
- img_attn_output = joint_hidden_states[:, seq_len_txt:, :] # Image part
+ joint_hidden_states = self.attn(joint_query, joint_key, joint_value, attn_metadata)
- # Apply output projections
- img_attn_output, _ = self.to_out[0](img_attn_output)
- if len(self.to_out) > 1:
- (img_attn_output,) = self.to_out[1](img_attn_output) # dropout
+ joint_hidden_states = joint_hidden_states.flatten(2, 3).to(joint_query.dtype)
+ txt_attn_output = joint_hidden_states[:, :seq_len_txt, :]
+ img_attn_output = joint_hidden_states[:, seq_len_txt:, :]
- txt_attn_output, _ = self.to_add_out(txt_attn_output)
+ img_attn_output = self.to_out(img_attn_output)
+ txt_attn_output = self.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
@@ -530,7 +565,7 @@ def __init__(
head_dim=attention_head_dim,
)
self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
- self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+ self.img_mlp = FeedForward(dim=dim, dim_out=dim)
# Text processing modules
self.txt_mod = nn.Sequential(
@@ -540,7 +575,7 @@ def __init__(
self.txt_norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
# Text doesn't need separate attention - it's handled by img_attn joint computation
self.txt_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps)
- self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+ self.txt_mlp = FeedForward(dim=dim, dim_out=dim)
self.zero_cond_t = zero_cond_t
@@ -604,7 +639,7 @@ def forward(
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Process image stream - norm1 + modulation
- img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1)
+ img_modulated, img_gate1 = self.img_norm1(hidden_states, img_mod1, modulate_index)
# Process text stream - norm1 + modulation
txt_modulated, txt_gate1 = self.txt_norm1(encoder_hidden_states, txt_mod1)
@@ -632,7 +667,8 @@ def forward(
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Process image stream - norm2 + MLP
- img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2)
+ img_modulated2, img_gate2 = self.img_norm2(hidden_states, img_mod2, modulate_index)
+
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
@@ -692,15 +728,13 @@ def __init__(
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
- guidance_embeds: bool = False, # TODO: this should probably be removed
+ guidance_embeds: bool = False,
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
zero_cond_t: bool = False,
use_additional_t_cond: bool = False,
use_layer3d_rope: bool = False,
):
super().__init__()
- model_config = od_config.tf_model_config
- num_layers = model_config.num_layers
self.parallel_config = od_config.parallel_config
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
@@ -791,8 +825,8 @@ def forward(
sp_size = get_sequence_parallel_world_size()
if seq_len % sp_size != 0:
- # flash_attn, ring_attn, sage_attn do not support attention_mask
- if get_attn_backend(-1).get_name() != "SDPA" and get_attn_backend(-1).get_name() != "ASCEND":
+ # only backends in _BACKENDS_SUPPORT_ATTENTION_MASK accept an attention_mask (ring_attn, sage_attn do not)
+ if get_attn_backend(-1).get_name() not in _BACKENDS_SUPPORT_ATTENTION_MASK:
raise ValueError(
f"When generating image shape that the sequence length is NOT divisible by sp_size={sp_size},"
f"cannot use {get_attn_backend(-1).get_name()} which does not support attention_mask."
@@ -893,6 +927,8 @@ def get_rotary_emb_chunk(freqs, padding=0):
if original_seq_len is not None:
output = output[:, :original_seq_len, :]
+ torch.cuda.empty_cache()
+
return Transformer2DModelOutput(sample=output)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -917,17 +953,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loaded_params: set[str] = set()
for name, loaded_weight in weights:
+ original_name = name
+ lookup_name = name
for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name not in name:
+ if weight_name not in original_name:
continue
- name = name.replace(weight_name, param_name)
- param = params_dict[name]
+ lookup_name = original_name.replace(weight_name, param_name)
+ param = params_dict[lookup_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
- param = params_dict[name]
+ if lookup_name not in params_dict and ".to_out.0." in lookup_name:
+ lookup_name = lookup_name.replace(".to_out.0.", ".to_out.")
+ param = params_dict[lookup_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
- loaded_params.add(name)
+ loaded_params.add(original_name)
+ loaded_params.add(lookup_name)
return loaded_params
diff --git a/vllm_omni/diffusion/models/schedulers/__init__.py b/vllm_omni/diffusion/models/schedulers/__init__.py
new file mode 100644
index 00000000000..6f8df78ebf0
--- /dev/null
+++ b/vllm_omni/diffusion/models/schedulers/__init__.py
@@ -0,0 +1,10 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm_omni.diffusion.models.schedulers.scheduling_flow_unipc_multistep import (
+ FlowUniPCMultistepScheduler,
+)
+
+__all__ = [
+ "FlowUniPCMultistepScheduler",
+]
diff --git a/vllm_omni/diffusion/models/schedulers/base.py b/vllm_omni/diffusion/models/schedulers/base.py
new file mode 100644
index 00000000000..bc9d87d7f55
--- /dev/null
+++ b/vllm_omni/diffusion/models/schedulers/base.py
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Adapted from https://github.com/hao-ai-lab/FastVideo
+# Originally from https://github.com/huggingface/diffusers
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+"""Base scheduler class for diffusion models."""
+
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class BaseScheduler(ABC):
+ """
+ Abstract base class for schedulers.
+
+ Subclasses must define:
+ - timesteps: torch.Tensor
+ - order: int
+ - num_train_timesteps: int
+ """
+
+ timesteps: torch.Tensor
+ order: int
+ num_train_timesteps: int
+
+ def __init__(self):
+ required_attrs = ["timesteps", "order", "num_train_timesteps"]
+ for attr in required_attrs:
+ if not hasattr(self, attr):
+ raise AttributeError(
+ f"Subclass {self.__class__.__name__} must define `{attr}` before calling super().__init__()"
+ )
+
+ @abstractmethod
+ def set_shift(self, shift: float) -> None:
+ """Set the shift parameter for the scheduler."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def set_timesteps(self, *args, **kwargs) -> None:
+ """Set the timesteps for the scheduler."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def scale_model_input(self, sample: torch.Tensor, timestep: int | None = None) -> torch.Tensor:
+ """Scale the model input."""
+ raise NotImplementedError
diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py
new file mode 100644
index 00000000000..3efe564bc61
--- /dev/null
+++ b/vllm_omni/diffusion/models/schedulers/scheduling_flow_unipc_multistep.py
@@ -0,0 +1,741 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Adapted from https://github.com/hao-ai-lab/FastVideo
+# Originally from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
+# Convert unipc for flow matching
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+"""
+FlowUniPCMultistepScheduler - A training-free framework for fast sampling of flow-matching diffusion models.
+
+This scheduler implements the UniPC (Unified Predictor-Corrector) algorithm adapted for flow matching,
+providing faster convergence than simple Euler methods while maintaining quality.
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Any
+
+import numpy as np
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
+from diffusers.utils import deprecate
+
+from vllm_omni.diffusion.models.schedulers.base import BaseScheduler
+
+
+class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
+ """
+ `FlowUniPCMultistepScheduler` is a training-free framework designed for the fast sampling of
+ flow-matching diffusion models.
+
+ This scheduler implements the UniPC (Unified Predictor-Corrector) algorithm adapted for flow matching,
+ which can achieve the same quality as Euler methods in fewer steps (typically 20-30 steps vs 40-50).
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ solver_order (`int`, default `2`):
+ The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
+ due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
+ unconditional sampling.
+ prediction_type (`str`, defaults to "flow_prediction"):
+ Prediction type of the scheduler function; must be `flow_prediction` for this scheduler.
+ shift (`float`, defaults to 1.0):
+ The shift parameter for the noise schedule. For Wan2.2: use 5.0 for 720p, 12.0 for 480p.
+ use_dynamic_shifting (`bool`, defaults to False):
+ Whether to use dynamic shifting based on image resolution.
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding.
+ predict_x0 (`bool`, defaults to `True`):
+ Whether to use the updating algorithm on the predicted x0.
+ solver_type (`str`, default `bh2`):
+ Solver type for UniPC. Use `bh1` for unconditional sampling when steps < 10, `bh2` otherwise.
+ lower_order_final (`bool`, default `True`):
+ Whether to use lower-order solvers in the final steps. Stabilizes sampling for steps < 15.
+ disable_corrector (`tuple`, defaults to `()`):
+ Steps to disable the corrector to mitigate misalignment with large guidance scales.
+ timestep_spacing (`str`, defaults to `"linspace"`):
+ The way the timesteps should be scaled.
+ final_sigmas_type (`str`, defaults to `"zero"`):
+ The final `sigma` value for the noise schedule. Either `"zero"` or `"sigma_min"`.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ solver_order: int = 2,
+ prediction_type: str = "flow_prediction",
+ shift: float | None = 1.0,
+ use_dynamic_shifting: bool = False,
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ predict_x0: bool = True,
+ solver_type: str = "bh2",
+ lower_order_final: bool = True,
+ disable_corrector: tuple = (),
+ solver_p: SchedulerMixin | None = None,
+ timestep_spacing: str = "linspace",
+ steps_offset: int = 0,
+ final_sigmas_type: str | None = "zero",
+ **kwargs,
+ ):
+ if solver_type not in ["bh1", "bh2"]:
+ if solver_type in ["midpoint", "heun", "logrho"]:
+ self.register_to_config(solver_type="bh2")
+ else:
+ raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
+
+ self.predict_x0 = predict_x0
+ self.num_inference_steps: int | None = None
+
+ # Initialize sigma schedule
+ alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
+ sigmas = 1.0 - alphas
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
+
+ if not use_dynamic_shifting:
+ # Apply timestep shifting based on shift parameter
+ assert shift is not None
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+ self.sigmas = sigmas
+ self.timesteps = sigmas * num_train_timesteps
+ self.num_train_timesteps = num_train_timesteps
+
+ # State for multistep solver
+ self.model_outputs: list[torch.Tensor | None] = [None] * solver_order
+ self.timestep_list: list[Any | None] = [None] * solver_order
+ self.lower_order_nums = 0
+ self.disable_corrector = list(disable_corrector)
+ self.solver_p = solver_p
+ self.last_sample: torch.Tensor | None = None
+ self._step_index: int | None = None
+ self._begin_index: int | None = None
+ self.this_order: int = 1
+
+ # Move sigmas to CPU to reduce GPU/CPU communication
+ self.sigmas = self.sigmas.to("cpu")
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+
+ BaseScheduler.__init__(self)
+
+ @property
+ def step_index(self) -> int | None:
+ """The index counter for current timestep. Increases by 1 after each scheduler step."""
+ return self._step_index
+
+ @property
+ def begin_index(self) -> int | None:
+ """The index for the first timestep. Should be set from pipeline with `set_begin_index` method."""
+ return self._begin_index
+
+ def set_shift(self, shift: float) -> None:
+ """Set the shift parameter for the scheduler."""
+ self.config.shift = shift
+
+ def set_begin_index(self, begin_index: int = 0) -> None:
+ """
+ Sets the begin index for the scheduler. Run from pipeline before inference.
+
+ Args:
+ begin_index (`int`): The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int | None = None,
+ device: str | torch.device | None = None,
+ sigmas: list[float] | None = None,
+ mu: float | None = None,
+ shift: float | None = None,
+ ) -> None:
+ """
+ Sets the discrete timesteps used for the diffusion chain (run before inference).
+
+ Args:
+ num_inference_steps (`int`, *optional*):
+ Total number of timesteps. Required when `sigmas` is not provided.
+ device (`str` or `torch.device`, *optional*):
+ The device to move timesteps to.
+ sigmas (`list[float]`, *optional*):
+ Custom sigma schedule.
+ mu (`float`, *optional*):
+ Parameter for dynamic shifting.
+ shift (`float`, *optional*):
+ Override shift parameter.
+ """
+ if self.config.use_dynamic_shifting and mu is None:
+ raise ValueError("Must pass a value for `mu` when `use_dynamic_shifting` is True")
+
+ if sigmas is None:
+ assert num_inference_steps is not None
+ sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1]
+
+ if self.config.use_dynamic_shifting:
+ assert mu is not None
+ sigmas = self.time_shift(mu, 1.0, sigmas)
+ else:
+ if shift is None:
+ shift = self.config.shift
+ assert isinstance(sigmas, np.ndarray)
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+ if self.config.final_sigmas_type == "sigma_min":
+ sigma_last = self.sigma_min
+ elif self.config.final_sigmas_type == "zero":
+ sigma_last = 0
+ else:
+ raise ValueError(f"`final_sigmas_type` must be 'zero' or 'sigma_min', got {self.config.final_sigmas_type}")
+
+ timesteps = sigmas * self.config.num_train_timesteps
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
+
+ self.sigmas = torch.from_numpy(sigmas)
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
+
+ self.num_inference_steps = len(timesteps)
+
+ # Reset state
+ self.model_outputs = [None] * self.config.solver_order
+ self.timestep_list = [None] * self.config.solver_order
+ self.lower_order_nums = 0
+ self.last_sample = None
+
+ if self.solver_p:
+ self.solver_p.set_timesteps(self.num_inference_steps, device=device)
+
+ self._step_index = None
+ self._begin_index = None
+ self.sigmas = self.sigmas.to("cpu")
+
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
+ """
+ Dynamic thresholding to prevent pixel saturation.
+
+ From "Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding"
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float()
+
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+ abs_sample = sample.abs()
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(s, min=1, max=self.config.sample_max_value)
+ s = s.unsqueeze(1)
+ sample = torch.clamp(sample, -s, s) / s
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
+ """Convert sigma to timestep."""
+ return sigma * self.config.num_train_timesteps
+
+ def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """Convert sigma to alpha and sigma_t for flow matching."""
+ return 1 - sigma, sigma
+
+ def time_shift(self, mu: float, sigma: float, t: np.ndarray) -> np.ndarray:
+ """Apply time shift transformation."""
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ def convert_model_output(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Convert the model output to the format needed by the UniPC algorithm.
+
+ Args:
+ model_output (`torch.Tensor`): Direct output from the diffusion model.
+ sample (`torch.Tensor`): Current sample in the diffusion process.
+
+ Returns:
+ `torch.Tensor`: Converted model output.
+ """
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError("missing `sample` as a required keyword argument")
+ if timestep is not None:
+ deprecate(
+ "timesteps",
+ "1.0.0",
+ "Passing `timesteps` is deprecated and has no effect as model output conversion "
+ "is now handled via an internal counter `self.step_index`",
+ )
+
+ sigma = self.sigmas[self.step_index].to(sample.device)
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+
+ if self.predict_x0:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = sigma.to(sample.device)
+ x0_pred = sample - sigma_t * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be `flow_prediction` "
+ "for the FlowUniPCMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ x0_pred = self._threshold_sample(x0_pred)
+
+ return x0_pred
+ else:
+ if self.config.prediction_type == "flow_prediction":
+ sigma_t = sigma.to(sample.device)
+ epsilon = sample - (1 - sigma_t) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be `flow_prediction` "
+ "for the FlowUniPCMultistepScheduler."
+ )
+
+ if self.config.thresholding:
+ sigma_t = sigma.to(sample.device)
+ x0_pred = sample - sigma_t * model_output
+ x0_pred = self._threshold_sample(x0_pred)
+ epsilon = model_output + x0_pred
+
+ return epsilon
+
+ def multistep_uni_p_bh_update(
+ self,
+ model_output: torch.Tensor,
+ *args,
+ sample: torch.Tensor | None = None,
+ order: int | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the UniP (B(h) version) predictor.
+
+ Args:
+ model_output (`torch.Tensor`): Direct output from the diffusion model.
+ sample (`torch.Tensor`): Current sample.
+ order (`int`): The order of UniP at this timestep.
+
+ Returns:
+ `torch.Tensor`: The sample tensor at the previous timestep.
+ """
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
+ if sample is None:
+ if len(args) > 1:
+ sample = args[1]
+ else:
+ raise ValueError("missing `sample` as a required keyword argument")
+ if order is None:
+ if len(args) > 2:
+ order = args[2]
+ else:
+ raise ValueError("missing `order` as a required keyword argument")
+ if prev_timestep is not None:
+ deprecate(
+ "prev_timestep",
+ "1.0.0",
+ "Passing `prev_timestep` is deprecated and has no effect.",
+ )
+
+ model_output_list = self.model_outputs
+
+ s0 = self.timestep_list[-1]
+ m0 = model_output_list[-1]
+ x = sample
+
+ if self.solver_p:
+ x_t = self.solver_p.step(model_output, s0, x).prev_sample
+ return x_t
+
+ device = sample.device
+ sigma_t, sigma_s0 = (
+ self.sigmas[self.step_index + 1].to(device),
+ self.sigmas[self.step_index].to(device),
+ )
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+
+ h = lambda_t - lambda_s0
+
+ rks = []
+ D1s: list[Any] | None = []
+ for i in range(1, order):
+ si = self.step_index - i
+ mi = model_output_list[-(i + 1)]
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device))
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ assert mi is not None
+ D1s.append((mi - m0) / rk)
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh)
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if D1s is not None and len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1)
+ if order == 2:
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ assert isinstance(R, torch.Tensor)
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
+ else:
+ D1s = None
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - alpha_t * B_h * pred_res
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
+ else:
+ pred_res = 0
+ x_t = x_t_ - sigma_t * B_h * pred_res
+
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def multistep_uni_c_bh_update(
+ self,
+ this_model_output: torch.Tensor,
+ *args,
+ last_sample: torch.Tensor | None = None,
+ this_sample: torch.Tensor | None = None,
+ order: int | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ One step for the UniC (B(h) version) corrector.
+
+ Args:
+ this_model_output (`torch.Tensor`): Model outputs at `x_t`.
+ last_sample (`torch.Tensor`): Sample before the last predictor `x_{t-1}`.
+ this_sample (`torch.Tensor`): Sample after the last predictor `x_{t}`.
+ order (`int`): The order of UniC-p. Effective accuracy is `order + 1`.
+
+ Returns:
+ `torch.Tensor`: The corrected sample tensor.
+ """
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
+ if last_sample is None:
+ if len(args) > 1:
+ last_sample = args[1]
+ else:
+ raise ValueError("missing `last_sample` as a required keyword argument")
+ if this_sample is None:
+ if len(args) > 2:
+ this_sample = args[2]
+ else:
+ raise ValueError("missing `this_sample` as a required keyword argument")
+ if order is None:
+ if len(args) > 3:
+ order = args[3]
+ else:
+ raise ValueError("missing `order` as a required keyword argument")
+ if this_timestep is not None:
+ deprecate(
+ "this_timestep",
+ "1.0.0",
+ "Passing `this_timestep` is deprecated and has no effect.",
+ )
+
+ model_output_list = self.model_outputs
+
+ m0 = model_output_list[-1]
+ x = last_sample
+ x_t = this_sample
+ model_t = this_model_output
+
+ device = this_sample.device
+ sigma_t, sigma_s0 = (
+ self.sigmas[self.step_index].to(device),
+ self.sigmas[self.step_index - 1].to(device),
+ )
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
+
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
+
+ h = lambda_t - lambda_s0
+
+ rks = []
+ D1s: list[Any] | None = []
+ for i in range(1, order):
+ si = self.step_index - (i + 1)
+ mi = model_output_list[-(i + 1)]
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si].to(device))
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
+ rk = (lambda_si - lambda_s0) / h
+ rks.append(rk)
+ assert mi is not None
+ D1s.append((mi - m0) / rk)
+
+ rks.append(1.0)
+ rks = torch.tensor(rks, device=device)
+
+ R = []
+ b = []
+
+ hh = -h if self.predict_x0 else h
+ h_phi_1 = torch.expm1(hh)
+ h_phi_k = h_phi_1 / hh - 1
+
+ factorial_i = 1
+
+ if self.config.solver_type == "bh1":
+ B_h = hh
+ elif self.config.solver_type == "bh2":
+ B_h = torch.expm1(hh)
+ else:
+ raise NotImplementedError()
+
+ for i in range(1, order + 1):
+ R.append(torch.pow(rks, i - 1))
+ b.append(h_phi_k * factorial_i / B_h)
+ factorial_i *= i + 1
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+ R = torch.stack(R)
+ b = torch.tensor(b, device=device)
+
+ if D1s is not None and len(D1s) > 0:
+ D1s = torch.stack(D1s, dim=1)
+ else:
+ D1s = None
+
+ if order == 1:
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
+ else:
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
+
+ if self.predict_x0:
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+ else:
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+ if D1s is not None:
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
+ else:
+ corr_res = 0
+ D1_t = model_t - m0
+ x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
+
+ x_t = x_t.to(x.dtype)
+ return x_t
+
+ def index_for_timestep(self, timestep: torch.Tensor, schedule_timesteps: torch.Tensor | None = None) -> int:
+ """Get the index for a given timestep."""
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+ pos = 1 if len(indices) > 1 else 0
+ step_index: int = indices[pos].item()
+
+ return step_index
+
+ def _init_step_index(self, timestep: torch.Tensor) -> None:
+ """Initialize the step_index counter for the scheduler."""
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int | torch.Tensor,
+ sample: torch.Tensor,
+ return_dict: bool = True,
+ generator: torch.Generator | None = None,
+ ) -> SchedulerOutput | tuple:
+ """
+ Predict the sample from the previous timestep by reversing the SDE using multistep UniPC.
+
+ Args:
+ model_output (`torch.Tensor`): Direct output from the diffusion model.
+ timestep (`int` or `torch.Tensor`): Current discrete timestep in the diffusion chain.
+ sample (`torch.Tensor`): Current sample created by the diffusion process.
+ return_dict (`bool`): Whether to return a SchedulerOutput or tuple.
+
+ Returns:
+ `SchedulerOutput` or `tuple`: The sample tensor at the previous timestep.
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ use_corrector = (
+ self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
+ )
+
+ model_output_convert = self.convert_model_output(model_output, sample=sample)
+
+ if use_corrector:
+ sample = self.multistep_uni_c_bh_update(
+ this_model_output=model_output_convert,
+ last_sample=self.last_sample,
+ this_sample=sample,
+ order=self.this_order,
+ )
+
+ # Update model output history
+ for i in range(self.config.solver_order - 1):
+ self.model_outputs[i] = self.model_outputs[i + 1]
+ self.timestep_list[i] = self.timestep_list[i + 1]
+
+ self.model_outputs[-1] = model_output_convert
+ self.timestep_list[-1] = timestep
+
+ # Determine order for this step
+ if self.config.lower_order_final:
+ this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
+ else:
+ this_order = self.config.solver_order
+
+ self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
+ assert self.this_order > 0
+
+ self.last_sample = sample
+ prev_sample = self.multistep_uni_p_bh_update(
+ model_output=model_output,
+ sample=sample,
+ order=self.this_order,
+ )
+
+ if self.lower_order_nums < self.config.solver_order:
+ self.lower_order_nums += 1
+
+ assert self._step_index is not None
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input.
+
+ Args:
+ sample (`torch.Tensor`): The input sample.
+
+ Returns:
+ `torch.Tensor`: A scaled input sample (unchanged for this scheduler).
+ """
+ return sample
+
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.Tensor:
+ """
+ Add noise to the original samples.
+
+ Args:
+ original_samples (`torch.Tensor`): Original samples.
+ noise (`torch.Tensor`): Noise to add.
+ timesteps (`torch.IntTensor`): Timesteps for noise addition.
+
+ Returns:
+ `torch.Tensor`: Noisy samples.
+ """
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ if self.begin_index is None:
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+ elif self.step_index is not None:
+ step_indices = [self.step_index] * timesteps.shape[0]
+ else:
+ step_indices = [self.begin_index] * timesteps.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
+ return noisy_samples
+
+ def __len__(self) -> int:
+ return self.config.num_train_timesteps
diff --git a/vllm_omni/diffusion/models/sd3/sd3_transformer.py b/vllm_omni/diffusion/models/sd3/sd3_transformer.py
index a2afa8c0946..22a11741a53 100644
--- a/vllm_omni/diffusion/models/sd3/sd3_transformer.py
+++ b/vllm_omni/diffusion/models/sd3/sd3_transformer.py
@@ -102,8 +102,8 @@ def __init__(
else:
self.to_out = None
- self.norm_added_q = RMSNorm(head_dim, eps=eps)
- self.norm_added_k = RMSNorm(head_dim, eps=eps)
+ self.norm_added_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
+ self.norm_added_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
self.attn = Attention(
num_heads=num_heads,
@@ -341,8 +341,10 @@ def __init__(
self.pooled_projection_dim = model_config.pooled_projection_dim
self.joint_attention_dim = model_config.joint_attention_dim
self.patch_size = model_config.patch_size
- self.dual_attention_layers = model_config.dual_attention_layers
- self.qk_norm = model_config.qk_norm
+ self.dual_attention_layers = (
+ model_config.dual_attention_layers if hasattr(model_config, "dual_attention_layers") else ()
+ )
+ self.qk_norm = model_config.qk_norm if hasattr(model_config, "qk_norm") else ""
self.pos_embed_max_size = model_config.pos_embed_max_size
self.pos_embed = PatchEmbed(
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
index 61a858ca4d2..3b04e2a5a9b 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
@@ -9,7 +9,7 @@
import PIL.Image
import torch
-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler
+from diffusers import AutoencoderKLWan
from diffusers.utils.torch_utils import randn_tensor
from torch import nn
from transformers import AutoTokenizer, UMT5EncoderModel
@@ -18,6 +18,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel
from vllm_omni.diffusion.request import OmniDiffusionRequest
@@ -244,12 +245,13 @@ def __init__(
else:
self.transformer_2 = None
- self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
- model, subfolder="scheduler", local_files_only=local_files_only
+ # Initialize UniPC scheduler
+ flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
+ self.scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=flow_shift,
+ prediction_type="flow_prediction",
)
- # Apply flow_shift if specified (12.0 for 480p, 5.0 for 720p recommended for Wan2.2)
- if od_config.flow_shift is not None:
- self.scheduler.config.flow_shift = od_config.flow_shift
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
index d72afc9ee84..ed0d6e4c6b6 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
@@ -9,7 +9,7 @@
import numpy as np
import PIL.Image
import torch
-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler
+from diffusers import AutoencoderKLWan
from diffusers.utils.torch_utils import randn_tensor
from torch import nn
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
@@ -18,6 +18,8 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.interface import SupportImageInput
+from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
create_transformer_from_config,
load_transformer_config,
@@ -112,7 +114,7 @@ def pre_process_func(requests: list[OmniDiffusionRequest]) -> list[OmniDiffusion
return pre_process_func
-class Wan22I2VPipeline(nn.Module):
+class Wan22I2VPipeline(nn.Module, SupportImageInput):
"""
Wan2.2 Image-to-Video Pipeline.
@@ -198,12 +200,13 @@ def __init__(
else:
self.transformer_2 = None
- # Scheduler
- self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
- model, subfolder="scheduler", local_files_only=local_files_only
+ # Initialize UniPC scheduler
+ flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
+ self.scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=flow_shift,
+ prediction_type="flow_prediction",
)
- if od_config.flow_shift is not None:
- self.scheduler.config.flow_shift = od_config.flow_shift
# VAE scale factors
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
index bee70a7a96b..6a9a6a6a0e9 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
@@ -22,7 +22,7 @@
import numpy as np
import PIL.Image
import torch
-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler
+from diffusers import AutoencoderKLWan
from diffusers.utils.torch_utils import randn_tensor
from torch import nn
from transformers import AutoTokenizer, UMT5EncoderModel
@@ -31,6 +31,8 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.interface import SupportImageInput
+from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
create_transformer_from_config,
load_transformer_config,
@@ -102,7 +104,7 @@ def pre_process_func(requests: list[OmniDiffusionRequest]) -> list[OmniDiffusion
return pre_process_func
-class Wan22TI2VPipeline(nn.Module):
+class Wan22TI2VPipeline(nn.Module, SupportImageInput):
"""
Wan2.2 Text-Image-to-Video (TI2V) Pipeline.
@@ -156,12 +158,13 @@ def __init__(
transformer_config = load_transformer_config(model, "transformer", local_files_only)
self.transformer = create_transformer_from_config(transformer_config)
- # Scheduler
- self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
- model, subfolder="scheduler", local_files_only=local_files_only
+ # Initialize UniPC scheduler
+ flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 # default for 720p
+ self.scheduler = FlowUniPCMultistepScheduler(
+ num_train_timesteps=1000,
+ shift=flow_shift,
+ prediction_type="flow_prediction",
)
- if od_config.flow_shift is not None:
- self.scheduler.config.flow_shift = od_config.flow_shift
# VAE scale factors
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if hasattr(self.vae, "config") else 4
diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py
index 94cf889210f..99b0ea4c765 100644
--- a/vllm_omni/diffusion/registry.py
+++ b/vllm_omni/diffusion/registry.py
@@ -79,6 +79,11 @@
"pipeline_sd3",
"StableDiffusion3Pipeline",
),
+ "Flux2KleinPipeline": (
+ "flux2_klein",
+ "pipeline_flux2_klein",
+ "Flux2KleinPipeline",
+ ),
}
@@ -127,6 +132,7 @@ def initialize_model(
"BagelPipeline": "get_bagel_post_process_func",
"LongCatImageEditPipeline": "get_longcat_image_post_process_func",
"StableDiffusion3Pipeline": "get_sd3_image_post_process_func",
+ "Flux2KleinPipeline": "get_flux2_klein_post_process_func",
}
_DIFFUSION_PRE_PROCESS_FUNCS = {
diff --git a/vllm_omni/diffusion/utils/tf_utils.py b/vllm_omni/diffusion/utils/tf_utils.py
new file mode 100644
index 00000000000..44a78804452
--- /dev/null
+++ b/vllm_omni/diffusion/utils/tf_utils.py
@@ -0,0 +1,54 @@
+import inspect
+from typing import Any
+
+from vllm_omni.diffusion.data import TransformerConfig
+
+
+def get_transformer_config_kwargs(
+ tf_model_config: TransformerConfig, model_class: type[Any] | None = None
+) -> dict[str, Any]:
+ """
+ This function extracts parameters from a TransformerConfig instance and filters out internal
+ diffusers metadata keys (those starting with '_') that should not be passed to model initialization.
+ Also filters out parameters that are not accepted by the model's __init__ method (e.g., pooled_projection_dim
+ for QwenImageTransformer2DModel).
+
+ This uses inspect.signature to dynamically detect accepted parameters, making it general for any model class.
+ Similar to how diffusers' @register_to_config decorator works.
+
+ Args:
+ tf_model_config: TransformerConfig instance containing model parameters
+ model_class: Optional model class to inspect for accepted __init__ parameters.
+ If None, all non-internal parameters are returned (backward compatibility).
+
+ Returns:
+ dict: Filtered dictionary of parameters suitable for transformer model initialization
+ """
+ # Extract transformer config parameters, filtering out internal diffusers metadata
+ # TransformerConfig stores params in a 'params' dict, and we need to exclude
+ # internal keys like '_class_name' and '_diffusers_version'
+ tf_config_params = tf_model_config.to_dict()
+
+ # Filter out internal diffusers metadata keys that start with '_'
+ filtered_params = {k: v for k, v in tf_config_params.items() if not k.startswith("_")}
+
+ # If model_class is provided, use inspect.signature to get accepted parameters
+ if model_class is not None:
+ try:
+ # Get the signature of the model's __init__ method
+ sig = inspect.signature(model_class.__init__)
+ # Collect the accepted parameter names, excluding 'self' and the **kwargs catch-all (VAR_KEYWORD)
+ accepted_params = {
+ name
+ for name, param in sig.parameters.items()
+ if name != "self" and param.kind != inspect.Parameter.VAR_KEYWORD # Exclude **kwargs
+ }
+
+ # Filter to only include parameters that are in the model's signature
+ filtered_params = {k: v for k, v in filtered_params.items() if k in accepted_params}
+ except (TypeError, AttributeError):
+ # If inspection fails, fall back to returning all non-internal params
+ # This maintains backward compatibility
+ pass
+
+ return filtered_params
diff --git a/vllm_omni/diffusion/worker/__init__.py b/vllm_omni/diffusion/worker/__init__.py
index dc3306dae3f..dfec4596bc2 100644
--- a/vllm_omni/diffusion/worker/__init__.py
+++ b/vllm_omni/diffusion/worker/__init__.py
@@ -2,6 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Worker classes for diffusion models."""
-from vllm_omni.diffusion.worker.gpu_worker import GPUWorker, WorkerProc
+from vllm_omni.diffusion.worker.gpu_diffusion_model_runner import GPUDiffusionModelRunner
+from vllm_omni.diffusion.worker.gpu_diffusion_worker import (
+ GPUDiffusionWorker,
+ WorkerProc,
+)
-__all__ = ["GPUWorker", "WorkerProc"]
+__all__ = [
+ "GPUDiffusionModelRunner",
+ "GPUDiffusionWorker",
+ "WorkerProc",
+]
diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py
new file mode 100644
index 00000000000..8ffa12ed2cc
--- /dev/null
+++ b/vllm_omni/diffusion/worker/gpu_diffusion_model_runner.py
@@ -0,0 +1,165 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Diffusion Model Runner for vLLM-Omni.
+
+Handles model loading, compilation, caching, and execution of diffusion model
+forward passes. This follows the AR pattern where the Runner handles all
+model-related operations.
+"""
+
+from __future__ import annotations
+
+import time
+from collections.abc import Iterable
+from contextlib import nullcontext
+
+import torch
+from vllm.config import LoadConfig
+from vllm.logger import init_logger
+from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes
+
+from vllm_omni.diffusion.cache.selector import get_cache_backend
+from vllm_omni.diffusion.compile import regionally_compile
+from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
+from vllm_omni.diffusion.forward_context import set_forward_context
+from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.offload import apply_offload_hooks
+from vllm_omni.diffusion.request import OmniDiffusionRequest
+
+logger = init_logger(__name__)
+
+
+class GPUDiffusionModelRunner:
+ """
+ Model runner that handles model loading and execution for diffusion models.
+
+ This class follows the AR pattern where the Runner handles all model-related
+ operations including loading, compilation, offloading, caching, and execution.
+ The Worker only handles infrastructure (device, distributed env).
+ """
+
+ def __init__(
+ self,
+ vllm_config,
+ od_config: OmniDiffusionConfig,
+ device: torch.device,
+ ):
+ """
+ Initialize the diffusion model runner.
+
+ Args:
+ vllm_config: vLLM configuration.
+ od_config: OmniDiffusion configuration.
+ device: The device to run on.
+ """
+ self.vllm_config = vllm_config
+ self.od_config = od_config
+ self.device = device
+ self.pipeline = None
+ self.cache_backend = None
+
+ def load_model(
+ self,
+ memory_pool_context_fn: callable | None = None,
+ ) -> None:
+ """
+ Load the diffusion model, apply compilation and offloading.
+
+ Args:
+ memory_pool_context_fn: Optional function that returns a context manager
+ for memory pool allocation (used for sleep mode).
+ """
+ load_device = "cpu" if self.od_config.enable_cpu_offload else str(self.device)
+
+ def get_memory_context():
+ if memory_pool_context_fn is not None:
+ return memory_pool_context_fn(tag="weights")
+ return nullcontext()
+
+ # Load model within forward context
+ with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config):
+ load_config = LoadConfig()
+ model_loader = DiffusersPipelineLoader(load_config)
+ time_before_load = time.perf_counter()
+
+ with get_memory_context():
+ with DeviceMemoryProfiler() as m:
+ self.pipeline = model_loader.load_model(
+ od_config=self.od_config,
+ load_device=load_device,
+ )
+ time_after_load = time.perf_counter()
+
+ logger.info(
+ "Model loading took %.4f GiB and %.6f seconds",
+ m.consumed_memory / GiB_bytes,
+ time_after_load - time_before_load,
+ )
+ logger.info("Model runner: Model loaded successfully.")
+
+ # Apply CPU offloading (DiT <-> encoders mutual exclusion)
+ if self.od_config.enable_cpu_offload:
+ for name in ["vae"]:
+ module = getattr(self.pipeline, name, None)
+ if module is None:
+ continue
+ try:
+ module.to(self.device, non_blocking=True)
+ except Exception as exc:
+ logger.debug("Failed to move %s to GPU: %s", name, exc)
+
+ apply_offload_hooks(self.pipeline, self.od_config, device=self.device)
+
+ # Apply torch.compile if not in eager mode
+ if not self.od_config.enforce_eager:
+ try:
+ self.pipeline.transformer = regionally_compile(
+ self.pipeline.transformer,
+ dynamic=True,
+ )
+ logger.info("Model runner: Model compiled with torch.compile.")
+ except Exception as e:
+ logger.warning(f"Model runner: torch.compile failed with error: {e}. Using eager mode.")
+
+ # Setup cache backend
+ self.cache_backend = get_cache_backend(self.od_config.cache_backend, self.od_config.cache_config)
+
+ if self.cache_backend is not None:
+ self.cache_backend.enable(self.pipeline)
+
+ logger.info("Model runner: Initialization complete.")
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ """Load weights into the pipeline."""
+ return self.pipeline.load_weights(weights)
+
+ @torch.inference_mode()
+ def execute_model(self, reqs: list[OmniDiffusionRequest]) -> DiffusionOutput:
+ """
+ Execute a forward pass for the given requests.
+
+ Args:
+ reqs: List of diffusion requests to process.
+
+ Returns:
+ DiffusionOutput with generated results.
+ """
+ assert self.pipeline is not None, "Model not loaded. Call load_model() first."
+ if not reqs or len(reqs) == 0:
+ raise ValueError("Cannot execute model with empty request list")
+
+ # TODO: only the first request is processed for now; any further requests in the list are ignored
+ req = reqs[0]
+
+ if req.generator is None and req.seed is not None:
+ req.generator = torch.Generator(device=self.device).manual_seed(req.seed)
+
+ # Refresh cache context if needed
+ if self.cache_backend is not None and self.cache_backend.is_enabled():
+ self.cache_backend.refresh(self.pipeline, req.num_inference_steps)
+
+ with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config):
+ output = self.pipeline.forward(req)
+
+ return output
diff --git a/vllm_omni/diffusion/worker/gpu_worker.py b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py
similarity index 65%
rename from vllm_omni/diffusion/worker/gpu_worker.py
rename to vllm_omni/diffusion/worker/gpu_diffusion_worker.py
index 99a718e389f..7fa1e3da4d5 100644
--- a/vllm_omni/diffusion/worker/gpu_worker.py
+++ b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py
@@ -1,20 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Diffusion Worker for vLLM-Omni.
+
+Handles GPU infrastructure initialization and delegates model operations
+to GPUDiffusionModelRunner.
+"""
+
import multiprocessing as mp
import os
-import time
-from collections.abc import Iterable
from contextlib import AbstractContextManager, nullcontext
import torch
import zmq
-from vllm.config import LoadConfig, VllmConfig
+from vllm.config import VllmConfig
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.logger import init_logger
-from vllm.utils.mem_utils import DeviceMemoryProfiler, GiB_bytes
+from vllm.utils.mem_utils import GiB_bytes
-from vllm_omni.diffusion.cache.selector import get_cache_backend
-from vllm_omni.diffusion.compile import regionally_compile
from vllm_omni.diffusion.data import (
DiffusionOutput,
OmniDiffusionConfig,
@@ -25,16 +28,23 @@
initialize_model_parallel,
)
from vllm_omni.diffusion.forward_context import set_forward_context
-from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.offload import apply_offload_hooks
from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.worker.gpu_diffusion_model_runner import GPUDiffusionModelRunner
logger = init_logger(__name__)
-class GPUWorker:
+class GPUDiffusionWorker:
"""
- A worker that executes the model on a single GPU.
+ A worker that manages GPU infrastructure and delegates to the model runner.
+
+ This class handles infrastructure initialization only:
+ - Device setup (CUDA device selection)
+ - Distributed environment (NCCL, model parallel)
+ - Memory management (sleep/wake)
+
+ All model-related operations (loading, compilation, execution) are
+ delegated to GPUDiffusionModelRunner.
"""
def __init__(
@@ -46,15 +56,17 @@ def __init__(
self.local_rank = local_rank
self.rank = rank
self.od_config = od_config
- self.pipeline = None
- self.device = None
+ self.device: torch.device | None = None
+ self.vllm_config: VllmConfig | None = None
+ self.model_runner: GPUDiffusionModelRunner | None = None
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
- self.init_device_and_model()
+ self.init_device()
- def init_device_and_model(self) -> None:
- """Initialize the device and load the model."""
+ def init_device(self) -> None:
+ """Initialize the device and distributed environment."""
world_size = self.od_config.num_gpus
rank = self.rank
+
# Set environment variables for distributed initialization
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(self.od_config.master_port)
@@ -62,19 +74,21 @@ def init_device_and_model(self) -> None:
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
+ # Setup device
self.device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(self.device)
- # hack
+ # Create vllm_config for parallel configuration
vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = self.od_config.parallel_config.tensor_parallel_size
vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size
self.vllm_config = vllm_config
- load_device = "cpu" if self.od_config.enable_cpu_offload else str(self.device)
+ # Initialize distributed environment
with set_forward_context(vllm_config=vllm_config, omni_diffusion_config=self.od_config):
init_distributed_environment(world_size=world_size, rank=rank)
logger.info(f"Worker {self.rank}: Initialized device and distributed environment.")
+
parallel_config = self.od_config.parallel_config
initialize_model_parallel(
data_parallel_size=parallel_config.data_parallel_size,
@@ -86,107 +100,45 @@ def init_device_and_model(self) -> None:
pipeline_parallel_size=parallel_config.pipeline_parallel_size,
)
- load_config = LoadConfig()
- model_loader = DiffusersPipelineLoader(load_config)
- time_before_load = time.perf_counter()
- with self._maybe_get_memory_pool_context(tag="weights"):
- with DeviceMemoryProfiler() as m:
- self.pipeline = model_loader.load_model(
- od_config=self.od_config,
- load_device=load_device,
- )
- time_after_load = time.perf_counter()
-
- logger.info(
- "Model loading took %.4f GiB and %.6f seconds",
- m.consumed_memory / GiB_bytes,
- time_after_load - time_before_load,
+ # Create model runner and load model
+ self.model_runner = GPUDiffusionModelRunner(
+ vllm_config=self.vllm_config,
+ od_config=self.od_config,
+ device=self.device,
)
- logger.info(f"Worker {self.rank}: Model loaded successfully.")
-
- # Apply CPU offloading (DiT <-> encoders mutual exclusion)
- if self.od_config.enable_cpu_offload:
- for name in ["vae"]:
- module = getattr(self.pipeline, name, None)
- if module is None:
- continue
- try:
- module.to(self.device, non_blocking=True)
- except Exception as exc:
- logger.debug("Failed to move %s to GPU: %s", name, exc)
-
- apply_offload_hooks(self.pipeline, self.od_config, device=self.device)
-
- if not self.od_config.enforce_eager:
- try:
- self.pipeline.transformer = regionally_compile(
- self.pipeline.transformer,
- dynamic=True,
- )
- logger.info(f"Worker {self.rank}: Model compiled with torch.compile.")
- except Exception as e:
- logger.warning(f"Worker {self.rank}: torch.compile failed with error: {e}. Using eager mode.")
-
- # Setup cache backend based on type (both backends use enable()/reset() interface)
- self.cache_backend = get_cache_backend(self.od_config.cache_backend, self.od_config.cache_config)
-
- if self.cache_backend is not None:
- self.cache_backend.enable(self.pipeline)
+ self.model_runner.load_model(
+ memory_pool_context_fn=self._maybe_get_memory_pool_context,
+ )
+ logger.info(f"Worker {self.rank}: Initialization complete.")
def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
- """
- Generate output for the given requests.
-
- Args:
- requests: List of diffusion requests
-
- Returns:
- DiffusionOutput with generated results
- """
+ """Generate output for the given requests."""
return self.execute_model(requests, self.od_config)
- @torch.inference_mode()
def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
- """
- Execute a forward pass.
- """
- assert self.pipeline is not None
- if not reqs or len(reqs) == 0:
- raise ValueError("Cannot execute model with empty request list")
- # TODO: dealing with first req for now
- req = reqs[0]
+ """Execute a forward pass by delegating to the model runner."""
+ assert self.model_runner is not None, "Model runner not initialized"
+ return self.model_runner.execute_model(reqs)
- if req.generator is None and req.seed is not None:
- req.generator = torch.Generator(device=self.device).manual_seed(req.seed)
-
- # Refresh cache context if needed
- if self.cache_backend is not None and self.cache_backend.is_enabled():
- self.cache_backend.refresh(self.pipeline, req.num_inference_steps)
- with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config):
- output = self.pipeline.forward(req)
- return output
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- return self.pipeline.load_weights(weights)
+ def load_weights(self, weights) -> set[str]:
+ """Load weights by delegating to the model runner."""
+ assert self.model_runner is not None, "Model runner not initialized"
+ return self.model_runner.load_weights(weights)
def sleep(self, level: int = 1) -> bool:
"""
- Put the worker to sleep. The worker should not process any requests.
- The caller should guarantee that no requests are being processed
- during the sleep period, before `wake_up` is called.
+ Put the worker to sleep, offloading model weights.
Args:
- level: The sleep level. Level 1 sleep will offload the model
- weights and discard the kv cache.
- Currently only support level 1.
+ level: Sleep level. Level 1 offloads the "weights" pool; level 2 offloads nothing and instead snapshots model buffers on CPU for wake_up to restore.
"""
from vllm.device_allocator.cumem import CuMemAllocator
free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
# Save the buffers before level 2 sleep
- if level == 2:
- model = self.pipeline
+ if level == 2 and self.model_runner is not None:
+ model = self.model_runner.pipeline
self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()}
allocator = CuMemAllocator.get_instance()
@@ -220,8 +172,8 @@ def wake_up(self, tags: list[str] | None = None) -> bool:
allocator.wake_up(tags)
# Restore the buffers after level 2 sleep
- if len(self._sleep_saved_buffers):
- model = self.pipeline
+ if len(self._sleep_saved_buffers) and self.model_runner is not None:
+ model = self.model_runner.pipeline
for name, buffer in model.named_buffers():
if name in self._sleep_saved_buffers:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
@@ -229,6 +181,7 @@ def wake_up(self, tags: list[str] | None = None) -> bool:
return True
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
+ """Get memory pool context for sleep mode support."""
if self.od_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator
@@ -240,6 +193,7 @@ def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
return nullcontext()
def shutdown(self) -> None:
+ """Shutdown the worker and cleanup distributed environment."""
destroy_distributed_env()
@@ -257,17 +211,14 @@ def __init__(
# Inter-process Communication
self.context = zmq.Context(io_threads=2)
- # Initialize MessageQueue reader from handle (unified for generation & RPC)
+ # Initialize MessageQueue reader from handle
self.mq = MessageQueue.create_from_handle(broadcast_handle, gpu_id)
self.result_mq = None
self.result_mq_handle = None
- # Setup result sender (only for rank 0 for now, or whoever needs to reply)
- # Assuming only rank 0 replies to scheduler as per original logic
+ # Setup result sender (only for rank 0)
if gpu_id == 0:
- # Create MessageQueue for results (1 writer -> 1 reader)
- # We assume the reader (SyncScheduler) will act as rank 0
self.result_mq = MessageQueue(n_reader=1, n_local_reader=1, local_reader_ranks=[0])
self.result_mq_handle = self.result_mq.export_handle()
logger.info(f"Worker {gpu_id} created result MessageQueue")
@@ -277,31 +228,25 @@ def __init__(
self.gpu_id = gpu_id
self._running = True
- def _create_worker(self, gpu_id: int, od_config: OmniDiffusionConfig) -> GPUWorker:
+ def _create_worker(self, gpu_id: int, od_config: OmniDiffusionConfig) -> GPUDiffusionWorker:
"""Create a worker instance. Override in subclasses for different worker types."""
- return GPUWorker(
+ return GPUDiffusionWorker(
local_rank=gpu_id,
rank=gpu_id,
od_config=od_config,
)
def return_result(self, output: DiffusionOutput):
- """
- replies to client, only on rank 0
- """
+ """Reply to client, only on rank 0."""
if self.result_mq is not None:
self.result_mq.enqueue(output)
def recv_message(self):
- """
- Receive unified messages (RPC requests, shutdown) from broadcast queue.
- Uses indefinite=True to block until a message arrives.
- """
+ """Receive messages from broadcast queue."""
return self.mq.dequeue(indefinite=True)
def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]:
"""Execute an RPC request and indicate whether to reply."""
-
method = rpc_request["method"]
args = rpc_request.get("args", ())
kwargs = rpc_request.get("kwargs", {})
@@ -325,14 +270,11 @@ def execute_rpc(self, rpc_request: dict) -> tuple[object | None, bool]:
logger.error(f"Error executing RPC: {e}", exc_info=True)
return {"status": "error", "error": str(e)}, should_reply
- # TODO: queueing, cancellation
def worker_busy_loop(self) -> None:
- """Main busy loop for Multiprocessing Workers"""
-
+ """Main busy loop for Multiprocessing Workers."""
logger.info(f"Worker {self.gpu_id} ready to receive requests via shared memory")
while self._running:
- # Receive unified message (generation request, RPC request, or shutdown)
msg = None
try:
msg = self.recv_message()
@@ -349,7 +291,6 @@ def worker_busy_loop(self) -> None:
# Route message based on type
if isinstance(msg, dict) and msg.get("type") == "rpc":
- # Handle RPC request
try:
result, should_reply = self.execute_rpc(msg)
if should_reply:
@@ -360,13 +301,12 @@ def worker_busy_loop(self) -> None:
self.return_result({"status": "error", "error": str(e)})
elif isinstance(msg, dict) and msg.get("type") == "shutdown":
- # Handle shutdown message
logger.info("Worker %s: Received shutdown message", self.gpu_id)
self._running = False
continue
else:
- # Handle generation request (OmniDiffusionRequest list)
+ # Handle generation request
try:
output = self.worker.execute_model(msg, self.od_config)
except Exception as e:
@@ -379,17 +319,14 @@ def worker_busy_loop(self) -> None:
try:
self.return_result(output)
except zmq.ZMQError as e:
- # Reply failed; log and keep loop alive to accept future requests
logger.error(f"ZMQ error sending reply: {e}")
continue
logger.info("event loop terminated.")
try:
self.worker.shutdown()
- except Exception as exc: # pragma: no cover - best effort cleanup
+ except Exception as exc:
logger.warning("Worker %s: Shutdown encountered an error: %s", self.gpu_id, exc)
- # if self.result_sender is not None:
- # self.result_sender.close()
self.context.term()
@staticmethod
@@ -400,7 +337,6 @@ def worker_main(
broadcast_handle,
) -> None:
"""Worker initialization and execution loops."""
-
worker_proc = WorkerProc(
od_config,
gpu_id=rank,
diff --git a/vllm_omni/diffusion/worker/npu/npu_worker.py b/vllm_omni/diffusion/worker/npu/npu_worker.py
index bfeb0d914c9..446c29cae4d 100644
--- a/vllm_omni/diffusion/worker/npu/npu_worker.py
+++ b/vllm_omni/diffusion/worker/npu/npu_worker.py
@@ -3,6 +3,8 @@
import multiprocessing as mp
import os
import time
+from collections.abc import Iterable
+from contextlib import AbstractContextManager, nullcontext
import torch
from vllm.config import LoadConfig, VllmConfig
@@ -11,25 +13,42 @@
from vllm_omni.diffusion.cache.selector import get_cache_backend
from vllm_omni.diffusion.data import (
+ DiffusionOutput,
OmniDiffusionConfig,
)
from vllm_omni.diffusion.distributed.parallel_state import (
+ destroy_distributed_env,
init_distributed_environment,
initialize_model_parallel,
)
from vllm_omni.diffusion.forward_context import set_forward_context
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.worker.gpu_worker import GPUWorker, WorkerProc
+from vllm_omni.diffusion.request import OmniDiffusionRequest
+from vllm_omni.diffusion.worker.gpu_diffusion_worker import WorkerProc
logger = init_logger(__name__)
-class NPUWorker(GPUWorker):
+class NPUWorker:
"""
A worker that executes the model on a single NPU.
Inherits from GPUWorker and overrides device-specific initialization.
"""
+ def __init__(
+ self,
+ local_rank: int,
+ rank: int,
+ od_config: OmniDiffusionConfig,
+ ):
+ self.local_rank = local_rank
+ self.rank = rank
+ self.od_config = od_config
+ self.pipeline = None
+ self.device = None
+ self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
+ self.init_device_and_model()
+
def init_device_and_model(self) -> None:
"""Initialize the NPU device and load the model."""
world_size = self.od_config.num_gpus
@@ -86,6 +105,115 @@ def init_device_and_model(self) -> None:
if self.cache_backend is not None:
self.cache_backend.enable(self.pipeline)
+ def generate(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput:
+ """
+ Generate output for the given requests.
+
+ Args:
+ requests: List of diffusion requests
+
+ Returns:
+ DiffusionOutput with generated results
+ """
+ return self.execute_model(requests, self.od_config)
+
+ @torch.inference_mode()
+ def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
+ """
+ Execute a forward pass.
+ """
+ assert self.pipeline is not None
+ if not reqs or len(reqs) == 0:
+ raise ValueError("Cannot execute model with empty request list")
+ # TODO: dealing with first req for now
+ req = reqs[0]
+
+ if req.generator is None and req.seed is not None:
+ req.generator = torch.Generator(device=self.device).manual_seed(req.seed)
+
+ # Refresh cache context if needed
+ if self.cache_backend is not None and self.cache_backend.is_enabled():
+ self.cache_backend.refresh(self.pipeline, req.num_inference_steps)
+ with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config):
+ output = self.pipeline.forward(req)
+ return output
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ return self.pipeline.load_weights(weights)
+
+ def sleep(self, level: int = 1) -> bool:
+ """
+ Put the worker to sleep. The worker should not process any requests.
+ The caller should guarantee that no requests are being processed
+ during the sleep period, before `wake_up` is called.
+
+ Args:
+ level: The sleep level. Level 1 sleep will offload the model
+ weights and discard the kv cache.
+ Currently only support level 1.
+ """
+ from vllm.device_allocator.cumem import CuMemAllocator
+
+ free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
+
+ # Save the buffers before level 2 sleep
+ if level == 2:
+ model = self.pipeline
+ self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()}
+
+ allocator = CuMemAllocator.get_instance()
+ allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
+ free_bytes_after_sleep, total = torch.cuda.mem_get_info()
+ freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
+ used_bytes = total - free_bytes_after_sleep
+ assert freed_bytes >= 0, "Memory usage increased after sleeping."
+ logger.info(
+ "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.",
+ freed_bytes / GiB_bytes,
+ used_bytes / GiB_bytes,
+ )
+ return True
+
+ def wake_up(self, tags: list[str] | None = None) -> bool:
+ """
+ Wake up the worker from sleep mode. See the sleep function
+ method for more details.
+
+ Args:
+ tags: An optional list of tags to reallocate the worker memory
+ for specific memory allocations. Values must be in
+ `("weights")`. If None, all memory is reallocated.
+ wake_up should be called with all tags (or None) before the
+ worker is used again.
+ """
+ from vllm.device_allocator.cumem import CuMemAllocator
+
+ allocator = CuMemAllocator.get_instance()
+ allocator.wake_up(tags)
+
+ # Restore the buffers after level 2 sleep
+ if len(self._sleep_saved_buffers):
+ model = self.pipeline
+ for name, buffer in model.named_buffers():
+ if name in self._sleep_saved_buffers:
+ buffer.data.copy_(self._sleep_saved_buffers[name].data)
+ self._sleep_saved_buffers = {}
+ return True
+
+ def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
+ if self.od_config.enable_sleep_mode:
+ from vllm.device_allocator.cumem import CuMemAllocator
+
+ allocator = CuMemAllocator.get_instance()
+ if tag == "weights":
+ assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process."
+ return allocator.use_memory_pool(tag=tag)
+ else:
+ return nullcontext()
+
+ def shutdown(self) -> None:
+ destroy_distributed_env()
+
class NPUWorkerProc(WorkerProc):
"""Wrapper that runs one NPUWorker in a separate process."""
diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py
index 79dc25cf494..47d094e4163 100644
--- a/vllm_omni/engine/__init__.py
+++ b/vllm_omni/engine/__init__.py
@@ -9,11 +9,9 @@
import msgspec
import torch
from vllm.v1.engine import (
- EngineCoreEvent,
EngineCoreRequest,
- FinishReason,
- LogprobsLists,
- LogprobsTensors,
+ EngineCoreOutput,
+ EngineCoreOutputs,
SchedulerStats,
UtilityOutput,
)
@@ -79,64 +77,10 @@ class OmniEngineCoreRequest(EngineCoreRequest):
additional_information: AdditionalInformationPayload | None = None
-class OmniEngineCoreOutput(
- msgspec.Struct,
- array_like=True, # type: ignore[call-arg]
- omit_defaults=True, # type: ignore[call-arg]
- gc=False,
-): # type: ignore[call-arg]
- request_id: str
- new_token_ids: list[int]
+class OmniEngineCoreOutput(EngineCoreOutput):
+ pooling_output: dict[str, torch.Tensor] | None = None
- new_logprobs: LogprobsLists | None = None
- new_prompt_logprobs_tensors: LogprobsTensors | None = None
- pooling_output: dict[str, torch.Tensor] | None = None
- finish_reason: FinishReason | None = None
- stop_reason: int | str | None = None
- events: list[EngineCoreEvent] | None = None
- kv_transfer_params: dict[str, Any] | None = None
-
- trace_headers: Mapping[str, str] | None = None
- # The number of tokens with prefix cache hits.
- num_cached_tokens: int = 0
-
- # The number of NaNs in logits.
- # A value greater than 0 indicates that the output is corrupted.
- num_nans_in_logits: int = 0
-
- @property
- def finished(self) -> bool:
- return self.finish_reason is not None
-
-
-class OmniEngineCoreOutputs(
- msgspec.Struct,
- array_like=True, # type: ignore[call-arg]
- omit_defaults=True, # type: ignore[call-arg]
- gc=False,
-): # type: ignore[call-arg]
- # NOTE(Nick): We could consider ways to make this more compact,
- # e.g. columnwise layout
-
- engine_index: int = 0
-
- # [num_reqs]
- outputs: list[OmniEngineCoreOutput] = []
- scheduler_stats: SchedulerStats | None = None
- timestamp: float = 0.0
-
- utility_output: UtilityOutput | None = None
- finished_requests: set[str] | None = None
-
- # In DP case, used to signal that the current wave of requests
- # has finished and the engines are paused.
- wave_complete: int | None = None
- # In DP case, used to signal that a request was received for an
- # "old" wave, so the next wave needs to be started in other engines.
- start_wave: int | None = None
-
- def __post_init__(self):
- if self.timestamp == 0.0:
- self.timestamp = time.monotonic()
+class OmniEngineCoreOutputs(EngineCoreOutputs):
+ outputs: list[OmniEngineCoreOutput] = []
\ No newline at end of file
diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py
index 714b8dfcc53..6aef7fb162c 100644
--- a/vllm_omni/engine/output_processor.py
+++ b/vllm_omni/engine/output_processor.py
@@ -8,8 +8,6 @@
from vllm.sampling_params import RequestOutputKind
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
-from vllm.v1.engine.detokenizer import IncrementalDetokenizer
-from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.output_processor import OutputProcessor as VLLMOutputProcessor
from vllm.v1.engine.output_processor import OutputProcessorOutput, RequestOutputCollector, RequestState
from vllm.v1.engine.parallel_sampling import ParentRequest
diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py
index 1cd7f506319..3c275147fa0 100644
--- a/vllm_omni/entrypoints/async_omni.py
+++ b/vllm_omni/entrypoints/async_omni.py
@@ -131,19 +131,21 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st
ulysses_degree = kwargs.get("ulysses_degree") or 1
ring_degree = kwargs.get("ring_degree") or 1
sequence_parallel_size = kwargs.get("sequence_parallel_size")
+ tensor_parallel_size = kwargs.get("tensor_parallel_size") or 1
+ cfg_parallel_size = kwargs.get("cfg_parallel_size") or 1
if sequence_parallel_size is None:
sequence_parallel_size = ulysses_degree * ring_degree
- num_devices = sequence_parallel_size
+ num_devices = sequence_parallel_size * tensor_parallel_size * cfg_parallel_size
for i in range(1, num_devices):
devices += f",{i}"
parallel_config = DiffusionParallelConfig(
pipeline_parallel_size=1,
data_parallel_size=1,
- tensor_parallel_size=1,
+ tensor_parallel_size=tensor_parallel_size,
sequence_parallel_size=sequence_parallel_size,
ulysses_degree=ulysses_degree,
ring_degree=ring_degree,
- cfg_parallel_size=1,
+ cfg_parallel_size=cfg_parallel_size,
)
default_stage_cfg = [
{
@@ -161,6 +163,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st
"cache_backend": cache_backend,
"cache_config": cache_config,
"enable_cpu_offload": kwargs.get("enable_cpu_offload", False),
+ "enforce_eager": kwargs.get("enforce_eager", False),
},
"final_output": True,
"final_output_type": "image",
@@ -347,12 +350,6 @@ async def generate(self, *args: Any, **kwargs: dict[str, Any]) -> AsyncGenerator
result = await req_state.queue.get()
assert stage_id == req_state.stage_id
- req_id = result.get("request_id")
- if "error" in result:
- logger.error(
- f"[{self._name}] Stage {stage_id} error on request {req_id}: {result['error']}",
- )
- raise RuntimeError(result) # Request Finished due to error
req_id = result.get("request_id")
if "error" in result:
logger.error(
diff --git a/vllm_omni/entrypoints/chat_utils.py b/vllm_omni/entrypoints/chat_utils.py
index 6517ae666fe..0fdef5edbb7 100644
--- a/vllm_omni/entrypoints/chat_utils.py
+++ b/vllm_omni/entrypoints/chat_utils.py
@@ -22,7 +22,6 @@
_postprocess_messages,
_ToolParser,
)
-from vllm.transformers_utils.tokenizer import AnyTokenizer
class OmniAsyncMultiModalItemTracker(AsyncMultiModalItemTracker):
@@ -129,7 +128,6 @@ def _cleanup_file_sync(file_path: str) -> None:
def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
- tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> tuple[
@@ -138,7 +136,7 @@ def parse_chat_messages_futures(
MultiModalUUIDDict | None,
]:
conversation: list[ConversationMessage] = []
- mm_tracker = OmniAsyncMultiModalItemTracker(model_config, tokenizer)
+ mm_tracker = OmniAsyncMultiModalItemTracker(model_config)
for msg in messages:
sub_messages = _parse_chat_message_content(
diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py
index 3b222c8179c..c3a37e3c82e 100644
--- a/vllm_omni/entrypoints/cli/serve.py
+++ b/vllm_omni/entrypoints/cli/serve.py
@@ -208,6 +208,9 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
default=None,
help="Scheduler flow_shift for video models (e.g., 5.0 for 720p, 12.0 for 480p).",
)
+ omni_config_group.add_argument(
+ "--cfg-parallel-size", type=int, default=1, help="Number of GPUs for CFG parallel computation"
+ )
return serve_parser
diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py
index 74fe6a80376..8db4b11d94d 100644
--- a/vllm_omni/entrypoints/omni_llm.py
+++ b/vllm_omni/entrypoints/omni_llm.py
@@ -1,14 +1,15 @@
+from collections.abc import Callable
from typing import Any
import cloudpickle
from pydantic import ValidationError
from tqdm import tqdm
-from vllm.outputs import RequestOutput, PoolingRequestOutput
-from typing import Callable
+
# External library imports (vLLM)
from vllm.config import CompilationConfig, StructuredOutputsConfig, is_init_field
from vllm.entrypoints.llm import LLM
from vllm.logger import init_logger
+from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter
@@ -193,9 +194,7 @@ def __del__(self) -> None: # best-effort
except Exception as e:
logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True)
- def _run_engine(
- self, *, use_tqdm: bool | Callable[..., tqdm] = True
- ) -> list[RequestOutput | PoolingRequestOutput]:
+ def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[RequestOutput | PoolingRequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -223,14 +222,9 @@ def _run_engine(
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids) * n
in_spd = total_in_toks / pbar.format_dict["elapsed"]
- total_out_toks += sum(
- len(stp.token_ids) for stp in output.outputs
- )
+ total_out_toks += sum(len(stp.token_ids) for stp in output.outputs)
out_spd = total_out_toks / pbar.format_dict["elapsed"]
- pbar.postfix = (
- f"est. speed input: {in_spd:.2f} toks/s, "
- f"output: {out_spd:.2f} toks/s"
- )
+ pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s"
pbar.update(n)
else:
pbar.update(1)
@@ -242,4 +236,4 @@ def _run_engine(
# Sort the outputs by the int part of request ID which is in format of 'int-uuid'.
# This is necessary because some requests may be finished earlier than
# its previous requests.
- return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0]))
\ No newline at end of file
+ return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0]))
diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py
index af6f60f0420..a2790cd06e5 100644
--- a/vllm_omni/entrypoints/omni_stage.py
+++ b/vllm_omni/entrypoints/omni_stage.py
@@ -474,12 +474,14 @@ def _stage_worker(
data_parallel_size = parallel_config.get("data_parallel_size", 1)
prefill_context_parallel_size = 1 # not used for diffusion
sequence_parallel_size = parallel_config.get("sequence_parallel_size", 1)
+ cfg_parallel_size = parallel_config.get("cfg_parallel_size", 1)
else:
tensor_parallel_size = engine_args.get("tensor_parallel_size", 1)
pipeline_parallel_size = engine_args.get("pipeline_parallel_size", 1)
data_parallel_size = engine_args.get("data_parallel_size", 1)
prefill_context_parallel_size = engine_args.get("prefill_context_parallel_size", 1)
sequence_parallel_size = 1 # not use in omni model
+ cfg_parallel_size = 1 # not used in omni model
# Calculate total number of devices needed for this stage
# For a single stage worker:
@@ -488,7 +490,8 @@ def _stage_worker(
# - DP: replicates model, but each replica uses TP devices
# - PCP: context parallelism, typically uses TP devices
# - SP: sequence parallelism, typically uses TP devices
- # The number of devices per stage is determined by TP * PP * DP * PCP * SP size
+ # - CFG: Classifier-Free Guidance parallelism for diffusion models
+ # The number of devices per stage is determined by TP * PP * DP * PCP * SP * CFG size
# (PP/DP/PCP are higher-level parallelism that don't add devices per stage)
num_devices_per_stage = (
tensor_parallel_size
@@ -496,6 +499,7 @@ def _stage_worker(
* data_parallel_size
* prefill_context_parallel_size
* sequence_parallel_size
+ * cfg_parallel_size
)
# Get physical device IDs from CUDA_VISIBLE_DEVICES
@@ -951,8 +955,6 @@ async def _stage_worker_async(
except Exception as e:
logger.warning("Device setup failed: %s", e)
- max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1)
- engine_args["max_num_seqs"] = max_batch_size
# Initialize OmniConnectors if configured to match sync worker behavior
connectors: dict[Any, Any] = {}
if connectors_config:
diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py
index 49958eb66f6..888d39085ac 100644
--- a/vllm_omni/entrypoints/openai/api_server.py
+++ b/vllm_omni/entrypoints/openai/api_server.py
@@ -1,3 +1,5 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
@@ -14,31 +16,48 @@
from fastapi import Depends, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.datastructures import State
-from vllm.config import VllmConfig
+from starlette.routing import Route
from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template, resolve_hf_chat_template, resolve_mistral_chat_template
+from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.api_server import (
base,
build_app,
load_log_config,
- maybe_register_tokenizer_info_endpoint,
router,
setup_server,
- validate_json_request,
)
-from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse
-from vllm.entrypoints.openai.serving_models import BaseModelPath, LoRAModulePath, OpenAIServingModels
-from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-
-# yapf conflicts with isort for this block
-# yapf: disable
-# yapf: enable
+from vllm.entrypoints.openai.orca_metrics import metrics_header
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ ErrorResponse,
+)
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
+from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
+from vllm.entrypoints.openai.serving_transcription import (
+ OpenAIServingTranscription,
+ OpenAIServingTranslation,
+)
+from vllm.entrypoints.openai.utils import validate_json_request
+from vllm.entrypoints.pooling.classify.serving import ServingClassification
+from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
+from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
+from vllm.entrypoints.pooling.score.serving import ServingScores
+from vllm.entrypoints.serve.disagg.serving import ServingTokens
+from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
-from vllm.entrypoints.utils import load_aware_call, with_cancellation
+from vllm.entrypoints.utils import (
+ load_aware_call,
+ process_chat_template,
+ process_lora_modules,
+ with_cancellation,
+)
from vllm.logger import init_logger
-from vllm.tokenizers import MistralTokenizer
+from vllm.tasks import POOLING_TASKS
+from vllm.tool_parsers import ToolParserManager
from vllm.utils.system_utils import decorate_logs
from vllm_omni.entrypoints.async_omni import AsyncOmni
@@ -57,6 +76,29 @@
logger = init_logger(__name__)
+ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"
+
+
+def _remove_route_from_router(router_obj, path: str, methods: set[str] | None = None):
+ """Remove a route from the router by path and optionally by methods.
+
+ This is needed because vllm's api_server registers routes when imported,
+ and we need to override some routes (like /v1/chat/completions) with
+ omni-specific implementations.
+ """
+ routes_to_remove = []
+ for route in router_obj.routes:
+ if isinstance(route, Route) and route.path == path:
+ if methods is None or (hasattr(route, "methods") and route.methods & methods):
+ routes_to_remove.append(route)
+
+ for route in routes_to_remove:
+ router_obj.routes.remove(route)
+
+
+# Remove vllm's /v1/chat/completions route so we can register our own omni version
+_remove_route_from_router(router, "/v1/chat/completions", {"POST"})
+
# Server entry points
@@ -88,6 +130,10 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None,
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
+ if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
+ from vllm.reasoning import ReasoningParserManager
+
+ ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)
# Load logging config for uvicorn if specified
log_config = load_log_config(args.log_config_file)
@@ -98,11 +144,11 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None,
args,
client_config=client_config,
) as engine_client:
- maybe_register_tokenizer_info_endpoint(args)
app = build_app(args)
+ await omni_init_app_state(engine_client, app.state, args)
+
vllm_config = await engine_client.get_vllm_config()
- await omni_init_app_state(engine_client, vllm_config, app.state, args)
# Check if pure diffusion mode (vllm_config will be None)
is_pure_diffusion = vllm_config is None
@@ -233,7 +279,6 @@ async def build_async_omni_from_stage_config(
async def omni_init_app_state(
engine_client: EngineClient,
- vllm_config: VllmConfig | None,
state: State,
args: Namespace,
) -> None:
@@ -246,10 +291,12 @@ async def omni_init_app_state(
Args:
engine_client: Engine client instance (AsyncOmni)
- vllm_config: vLLM configuration object (may be None for pure diffusion)
state: FastAPI application state object to initialize
args: Parsed command-line arguments
"""
+ # Get vllm_config from engine_client (following 0.14.0 pattern)
+ vllm_config = await engine_client.get_vllm_config()
+
# Detect if it's pure Diffusion mode (single stage and is Diffusion)
is_pure_diffusion = False
if hasattr(engine_client, "stage_configs") and engine_client.stage_configs:
@@ -273,6 +320,7 @@ async def omni_init_app_state(
base_model_paths = [BaseModelPath(name=name, model_path=args.model) for name in served_model_names]
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
+ state.args = args
# For omni models
state.stage_configs = engine_client.stage_configs if hasattr(engine_client, "stage_configs") else None
@@ -302,32 +350,18 @@ async def omni_init_app_state(
logger.warning("vllm_config is None, some features may not work correctly")
state.vllm_config = vllm_config
- if vllm_config is not None:
- _model_config = vllm_config.model_config
-
- resolved_chat_template = load_chat_template(args.chat_template)
- if resolved_chat_template is not None and vllm_config is not None:
- # Get the tokenizer to check official template
- tokenizer = await engine_client.get_tokenizer()
-
- if tokenizer is not None:
- if isinstance(tokenizer, MistralTokenizer):
- # The warning is logged in resolve_mistral_chat_template.
- resolved_chat_template = resolve_mistral_chat_template(chat_template=resolved_chat_template)
- else:
- hf_chat_template = resolve_hf_chat_template(
- tokenizer=tokenizer,
- chat_template=None,
- tools=None,
- model_config=vllm_config.model_config,
- )
- if hf_chat_template != resolved_chat_template:
- logger.warning(
- "Using supplied chat template: %s\nIt is different from official chat template '%s'. This discrepancy may lead to performance degradation.", # noqa: E501
- resolved_chat_template,
- args.model,
- )
+ # Get supported tasks
+ supported_tasks: set[str] = {"generate"}
+ if hasattr(engine_client, "get_supported_tasks"):
+ supported_tasks = set(await engine_client.get_supported_tasks())
+ logger.info("Supported tasks: %s", supported_tasks)
+
+ resolved_chat_template = await process_chat_template(
+ args.chat_template,
+ engine_client,
+ vllm_config.model_config if vllm_config is not None else None,
+ )
if args.tool_server == "demo":
tool_server: ToolServer | None = DemoToolServer()
@@ -340,23 +374,12 @@ async def omni_init_app_state(
tool_server = None
# Merge default_mm_loras into the static lora_modules
- default_mm_loras = {}
- if vllm_config is not None and vllm_config.lora_config is not None:
- default_mm_loras = vllm_config.lora_config.default_mm_loras
-
- lora_modules = args.lora_modules
- if default_mm_loras:
- default_mm_lora_paths = [
- LoRAModulePath(
- name=modality,
- path=lora_path,
- )
- for modality, lora_path in default_mm_loras.items()
- ]
- if args.lora_modules is None:
- lora_modules = default_mm_lora_paths
- else:
- lora_modules += default_mm_lora_paths
+ default_mm_loras = (
+ vllm_config.lora_config.default_mm_loras
+ if vllm_config is not None and vllm_config.lora_config is not None
+ else {}
+ )
+ lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)
# Ensure input_processor, io_processor, and model_config exist for OpenAIServingModels compatibility
if (
@@ -415,24 +438,182 @@ async def omni_init_app_state(
lora_modules=lora_modules,
)
await state.openai_serving_models.init_static_loras()
- state.openai_serving_chat = OmniOpenAIServingChat(
+
+ state.openai_serving_responses = (
+ OpenAIServingResponses(
+ engine_client,
+ state.openai_serving_models,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+ enable_auto_tools=args.enable_auto_tool_choice,
+ tool_parser=args.tool_call_parser,
+ tool_server=tool_server,
+ reasoning_parser=args.structured_outputs_config.reasoning_parser,
+ enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+ enable_force_include_usage=args.enable_force_include_usage,
+ enable_log_outputs=args.enable_log_outputs,
+ log_error_stack=args.log_error_stack,
+ )
+ if "generate" in supported_tasks
+ else None
+ )
+ state.openai_serving_chat = (
+ OmniOpenAIServingChat(
+ engine_client,
+ state.openai_serving_models,
+ args.response_role,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
+ default_chat_template_kwargs=args.default_chat_template_kwargs,
+ trust_request_chat_template=args.trust_request_chat_template,
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+ enable_auto_tools=args.enable_auto_tool_choice,
+ exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
+ tool_parser=args.tool_call_parser,
+ reasoning_parser=args.structured_outputs_config.reasoning_parser,
+ enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+ enable_force_include_usage=args.enable_force_include_usage,
+ enable_log_outputs=args.enable_log_outputs,
+ enable_log_deltas=args.enable_log_deltas,
+ log_error_stack=args.log_error_stack,
+ )
+ if "generate" in supported_tasks
+ else None
+ )
+ # Warm up chat template processing to avoid first-request latency
+ if state.openai_serving_chat is not None:
+ await state.openai_serving_chat.warmup()
+
+ state.openai_serving_completion = (
+ OpenAIServingCompletion(
+ engine_client,
+ state.openai_serving_models,
+ request_logger=request_logger,
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+ enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+ enable_force_include_usage=args.enable_force_include_usage,
+ log_error_stack=args.log_error_stack,
+ )
+ if "generate" in supported_tasks
+ else None
+ )
+ state.openai_serving_pooling = (
+ OpenAIServingPooling(
+ engine_client,
+ state.openai_serving_models,
+ supported_tasks=supported_tasks,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
+ trust_request_chat_template=args.trust_request_chat_template,
+ log_error_stack=args.log_error_stack,
+ )
+ if any(task in POOLING_TASKS for task in supported_tasks)
+ else None
+ )
+ state.openai_serving_embedding = (
+ OpenAIServingEmbedding(
+ engine_client,
+ state.openai_serving_models,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
+ trust_request_chat_template=args.trust_request_chat_template,
+ log_error_stack=args.log_error_stack,
+ )
+ if "embed" in supported_tasks
+ else None
+ )
+ state.openai_serving_classification = (
+ ServingClassification(
+ engine_client,
+ state.openai_serving_models,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
+ trust_request_chat_template=args.trust_request_chat_template,
+ log_error_stack=args.log_error_stack,
+ )
+ if "classify" in supported_tasks
+ else None
+ )
+ state.openai_serving_scores = (
+ ServingScores(
+ engine_client,
+ state.openai_serving_models,
+ request_logger=request_logger,
+ score_template=resolved_chat_template,
+ log_error_stack=args.log_error_stack,
+ )
+ if ("embed" in supported_tasks or "score" in supported_tasks)
+ else None
+ )
+ state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
state.openai_serving_models,
- args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
- return_tokens_as_token_ids=args.return_tokens_as_token_ids,
- enable_auto_tools=args.enable_auto_tool_choice,
- exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
- tool_parser=args.tool_call_parser,
- reasoning_parser=args.structured_outputs_config.reasoning_parser,
- enable_prompt_tokens_details=args.enable_prompt_tokens_details,
- enable_force_include_usage=args.enable_force_include_usage,
- enable_log_outputs=args.enable_log_outputs,
log_error_stack=args.log_error_stack,
)
+ state.openai_serving_transcription = (
+ OpenAIServingTranscription(
+ engine_client,
+ state.openai_serving_models,
+ request_logger=request_logger,
+ log_error_stack=args.log_error_stack,
+ enable_force_include_usage=args.enable_force_include_usage,
+ )
+ if "transcription" in supported_tasks
+ else None
+ )
+ state.openai_serving_translation = (
+ OpenAIServingTranslation(
+ engine_client,
+ state.openai_serving_models,
+ request_logger=request_logger,
+ log_error_stack=args.log_error_stack,
+ enable_force_include_usage=args.enable_force_include_usage,
+ )
+ if "transcription" in supported_tasks
+ else None
+ )
+ state.anthropic_serving_messages = (
+ AnthropicServingMessages(
+ engine_client,
+ state.openai_serving_models,
+ args.response_role,
+ request_logger=request_logger,
+ chat_template=resolved_chat_template,
+ chat_template_content_format=args.chat_template_content_format,
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+ enable_auto_tools=args.enable_auto_tool_choice,
+ tool_parser=args.tool_call_parser,
+ reasoning_parser=args.structured_outputs_config.reasoning_parser,
+ enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+ enable_force_include_usage=args.enable_force_include_usage,
+ )
+ if "generate" in supported_tasks
+ else None
+ )
+ state.serving_tokens = (
+ ServingTokens(
+ engine_client,
+ state.openai_serving_models,
+ request_logger=request_logger,
+ return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+ log_error_stack=args.log_error_stack,
+ enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+ enable_log_outputs=args.enable_log_outputs,
+ force_no_detokenize=args.tokens_only,
+ )
+ if "generate" in supported_tasks
+ else None
+ )
state.openai_serving_speech = OmniOpenAIServingSpeech(
engine_client, state.openai_serving_models, request_logger=request_logger
@@ -463,6 +644,7 @@ def Omnispeech(request: Request) -> OmniOpenAIServingSpeech | None:
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
+ metrics_header_format = raw_request.headers.get(ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, "")
handler = Omnichat(raw_request)
if handler is None:
return base(raw_request).create_error_response(message="The model does not support Chat Completions API")
@@ -474,7 +656,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
if isinstance(generator, ErrorResponse):
return JSONResponse(
- content=generator.model_dump(), status_code=generator.code if hasattr(generator, "code") else 400
+ content=generator.model_dump(),
+ status_code=generator.error.code if generator.error else 400,
)
elif isinstance(generator, ChatCompletionResponse):
@@ -490,18 +673,27 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
try:
# Use serialize_as_any=True to bypass type checking
response_dict = generator.model_dump(mode="json", serialize_as_any=True, warnings="none")
- return JSONResponse(content=response_dict)
+ return JSONResponse(
+ content=response_dict,
+ headers=metrics_header(metrics_header_format),
+ )
except Exception:
# Fallback: convert to JSON string and parse back to avoid any serialization issues
try:
response_json = generator.model_dump_json(warnings="none", serialize_as_any=True)
response_dict = json_lib.loads(response_json)
- return JSONResponse(content=response_dict)
+ return JSONResponse(
+ content=response_dict,
+ headers=metrics_header(metrics_header_format),
+ )
except Exception:
# Last resort: regular dump with warnings suppressed
with warnings_module.catch_warnings():
warnings_module.filterwarnings("ignore", category=UserWarning)
- return JSONResponse(content=generator.model_dump(mode="json", warnings="none"))
+ return JSONResponse(
+ content=generator.model_dump(mode="json", warnings="none"),
+ headers=metrics_header(metrics_header_format),
+ )
return StreamingResponse(content=generator, media_type="text/event-stream")
diff --git a/vllm_omni/entrypoints/openai/protocol/chat_completion.py b/vllm_omni/entrypoints/openai/protocol/chat_completion.py
index 9f607624938..d0c83f56f8b 100644
--- a/vllm_omni/entrypoints/openai/protocol/chat_completion.py
+++ b/vllm_omni/entrypoints/openai/protocol/chat_completion.py
@@ -1,4 +1,6 @@
-from vllm.entrypoints.openai.protocol import ChatCompletionStreamResponse
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionStreamResponse,
+)
class OmniChatCompletionStreamResponse(ChatCompletionStreamResponse):
diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py
index 6fb9750ccc1..6320cb73a16 100644
--- a/vllm_omni/entrypoints/openai/serving_chat.py
+++ b/vllm_omni/entrypoints/openai/serving_chat.py
@@ -29,7 +29,10 @@
make_tool_call_id,
resolve_chat_template_content_format,
)
-from vllm.entrypoints.harmony_utils import get_streamable_parser_for_assistant, parse_chat_output
+from vllm.entrypoints.openai.parser.harmony_utils import (
+ get_streamable_parser_for_assistant,
+ parse_chat_output,
+)
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
@@ -41,29 +44,24 @@
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
+ ErrorInfo,
ErrorResponse,
FunctionCall,
FunctionDefinition,
PromptTokenUsageInfo,
RequestResponseMetadata,
+ ResponsesRequest,
ToolCall,
UsageInfo,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import (
ChatLikeRequest,
- EngineTokensPrompt,
- RequestPrompt,
- ResponsesRequest,
- TextTokensPrompt,
clamp_prompt_logprobs,
- is_list_of,
)
-from vllm.entrypoints.openai.tool_parsers import ToolParser
-from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import should_include_usage
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
@@ -75,8 +73,10 @@
truncate_tool_call_ids,
validate_request_params,
)
+from vllm.tool_parsers import ToolParser
+from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils.collection_utils import as_list
+from vllm.utils.collection_utils import as_list, is_list_of
from vllm_omni.entrypoints.chat_utils import parse_chat_messages_futures
from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
@@ -177,52 +177,68 @@ async def create_chat_completion(
truncate_tool_call_ids(request)
validate_request_params(request)
- if (
- request.tool_choice == "auto"
- and not (self.enable_auto_tools and tool_parser is not None)
- and not isinstance(tokenizer, MistralTokenizer)
- and not self.use_harmony
+ # Tool calls cannot be parsed: no tool parser, non-Mistral tokenizer, harmony off (shared by the checks below)
+ tool_parsing_unavailable = (
+ tool_parser is None and not isinstance(tokenizer, MistralTokenizer) and not self.use_harmony
+ )
+
+ # Validate tool_choice when tool parsing is required but unavailable
+ if tool_parsing_unavailable and request.tool_choice not in (
+ None,
+ "none",
):
- # for hf tokenizers, "auto" tools requires
- # --enable-auto-tool-choice and --tool-call-parser
- return self.create_error_response(
- '"auto" tool choice requires --enable-auto-tool-choice and --tool-call-parser to be set'
- )
+ if request.tool_choice == "auto" and not self.enable_auto_tools:
+ # for HF tokenizers, "auto" tool choice requires
+ # --enable-auto-tool-choice and --tool-call-parser
+ return self.create_error_response(
+ '"auto" tool choice requires --enable-auto-tool-choice and --tool-call-parser to be set'
+ )
+ elif request.tool_choice != "auto":
+ # "required" or named tool requires tool parser
+ return self.create_error_response(
+ f'tool_choice="{request.tool_choice}" requires --tool-call-parser to be set'
+ )
if request.tools is None or (request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none):
tool_dicts = None
else:
tool_dicts = [tool.model_dump() for tool in request.tools]
- # Common case.
- request_chat_template = request.chat_template
- chat_template_kwargs = request.chat_template_kwargs
- if not self.trust_request_chat_template and (
- request_chat_template is not None
- or (chat_template_kwargs and chat_template_kwargs.get("chat_template") is not None)
- ):
- return self.create_error_response(
- "Chat template is passed with request, but --trust-request-chat-template is not set. "
- "Refused request with untrusted chat template."
+ if not self.use_harmony:
+ error_check_ret = self._validate_chat_template(
+ request_chat_template=request.chat_template,
+ chat_template_kwargs=request.chat_template_kwargs,
+ trust_request_chat_template=self.trust_request_chat_template,
)
- (
- conversation,
- request_prompts,
- engine_prompts,
- ) = await self._preprocess_chat(
- request,
- tokenizer,
- request.messages,
- chat_template=request_chat_template or self.chat_template,
- chat_template_content_format=self.chat_template_content_format,
- add_generation_prompt=request.add_generation_prompt,
- continue_final_message=request.continue_final_message,
- tool_dicts=tool_dicts,
- documents=request.documents,
- chat_template_kwargs=request.chat_template_kwargs,
- tool_parser=tool_parser,
- add_special_tokens=request.add_special_tokens,
- )
+ if error_check_ret is not None:
+ return error_check_ret
+
+ chat_template_kwargs = request.chat_template_kwargs or {}
+ chat_template_kwargs.update(reasoning_effort=request.reasoning_effort)
+
+ (
+ conversation,
+ request_prompts,
+ engine_prompts,
+ ) = await self._preprocess_chat(
+ request,
+ tokenizer,
+ request.messages,
+ chat_template=request.chat_template or self.chat_template,
+ chat_template_content_format=self.chat_template_content_format,
+ add_generation_prompt=request.add_generation_prompt,
+ continue_final_message=request.continue_final_message,
+ tool_dicts=tool_dicts,
+ documents=getattr(request, "documents", None),
+ chat_template_kwargs=chat_template_kwargs,
+ default_chat_template_kwargs=self.default_chat_template_kwargs,
+ tool_parser=tool_parser,
+ add_special_tokens=request.add_special_tokens,
+ )
+ else:
+ should_include_tools = tool_dicts is not None
+ conversation, engine_prompts = self._make_request_with_harmony(request, should_include_tools)
+ request_prompts = [engine_prompt.get("prompt_token_ids", []) for engine_prompt in engine_prompts]
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
@@ -314,12 +330,13 @@ async def _preprocess_chat(
tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
+ default_chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False,
) -> tuple[
list[ConversationMessage],
- Sequence[RequestPrompt],
- list[EngineTokensPrompt],
+ Sequence[PromptType],
+ list[TokensPrompt],
]:
model_config = self.model_config
@@ -333,11 +350,17 @@ async def _preprocess_chat(
conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
messages,
model_config,
- tokenizer,
content_format=resolved_content_format,
mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
)
+ # Merge default_chat_template_kwargs with request-provided kwargs
+ # Request kwargs take precedence over defaults
+ merged_kwargs = self._prepare_extra_chat_template_kwargs(
+ chat_template_kwargs,
+ default_chat_template_kwargs,
+ )
+
_chat_template_kwargs: dict[str, Any] = dict(
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
@@ -345,7 +368,7 @@ async def _preprocess_chat(
tools=tool_dicts,
documents=documents,
)
- _chat_template_kwargs.update(chat_template_kwargs or {})
+ _chat_template_kwargs.update(merged_kwargs)
request_prompt: str | list[int]
@@ -388,7 +411,7 @@ async def _preprocess_chat(
"Prompt has to be a string",
"when the tokenizer is not initialised",
)
- prompt_inputs = TextTokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
+ prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
elif isinstance(request_prompt, str):
prompt_inputs = await self._tokenize_prompt_input_async(
request,
@@ -399,20 +422,21 @@ async def _preprocess_chat(
else:
# For MistralTokenizer
assert is_list_of(request_prompt, int), "Prompt has to be either a string or a list of token ids"
- prompt_inputs = TextTokensPrompt(
+ prompt_inputs = TokensPrompt(
prompt=tokenizer.decode(request_prompt),
prompt_token_ids=request_prompt,
)
- engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
+ engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
- if request.mm_processor_kwargs is not None:
- engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
+ mm_processor_kwargs = getattr(request, "mm_processor_kwargs", None)
+ if mm_processor_kwargs is not None:
+ engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
@@ -513,7 +537,7 @@ def _build_sampling_params_list_from_request(
def _log_inputs(
self,
request_id: str,
- inputs: RequestPrompt | PromptType,
+ inputs: PromptType,
params_list: list[SamplingParams] | None,
lora_request: LoRARequest | None,
) -> None:
@@ -599,9 +623,13 @@ async def chat_completion_stream_generator(
try:
if self.reasoning_parser:
+ chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
+ request.chat_template_kwargs,
+ self.default_chat_template_kwargs,
+ )
reasoning_parser = self.reasoning_parser(
tokenizer,
- chat_template_kwargs=request.chat_template_kwargs, # type: ignore
+ chat_template_kwargs=chat_template_kwargs, # type: ignore
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
@@ -1052,10 +1080,9 @@ async def chat_completion_stream_generator(
# wasn't ready to send a token, then
# get the next token without streaming a chunk
if delta_message is None:
- if output.finish_reason is None:
+ if output.finish_reason is None and not request.return_token_ids:
continue
- else:
- delta_message = DeltaMessage()
+ delta_message = DeltaMessage()
# Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger:
@@ -1438,13 +1465,20 @@ def _create_text_choice(
if self.reasoning_parser:
try:
- reasoning_parser = self.reasoning_parser(tokenizer)
+ chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
+ request.chat_template_kwargs,
+ self.default_chat_template_kwargs,
+ )
+ reasoning_parser = self.reasoning_parser(
+ tokenizer,
+ chat_template_kwargs=chat_template_kwargs, # type: ignore
+ )
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
# If the reasoning parser is enabled,
# tool calls are extracted exclusively from the content.
- reasoning_content, content = reasoning_parser.extract_reasoning_content(output.text, request=request)
+ reasoning_content, content = reasoning_parser.extract_reasoning(output.text, request=request)
if not request.include_reasoning:
reasoning_content = None
else:
@@ -2049,7 +2083,9 @@ def _create_error_response(
) -> ErrorResponse:
"""Create an error response following OpenAI error format."""
return ErrorResponse(
- message=message,
- type=err_type,
- code=status_code,
+ error=ErrorInfo(
+ message=message,
+ type=err_type,
+ code=status_code,
+ )
)
diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py
index 135b0e89ff2..eae3ea7afc4 100644
--- a/vllm_omni/entrypoints/utils.py
+++ b/vllm_omni/entrypoints/utils.py
@@ -195,6 +195,10 @@ def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None
# Update base_engine_args with stage-specific engine_args if they exist
if hasattr(stage_arg, "engine_args") and stage_arg.engine_args is not None:
base_engine_args_tmp = OmegaConf.merge(base_engine_args_tmp, stage_arg.engine_args)
+ if hasattr(stage_arg, "runtime") and stage_arg.runtime is not None:
+ runtime_cfg = stage_arg.runtime
+ max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1)
+ base_engine_args_tmp["max_num_seqs"] = max_batch_size
stage_arg.engine_args = base_engine_args_tmp
return stage_args
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py
index 37733e2909c..45a1447b3a7 100644
--- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni.py
@@ -107,6 +107,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
if t2w_token_end_id:
self.model.set_suppress_start_id(t2w_token_end_id + 1)
+ self.requires_raw_input_tokens = True
elif self.model_stage == "code2wav":
self.thinker = None
@@ -125,6 +126,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._token2wav_conds: dict[str, torch.Tensor] = {}
self._token2wav_ref_mels: dict[str, torch.Tensor] = {}
self.model = self.token2wav
+ self.requires_raw_input_tokens = True
else:
raise ValueError("Invalid model stage")
@@ -248,14 +250,27 @@ def forward(
if is_npu():
# TODO: remove this hack when NPU supports batched inputs properly
thinker_input_ids = input_ids[0] if input_ids is not None and added_batch_dim else input_ids
- thinker_positions = positions[0] if positions.ndim > 1 else positions
+ # For MRoPE, positions shape is [3, num_tokens] (T/H/W), don't slice it
+ if positions.ndim == 2 and positions.shape[0] == 3:
+ thinker_positions = positions # MRoPE positions, keep as is
+ else:
+ thinker_positions = positions[0] if positions.ndim > 1 else positions
thinker_inputs_embeds = (
inputs_embeds[0] if inputs_embeds is not None and added_batch_dim else inputs_embeds
)
else:
- thinker_input_ids = input_ids
- thinker_positions = positions[0]
- thinker_inputs_embeds = inputs_embeds
+ # Squeeze back if we added batch dim earlier
+ thinker_input_ids = input_ids[0] if input_ids is not None and added_batch_dim else input_ids
+ # For MRoPE, positions shape is [3, num_tokens] (T/H/W), don't slice it
+ if positions.ndim == 2 and positions.shape[0] == 3:
+ thinker_positions = positions # MRoPE positions, keep as is
+ elif added_batch_dim:
+ thinker_positions = positions[0]
+ else:
+ thinker_positions = positions
+ thinker_inputs_embeds = (
+ inputs_embeds[0] if inputs_embeds is not None and added_batch_dim else inputs_embeds
+ )
# Run thinker
thinker_output = self.thinker(
@@ -288,10 +303,16 @@ def forward(
if not hasattr(self, "voice_type"):
self.voice_type = voice_type
+ # For MRoPE, positions shape is [3, num_tokens] (T/H/W), don't slice it
+ if positions.ndim == 2 and positions.shape[0] == 3:
+ talker_positions = positions # MRoPE positions, keep as is
+ else:
+ talker_positions = positions[0]
+
with torch.inference_mode():
talker_hidden = self.talker(
input_ids=input_ids,
- positions=positions[0],
+ positions=talker_positions,
inputs_embeds=inputs_embeds,
)
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py
index 46175568e07..927bc552573 100644
--- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py
@@ -9,7 +9,6 @@
# from vllm.attention import AttentionMetadata # unused import
from vllm.config import VllmConfig
from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from vllm.model_executor.models.qwen2_5_omni_thinker import (
Qwen2_5OmniThinkerDummyInputsBuilder,
@@ -68,13 +67,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.config = config
- self.thinker_to_talker_proj = ColumnParallelLinear(
+ self.thinker_to_talker_proj = nn.Linear(
self.config.embedding_size,
self.config.hidden_size,
- bias=True,
- gather_output=True,
- skip_bias_add=False,
- quant_config=quant_config,
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -145,7 +140,7 @@ def forward(
input_ids = None
# projection
- inputs_embeds, _ = self.thinker_to_talker_proj(inputs_embeds)
+ inputs_embeds = self.thinker_to_talker_proj(inputs_embeds)
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
@@ -155,7 +150,18 @@ def forward(
def bad_word_processor(self, logits: torch.Tensor) -> torch.Tensor:
# suppress token IDs unsupported by token2wav
if self.suppress_start_id and self.suppress_start_id < logits.size(-1):
- logits[..., self.suppress_start_id : logits.size(-1)] = -1e9
+ # skip the end token id.
+ if hasattr(self.config, "tts_codec_end_token_id"):
+ end_id = int(getattr(self.config, "tts_codec_end_token_id"))
+ if self.suppress_start_id == end_id:
+ logits[..., end_id + 1 : logits.size(-1)] = -1e9
+ elif self.suppress_start_id < end_id:
+ logits[..., self.suppress_start_id : end_id] = -1e9
+ logits[..., end_id + 1 : logits.size(-1)] = -1e9
+ else:
+ logits[..., self.suppress_start_id : logits.size(-1)] = -1e9
+ else:
+ raise ValueError("config must have tts_codec_end_token_id attribute")
if hasattr(self.config, "tts_codec_start_token_id"):
bos_id = int(getattr(self.config, "tts_codec_start_token_id"))
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py
index 71c5e8377ac..a2250b82e73 100644
--- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_thinker.py
@@ -1,4 +1,4 @@
-"""Thin Omni wrapper: reuse upstream Qwen2.5-Omni thinker (v0.12) with minimal overrides."""
+"""Thin Omni wrapper: reuse upstream Qwen2.5-Omni thinker (v0.14) with minimal overrides."""
from collections.abc import Iterable
from typing import Any
@@ -12,6 +12,7 @@
Qwen2_5OmniAudioEncoder,
)
from vllm.config import VllmConfig
+from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
@@ -26,8 +27,6 @@
Qwen2_5OmniThinkerDummyInputsBuilder,
Qwen2_5OmniThinkerMultiModalProcessor,
Qwen2_5OmniThinkerProcessingInfo,
- get_llm_pos_ids_for_vision,
- split_list_into_ranges,
)
from vllm.model_executor.models.qwen2_5_omni_thinker import (
Qwen2_5OmniConditionalGenerationMixin as Qwen2_5OmniConditionalGenerationMixinBase,
@@ -46,7 +45,9 @@
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
+ split_list_into_ranges,
)
+from vllm.model_executor.models.vision import get_llm_pos_ids_for_vision
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
@@ -166,6 +167,43 @@ def _parse_and_validate_video_input(
video_grid_thw=video_grid_thw,
)
+    def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
+        """Turn a parsed image input into image embeddings.
+
+        Inputs whose ``type`` is ``"image_embeds"`` are returned directly after
+        a cast to ``self.visual.dtype``. Otherwise the pixel values are run
+        through ``self.visual`` and the concatenated output is split into one
+        chunk per row of ``image_grid_thw``.
+
+        NOTE(review): the ``"image_embeds"`` branch returns the result of
+        ``.type(...)`` on the stored value, which looks like a single tensor
+        rather than the annotated tuple -- confirm against callers.
+        """
+        if image_input["type"] == "image_embeds":
+            return image_input["image_embeds"].type(self.visual.dtype)
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        # The vision tower is called inside a forward context created from
+        # ``self.vllm_config`` (the first argument is passed as None).
+        with set_forward_context(None, self.vllm_config):
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+        # Split concatenated embeddings for each image item.
+        # Each chunk length is the product of a grid row divided by merge_size twice.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(
+        self,
+        video_input: Qwen2_5_VLVideoInputs,
+        video_hashes: list[str] | None = None,
+        cached_video_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
+        """Turn a parsed video input into video embeddings.
+
+        Inputs whose ``type`` is ``"video_embeds"`` are returned directly after
+        a cast to ``self.visual.dtype``. Otherwise the pixel values are run
+        through ``self.visual`` and the concatenated output is split into one
+        chunk per row of ``video_grid_thw``.
+
+        Args:
+            video_input: Parsed video input; its ``type`` key selects the branch.
+            video_hashes: Accepted for signature compatibility; not read here.
+            cached_video_embeds: Accepted for signature compatibility; not
+                read here.
+
+        Returns:
+            The cast ``video_embeds`` value for pre-computed embeddings,
+            otherwise the tuple produced by splitting the vision tower output.
+        """
+        if video_input["type"] == "video_embeds":
+            return video_input["video_embeds"].type(self.visual.dtype)
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
+        # The vision tower is called inside a forward context created from
+        # ``self.vllm_config`` (the first argument is passed as None).
+        with set_forward_context(None, self.vllm_config):
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+        # Split concatenated embeddings for each video item.
+        # Each chunk length is the product of a grid row divided by merge_size twice.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+        return video_embeds.split(sizes.tolist())
+
@MULTIMODAL_REGISTRY.register_processor(
Qwen2_5OmniThinkerMultiModalProcessor,
@@ -180,8 +218,6 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
SupportsMRoPE,
Qwen2_5OmniConditionalGenerationMixin,
):
- merge_by_field_config = True
-
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"thinker.lm_head.": "language_model.lm_head.",
@@ -250,6 +286,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
+ multimodal_config=multimodal_config,
)
else:
self.visual = None
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py
index ce9819b3e68..e04010196a8 100644
--- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_old.py
@@ -3,9 +3,6 @@
import torch
from torch import nn
from transformers import Qwen2Config
-from vllm.attention.backends.abstract import (
- AttentionType,
-)
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -15,7 +12,6 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding
@@ -24,15 +20,14 @@
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
- WeightsMapper,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
from vllm.sequence import IntermediateTensors
-from vllm.v1.outputs import PoolerOutput, SamplerOutput
-from vllm.v1.pool.metadata import PoolingMetadata
+from vllm.v1.attention.backend import AttentionType
+from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
@@ -130,7 +125,6 @@ def __init__(
self.rotary_pos_emb = get_rope(
head_size=self.head_dim,
- rotary_dim=self.head_dim,
max_position=max_position,
is_neox_style=True,
rope_parameters={
@@ -460,66 +454,3 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
-
-
-class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
- packed_modules_mapping = {
- "qkv_proj": [
- "q_proj",
- "k_proj",
- "v_proj",
- ],
- "gate_up_proj": [
- "gate_proj",
- "up_proj",
- ],
- }
-
- hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- config = vllm_config.model_config.hf_config
- quant_config = vllm_config.quant_config
- lora_config = vllm_config.lora_config
- pooler_config = vllm_config.model_config.pooler_config
-
- self.config = config
- self.lora_config = lora_config
-
- self.quant_config = quant_config
- self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
-
- # TODO: Replace this model class with as_embedding_model(
- # Qwen2ForCausalLM) after changing the default pooling method
- if pooler_config.pooling_type is None:
- logger.warning(
- "This embedding model will default to last-token pooling in "
- "an upcoming version. To avoid breaking changes, you should "
- 'pass `--override-pooler-config \'{"pooling_type": "MEAN"}\'`'
- " explicitly."
- )
-
- self._pooler = Pooler.from_config_with_defaults(
- pooler_config, pooling_type=PoolingType.MEAN, normalize=True, softmax=False
- )
-
- def forward(
- self,
- input_ids: torch.Tensor,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- ) -> torch.Tensor:
- return self.model(input_ids, positions, intermediate_tensors)
-
- def pooler(
- self,
- hidden_states: torch.Tensor,
- pooling_metadata: PoolingMetadata,
- ) -> PoolerOutput | None:
- return self._pooler(hidden_states, pooling_metadata)
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
- weights = self.hf_to_vllm_mapper.apply(weights)
- weights = ((name, data) for name, data in weights if not name.startswith("lm_head."))
- self.model.load_weights(weights)
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
index c459289af34..17f16d1d929 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
@@ -356,18 +356,24 @@ def forward(
elif self.model_stage == "code2wav":
# Extract codec codes from input
codes = []
- if input_ids is not None:
+ if input_ids.shape[0] % 16 == 0:
codes.append(input_ids.reshape(1, 16, -1))
-
else:
- # for profile, we use max length from inputs_embeds
- codes.append(
- torch.zeros(
- (1, 16, inputs_embeds.shape[1]),
- dtype=torch.long,
- device=inputs_embeds.device,
- )
+ logger.warning(
+ (
+ "Input_ids length: %s is not divisible by 16, padding "
+ "with zeros. This should only happen in warm up."
+ ),
+ input_ids.shape[0],
+ )
+ input_ids_flatten = input_ids.reshape(-1)
+ input_ids_flatten = torch.cat(
+ [
+ input_ids_flatten,
+ torch.zeros(16 - input_ids.shape[0] % 16, dtype=torch.long, device=input_ids.device),
+ ]
)
+ codes.append(input_ids_flatten.reshape(1, 16, -1))
# Generate audio from codec codes
audio_tensors = []
@@ -575,30 +581,22 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor,
if input_embeds is None and input_ids is not None:
input_embeds = self.talker.embed_input_ids(input_ids)
- text_step = torch.zeros(
- 1,
- self.talker_config.text_config.hidden_size,
- device=self._module_device(self.talker),
- dtype=torch.bfloat16,
- )
- last_talker_hidden = torch.zeros(
- 1,
- 1,
- self.talker_config.text_config.hidden_size,
- device=self._module_device(self.talker),
- dtype=torch.bfloat16,
- )
-
span_len = input_ids.shape[0]
if span_len > 1:
# prefill
input_ids, input_embeds, update_dict = self.talker_preprocess_prefill(input_ids, input_embeds, **info_dict)
+ code_predictor_codes = torch.zeros(
+ (input_embeds.shape[0], self.talker.num_code_groups),
+ device=self._module_device(self.talker),
+ dtype=torch.long,
+ )
+ update_dict["code_predictor_codes"] = code_predictor_codes
else:
last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode(
input_ids, input_embeds, **info_dict
)
- update_dict["mtp_inputs"] = last_talker_hidden, text_step
+ update_dict["mtp_inputs"] = last_talker_hidden, text_step
return input_ids, input_embeds, update_dict
@@ -610,24 +608,19 @@ def talker_mtp(
text_step: torch.Tensor,
):
# TODO(Peiqi): not support intermediate_tensors now
- input_ids = safe_tensor_reshape(input_ids, (1, -1))
+ input_ids = safe_tensor_reshape(input_ids, (input_ids.shape[0], -1))
inputs_embeds = safe_tensor_reshape(input_embeds, (-1, self.talker_config.text_config.hidden_size))
- text_step = safe_tensor_reshape(text_step, (1, -1))
- last_talker_hidden = safe_tensor_reshape(last_talker_hidden, (1, 1, self.talker_config.text_config.hidden_size))
+ text_step = safe_tensor_reshape(text_step, (-1, self.talker_config.text_config.hidden_size))
+ last_talker_hidden = safe_tensor_reshape(
+ last_talker_hidden, (-1, 1, self.talker_config.text_config.hidden_size)
+ )
# for profiling
if inputs_embeds.shape[-1] == 2048:
inputs_embeds = self.text_projection(inputs_embeds)
- if inputs_embeds.shape[0] == 1:
- code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward(
- input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden
- )
- inputs_embeds = summed_embeddings.clone()
- else:
- code_predictor_codes = torch.zeros(
- (inputs_embeds.shape[0], self.talker.num_code_groups),
- device=self._module_device(self.talker),
- dtype=torch.long,
- )
+ code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward(
+ input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden
+ )
+ inputs_embeds = summed_embeddings.clone()
inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size)
return inputs_embeds, code_predictor_codes.squeeze(-1)
@@ -850,7 +843,7 @@ def talker_preprocess_decode(self, input_ids: torch.Tensor, input_embeds: torch.
use_vec = q_tail[0:1, :]
new_q_tail = (
q_tail[1:, :].detach().to("cpu").contiguous()
- if q_tail.shape[1] > 1
+ if q_tail.shape[0] > 1
else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype)
)
text_step = use_vec.to(input_embeds.device, dtype=input_embeds.dtype)
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
index 4e8730eab52..2e872340da4 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py
@@ -111,7 +111,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
talker_config: Qwen3OmniMoeTalkerConfig = vllm_config.model_config.hf_config
talker_config.text_config.rope_parameters = talker_config.text_config.rope_scaling
- talker_config.text_config.rope_parameters["rope_theta"] = talker_config.text_config.rope_theta
+ talker_config.text_config.rope_parameters["rope_theta"] = talker_config.text_config.rope_parameters[
+ "rope_theta"
+ ]
self.quant_config = vllm_config.quant_config
self.prefix = prefix
self.vllm_config = vllm_config
@@ -234,13 +236,7 @@ def code_predictor_forward(
# Use the corresponding lm_head for this layer
logits = self.code_predictor.lm_head[layer_idx](hidden_state[:, -1:, :]) # [batch, 1, vocab_size]
- if len(pos_codes) > 1:
- input_ids_for_logits_processors = torch.cat(pos_codes[1:], dim=1).to(
- device=logits.device, dtype=torch.long
- )
- else:
- input_ids_for_logits_processors = self.empty_code
- logits = logits_processors(input_ids_for_logits_processors, logits.squeeze(0)).unsqueeze(0)
+ logits = logits_processors(None, logits[:, -1])
# Sample from the filtered distribution
probs = F.softmax(logits, dim=-1)
@@ -288,7 +284,7 @@ def code_predictor_forward(
all_summed_embeddings.append(pos_summed)
# Concatenate across positions: [batch, seq_len, hidden_size]
- summed_embeddings = torch.cat(all_summed_embeddings, dim=1)
+ summed_embeddings = torch.cat(all_summed_embeddings, dim=1).squeeze(1)
return result_codes, summed_embeddings
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py
index 7f3320a82eb..2d479062eb2 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py
@@ -37,9 +37,6 @@
Qwen3OmniMoeConfig,
Qwen3OmniMoeThinkerConfig,
)
-from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
- Qwen3OmniMoeAudioEncoder,
-)
from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
Qwen3OmniMoeProcessor,
)
@@ -56,6 +53,7 @@
SupportsMultiModal,
SupportsPP,
)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_omni_thinker import (
Qwen2_5OmniAudioFeatureInputs,
Qwen2_5OmniThinkerDummyInputsBuilder,
@@ -69,6 +67,7 @@
from vllm.model_executor.models.qwen3_moe import Qwen3MoeModel as _Qwen3MoeLLMModel
from vllm.model_executor.models.qwen3_omni_moe_thinker import (
Qwen3Omni_VisionTransformer,
+ Qwen3OmniMoeAudioEncoder,
_get_feat_extract_output_lengths,
)
from vllm.model_executor.models.utils import (
@@ -248,11 +247,36 @@ def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray:
# https://github.com/huggingface/transformers/pull/41473
mm_kwargs = dict(mm_kwargs)
tok_kwargs = dict(tok_kwargs)
+ mm_kwargs["audio_kwargs"] = dict(mm_kwargs.get("audio_kwargs") or {})
+ mm_kwargs["text_kwargs"] = dict(mm_kwargs.get("text_kwargs") or {})
if Version(TRANSFORMERS_VERSION) < Version("4.58.0"):
+ # Extract audio_sample_rate before restructuring
+ audio_sample_rate = mm_kwargs.pop("audio_sample_rate", None)
+
# move truncation to audio_kwargs level to avoid conflict
# with tok_kwargs
- mm_kwargs["audio_kwargs"] = {"truncation": mm_kwargs.pop("truncation", False)}
- mm_kwargs["text_kwargs"] = {"truncation": tok_kwargs.pop("truncation", False)}
+ mm_kwargs["audio_kwargs"].setdefault("truncation", mm_kwargs.pop("truncation", False))
+ mm_kwargs["text_kwargs"].setdefault("truncation", tok_kwargs.pop("truncation", False))
+
+ # Validate and conditionally pass audio_sample_rate
+ # WhisperFeatureExtractor has a fixed sampling rate, and vLLM's
+ # audio loader already resamples audio to the target rate.
+ # Only pass the value if it matches to avoid unexpected behavior.
+ if audio_sample_rate is not None:
+ expected_sr = feature_extractor.sampling_rate
+ if audio_sample_rate != expected_sr:
+ logger.warning(
+ "[%s] audio_sample_rate mismatch: user provided %dHz "
+ "but model expects %dHz. Ignoring user value. "
+ "vLLM's audio loader already resampled to %dHz.",
+ self.__class__.__name__,
+ audio_sample_rate,
+ expected_sr,
+ expected_sr,
+ )
+ else:
+ # Sample rate matches, safe to pass
+ mm_kwargs["audio_kwargs"]["audio_sample_rate"] = audio_sample_rate
hf_inputs = super()._call_hf_processor(
prompt=prompt,
@@ -398,11 +422,11 @@ def _get_prompt_updates(
if audio_feature_lengths is None and feature_attention_mask is None:
audio_output_lengths = []
elif audio_feature_lengths is not None:
- _, audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
+ audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
audio_output_lengths = audio_output_lens.tolist()
elif feature_attention_mask is not None:
assert isinstance(feature_attention_mask, torch.Tensor)
- _, audio_output_lens = _get_feat_extract_output_lengths(feature_attention_mask.sum(-1))
+ audio_output_lens = _get_feat_extract_output_lengths(feature_attention_mask.sum(-1))
audio_output_lengths = audio_output_lens.tolist()
# number of audios read from video.
@@ -569,59 +593,21 @@ def _get_raw_input_ids(
class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixin):
- def _parse_and_validate_audio_input(self, **kwargs: object) -> Qwen2_5OmniAudioFeatureInputs | None:
- input_audio_features = kwargs.pop("input_audio_features", None)
- audio_feature_lengths = kwargs.pop("audio_feature_lengths", None)
- feature_attention_mask = kwargs.pop("feature_attention_mask", None)
- if input_audio_features is None:
- return None
- if (
- input_audio_features is not None
- and isinstance(input_audio_features, torch.Tensor)
- and input_audio_features.ndim == 3
- ):
- # (batch_size, feature_dim, chunk_size) -> (feature_dim, batch_size * chunk_size)
- input_audio_features = input_audio_features.permute(1, 0, 2).flatten(1)
- elif input_audio_features is not None and isinstance(input_audio_features, list):
- input_audio_features = torch.cat(input_audio_features, dim=-1)
- if (
- audio_feature_lengths is not None
- and isinstance(audio_feature_lengths, torch.Tensor)
- and audio_feature_lengths.ndim == 2
- ):
- audio_feature_lengths = audio_feature_lengths.reshape(-1)
- elif audio_feature_lengths is not None and isinstance(audio_feature_lengths, list):
- audio_feature_lengths = torch.cat(audio_feature_lengths, dim=-1)
- if (
- feature_attention_mask is not None
- and isinstance(feature_attention_mask, torch.Tensor)
- and feature_attention_mask.ndim == 3
- ):
- feature_attention_mask = feature_attention_mask.reshape(-1, feature_attention_mask.shape[-1])
- elif feature_attention_mask is not None and isinstance(feature_attention_mask, list):
- for i in range(len(feature_attention_mask)):
- feature_attention_mask[i] = feature_attention_mask[i].reshape(-1)
- return Qwen2_5OmniAudioFeatureInputs(
- type="audio_features",
- input_features=input_audio_features,
- audio_feature_lengths=audio_feature_lengths,
- feature_attention_mask=feature_attention_mask,
- )
-
def _process_audio_input(
self,
audio_input: Qwen2_5OmniAudioFeatureInputs,
audio_hashes: list[str] | None = None,
cached_audio_features: torch.Tensor | None = None,
- ) -> torch.Tensor:
+ ) -> tuple[torch.Tensor, ...]:
input_features = audio_input["input_features"]
audio_feature_lengths = audio_input["audio_feature_lengths"]
- audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
+ audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)
audio_outputs = self.audio_tower(
input_features.to(self.audio_tower.dtype),
feature_lens=audio_feature_lengths,
+ aftercnn_lens=audio_output_lengths,
)
audio_features = audio_outputs.last_hidden_state
return audio_features.split(audio_output_lengths.tolist())
@@ -639,8 +625,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
SupportsMRoPE,
Qwen3OmniMoeConditionalGenerationMixin,
):
- merge_by_field_config = True
-
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"thinker.lm_head.": "language_model.lm_head.",
@@ -649,6 +633,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
}
)
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
@@ -669,20 +665,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = thinker_config
self.multimodal_config = multimodal_config
- # force "use_flash_attention_2=True" to audio tower to align
- # the results.
- if flash_attn is not None:
- audio_config = thinker_config.audio_config
- audio_config._attn_implementation_autoset = True
- audio_config._attn_implementation = "flash_attention_2"
- else:
- logger.warning(
- "flash_attn is not available, the model may not yield the "
- "exactly same result as the transformers implementation "
- "in the audio tower part."
- )
-
- self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)
+ self.audio_tower = Qwen3OmniMoeAudioEncoder(
+ thinker_config.audio_config,
+ )
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
@@ -810,8 +795,6 @@ def embed_input_ids(
return inputs_embeds
deepstack_input_embeds = None
- # TODO (ywang96): support overlapping modalitiy embeddings so that
- # `use_audio_in_video` will work on V1.
# split the feat dim to obtain multi-scale visual feature
has_vision_embeddings = [
embeddings.shape[-1] != self.config.text_config.hidden_size for embeddings in multimodal_embeddings
@@ -1005,7 +988,7 @@ def get_mrope_input_positions(
bos_len = 1
llm_pos_ids_list.append(torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
- _, audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx])
+ audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx])
llm_pos_ids = torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx
llm_pos_ids_list.append(llm_pos_ids)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
@@ -1077,7 +1060,7 @@ def get_mrope_input_positions(
llm_pos_ids_list.append(bos_block)
llm_pos_ids_list.append(bos_block)
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
- _, audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx])
+ audio_len = _get_feat_extract_output_lengths(audio_feature_lengths[audio_idx])
audio_llm_pos_ids = torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1) + st_idx
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
@@ -1121,3 +1104,13 @@ def get_mrope_input_positions(
mrope_position_delta = llm_positions.max() + 1 - seq_len
return llm_positions, mrope_position_delta
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """Return the module-name prefixes of this multimodal model's parts.
+
+        Returns:
+            A ``MultiModelKeys`` built via ``from_string_field`` with
+            ``"language_model"`` as the language model, ``"visual.merger"`` as
+            the connector, and ``"visual."`` and ``"audio_tower."`` as the
+            tower-model prefixes.
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="visual.merger",
+            tower_model=["visual.", "audio_tower."],
+        )
diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml
index fc84f485ea4..e3d740ad580 100644
--- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml
+++ b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml
@@ -76,6 +76,7 @@ stage_args:
trust_remote_code: true
enable_prefix_caching: false
max_num_batched_tokens: 32768
+ async_scheduling: false
engine_output_type: audio
engine_input_source: [1]
final_output: true
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml
index c63dc563815..d4de078231a 100644
--- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml
+++ b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml
@@ -16,7 +16,7 @@ stage_args:
worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
gpu_memory_utilization: 0.6
- enforce_eager: false
+ enforce_eager: true
trust_remote_code: true
engine_output_type: latent # Output hidden states for talker
distributed_executor_backend: "mp"
@@ -46,8 +46,8 @@ stage_args:
model_arch: Qwen3OmniMoeForConditionalGeneration
worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.3
- enforce_eager: false
+ gpu_memory_utilization: 0.35
+ enforce_eager: true
trust_remote_code: true
engine_output_type: latent # Output codec codes for code2wav
# tensor_parallel_size: 2
@@ -80,6 +80,7 @@ stage_args:
scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
enforce_eager: true
trust_remote_code: true
+ async_scheduling: false
enable_prefix_caching: false
engine_output_type: audio # Final output: audio waveform
gpu_memory_utilization: 0.1
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
index 246ea2996e8..a1457a9750b 100644
--- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
+++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
@@ -160,7 +160,7 @@ def talker2code2wav(
# Process each talker output
for i, talker_output in enumerate(talker_outputs):
output = talker_output.outputs[0]
- seq_len = len(output.token_ids)
+ seq_len = len(output.token_ids) - 1
# Extract codec codes from talker output
# Expected shape: [8, seq_len] (8-layer RVQ codes)
codec_codes = (
diff --git a/vllm_omni/utils/platform_utils.py b/vllm_omni/utils/platform_utils.py
index 5f8259ab83d..fb47018e789 100644
--- a/vllm_omni/utils/platform_utils.py
+++ b/vllm_omni/utils/platform_utils.py
@@ -53,6 +53,6 @@ def get_diffusion_worker_class() -> type:
return NPUWorkerProc
else:
# Default to GPU worker for cuda and other devices
- from vllm_omni.diffusion.worker.gpu_worker import WorkerProc
+ from vllm_omni.diffusion.worker.gpu_diffusion_worker import WorkerProc
return WorkerProc
diff --git a/vllm_omni/worker/gpu_ar_model_runner.py b/vllm_omni/worker/gpu_ar_model_runner.py
index d4e7e195fe8..18c8bc761b1 100644
--- a/vllm_omni/worker/gpu_ar_model_runner.py
+++ b/vllm_omni/worker/gpu_ar_model_runner.py
@@ -12,10 +12,15 @@
import numpy as np
import torch
from vllm.config import CUDAGraphMode
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
+from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
+ RoutedExpertsCapturer,
+)
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
-from vllm.v1.outputs import AsyncModelRunnerOutput
+from vllm.v1.outputs import AsyncModelRunnerOutput, make_empty_encoder_model_runner_output
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import record_function_or_nullcontext
@@ -26,18 +31,12 @@
get_pp_group,
get_tp_group,
has_kv_transfer_group,
-
)
+from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
from vllm.v1.worker.utils import is_residual_scattered_for_sp
-from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
- RoutedExpertsCapturer,
-)
-from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
-from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
+
from vllm_omni.outputs import OmniModelRunnerOutput
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
-from vllm.v1.outputs import make_empty_encoder_model_runner_output
-from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
logger = init_logger(__name__)
@@ -91,10 +90,7 @@ def execute_model(
intermediate_tensors: IntermediateTensors | None = None,
) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None:
if self.execute_model_state is not None:
- raise RuntimeError(
- "State error: sample_tokens() must be called "
- "after execute_model() returns None."
- )
+ raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.")
if self.vllm_config.model_config.enable_return_routed_experts:
capturer = RoutedExpertsCapturer.get_instance()
@@ -104,9 +100,7 @@ def execute_model(
logger.error("RoutedExpertsCapturer not initialized.")
if scheduler_output.preempted_req_ids and has_kv_transfer_group():
- get_kv_transfer_group().handle_preemptions(
- scheduler_output.preempted_req_ids
- )
+ get_kv_transfer_group().handle_preemptions(scheduler_output.preempted_req_ids)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with (
@@ -126,8 +120,7 @@ def execute_model(
if not num_scheduled_tokens:
if (
- self.parallel_config.distributed_executor_backend
- == "external_launcher"
+ self.parallel_config.distributed_executor_backend == "external_launcher"
and self.parallel_config.data_parallel_size > 1
):
# this is a corner case when both external launcher
@@ -196,9 +189,7 @@ def execute_model(
)
num_tokens_padded = batch_desc.num_tokens
- num_reqs_padded = (
- batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
- )
+ num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
should_ubatch,
num_scheduled_tokens_np,
@@ -218,19 +209,17 @@ def execute_model(
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
- attn_metadata, spec_decode_common_attn_metadata = (
- self._build_attention_metadata(
- num_tokens=num_tokens_unpadded,
- num_tokens_padded=num_tokens_padded if pad_attn else None,
- num_reqs=num_reqs,
- num_reqs_padded=num_reqs_padded if pad_attn else None,
- max_query_len=max_num_scheduled_tokens,
- ubatch_slices=ubatch_slices_attn,
- logits_indices=logits_indices,
- use_spec_decode=use_spec_decode,
- num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
- cascade_attn_prefix_lens=cascade_attn_prefix_lens,
- )
+ attn_metadata, spec_decode_common_attn_metadata = self._build_attention_metadata(
+ num_tokens=num_tokens_unpadded,
+ num_tokens_padded=num_tokens_padded if pad_attn else None,
+ num_reqs=num_reqs,
+ num_reqs_padded=num_reqs_padded if pad_attn else None,
+ max_query_len=max_num_scheduled_tokens,
+ ubatch_slices=ubatch_slices_attn,
+ logits_indices=logits_indices,
+ use_spec_decode=use_spec_decode,
+ num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
+ cascade_attn_prefix_lens=cascade_attn_prefix_lens,
)
(
@@ -240,9 +229,7 @@ def execute_model(
intermediate_tensors,
model_kwargs,
ec_connector_output,
- ) = self._preprocess(
- scheduler_output, num_tokens_padded, intermediate_tensors
- )
+ ) = self._preprocess(scheduler_output, num_tokens_padded, intermediate_tensors)
# Set cudagraph mode to none if calc_kv_scales is true.
# KV scales calculation involves dynamic operations that are incompatible
@@ -329,9 +316,7 @@ def execute_model(
sample_hidden_states = hidden_states[logits_indices]
if not get_pp_group().is_last_rank:
all_gather_tensors = {
- "residual": not is_residual_scattered_for_sp(
- self.vllm_config, num_tokens_padded
- )
+ "residual": not is_residual_scattered_for_sp(self.vllm_config, num_tokens_padded)
}
get_pp_group().send_tensor_dict(
hidden_states.tensors,
@@ -408,9 +393,7 @@ def sample_tokens(
# Apply structured output bitmasks if present.
if grammar_output is not None:
- apply_grammar_bitmask(
- scheduler_output, grammar_output, self.input_batch, logits
- )
+ apply_grammar_bitmask(scheduler_output, grammar_output, self.input_batch, logits)
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
@@ -450,23 +433,19 @@ def propose_draft_token_ids(sampled_token_ids):
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
assert spec_decode_common_attn_metadata is not None
- next_token_ids, valid_sampled_tokens_count = (
- self.drafter.prepare_next_token_ids_padded(
- spec_decode_common_attn_metadata,
- sampled_token_ids,
- self.requests,
- self.input_batch,
- self.discard_request_mask.gpu,
- )
- )
- self._copy_valid_sampled_token_count(
- next_token_ids, valid_sampled_tokens_count
+ next_token_ids, valid_sampled_tokens_count = self.drafter.prepare_next_token_ids_padded(
+ spec_decode_common_attn_metadata,
+ sampled_token_ids,
+ self.requests,
+ self.input_batch,
+ self.discard_request_mask.gpu,
)
+ self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
- self._draft_token_ids = torch.zeros(
- 1, device=self.device, dtype=torch.int32
- ).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
+ self._draft_token_ids = torch.zeros(1, device=self.device, dtype=torch.int32).expand(
+ len(self.input_batch.req_ids), self.num_spec_tokens
+ )
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
else:
propose_drafts_after_bookkeeping = input_fits_in_drafter
diff --git a/vllm_omni/worker/gpu_ar_worker.py b/vllm_omni/worker/gpu_ar_worker.py
index 599dea31f2f..d2dafce6877 100644
--- a/vllm_omni/worker/gpu_ar_worker.py
+++ b/vllm_omni/worker/gpu_ar_worker.py
@@ -3,19 +3,18 @@
import torch
from vllm.logger import init_logger
-from vllm.utils.torch_utils import set_random_seed
from vllm.platforms import current_platform
from vllm.utils.mem_utils import MemorySnapshot, format_gib
+from vllm.utils.torch_utils import set_random_seed
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_worker import Worker as GPUWorker
from vllm.v1.worker.gpu_worker import init_worker_distributed_environment
+from vllm.v1.worker.utils import request_memory
+from vllm.v1.worker.workspace import init_workspace_manager
from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner
-from vllm.v1.worker.workspace import init_workspace_manager
-from vllm.v1.worker.utils import request_memory
-from vllm.logger import init_logger
-logger = init_logger(__name__)
+logger = init_logger(__name__)
class GPUARWorker(GPUWorker):
@@ -31,8 +30,7 @@ def init_device(self):
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
parallel_config = self.parallel_config
if (
- parallel_config.distributed_executor_backend
- not in ("ray", "external_launcher")
+ parallel_config.distributed_executor_backend not in ("ray", "external_launcher")
and parallel_config.data_parallel_backend != "ray"
and parallel_config.nnodes_within_dp == 1
):
@@ -42,8 +40,7 @@ def init_device(self):
dp_local_rank = self.parallel_config.data_parallel_index
tp_pp_world_size = (
- self.parallel_config.pipeline_parallel_size
- * self.parallel_config.tensor_parallel_size
+ self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size
)
# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
@@ -51,9 +48,7 @@ def init_device(self):
assert self.local_rank < torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
- visible_device_count = (
- torch.cuda.device_count() if torch.cuda.is_available() else 0
- )
+ visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must "
f"be less than or equal to the number of visible devices "
@@ -87,9 +82,7 @@ def init_device(self):
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
self.requested_memory = request_memory(init_snapshot, self.cache_config)
logger.debug("worker init memory snapshot: %r", self.init_snapshot)
- logger.debug(
- "worker requested memory: %sGiB", format_gib(self.requested_memory)
- )
+ logger.debug("worker requested memory: %sGiB", format_gib(self.requested_memory))
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
@@ -102,4 +95,4 @@ def init_device(self):
if self.rank == 0:
# If usage stat is enabled, collect relevant info.
- report_usage_stats(self.vllm_config)
\ No newline at end of file
+ report_usage_stats(self.vllm_config)
diff --git a/vllm_omni/worker/gpu_generation_model_runner.py b/vllm_omni/worker/gpu_generation_model_runner.py
index 40b26e2ba50..011c74d2e83 100644
--- a/vllm_omni/worker/gpu_generation_model_runner.py
+++ b/vllm_omni/worker/gpu_generation_model_runner.py
@@ -5,16 +5,23 @@
"""
from __future__ import annotations
-from copy import copy
import gc
import logging
-from typing import Any
+from copy import copy
+
import numpy as np
import torch
from vllm.config import CUDAGraphMode
+from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
+from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
+ RoutedExpertsCapturer,
+)
+from vllm.model_executor.models.interfaces import supports_mm_encoder_only
from vllm.utils.math_utils import cdiv
-from vllm.v1.core.sched.output import SchedulerOutput, GrammarOutput
+from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
+from vllm.v1.outputs import AsyncModelRunnerOutput, make_empty_encoder_model_runner_output
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.gpu_model_runner import (
@@ -25,20 +32,13 @@
get_pp_group,
set_forward_context,
)
-from vllm.model_executor.models.interfaces import supports_mm_encoder_only
-from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
+from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
+
from vllm_omni.outputs import OmniModelRunnerOutput
-from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
-from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
- RoutedExpertsCapturer,
-)
-from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
-from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
-from vllm.v1.outputs import make_empty_encoder_model_runner_output
-from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
-from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm_omni.worker.gpu_ar_model_runner import ExecuteModelState
+from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
+
logger = logging.getLogger(__name__)
@@ -57,10 +57,7 @@ def execute_model(
intermediate_tensors: IntermediateTensors | None = None,
) -> OmniModelRunnerOutput | IntermediateTensors:
if self.execute_model_state is not None:
- raise RuntimeError(
- "State error: sample_tokens() must be called "
- "after execute_model() returns None."
- )
+ raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.")
if self.vllm_config.model_config.enable_return_routed_experts:
capturer = RoutedExpertsCapturer.get_instance()
@@ -70,10 +67,8 @@ def execute_model(
logger.error("RoutedExpertsCapturer not initialized.")
if scheduler_output.preempted_req_ids and has_kv_transfer_group():
- get_kv_transfer_group().handle_preemptions(
- scheduler_output.preempted_req_ids
- )
-
+ get_kv_transfer_group().handle_preemptions(scheduler_output.preempted_req_ids)
+
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with (
record_function_or_nullcontext("gpu_model_runner: preprocess"),
@@ -82,7 +77,7 @@ def execute_model(
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
return EMPTY_MODEL_RUNNER_OUTPUT
-
+
if has_ec_transfer() and get_ec_transfer().is_producer:
with self.maybe_get_ec_connector_output(
scheduler_output,
@@ -93,8 +88,7 @@ def execute_model(
if not num_scheduled_tokens:
if (
- self.parallel_config.distributed_executor_backend
- == "external_launcher"
+ self.parallel_config.distributed_executor_backend == "external_launcher"
and self.parallel_config.data_parallel_size > 1
):
# this is a corner case when both external launcher
@@ -107,7 +101,7 @@ def execute_model(
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
-
+
return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
if self.cache_config.kv_sharing_fast_prefill:
@@ -127,7 +121,7 @@ def execute_model(
scheduler_output,
num_scheduled_tokens_np,
)
-
+
cascade_attn_prefix_lens = None
# Disable cascade attention when using microbatching (DBO)
if self.cascade_attn_enabled and not self.parallel_config.use_ubatching:
@@ -137,7 +131,7 @@ def execute_model(
self.input_batch.num_computed_tokens_cpu[:num_reqs],
scheduler_output.num_common_prefix_blocks,
)
-
+
(
cudagraph_mode,
batch_desc,
@@ -163,9 +157,7 @@ def execute_model(
)
num_tokens_padded = batch_desc.num_tokens
- num_reqs_padded = (
- batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
- )
+ num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
should_ubatch,
num_scheduled_tokens_np,
@@ -185,19 +177,17 @@ def execute_model(
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
- attn_metadata, spec_decode_common_attn_metadata = (
- self._build_attention_metadata(
- num_tokens=num_tokens_unpadded,
- num_tokens_padded=num_tokens_padded if pad_attn else None,
- num_reqs=num_reqs,
- num_reqs_padded=num_reqs_padded if pad_attn else None,
- max_query_len=max_num_scheduled_tokens,
- ubatch_slices=ubatch_slices_attn,
- logits_indices=logits_indices,
- use_spec_decode=use_spec_decode,
- num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
- cascade_attn_prefix_lens=cascade_attn_prefix_lens,
- )
+ attn_metadata, spec_decode_common_attn_metadata = self._build_attention_metadata(
+ num_tokens=num_tokens_unpadded,
+ num_tokens_padded=num_tokens_padded if pad_attn else None,
+ num_reqs=num_reqs,
+ num_reqs_padded=num_reqs_padded if pad_attn else None,
+ max_query_len=max_num_scheduled_tokens,
+ ubatch_slices=ubatch_slices_attn,
+ logits_indices=logits_indices,
+ use_spec_decode=use_spec_decode,
+ num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
+ cascade_attn_prefix_lens=cascade_attn_prefix_lens,
)
(
@@ -260,12 +250,15 @@ def execute_model(
)
self.kv_connector_output = kv_connector_output
return None
-
+
@torch.inference_mode()
def sample_tokens(
self,
grammar_output: GrammarOutput | None = None,
) -> OmniModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
+ # NOTE: Even though the model is non-autoregressive, we still need
+ # this function to match the interface of the engine core:
+ # sample_tokens() must be called after execute_model() returns None.
kv_connector_output = self.kv_connector_output
self.kv_connector_output = None
@@ -329,9 +322,7 @@ def sample_tokens(
kv_connector_output=kv_connector_output,
num_nans_in_logits={},
cudagraph_stats=cudagraph_stats,
- ec_connector_output=ec_connector_output
- if self.supports_mm_inputs
- else None,
+ ec_connector_output=ec_connector_output if self.supports_mm_inputs else None,
)
if not self.use_async_scheduling:
@@ -433,10 +424,7 @@ def _dummy_run(
# mm encoder dummy run may need to add in the future.
return torch.tensor([]), torch.tensor([])
- assert (
- cudagraph_runtime_mode is None
- or cudagraph_runtime_mode.valid_runtime_modes()
- )
+ assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes()
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
@@ -497,8 +485,7 @@ def _dummy_run(
max_num_scheduled_tokens=max_query_len,
use_cascade_attn=False,
allow_microbatching=allow_microbatching,
- force_eager=is_profile
- or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
+ force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
# `force_uniform_decode` is used for cudagraph capture; because for
# capturing mixed prefill-decode batches, we sometimes use
# num_tokens == num_reqs which looks like a uniform decode batch to the
@@ -520,9 +507,7 @@ def _dummy_run(
)
num_tokens_padded = batch_desc.num_tokens
- num_reqs_padded = (
- batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
- )
+ num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
should_ubatch,
num_scheduled_tokens,
@@ -601,17 +586,13 @@ def _dummy_run(
intermediate_tensors = None
else:
if self.intermediate_tensors is None:
- self.intermediate_tensors = (
- self.model.make_empty_intermediate_tensors(
- batch_size=self.max_num_tokens,
- dtype=self.model_config.dtype,
- device=self.device,
- )
+ self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
+ batch_size=self.max_num_tokens,
+ dtype=self.model_config.dtype,
+ device=self.device,
)
- intermediate_tensors = self.sync_and_slice_intermediate_tensors(
- num_tokens_padded, None, False
- )
+ intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False)
if ubatch_slices_padded is not None:
# Adjust values to reflect a single ubatch.
@@ -652,14 +633,8 @@ def _dummy_run(
# Therefore only use cudagraphs if the main model uses PIECEWISE
# NOTE(lucas): this is a hack, need to clean up.
use_cudagraphs = (
- (
- is_graph_capturing
- and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
- )
- or (
- not is_graph_capturing
- and cudagraph_runtime_mode != CUDAGraphMode.NONE
- )
+ (is_graph_capturing and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE)
+ or (not is_graph_capturing and cudagraph_runtime_mode != CUDAGraphMode.NONE)
) and not self.speculative_config.enforce_eager
# Note(gnovack) - We need to disable cudagraphs for one of the two
diff --git a/vllm_omni/worker/gpu_generation_worker.py b/vllm_omni/worker/gpu_generation_worker.py
index 6a1a3039211..19f8ab84b99 100644
--- a/vllm_omni/worker/gpu_generation_worker.py
+++ b/vllm_omni/worker/gpu_generation_worker.py
@@ -2,17 +2,21 @@
import os
import torch
-from vllm.utils.torch_utils import set_random_seed
+from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.mem_utils import MemorySnapshot, format_gib
+from vllm.utils.torch_utils import set_random_seed
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_worker import Worker as GPUWorker
from vllm.v1.worker.gpu_worker import init_worker_distributed_environment
-from vllm.v1.worker.workspace import init_workspace_manager
from vllm.v1.worker.utils import request_memory
+from vllm.v1.worker.workspace import init_workspace_manager
+
from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner
-from vllm.logger import init_logger
+
logger = init_logger(__name__)
+
+
class GPUGenerationWorker(GPUWorker):
"""GPU Worker for Generation model (non-autoregressive waveform generation).
@@ -26,8 +30,7 @@ def init_device(self):
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
parallel_config = self.parallel_config
if (
- parallel_config.distributed_executor_backend
- not in ("ray", "external_launcher")
+ parallel_config.distributed_executor_backend not in ("ray", "external_launcher")
and parallel_config.data_parallel_backend != "ray"
and parallel_config.nnodes_within_dp == 1
):
@@ -37,8 +40,7 @@ def init_device(self):
dp_local_rank = self.parallel_config.data_parallel_index
tp_pp_world_size = (
- self.parallel_config.pipeline_parallel_size
- * self.parallel_config.tensor_parallel_size
+ self.parallel_config.pipeline_parallel_size * self.parallel_config.tensor_parallel_size
)
# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
@@ -46,9 +48,7 @@ def init_device(self):
assert self.local_rank < torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
- visible_device_count = (
- torch.cuda.device_count() if torch.cuda.is_available() else 0
- )
+ visible_device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
assert self.parallel_config.local_world_size <= visible_device_count, (
f"local_world_size ({self.parallel_config.local_world_size}) must "
f"be less than or equal to the number of visible devices "
@@ -82,9 +82,7 @@ def init_device(self):
self.init_snapshot = init_snapshot = MemorySnapshot(device=self.device)
self.requested_memory = request_memory(init_snapshot, self.cache_config)
logger.debug("worker init memory snapshot: %r", self.init_snapshot)
- logger.debug(
- "worker requested memory: %sGiB", format_gib(self.requested_memory)
- )
+ logger.debug("worker requested memory: %sGiB", format_gib(self.requested_memory))
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py
index 24d4ffd028e..06abd9f3a78 100644
--- a/vllm_omni/worker/gpu_model_runner.py
+++ b/vllm_omni/worker/gpu_model_runner.py
@@ -8,7 +8,7 @@
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
-from vllm.model_executor.models.interfaces import supports_mrope, supports_mm_encoder_only
+from vllm.model_executor.models.interfaces import supports_mm_encoder_only, supports_mrope
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
from vllm.sampling_params import SamplingType
from vllm.utils.import_utils import LazyLoader
@@ -17,6 +17,7 @@
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.gpu_model_runner import GPUModelRunner, IntermediateTensors, PerLayerAttnMetadata
from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
+
from vllm_omni.model_executor.models.output_templates import OmniOutput
if TYPE_CHECKING:
@@ -94,7 +95,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
if use_audio_in_video_value is not None:
use_audio_in_video = bool(use_audio_in_video_value.item())
- if supports_mrope(self.model):
+ if supports_mrope(self.get_model()):
req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions(
req_state.prompt_token_ids,
mm_features=req_state.mm_features,
@@ -248,7 +249,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
except Exception as e:
logger.error(f"Error decoding additional information: {e}")
pass
-
+
if sampling_params and sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.input_batch.vocab_size
@@ -258,11 +259,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(req_state)
-
+
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if self.uses_xdrope_dim > 0:
self._init_xdrope_positions(req_state)
-
+
reqs_to_add.append(self.requests[req_id])
# Update the states of the running/resumed requests.
@@ -281,14 +282,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
resumed_from_preemption = req_id in req_data.resumed_req_ids
num_output_tokens = req_data.num_output_tokens[i]
req_index = self.input_batch.req_id_to_index.get(req_id)
-
+
if req_state.prev_num_draft_len and self.use_async_scheduling:
# prev_num_draft_len is used in async scheduling mode with
# spec decode. it indicates if need to update num_computed_tokens
# of the request. for example:
- # fist step: num_computed_tokens = 0, spec_tokens = [],
+ # first step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
- # second step: num_computed_tokens = 100(prompt lenth),
+ # second step: num_computed_tokens = 100(prompt length),
# spec_tokens = [a,b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# prev_num_draft_len = 2.
@@ -305,7 +306,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
num_rejected = req_state.prev_num_draft_len - num_accepted
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
-
+
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
@@ -327,12 +328,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# failure. Align the cached state.
del req_state.output_token_ids[num_output_tokens:]
if req_index is not None:
- end_idx = (
- self.input_batch.num_prompt_tokens[req_index]
- + num_output_tokens
- )
+ end_idx = self.input_batch.num_prompt_tokens[req_index] + num_output_tokens
self.input_batch.num_tokens_no_spec[req_index] = end_idx
-
+
# Update the block IDs.
if not resumed_from_preemption:
if new_block_ids is not None:
@@ -372,15 +370,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
- self.input_batch.token_ids_cpu[
- req_index, start_token_index:end_token_index
- ] = new_token_ids
+ self.input_batch.token_ids_cpu[req_index, start_token_index:end_token_index] = new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
-
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
for request in reqs_to_add:
@@ -457,10 +452,7 @@ def _dummy_run(
# mm encoder dummy run may need to add in the future.
return torch.tensor([]), torch.tensor([])
- assert (
- cudagraph_runtime_mode is None
- or cudagraph_runtime_mode.valid_runtime_modes()
- )
+ assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes()
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
@@ -521,8 +513,7 @@ def _dummy_run(
max_num_scheduled_tokens=max_query_len,
use_cascade_attn=False,
allow_microbatching=allow_microbatching,
- force_eager=is_profile
- or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
+ force_eager=is_profile or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
# `force_uniform_decode` is used for cudagraph capture; because for
# capturing mixed prefill-decode batches, we sometimes use
# num_tokens == num_reqs which looks like a uniform decode batch to the
@@ -544,9 +535,7 @@ def _dummy_run(
)
num_tokens_padded = batch_desc.num_tokens
- num_reqs_padded = (
- batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
- )
+ num_reqs_padded = batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
should_ubatch,
num_scheduled_tokens,
@@ -625,17 +614,13 @@ def _dummy_run(
intermediate_tensors = None
else:
if self.intermediate_tensors is None:
- self.intermediate_tensors = (
- self.model.make_empty_intermediate_tensors(
- batch_size=self.max_num_tokens,
- dtype=self.model_config.dtype,
- device=self.device,
- )
+ self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
+ batch_size=self.max_num_tokens,
+ dtype=self.model_config.dtype,
+ device=self.device,
)
- intermediate_tensors = self.sync_and_slice_intermediate_tensors(
- num_tokens_padded, None, False
- )
+ intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False)
if ubatch_slices_padded is not None:
# Adjust values to reflect a single ubatch.
@@ -657,6 +642,13 @@ def _dummy_run(
ubatch_slices=ubatch_slices_padded,
),
):
+ if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"):
+ outputs = self.talker_mtp(
+ self.talker_mtp_input_ids.gpu[:num_tokens_padded],
+ self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded],
+ self.last_talker_hidden.gpu[:num_tokens_padded],
+ self.text_step.gpu[:num_tokens_padded],
+ )
outputs = self.model(
input_ids=input_ids,
positions=positions,
@@ -676,14 +668,8 @@ def _dummy_run(
# Therefore only use cudagraphs if the main model uses PIECEWISE
# NOTE(lucas): this is a hack, need to clean up.
use_cudagraphs = (
- (
- is_graph_capturing
- and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
- )
- or (
- not is_graph_capturing
- and cudagraph_runtime_mode != CUDAGraphMode.NONE
- )
+ (is_graph_capturing and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE)
+ or (not is_graph_capturing and cudagraph_runtime_mode != CUDAGraphMode.NONE)
) and not self.speculative_config.enforce_eager
# Note(gnovack) - We need to disable cudagraphs for one of the two
@@ -721,9 +707,7 @@ def _dummy_run(
self.eplb_step(is_dummy=True, is_profile=is_profile)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
- logit_indices_device = torch.from_numpy(logit_indices).to(
- self.device, non_blocking=True
- )
+ logit_indices_device = torch.from_numpy(logit_indices).to(self.device, non_blocking=True)
return hidden_states, hidden_states[logit_indices_device]
def _decode_and_store_request_payloads(self, scheduler_output: "SchedulerOutput") -> None:
@@ -988,6 +972,7 @@ def _preprocess(
if hasattr(self.model, "has_preprocess") and self.model.has_preprocess:
# Overlay custom prompt_embeds per request for the prompt portion;
# collect additional_information (tensor/list) for prefill portion only
+ decode_req_ids = []
for req_index, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests.get(req_id)
req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None
@@ -998,43 +983,17 @@ def _preprocess(
span_len = int(e) - int(s)
# call the custom process function
- try:
- req_input_ids, req_embeds, update_dict = self.model.preprocess(
- input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos
- )
- except Exception as e:
- logger.error(f"Error in preprocess for request {req_id}: {e}")
- import traceback
- traceback.print_exc()
- raise e
- #TODO: This is Model Specific Code, need to be generalized in the future ZTC
- # run talker mtp decode
- if hasattr(self.model, "talker_mtp"):
- _cudagraph_mode, batch_desc, _, _, _ = self._determine_batch_execution_and_padding(
- num_tokens=span_len,
- num_reqs=1,
- num_scheduled_tokens_np=num_scheduled_tokens_np[req_index],
- max_num_scheduled_tokens=1,
- force_eager=span_len > 1,
- use_cascade_attn=False,
- )
+ req_input_ids, req_embeds, update_dict = self.model.preprocess(
+ input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos
+ )
+ if hasattr(self.model, "talker_mtp") and span_len == 1:
last_talker_hidden, text_step = update_dict.pop("mtp_inputs")
- if _cudagraph_mode != CUDAGraphMode.NONE:
- self.talker_mtp_input_ids.gpu[:span_len].copy_(req_input_ids)
- self.talker_mtp_inputs_embeds.gpu[:span_len].copy_(req_embeds)
- self.last_talker_hidden.gpu[:span_len].copy_(last_talker_hidden)
- self.text_step.gpu[:span_len].copy_(text_step)
- req_input_ids = self.talker_mtp_input_ids.gpu[:span_len]
- req_embeds = self.talker_mtp_inputs_embeds.gpu[:span_len]
- last_talker_hidden = self.last_talker_hidden.gpu[:span_len]
- text_step = self.text_step.gpu[:span_len]
- with set_forward_context(
- None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
- ):
- req_embeds, code_predictor_codes = self.talker_mtp(
- req_input_ids, req_embeds, last_talker_hidden, text_step
- )
- update_dict["code_predictor_codes"] = code_predictor_codes
+ decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1)
+ self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids)
+ self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds)
+ self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden)
+ self.text_step.gpu[decode_slice].copy_(text_step)
+ decode_req_ids.append(req_id)
# TODO(Peiqi): the merge stage could move out from the critical path
self._merge_additional_information_update(req_id, update_dict)
@@ -1045,6 +1004,10 @@ def _preprocess(
if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len:
input_ids[s : s + seg_len] = req_input_ids
+ # run talker mtp decode
+ if hasattr(self.model, "talker_mtp"):
+ self._talker_mtp_forward(decode_req_ids, inputs_embeds)
+
return (
input_ids,
inputs_embeds,
@@ -1054,6 +1017,34 @@ def _preprocess(
ec_connector_output,
)
+ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None:
+ decode_batch_size = len(decode_req_ids)
+ if decode_batch_size == 0:
+ return
+ _cudagraph_mode, batch_desc, _, _, _ = self._determine_batch_execution_and_padding(
+ num_tokens=decode_batch_size,
+ num_reqs=decode_batch_size,
+ num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32),
+ max_num_scheduled_tokens=1,
+ use_cascade_attn=False,
+ )
+ req_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size]
+ req_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size]
+ last_talker_hidden = self.last_talker_hidden.gpu[:decode_batch_size]
+ text_step = self.text_step.gpu[:decode_batch_size]
+ with set_forward_context(
+ None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
+ ):
+ req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step)
+ # update the inputs_embeds and code_predictor_codes
+ code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous()
+ for idx, req_id in enumerate(decode_req_ids):
+ req_index = self.input_batch.req_ids.index(req_id)
+ start_offset = int(self.query_start_loc.cpu[req_index])
+ inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1]
+ update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]}
+ self._merge_additional_information_update(req_id, update_dict)
+
def _model_forward(
self,
input_ids: torch.Tensor | None = None,
diff --git a/vllm_omni/worker/npu/npu_model_runner.py b/vllm_omni/worker/npu/npu_model_runner.py
index 1083bbe40d4..264d5e3413b 100644
--- a/vllm_omni/worker/npu/npu_model_runner.py
+++ b/vllm_omni/worker/npu/npu_model_runner.py
@@ -64,7 +64,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
if use_audio_in_video_value is not None:
use_audio_in_video = bool(use_audio_in_video_value.item())
- if supports_mrope(self.model):
+ if supports_mrope(self.get_model()):
req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions(
req_state.prompt_token_ids,
mm_features=req_state.mm_features,